lucasddmc committed
Commit d012f9c · 1 Parent(s): 98cf39b

feat: make it possible to use different ViT formats and architectures

Files changed (3)
  1. .github/copilot-instructions.md +32 -19
  2. app.py +9 -9
  3. utils/model_loader.py +218 -53
.github/copilot-instructions.md CHANGED
@@ -3,7 +3,7 @@
 ## Project Overview
 
 **ViTViz** is a Gradio-based web app for visualizing Vision Transformer (ViT) attention mechanisms and adversarial attacks on image classification. The app supports:
-- Custom ViT model upload (.pth files) or Hugging Face Hub models
+- Custom ViT model upload (.pth, .pt, .safetensors files) or Hugging Face Hub models
 - Multiple adversarial attack methods (FGSM, PGD, MIM, TGR, SAGA)
 - Attention visualization via Attention Rollout and per-layer/per-head views
 - Interactive iteration-by-iteration comparison of adversarial examples
@@ -24,40 +24,53 @@
 ### Key Design Patterns
 
 #### Dynamic Architecture Support (ViTConfig)
-The codebase now supports multiple ViT architectures via automatic inference:
+The codebase supports **any ViT architecture** with a timm-compatible structure (`model.blocks[i].attn.qkv`), not just a set of predefined model names. Architecture parameters are inferred automatically from the state_dict:
 
 ```python
-from utils.model_loader import ViTConfig, infer_config_from_model, infer_config_from_state_dict
+from utils.model_loader import ViTConfig, create_vit_from_config, infer_config_from_state_dict
 
 # ViTConfig contains all architecture parameters
 config = ViTConfig(
-    embed_dim=768,    # 384=small, 768=base, 1024=large
-    num_heads=12,     # 6=small, 12=base, 16=large
-    num_layers=12,    # varies by model
-    patch_size=16,    # 16 or 32
-    img_size=224,     # 224, 384, etc.
-    num_classes=1000
+    embed_dim=768,    # Any value (192, 384, 512, 768, 1024, etc.)
+    num_heads=12,     # Any valid divisor of embed_dim
+    num_layers=12,    # Any depth
+    patch_size=16,    # 8, 14, 16, 32, etc.
+    img_size=224,     # 224, 384, 448, etc.
+    num_classes=1000,
+    mlp_ratio=4.0,    # MLP hidden dim = embed_dim * mlp_ratio
+    qkv_bias=True     # Whether the QKV projection has a bias
 )
 
+# Create a model directly from the config (no predefined names needed)
+model = create_vit_from_config(config, device=DEVICE)
+
 # Properties computed automatically
 config.grid_size        # img_size // patch_size (e.g., 14 for 224/16)
 config.num_patches      # grid_size ** 2
-config.timm_model_name  # e.g., "vit_base_patch16_224"
+config.timm_model_name  # Informational: "vit_base_patch16_224" or "vit_custom_patch16_224"
 ```
 
-Supported architectures (auto-detected):
-- `vit_tiny_patch16_224` (embed_dim=192, heads=3)
-- `vit_small_patch16_224` (embed_dim=384, heads=6)
-- `vit_base_patch16_224` (embed_dim=768, heads=12)
-- `vit_large_patch16_224` (embed_dim=1024, heads=16)
-- `vit_base_patch32_224` (embed_dim=768, patch_size=32, grid=7)
+**Inference from state_dict**: when loading a checkpoint, all parameters are inferred automatically (see the sketch after this list):
+- `embed_dim`: from `blocks.0.attn.qkv.weight.shape[1]`
+- `num_heads`: heuristic based on common head_dim values (64, 32, 96)
+- `num_layers`: count of `blocks.X.attn.*` keys
+- `patch_size`: from `patch_embed.proj.weight.shape[2]`
+- `img_size`: from `pos_embed.shape[1]` (num_patches + 1)
+- `mlp_ratio`: from `blocks.0.mlp.fc1.weight.shape[0] / embed_dim`
+- `qkv_bias`: presence of the `blocks.0.attn.qkv.bias` key
+
+**Validation**: use `validate_vit_structure(model)` to check whether a model has the required structure before attempting attention extraction.
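
Taken together, the inference path can be exercised roughly as in this minimal sketch, assuming a checkpoint file that holds a bare state_dict (the path is hypothetical; the functions are the ones introduced in this commit):

```python
import torch

from utils.model_loader import (
    create_vit_from_config,
    infer_config_from_state_dict,
    validate_vit_structure,
)

# Hypothetical checkpoint holding a bare state_dict
state_dict = torch.load("my_vit.pth", map_location="cpu", weights_only=False)

# Infer embed_dim, num_heads, num_layers, patch_size, img_size, mlp_ratio, qkv_bias
config = infer_config_from_state_dict(state_dict)

# Build a VisionTransformer with exactly that architecture and load the weights
model = create_vit_from_config(config, device=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)

# Confirm the blocks[i].attn.qkv structure needed for attention extraction
is_valid, error_msg = validate_vit_structure(model)
assert is_valid, error_msg
```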
 
 #### Model Loading Strategy
 The codebase supports multiple model sources:
-1. **Local .pth files**: Can contain full model, `state_dict`, `model_state_dict`, or checkpoint dicts with `class_names`
+1. **Local files**: `.pth`, `.pt`, `.safetensors` - can contain a full model, `state_dict`, `model_state_dict`, or checkpoint dicts with `class_names`
 2. **Hugging Face Hub**: Use `hf-model://username/repo-name` format; automatically converts HF ViT to timm-compatible format
 3. **Special `hf://` URIs**: For CNN backbones in SAGA attacks (e.g., `hf://lucasddmc/resnet101-stanford40-actions/resnet.pth`)
 
+**Supported file formats** (see the loading sketch after this list):
+- `.pth` / `.pt`: standard PyTorch checkpoint (`torch.load`)
+- `.safetensors`: modern Hugging Face format (faster, more secure)
+
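
How the two on-disk formats differ in practice; a minimal sketch, assuming `safetensors` is installed (the helper name and path are hypothetical; `load_checkpoint` in `utils/model_loader.py` performs the real dispatch):

```python
import torch
from safetensors.torch import load_file as load_safetensors

def load_weights(path: str, device: str = "cpu"):
    """Hypothetical helper: dispatch on file extension."""
    if path.endswith(".safetensors"):
        # safetensors stores only tensors, so this is always a state_dict
        return load_safetensors(path, device=device)
    # .pth/.pt may hold a full model, a state_dict, or a checkpoint dict
    return torch.load(path, map_location=device, weights_only=False)

checkpoint = load_weights("model.safetensors")  # hypothetical path
```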
 The main loader returns 4 values:
 ```python
 model, class_names, label_source, vit_config = load_model_and_labels(model_path, None, device=DEVICE)
@@ -143,7 +156,7 @@ The app injects Bootstrap Icons via CDN and custom CSS for panels/tables. Icon c
 ## External Dependencies
 
-- **timm**: ViT model architecture (`vit_base_patch16_224` is the default)
+- **timm**: ViT model architecture (the `VisionTransformer` class for flexible model creation)
 - **torchattacks**: Base classes for adversarial attacks
 - **transformers**: Optional, for loading HF Hub models
 - **gradio**: Version 5.49.1 (specified in requirements)
@@ -158,7 +171,7 @@ Currently no automated tests. Manual testing workflow:
 ## Known Limitations
 
-- Supports timm ViT architectures (tiny, small, base, large) with patch sizes 16 and 32
+- Supports any timm-compatible ViT (must have the `model.blocks[i].attn.qkv` structure)
 - No support for non-standard ViT variants (DeiT distillation token, Swin hierarchical, BEiT) without additional conversion
 - Custom CSS may break with Gradio version updates
 - No batch processing support (processes one image at a time)
app.py CHANGED
@@ -83,7 +83,7 @@ def classify_image(model_file, use_hf_vit: bool, image):
     """
     try:
         if not use_hf_vit and model_file is None:
-            return "Please upload a model file (.pth) or enable 'Use vit-b16-stanford40-actions'"
+            return "Please upload a model file (.pth/.pt/.safetensors/.ckpt) or enable 'Use vit-b16-stanford40-actions'"
         # Extract the paths from the Gradio file components
         model_path = HF_VIT_MODEL_SPEC if use_hf_vit else _to_path(model_file)
 
@@ -153,7 +153,7 @@ def visualize_attention(
     """
     try:
         if not use_hf_vit and model_file is None:
-            return None, "Please upload a model file (.pth) or enable 'Use vit-b16-stanford40-actions'"
+            return None, "Please upload a model file (.pth/.pt/.safetensors/.ckpt) or enable 'Use vit-b16-stanford40-actions'"
         if image is None:
             return None, "Please upload an image"
 
@@ -244,7 +244,7 @@ def run_attack(
     """
    try:
         if not use_hf_vit and model_file is None:
-            return [], "Please upload a model file (.pth) or enable 'Use vit-b16-stanford40-actions'", []
+            return [], "Please upload a model file (.pth/.pt/.safetensors/.ckpt) or enable 'Use vit-b16-stanford40-actions'", []
         if image is None:
             return [], "Please upload an image", []
 
@@ -601,8 +601,8 @@ def create_app():
                     label="Use vit-b16-stanford40-actions"
                 )
                 model_upload_classify = gr.File(
-                    label="Upload Model (.pth/.pt)",
-                    file_types=[".pth", ".pt"],
+                    label="Upload Model (.pth/.pt/.safetensors/.ckpt)",
+                    file_types=[".pth", ".pt", ".safetensors", ".ckpt"],
                     interactive=False
                 )
             with gr.Column(scale=2):
@@ -641,8 +641,8 @@ def create_app():
                     label="Use vit-b16-stanford40-actions"
                 )
                 model_upload_attention = gr.File(
-                    label="Upload Model (.pth/.pt)",
-                    file_types=[".pth", ".pt"],
+                    label="Upload Model (.pth/.pt/.safetensors/.ckpt)",
+                    file_types=[".pth", ".pt", ".safetensors", ".ckpt"],
                     interactive=False
                )
             with gr.Column(scale=2):
@@ -714,8 +714,8 @@ def create_app():
                     label="Use vit-b16-stanford40-actions"
                 )
                 model_upload_attack = gr.File(
-                    label="Upload Model (.pth/.pt)",
-                    file_types=[".pth", ".pt"],
+                    label="Upload Model (.pth/.pt/.safetensors/.ckpt)",
+                    file_types=[".pth", ".pt", ".safetensors", ".ckpt"],
                     interactive=False
                 )
             with gr.Column(scale=3):
utils/model_loader.py CHANGED
@@ -4,11 +4,23 @@ import timm
 from dataclasses import dataclass
 from typing import Optional, Tuple, Dict, Any
 
+# Import VisionTransformer directly to build models with custom architectures
+try:
+    from timm.models.vision_transformer import VisionTransformer
+except ImportError:
+    VisionTransformer = None
+
 try:
     from transformers import AutoModelForImageClassification
 except Exception:  # pragma: no cover
     AutoModelForImageClassification = None
 
+# safetensors support (the modern Hugging Face format)
+try:
+    from safetensors.torch import load_file as load_safetensors
+except ImportError:
+    load_safetensors = None
+
 DEVICE_DEFAULT = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
@@ -21,6 +33,8 @@ class ViTConfig:
     patch_size: int = 16
     img_size: int = 224
     num_classes: int = 1000
+    mlp_ratio: float = 4.0
+    qkv_bias: bool = True
 
     @property
     def grid_size(self) -> int:
@@ -34,7 +48,7 @@
 
     @property
     def timm_model_name(self) -> str:
-        """Returns the timm model name matching this configuration."""
+        """Returns the matching timm model name (for informational purposes only)."""
         # Mapping based on embed_dim and num_heads
         size_map = {
             (192, 3): 'tiny',
@@ -43,10 +57,107 @@
             (1024, 16): 'large',
             (1280, 16): 'huge',
         }
-        size = size_map.get((self.embed_dim, self.num_heads), 'base')
+        size = size_map.get((self.embed_dim, self.num_heads), 'custom')
         return f"vit_{size}_patch{self.patch_size}_{self.img_size}"
 
 
+def create_vit_from_config(config: ViTConfig, device: Optional[torch.device] = None) -> torch.nn.Module:
+    """Creates a ViT model directly from the inferred configuration.
+
+    This allows building models with arbitrary architectures, not limited
+    to timm's predefined names (vit_base_patch16_224, etc.).
+    """
+    device = device or DEVICE_DEFAULT
+
+    if VisionTransformer is None:
+        raise RuntimeError("VisionTransformer is unavailable. Check your timm installation.")
+
+    model = VisionTransformer(
+        img_size=config.img_size,
+        patch_size=config.patch_size,
+        in_chans=3,
+        num_classes=config.num_classes,
+        embed_dim=config.embed_dim,
+        depth=config.num_layers,
+        num_heads=config.num_heads,
+        mlp_ratio=config.mlp_ratio,
+        qkv_bias=config.qkv_bias,
+        class_token=True,
+        global_pool='token',
+    )
+
+    return model.to(device)
+
+
+def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """Removes common framework prefixes (Lightning, DDP, etc.) from state_dict keys.
+
+    Handled prefixes:
+    - 'model.' (PyTorch Lightning)
+    - 'module.' (DataParallel/DistributedDataParallel)
+    - 'encoder.' (some self-supervised learning frameworks)
+    - 'backbone.' (some detection frameworks)
+
+    Returns:
+        state_dict with unprefixed keys
+    """
+    prefixes = ['model.', 'module.', 'encoder.', 'backbone.']
+
+    # Check whether any key carries a prefix
+    has_prefix = False
+    detected_prefix = None
+    for key in state_dict.keys():
+        for prefix in prefixes:
+            if key.startswith(prefix):
+                has_prefix = True
+                detected_prefix = prefix
+                break
+        if has_prefix:
+            break
+
+    if not has_prefix:
+        return state_dict
+
+    print(f"[ViTViz] Detected prefix '{detected_prefix}' in state_dict keys (Lightning/DDP). Stripping it...")
+
+    new_sd: Dict[str, torch.Tensor] = {}
+    for key, value in state_dict.items():
+        new_key = key
+        for prefix in prefixes:
+            if key.startswith(prefix):
+                new_key = key[len(prefix):]
+                break
+        new_sd[new_key] = value
+
+    return new_sd
+
+
+def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
+    """Validates that the model has the expected timm-compatible ViT structure.
+
+    Returns:
+        (is_valid, error_message) - if invalid, error_message describes the problem
+    """
+    if not hasattr(model, 'blocks'):
+        return False, "Model has no 'blocks' attribute. Not a compatible ViT."
+
+    if len(model.blocks) == 0:
+        return False, "Model's 'blocks' is empty."
+
+    block = model.blocks[0]
+    if not hasattr(block, 'attn'):
+        return False, "Block has no 'attn' attribute. Incompatible structure."
+
+    attn = block.attn
+    if not hasattr(attn, 'qkv'):
+        return False, "Attention module has no 'qkv'. Incompatible structure."
+
+    if not hasattr(attn, 'num_heads'):
+        return False, "Attention module has no 'num_heads'. Incompatible structure."
+
+    return True, ""
+
+
 def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
     """Infers a ViT configuration from a loaded timm model."""
     config = ViTConfig()
@@ -95,20 +206,47 @@ def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConf
     if layer_indices:
         config.num_layers = max(layer_indices) + 1
 
-    # Infer embed_dim and num_heads from the first block
+    # Infer embed_dim from the first block
     qkv_key = 'blocks.0.attn.qkv.weight'
     if qkv_key in state_dict:
         qkv_weight = state_dict[qkv_key]
         # qkv.weight shape: [3*embed_dim, embed_dim]
         config.embed_dim = qkv_weight.shape[1]
+        # Inferring num_heads directly: qkv has shape [3*embed_dim, embed_dim];
+        # the output is 3*embed_dim = 3*num_heads*head_dim, so
+        # num_heads = (qkv_out // 3) // head_dim. But head_dim varies,
+        # so we infer it another way below.
 
-    # Infer num_heads from the proj bias or heuristically
+    # Infer num_heads: try multiple methods
     proj_key = 'blocks.0.attn.proj.weight'
-    if proj_key in state_dict:
-        # proj.weight shape: [embed_dim, embed_dim]
+    if proj_key in state_dict and qkv_key in state_dict:
         embed_dim = state_dict[proj_key].shape[0]
-        # Heuristic: typical head_dim is 64
-        config.num_heads = embed_dim // 64
+        qkv_out = state_dict[qkv_key].shape[0]  # 3*embed_dim
+
+        # Method 1: if qkv_out == 3*embed_dim, try common head_dims (64, 32, 96)
+        if qkv_out == 3 * embed_dim:
+            # Try common head_dims in order of preference
+            for head_dim in [64, 32, 96, 48, 128]:
+                if embed_dim % head_dim == 0:
+                    config.num_heads = embed_dim // head_dim
+                    break
+        else:
+            # Fallback: assume num_heads divides embed_dim evenly
+            # and try common num_heads values
+            for nh in [12, 16, 8, 6, 24, 4, 3]:
+                if embed_dim % nh == 0:
+                    config.num_heads = nh
+                    break
+
+    # Infer qkv_bias
+    qkv_bias_key = 'blocks.0.attn.qkv.bias'
+    config.qkv_bias = qkv_bias_key in state_dict
+
+    # Infer mlp_ratio from the MLP
+    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
+    if mlp_fc1_key in state_dict and config.embed_dim > 0:
+        mlp_hidden = state_dict[mlp_fc1_key].shape[0]
+        config.mlp_ratio = mlp_hidden / config.embed_dim
 
     # Infer num_classes from the head
     head_key = 'head.weight'
@@ -219,6 +357,8 @@ def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = No
     num_heads = int(getattr(cfg, "num_attention_heads", 12)) if cfg is not None else 12
     patch_size = int(getattr(cfg, "patch_size", 16)) if cfg is not None else 16
     img_size = int(getattr(cfg, "image_size", 224)) if cfg is not None else 224
+    intermediate_size = int(getattr(cfg, "intermediate_size", hidden_size * 4)) if cfg is not None else hidden_size * 4
+    qkv_bias = bool(getattr(cfg, "qkv_bias", True)) if cfg is not None else True
     class_names = _hf_id2label_to_class_names(getattr(cfg, "id2label", None)) if cfg is not None else None
 
     # Build the dynamic config
@@ -228,26 +368,23 @@
         num_layers=num_layers,
         patch_size=patch_size,
         img_size=img_size,
-        num_classes=num_labels
+        num_classes=num_labels,
+        mlp_ratio=intermediate_size / hidden_size,
+        qkv_bias=qkv_bias
     )
 
-    # Try to find the matching timm model
-    timm_name = vit_config.timm_model_name
-    try:
-        timm_model = timm.create_model(timm_name, pretrained=False, num_classes=num_labels)
-    except Exception:
-        # Fall back to vit_base_patch16_224 if the model does not exist
-        print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
-        timm_model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_labels)
+    print(f"[ViTViz] Loading from Hugging Face: {vit_config.timm_model_name} "
+          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
+          f"layers={vit_config.num_layers})")
+
+    # Build the model with the custom architecture directly
+    timm_model = create_vit_from_config(vit_config, device=device)
 
+    # Convert and load the state_dict
     timm_sd = _convert_hf_vit_to_timm_state_dict(hf_model.state_dict(), num_layers=num_layers)
     timm_model.load_state_dict(timm_sd, strict=False)
-    timm_model = timm_model.to(device)
     timm_model.eval()
 
-    # Update the config with the actual values from the loaded model
-    vit_config = infer_config_from_model(timm_model)
-
     return timm_model, class_names, vit_config
 
 
@@ -263,11 +400,27 @@ class CustomUnpickler(pickle.Unpickler):
 
 
 def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
-    """Loads a checkpoint/model from the given path, with a fallback to a custom unpickler.
+    """Loads a checkpoint/model from the given path.
+
+    Supported formats:
+    - .pth / .pt: PyTorch checkpoint (torch.load)
+    - .safetensors: the modern Hugging Face format (safer and faster)
 
     Returns the loaded object (a full model, a state_dict, or a checkpoint dict).
     """
     device = device or DEVICE_DEFAULT
+
+    # Detect the safetensors format
+    if model_path.endswith('.safetensors'):
+        if load_safetensors is None:
+            raise ImportError(
+                "safetensors is not installed. Install it with: pip install safetensors"
+            )
+        # safetensors always returns a state_dict (full models are not supported)
+        state_dict = load_safetensors(model_path, device=str(device))
+        return state_dict
+
+    # Standard PyTorch format (.pth, .pt, .ckpt, etc.)
     try:
         return torch.load(model_path, map_location=device, weights_only=False)
     except (AttributeError, ModuleNotFoundError, RuntimeError):
@@ -349,59 +502,71 @@ def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int,
 def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
     """Builds a model from a checkpoint that may be a dict, a state_dict, or the model itself.
 
+    Supports arbitrary ViT architectures, not limited to timm's predefined names.
+
     Returns:
         (model, config) - the loaded model and the inferred configuration
     """
     device = device or DEVICE_DEFAULT
     config: Optional[ViTConfig] = None
 
+    # Detect and log PyTorch Lightning checkpoints
+    if isinstance(checkpoint, dict) and 'pytorch-lightning_version' in checkpoint:
+        print(f"[ViTViz] Detected a PyTorch Lightning checkpoint (v{checkpoint.get('pytorch-lightning_version', '?')})")
+
     if isinstance(checkpoint, dict):
         if 'model' in checkpoint:
+            # A full model stored inside the dict
             model = checkpoint['model']
             config = infer_config_from_model(model)
+            # Validate the structure
+            is_valid, error_msg = validate_vit_structure(model)
+            if not is_valid:
+                raise ValueError(f"Invalid model: {error_msg}")
         elif 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
+            # Strip framework prefixes (Lightning, DDP, etc.)
+            state_dict = _strip_state_dict_prefix(state_dict)
             config = infer_config_from_state_dict(state_dict)
-            num_classes = infer_num_classes(state_dict)
-            config.num_classes = num_classes
-            # Use the inferred architecture
-            timm_name = config.timm_model_name
-            try:
-                model = timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
-            except Exception:
-                print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
-                model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
-            model.load_state_dict(state_dict)
+            print(f"[ViTViz] Inferred architecture: {config.timm_model_name} "
+                  f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
+                  f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
+            # Build the model with the custom architecture
+            model = create_vit_from_config(config, device=device)
+            # strict=False to support variations such as CLIP (norm_pre, etc.)
+            model.load_state_dict(state_dict, strict=False)
         elif 'model_state_dict' in checkpoint:
             # Newer format with embedded class_names
             state_dict = checkpoint['model_state_dict']
+            # Strip framework prefixes (Lightning, DDP, etc.)
+            state_dict = _strip_state_dict_prefix(state_dict)
             config = infer_config_from_state_dict(state_dict)
-            num_classes = infer_num_classes(state_dict)
-            config.num_classes = num_classes
-            # Use the inferred architecture
-            timm_name = config.timm_model_name
-            try:
-                model = timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
-            except Exception:
-                print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
-                model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
-            model.load_state_dict(state_dict)
+            print(f"[ViTViz] Inferred architecture: {config.timm_model_name} "
+                  f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
+                  f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
+            # Build the model with the custom architecture
+            model = create_vit_from_config(config, device=device)
+            # strict=False to support variations such as CLIP (norm_pre, etc.)
+            model.load_state_dict(state_dict, strict=False)
         else:
-            # assume the dict is a state_dict
+            # assume the dict is a plain state_dict
+            # Strip framework prefixes (Lightning, DDP, etc.)
+            checkpoint = _strip_state_dict_prefix(checkpoint)
             config = infer_config_from_state_dict(checkpoint)
-            num_classes = infer_num_classes(checkpoint)
-            config.num_classes = num_classes
-            # Use the inferred architecture
-            timm_name = config.timm_model_name
-            try:
-                model = timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
-            except Exception:
-                print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
-                model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
-            model.load_state_dict(checkpoint)
+            print(f"[ViTViz] Inferred architecture: {config.timm_model_name} "
+                  f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
+                  f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
+            # Build the model with the custom architecture
+            model = create_vit_from_config(config, device=device)
+            # strict=False to support variations such as CLIP (norm_pre, etc.)
+            model.load_state_dict(checkpoint, strict=False)
     else:
         # Full model saved via torch.save(model, ...)
         model = checkpoint
+        # Validate the structure
+        is_valid, error_msg = validate_vit_structure(model)
+        if not is_valid:
+            raise ValueError(f"Invalid model: {error_msg}")
        config = infer_config_from_model(model)
 
     model = model.to(device)
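
End to end, the new loading path composes as in this minimal sketch (the file path is hypothetical; `load_checkpoint` and `build_model_from_checkpoint` are the functions changed above):

```python
import torch

from utils.model_loader import build_model_from_checkpoint, load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Works for .pth/.pt/.ckpt (torch.load) and .safetensors (state_dict only);
# Lightning/DDP prefixes are stripped and the architecture is inferred.
checkpoint = load_checkpoint("checkpoints/vit_custom.safetensors", device=device)  # hypothetical path
model, config = build_model_from_checkpoint(checkpoint, device=device)

print(config.timm_model_name, config.embed_dim, config.num_heads, config.num_layers)
```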