# stradavit-base / examples/stradavit_model.py
# Uploaded by andrea86 ("Upload 2 files", commit 41a6ec2, verified).
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn as nn
@dataclass
class StradaViTOutput:
    """Forward-pass result container for `StradaViTModel`.

    Mirrors the attribute names of Hugging Face model outputs so downstream
    code can read the pooled `embedding` alongside the optional raw encoder
    tensors.
    """

    # (B, D) pooled image embedding — mean over patch tokens, CLS dropped
    # (see `_pool_patch_mean` / `StradaViTModel.forward`).
    embedding: torch.Tensor
    # (B, T, D) token-level encoder output, passed through from the backbone.
    last_hidden_state: torch.Tensor | None = None
    # Per-layer hidden states / attention maps as returned by the backbone;
    # only populated when requested via output_hidden_states / output_attentions.
    hidden_states: Any | None = None
    attentions: Any | None = None
def _pool_patch_mean(last_hidden_state: torch.Tensor) -> torch.Tensor:
# Mirror `pretraining/ft_test_llrd.py`: mean over all non-CLS tokens.
if last_hidden_state.dim() != 3 or last_hidden_state.size(1) < 2:
raise ValueError(f"Expected (B, T, D) with CLS+patches, got {tuple(last_hidden_state.shape)}")
return last_hidden_state[:, 1:, :].mean(dim=1)
class StradaViTModel(nn.Module):
    """
    Lightweight encoder-only wrapper that exposes a consistent embedding API for:
      - vanilla ViTMAE checkpoints (any patch size)
      - register-aware / Dinov2Encoder-backed MAE checkpoints
    Embedding policy matches `pretraining/ft_test_llrd.py`:
      embedding = mean over patch tokens (drop CLS).
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        # Surface the backbone's HF config (if any) so callers can introspect it.
        self.config = getattr(backbone, "config", None)

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, **kwargs):
        """
        Loads a backbone in a way that is compatible with our checkpoints:
          - If config indicates registers or Dinov2Encoder path, use `ViTMAEWithRegistersModel`.
          - Else use `ViTModel` to avoid MAE random masking/shuffling in downstream usage.
        """
        # Imported lazily so this module can be imported without transformers installed.
        from transformers import ViTModel, ViTMAEConfig
        config = ViTMAEConfig.from_pretrained(checkpoint_path)
        use_dino_encoder = bool(getattr(config, "use_dino_encoder", False))
        # `or 0` guards against an explicit `n_registers: null` in the config JSON.
        n_registers = int(getattr(config, "n_registers", 0) or 0)
        if use_dino_encoder or n_registers > 0:
            from pretraining.vit_mae_registers import ViTMAEWithRegistersModel
            backbone = ViTMAEWithRegistersModel.from_pretrained(
                checkpoint_path,
                n_registers=n_registers,
                ignore_mismatched_sizes=True,
                **kwargs,
            )
        else:
            # ViTModel loads MAE weights with an expected "vit_mae -> vit" type conversion warning.
            backbone = ViTModel.from_pretrained(
                checkpoint_path,
                add_pooling_layer=False,
                **kwargs,
            )
        return cls(backbone=backbone)

    def _forward_backbone(self, pixel_values: torch.Tensor, **kwargs) -> Any:
        """
        Runs the backbone and returns its native outputs.
        For MAE-family backbones, we disable embeddings.random_masking to get a full-image encoding.
        """
        bb = self.backbone
        emb = getattr(bb, "embeddings", None)
        # Non-MAE backbones (e.g. plain ViTModel) expose no random_masking: run directly.
        if emb is None or not hasattr(emb, "random_masking"):
            return bb(pixel_values=pixel_values, **kwargs)
        orig_random_masking = emb.random_masking

        def _random_masking_noop(self, x: torch.Tensor, noise: torch.Tensor | None = None):
            # Identity stand-in for MAE's random_masking: keep every token and
            # report an all-zeros ("nothing masked") mask with an identity
            # restore ordering, matching the (sequence, mask, ids_restore)
            # return contract of the original method. `noise` is accepted for
            # signature compatibility and ignored.
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x)
            if x.dim() != 3:
                # Defensive fallback for unexpected ranks; synthesizes sizes so
                # the three-tuple contract still holds.
                B = x.size(0) if x.dim() > 0 else 1
                L = x.size(1) if x.dim() > 1 else 1
                mask = x.new_zeros(B, L)
                ids_restore = torch.arange(L, device=x.device).unsqueeze(0).expand(B, -1)
                return x, mask, ids_restore
            B, L, _ = x.shape
            device = x.device
            mask = x.new_zeros(B, L)
            ids_restore = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
            return x, mask, ids_restore

        try:
            import types
            # Temporarily monkey-patch the bound method; restored in `finally`
            # so later callers of this backbone see the original masking again.
            # NOTE(review): the patch mutates shared backbone state, so this is
            # not safe under concurrent forward calls on the same instance.
            emb.random_masking = types.MethodType(_random_masking_noop, emb)
            return bb(pixel_values=pixel_values, **kwargs)
        finally:
            emb.random_masking = orig_random_masking

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        # Accepted for HF API compatibility; internally we always request
        # dict-style outputs from the backbone regardless of this value.
        return_dict: bool | None = True,
        **kwargs,
    ) -> StradaViTOutput:
        """
        Encode `pixel_values` and return a `StradaViTOutput` whose `embedding`
        is the mean over patch tokens (CLS dropped).

        Raises:
            ValueError: if the backbone output exposes no last_hidden_state.
        """
        outputs = self._forward_backbone(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
            **kwargs,
        )
        last_hidden_state = getattr(outputs, "last_hidden_state", None)
        if last_hidden_state is None:
            # Some HF models may return a tuple.
            if isinstance(outputs, (tuple, list)) and len(outputs) > 0:
                last_hidden_state = outputs[0]
            else:
                raise ValueError("Backbone output does not include last_hidden_state")
        emb = _pool_patch_mean(last_hidden_state)
        out = StradaViTOutput(
            embedding=emb,
            last_hidden_state=last_hidden_state,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
        return out
class StradaViTForImageClassification(nn.Module):
"""
Simple classification head on top of `StradaViTModel` embeddings.
Head policy:
- LayerNorm (+ optional dropout) + Linear for all MAE-family variants.
Rationale: consistent ViT fine-tuning protocol and batch-size agnostic normalization.
"""
def __init__(
self,
checkpoint_path: str,
num_labels: int,
class_weights: list[float] | None = None,
head_norm: str = "ln", # kept for backward compatibility; must be "ln" or "auto"
n_registers: int | None = None, # accepted for call-site compatibility; config remains source of truth
):
super().__init__()
self.backbone = StradaViTModel.from_pretrained(checkpoint_path)
self.config = getattr(self.backbone, "config", None)
self.num_labels = int(num_labels)
hidden_size = None
if self.config is not None:
hidden_size = getattr(self.config, "hidden_size", None)
if hidden_size is None:
raise ValueError("Could not infer hidden_size from backbone config.")
if class_weights is not None:
self.register_buffer(
"class_weights",
torch.tensor(class_weights, dtype=torch.float32),
)
else:
self.class_weights = None
cfg_n_regs = int(getattr(self.config, "n_registers", 0) or 0) if self.config is not None else 0
cfg_use_dino = bool(getattr(self.config, "use_dino_encoder", False)) if self.config is not None else False
if n_registers is not None and int(n_registers) != cfg_n_regs:
raise ValueError(f"n_registers={int(n_registers)} does not match checkpoint config.n_registers={cfg_n_regs}.")
if head_norm not in ("auto", "ln"):
raise ValueError("head_norm must be one of {'ln','auto'} (BatchNorm is disabled).")
# "auto" is retained for older call sites; it maps to LN unconditionally now.
head_norm = "ln"
dropout_prob = float(getattr(self.config, "classifier_dropout_prob", 0.0) or 0.0) if self.config is not None else 0.0
ln_eps = float(getattr(self.config, "layer_norm_eps", 1e-6) or 1e-6) if self.config is not None else 1e-6
self.norm = nn.LayerNorm(int(hidden_size), eps=ln_eps)
self.dropout = nn.Dropout(dropout_prob)
self.classifier = nn.Linear(int(hidden_size), self.num_labels)
nn.init.trunc_normal_(self.classifier.weight, std=0.02)
if self.classifier.bias is not None:
nn.init.zeros_(self.classifier.bias)
def forward(self, pixel_values=None, labels=None, **kwargs):
out = self.backbone(pixel_values=pixel_values, **kwargs)
x = out.embedding
x = self.norm(x)
x = self.dropout(x)
logits = self.classifier(x)
loss = None
if labels is not None:
if getattr(self, "class_weights", None) is not None:
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# Prefer HF's standard output container when available (Trainer-friendly),
# but keep a dict fallback so this module can be imported without transformers installed.
try:
from transformers.modeling_outputs import ImageClassifierOutput # type: ignore
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=out.hidden_states,
attentions=out.attentions,
)
except Exception:
return {
"loss": loss,
"logits": logits,
"hidden_states": out.hidden_states,
"attentions": out.attentions,
}