| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
@dataclass
class StradaViTOutput:
    """Container for `StradaViTModel.forward` results.

    Mimics the attribute-access style of transformers output objects so
    downstream code can read fields uniformly.
    """

    # Pooled image embedding: mean over patch tokens with CLS dropped (see
    # `_pool_patch_mean`); shape (B, D).
    embedding: torch.Tensor
    # Full token sequence from the backbone encoder; presumably (B, T, D).
    last_hidden_state: torch.Tensor | None = None
    # Per-layer hidden states as returned by the backbone (passed through opaquely).
    hidden_states: Any | None = None
    # Per-layer attention maps as returned by the backbone (passed through opaquely).
    attentions: Any | None = None
|
|
|
|
| def _pool_patch_mean(last_hidden_state: torch.Tensor) -> torch.Tensor: |
| |
| if last_hidden_state.dim() != 3 or last_hidden_state.size(1) < 2: |
| raise ValueError(f"Expected (B, T, D) with CLS+patches, got {tuple(last_hidden_state.shape)}") |
| return last_hidden_state[:, 1:, :].mean(dim=1) |
|
|
|
|
class StradaViTModel(nn.Module):
    """
    Lightweight encoder-only wrapper that exposes a consistent embedding API for:
    - vanilla ViTMAE checkpoints (any patch size)
    - register-aware / Dinov2Encoder-backed MAE checkpoints

    Embedding policy matches `pretraining/ft_test_llrd.py`:
    embedding = mean over patch tokens (drop CLS).
    """

    def __init__(self, backbone: nn.Module):
        """Wrap an already-constructed backbone module.

        Args:
            backbone: HF-style encoder; its `.config` (if present) is re-exposed
                so callers can read e.g. `hidden_size` without reaching inside.
        """
        super().__init__()
        self.backbone = backbone
        # May be None for backbones without an HF config attribute.
        self.config = getattr(backbone, "config", None)

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, **kwargs):
        """
        Loads a backbone in a way that is compatible with our checkpoints:
        - If config indicates registers or Dinov2Encoder path, use `ViTMAEWithRegistersModel`.
        - Else use `ViTModel` to avoid MAE random masking/shuffling in downstream usage.

        Args:
            checkpoint_path: path or hub id resolvable by `from_pretrained`.
            **kwargs: forwarded verbatim to the underlying `from_pretrained` call.

        Returns:
            A `StradaViTModel` wrapping the loaded backbone.
        """
        # Imported lazily so this module can be imported without transformers installed.
        from transformers import ViTModel, ViTMAEConfig

        config = ViTMAEConfig.from_pretrained(checkpoint_path)
        use_dino_encoder = bool(getattr(config, "use_dino_encoder", False))
        # `or 0` also guards against an explicit None stored in the config.
        n_registers = int(getattr(config, "n_registers", 0) or 0)

        if use_dino_encoder or n_registers > 0:
            from pretraining.vit_mae_registers import ViTMAEWithRegistersModel

            backbone = ViTMAEWithRegistersModel.from_pretrained(
                checkpoint_path,
                n_registers=n_registers,
                # Register tokens can change embedding shapes vs. the saved weights.
                ignore_mismatched_sizes=True,
                **kwargs,
            )
        else:
            # ViTModel (not ViTMAEModel) so no random masking path exists at all.
            backbone = ViTModel.from_pretrained(
                checkpoint_path,
                add_pooling_layer=False,
                **kwargs,
            )
        return cls(backbone=backbone)

    def _forward_backbone(self, pixel_values: torch.Tensor, **kwargs) -> Any:
        """
        Runs the backbone and returns its native outputs.
        For MAE-family backbones, we disable embeddings.random_masking to get a full-image encoding.
        """
        bb = self.backbone
        emb = getattr(bb, "embeddings", None)
        # Non-MAE backbones (e.g. ViTModel) have no random_masking: call directly.
        if emb is None or not hasattr(emb, "random_masking"):
            return bb(pixel_values=pixel_values, **kwargs)

        # Temporarily replace random_masking with a keep-everything no-op.
        orig_random_masking = emb.random_masking

        def _random_masking_noop(self, x: torch.Tensor, noise: torch.Tensor | None = None):
            # Matches the MAE random_masking contract: returns (sequence, mask,
            # ids_restore) but keeps every token (mask of zeros, identity order).
            # `self` here is the embeddings module once bound via MethodType below.
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x)
            if x.dim() != 3:
                # Defensive fallback for unexpected ranks; degrades to size-1 dims.
                B = x.size(0) if x.dim() > 0 else 1
                L = x.size(1) if x.dim() > 1 else 1
                mask = x.new_zeros(B, L)
                ids_restore = torch.arange(L, device=x.device).unsqueeze(0).expand(B, -1)
                return x, mask, ids_restore
            B, L, _ = x.shape
            device = x.device
            # mask == 0 everywhere means "no token removed" in MAE convention.
            mask = x.new_zeros(B, L)
            # Identity restore order: tokens were never shuffled.
            ids_restore = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
            return x, mask, ids_restore

        try:
            import types

            # Bind as an instance method so the call signature matches the original.
            emb.random_masking = types.MethodType(_random_masking_noop, emb)
            return bb(pixel_values=pixel_values, **kwargs)
        finally:
            # Always restore the original masking, even if the forward pass raises.
            emb.random_masking = orig_random_masking

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = True,  # accepted for API parity; backbone is always called with return_dict=True
        **kwargs,
    ) -> StradaViTOutput:
        """Encode images and return a `StradaViTOutput` with the pooled embedding.

        Raises:
            ValueError: if the backbone output exposes no last_hidden_state.
        """
        outputs = self._forward_backbone(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
            **kwargs,
        )
        last_hidden_state = getattr(outputs, "last_hidden_state", None)
        if last_hidden_state is None:
            # Some backbones may return a plain tuple; by HF convention the
            # hidden states come first — TODO confirm for custom backbones.
            if isinstance(outputs, (tuple, list)) and len(outputs) > 0:
                last_hidden_state = outputs[0]
            else:
                raise ValueError("Backbone output does not include last_hidden_state")

        # Embedding policy: mean over patch tokens, CLS dropped.
        emb = _pool_patch_mean(last_hidden_state)
        out = StradaViTOutput(
            embedding=emb,
            last_hidden_state=last_hidden_state,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
        return out
|
|
|
|
class StradaViTForImageClassification(nn.Module):
    """
    Simple classification head on top of `StradaViTModel` embeddings.

    Head policy:
    - LayerNorm (+ optional dropout) + Linear for all MAE-family variants.

    Rationale: consistent ViT fine-tuning protocol and batch-size agnostic normalization.
    """

    def __init__(
        self,
        checkpoint_path: str,
        num_labels: int,
        class_weights: list[float] | None = None,
        head_norm: str = "ln",
        n_registers: int | None = None,
    ):
        """Build the backbone + LayerNorm/Dropout/Linear head.

        Args:
            checkpoint_path: passed to `StradaViTModel.from_pretrained`.
            num_labels: number of output classes.
            class_weights: optional per-class weights for the CE loss.
            head_norm: accepted values are 'ln' and 'auto'; both resolve to
                LayerNorm (BatchNorm is deliberately disabled).
            n_registers: optional sanity check against the checkpoint config.

        Raises:
            ValueError: if hidden_size cannot be inferred, `n_registers`
                mismatches the checkpoint, or `head_norm` is invalid.
        """
        super().__init__()
        self.backbone = StradaViTModel.from_pretrained(checkpoint_path)
        self.config = getattr(self.backbone, "config", None)
        self.num_labels = int(num_labels)

        # The head width must come from the backbone config; fail loudly otherwise.
        hidden_size = None
        if self.config is not None:
            hidden_size = getattr(self.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from backbone config.")

        # Buffer (not Parameter) so weights follow .to(device) / state_dict
        # without being trained.
        if class_weights is not None:
            self.register_buffer(
                "class_weights",
                torch.tensor(class_weights, dtype=torch.float32),
            )
        else:
            self.class_weights = None

        # Guard against silently loading a checkpoint whose register count
        # differs from what the caller expects. (`or 0` covers explicit None.)
        cfg_n_regs = int(getattr(self.config, "n_registers", 0) or 0) if self.config is not None else 0
        if n_registers is not None and int(n_registers) != cfg_n_regs:
            raise ValueError(f"n_registers={int(n_registers)} does not match checkpoint config.n_registers={cfg_n_regs}.")

        # 'auto' is accepted for backward compatibility; the head is always LayerNorm.
        if head_norm not in ("auto", "ln"):
            raise ValueError("head_norm must be one of {'ln','auto'} (BatchNorm is disabled).")

        dropout_prob = float(getattr(self.config, "classifier_dropout_prob", 0.0) or 0.0) if self.config is not None else 0.0
        ln_eps = float(getattr(self.config, "layer_norm_eps", 1e-6) or 1e-6) if self.config is not None else 1e-6

        self.norm = nn.LayerNorm(int(hidden_size), eps=ln_eps)
        self.dropout = nn.Dropout(dropout_prob)

        self.classifier = nn.Linear(int(hidden_size), self.num_labels)
        # Standard ViT head init: truncated normal weights, zero bias.
        nn.init.trunc_normal_(self.classifier.weight, std=0.02)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify images; computes CE loss when `labels` is given.

        Returns:
            `transformers.ImageClassifierOutput` when transformers is importable,
            otherwise a plain dict with the same keys.
        """
        out = self.backbone(pixel_values=pixel_values, **kwargs)
        x = out.embedding
        x = self.norm(x)
        x = self.dropout(x)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            # getattr guard: class_weights may be a buffer tensor or plain None.
            if getattr(self, "class_weights", None) is not None:
                loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            else:
                loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Prefer the HF output type so Trainer-style code works unchanged, but
        # degrade to a dict if transformers is unavailable at runtime.
        try:
            from transformers.modeling_outputs import ImageClassifierOutput

            return ImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=out.hidden_states,
                attentions=out.attentions,
            )
        except Exception:
            return {
                "loss": loss,
                "logits": logits,
                "hidden_states": out.hidden_states,
                "attentions": out.attentions,
            }
|
|