andrea86 committed on
Commit
41a6ec2
·
verified ·
1 Parent(s): 7959c5a

Upload 2 files

Browse files

Example for using the model, with required classes.

Files changed (2) hide show
  1. examples/demo.py +86 -0
  2. examples/stradavit_model.py +229 -0
examples/demo.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example: load a shipped StradaViT checkpoint and extract embeddings.
3
+
4
+ This mirrors the embedding policy:
5
+ - Use the ViT encoder's `last_hidden_state`
6
+ - Mean-pool patch tokens (drop CLS): `hs[:, 1:, :].mean(dim=1)`
7
+
8
+ Expected checkpoint layout (from our training scripts):
9
+ <RUN_ROOT>/checkpoints/
10
+ - config.json
11
+ - pytorch_model.bin (or model.safetensors)
12
+ - preprocessor_config.json
13
+ - (optional) tokenizer/feature extractor extras
14
+
15
+ Usage:
16
+ python3 examples/demo.py \\
17
+ --checkpoint /path/to/run/checkpoints \\
18
+ --image /path/to/image.png
19
+ """
20
+
21
from __future__ import annotations

import argparse
import os
from typing import Any

import torch

try:
    # Legacy module-style import, kept for layouts that ship a `StradaViTModel` module.
    import StradaViTModel
except ImportError:
    # Shipped layout: the class is defined in examples/stradavit_model.py.
    from stradavit_model import StradaViTModel
29
+
30
+
31
def load_model_and_processor(checkpoint_dir: str):
    """
    Load a StradaViT checkpoint plus its matching HF image processor.

    Returns a ``(model, processor, config)`` tuple; the model is switched
    to eval mode before being returned.
    """
    from transformers import ViTImageProcessor, ViTMAEConfig

    cfg = ViTMAEConfig.from_pretrained(checkpoint_dir)
    image_processor = ViTImageProcessor.from_pretrained(checkpoint_dir)
    vit = StradaViTModel.from_pretrained(checkpoint_dir)
    vit.eval()
    return vit, image_processor, cfg
42
+
43
+
44
def load_image(path: str):
    """Open *path* with PIL and return it converted to RGB."""
    from PIL import Image

    return Image.open(path).convert("RGB")
49
+
50
+
51
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: load a checkpoint, embed one image, print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True, help="Path to <run_root>/checkpoints (contains config + weights)")
    parser.add_argument("--image", required=True, help="Path to an image file (png/jpg/...)")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ns = parser.parse_args(argv)

    checkpoint_dir = os.path.abspath(ns.checkpoint)
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"--checkpoint must be a directory: {checkpoint_dir}")

    target_device = torch.device(ns.device)
    model, processor, config = load_model_and_processor(checkpoint_dir)
    model.to(target_device)

    image = load_image(ns.image)
    batch: dict[str, Any] = processor(images=image, return_tensors="pt")
    pixel_values = batch["pixel_values"].to(target_device)

    # inference_mode: no autograd bookkeeping for the single forward pass.
    with torch.inference_mode():
        embedding = model(pixel_values=pixel_values).embedding

    summary = (
        f"Loaded checkpoint: {checkpoint_dir}\n"
        f" model_type={getattr(config, 'model_type', None)} use_dino_encoder={bool(getattr(config, 'use_dino_encoder', False))} "
        f"n_registers={int(getattr(config, 'n_registers', 0) or 0)}\n"
        f" image_size={int(getattr(config, 'image_size', 0) or 0)} patch_size={int(getattr(config, 'patch_size', 0) or 0)}\n"
        f"Embedding shape: {tuple(embedding.shape)} dtype={embedding.dtype} device={embedding.device}"
    )
    print(summary)
    return 0
83
+
84
+
85
if __name__ == "__main__":
    # Propagate main()'s int return value as the process exit status.
    raise SystemExit(main())
examples/stradavit_model.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
@dataclass
class StradaViTOutput:
    """Container returned by `StradaViTModel.forward`."""

    # Mean-pooled patch-token embedding (CLS dropped); computed via `_pool_patch_mean`.
    embedding: torch.Tensor
    # Full encoder hidden state (B, T, D), when the backbone exposes one.
    last_hidden_state: torch.Tensor | None = None
    # Pass-through of the backbone's hidden_states (when requested).
    hidden_states: Any | None = None
    # Pass-through of the backbone's attentions (when requested).
    attentions: Any | None = None
16
+
17
+
18
+ def _pool_patch_mean(last_hidden_state: torch.Tensor) -> torch.Tensor:
19
+ # Mirror `pretraining/ft_test_llrd.py`: mean over all non-CLS tokens.
20
+ if last_hidden_state.dim() != 3 or last_hidden_state.size(1) < 2:
21
+ raise ValueError(f"Expected (B, T, D) with CLS+patches, got {tuple(last_hidden_state.shape)}")
22
+ return last_hidden_state[:, 1:, :].mean(dim=1)
23
+
24
+
25
class StradaViTModel(nn.Module):
    """
    Lightweight encoder-only wrapper that exposes a consistent embedding API for:
      - vanilla ViTMAE checkpoints (any patch size)
      - register-aware / Dinov2Encoder-backed MAE checkpoints

    Embedding policy matches `pretraining/ft_test_llrd.py`:
      embedding = mean over patch tokens (drop CLS).
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        # Mirror the backbone's HF config (None if the backbone has no `.config`).
        self.config = getattr(backbone, "config", None)

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, **kwargs):
        """
        Loads a backbone in a way that is compatible with our checkpoints:
          - If config indicates registers or Dinov2Encoder path, use `ViTMAEWithRegistersModel`.
          - Else use `ViTModel` to avoid MAE random masking/shuffling in downstream usage.
        """
        from transformers import ViTModel, ViTMAEConfig

        config = ViTMAEConfig.from_pretrained(checkpoint_path)
        use_dino_encoder = bool(getattr(config, "use_dino_encoder", False))
        n_registers = int(getattr(config, "n_registers", 0) or 0)

        if use_dino_encoder or n_registers > 0:
            # Project-local backbone; only importable inside the training repo.
            from pretraining.vit_mae_registers import ViTMAEWithRegistersModel

            backbone = ViTMAEWithRegistersModel.from_pretrained(
                checkpoint_path,
                n_registers=n_registers,
                # Register tokens change embedding-table shapes vs the vanilla checkpoint.
                ignore_mismatched_sizes=True,
                **kwargs,
            )
        else:
            # ViTModel loads MAE weights with an expected "vit_mae -> vit" type conversion warning.
            backbone = ViTModel.from_pretrained(
                checkpoint_path,
                add_pooling_layer=False,
                **kwargs,
            )
        return cls(backbone=backbone)

    def _forward_backbone(self, pixel_values: torch.Tensor, **kwargs) -> Any:
        """
        Runs the backbone and returns its native outputs.
        For MAE-family backbones, we disable embeddings.random_masking to get a full-image encoding.

        NOTE(review): this temporarily rebinds `embeddings.random_masking` on a shared
        module, so concurrent forward calls on the same backbone are presumably unsafe —
        confirm single-threaded use.
        """
        bb = self.backbone
        emb = getattr(bb, "embeddings", None)
        # Non-MAE backbones (e.g. plain ViTModel) have no random_masking: call through.
        if emb is None or not hasattr(emb, "random_masking"):
            return bb(pixel_values=pixel_values, **kwargs)

        orig_random_masking = emb.random_masking

        def _random_masking_noop(self, x: torch.Tensor, noise: torch.Tensor | None = None):
            # Identity replacement: returns every token, an all-zeros mask, and an
            # identity ids_restore (presumably 0 == "kept" in the MAE convention —
            # confirm against the upstream ViTMAE implementation).
            if not isinstance(x, torch.Tensor):
                x = torch.as_tensor(x)
            if x.dim() != 3:
                # Defensive fallback for unexpected ranks; sizes default to 1.
                B = x.size(0) if x.dim() > 0 else 1
                L = x.size(1) if x.dim() > 1 else 1
                mask = x.new_zeros(B, L)
                ids_restore = torch.arange(L, device=x.device).unsqueeze(0).expand(B, -1)
                return x, mask, ids_restore
            B, L, _ = x.shape
            device = x.device
            mask = x.new_zeros(B, L)
            ids_restore = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
            return x, mask, ids_restore

        try:
            import types

            # Bind the no-op as a method so its `self` is the embeddings module.
            emb.random_masking = types.MethodType(_random_masking_noop, emb)
            return bb(pixel_values=pixel_values, **kwargs)
        finally:
            # Always restore the original masking, even if the forward raises.
            emb.random_masking = orig_random_masking

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: bool | None = None,
        output_attentions: bool | None = None,
        return_dict: bool | None = True,  # accepted for API symmetry; backbone is always called with return_dict=True
        **kwargs,
    ) -> StradaViTOutput:
        """Encode `pixel_values` and return a `StradaViTOutput` with the pooled embedding."""
        outputs = self._forward_backbone(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
            **kwargs,
        )
        last_hidden_state = getattr(outputs, "last_hidden_state", None)
        if last_hidden_state is None:
            # Some HF models may return a tuple.
            if isinstance(outputs, (tuple, list)) and len(outputs) > 0:
                last_hidden_state = outputs[0]
            else:
                raise ValueError("Backbone output does not include last_hidden_state")

        # Embedding policy: mean over patch tokens, CLS dropped.
        emb = _pool_patch_mean(last_hidden_state)
        out = StradaViTOutput(
            embedding=emb,
            last_hidden_state=last_hidden_state,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )
        return out
137
+
138
+
139
+ class StradaViTForImageClassification(nn.Module):
140
+ """
141
+ Simple classification head on top of `StradaViTModel` embeddings.
142
+
143
+ Head policy:
144
+ - LayerNorm (+ optional dropout) + Linear for all MAE-family variants.
145
+
146
+ Rationale: consistent ViT fine-tuning protocol and batch-size agnostic normalization.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ checkpoint_path: str,
152
+ num_labels: int,
153
+ class_weights: list[float] | None = None,
154
+ head_norm: str = "ln", # kept for backward compatibility; must be "ln" or "auto"
155
+ n_registers: int | None = None, # accepted for call-site compatibility; config remains source of truth
156
+ ):
157
+ super().__init__()
158
+ self.backbone = StradaViTModel.from_pretrained(checkpoint_path)
159
+ self.config = getattr(self.backbone, "config", None)
160
+ self.num_labels = int(num_labels)
161
+
162
+ hidden_size = None
163
+ if self.config is not None:
164
+ hidden_size = getattr(self.config, "hidden_size", None)
165
+ if hidden_size is None:
166
+ raise ValueError("Could not infer hidden_size from backbone config.")
167
+
168
+ if class_weights is not None:
169
+ self.register_buffer(
170
+ "class_weights",
171
+ torch.tensor(class_weights, dtype=torch.float32),
172
+ )
173
+ else:
174
+ self.class_weights = None
175
+
176
+ cfg_n_regs = int(getattr(self.config, "n_registers", 0) or 0) if self.config is not None else 0
177
+ cfg_use_dino = bool(getattr(self.config, "use_dino_encoder", False)) if self.config is not None else False
178
+ if n_registers is not None and int(n_registers) != cfg_n_regs:
179
+ raise ValueError(f"n_registers={int(n_registers)} does not match checkpoint config.n_registers={cfg_n_regs}.")
180
+
181
+ if head_norm not in ("auto", "ln"):
182
+ raise ValueError("head_norm must be one of {'ln','auto'} (BatchNorm is disabled).")
183
+ # "auto" is retained for older call sites; it maps to LN unconditionally now.
184
+ head_norm = "ln"
185
+
186
+ dropout_prob = float(getattr(self.config, "classifier_dropout_prob", 0.0) or 0.0) if self.config is not None else 0.0
187
+ ln_eps = float(getattr(self.config, "layer_norm_eps", 1e-6) or 1e-6) if self.config is not None else 1e-6
188
+
189
+ self.norm = nn.LayerNorm(int(hidden_size), eps=ln_eps)
190
+ self.dropout = nn.Dropout(dropout_prob)
191
+
192
+ self.classifier = nn.Linear(int(hidden_size), self.num_labels)
193
+ nn.init.trunc_normal_(self.classifier.weight, std=0.02)
194
+ if self.classifier.bias is not None:
195
+ nn.init.zeros_(self.classifier.bias)
196
+
197
+ def forward(self, pixel_values=None, labels=None, **kwargs):
198
+ out = self.backbone(pixel_values=pixel_values, **kwargs)
199
+ x = out.embedding
200
+ x = self.norm(x)
201
+ x = self.dropout(x)
202
+ logits = self.classifier(x)
203
+
204
+ loss = None
205
+ if labels is not None:
206
+ if getattr(self, "class_weights", None) is not None:
207
+ loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
208
+ else:
209
+ loss_fct = nn.CrossEntropyLoss()
210
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
211
+
212
+ # Prefer HF's standard output container when available (Trainer-friendly),
213
+ # but keep a dict fallback so this module can be imported without transformers installed.
214
+ try:
215
+ from transformers.modeling_outputs import ImageClassifierOutput # type: ignore
216
+
217
+ return ImageClassifierOutput(
218
+ loss=loss,
219
+ logits=logits,
220
+ hidden_states=out.hidden_states,
221
+ attentions=out.attentions,
222
+ )
223
+ except Exception:
224
+ return {
225
+ "loss": loss,
226
+ "logits": logits,
227
+ "hidden_states": out.hidden_states,
228
+ "attentions": out.attentions,
229
+ }