lucasddmc committed on
Commit 98cf39b · 1 Parent(s): 64e8f48

feat: add support for different models

.github/copilot-instructions.md ADDED
@@ -0,0 +1,164 @@
+ # ViTViz - AI Coding Agent Instructions
+
+ ## Project Overview
+
+ **ViTViz** is a Gradio-based web app for visualizing Vision Transformer (ViT) attention mechanisms and adversarial attacks on image classification. The app supports:
+ - Custom ViT model upload (.pth files) or Hugging Face Hub models
+ - Multiple adversarial attack methods (FGSM, PGD, MIM, TGR, SAGA)
+ - Attention visualization via Attention Rollout and per-layer/per-head views
+ - Interactive iteration-by-iteration comparison of adversarial examples
+
+ ## Architecture
+
+ ### Core Components
+ - **[app.py](app.py)**: Main Gradio interface with three tabs: Basic Classification, Attention Visualization, and Adversarial Attack Analysis
+ - **[utils/model_loader.py](utils/model_loader.py)**: Handles model loading from local .pth files, the Hugging Face Hub, or special `hf://` URIs. Includes:
+   - `ViTConfig` dataclass for dynamic architecture configuration
+   - Automatic architecture inference from a state_dict or a loaded model
+   - Hugging Face → timm state_dict conversion
+ - **[utils/attacks.py](utils/attacks.py)**: Custom adversarial attack implementations that capture attention maps during attack iterations
+ - **[utils/visualization.py](utils/visualization.py)**: Attention extraction via forward hooks, attention rollout computation, and overlay creation with dynamic grid-size inference
+ - **[utils/inference.py](utils/inference.py)**: Top-k prediction logic
+ - **[utils/preprocessing.py](utils/preprocessing.py)**: ImageNet-standard transforms with dynamic `img_size` support
+
+ ### Key Design Patterns
+
+ #### Dynamic Architecture Support (ViTConfig)
+ The codebase now supports multiple ViT architectures via automatic inference:
+
+ ```python
+ from utils.model_loader import ViTConfig, infer_config_from_model, infer_config_from_state_dict
+
+ # ViTConfig contains all architecture parameters
+ config = ViTConfig(
+     embed_dim=768,    # 384=small, 768=base, 1024=large
+     num_heads=12,     # 6=small, 12=base, 16=large
+     num_layers=12,    # varies by model
+     patch_size=16,    # 16 or 32
+     img_size=224,     # 224, 384, etc.
+     num_classes=1000
+ )
+
+ # Properties computed automatically
+ config.grid_size        # img_size // patch_size (e.g., 14 for 224/16)
+ config.num_patches      # grid_size ** 2
+ config.timm_model_name  # e.g., "vit_base_patch16_224"
+ ```
+
+ Supported architectures (auto-detected):
+ - `vit_tiny_patch16_224` (embed_dim=192, heads=3)
+ - `vit_small_patch16_224` (embed_dim=384, heads=6)
+ - `vit_base_patch16_224` (embed_dim=768, heads=12)
+ - `vit_large_patch16_224` (embed_dim=1024, heads=16)
+ - `vit_base_patch32_224` (embed_dim=768, patch_size=32, grid=7)
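+
+ For instance, inferring the configuration from a checkpoint before building the model (a minimal sketch; the checkpoint path is hypothetical):
+
+ ```python
+ import torch
+ from utils.model_loader import infer_config_from_state_dict
+
+ state_dict = torch.load("models/my_vit.pth", map_location="cpu")  # hypothetical path
+ config = infer_config_from_state_dict(state_dict)
+ print(config.timm_model_name)                # e.g., "vit_small_patch16_224"
+ print(config.grid_size, config.num_patches)  # e.g., 14 196
+ ```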
+
+ #### Model Loading Strategy
+ The codebase supports three model sources:
+ 1. **Local .pth files**: Can contain a full model, a `state_dict`, a `model_state_dict`, or checkpoint dicts with `class_names`
+ 2. **Hugging Face Hub**: Use the `hf-model://username/repo-name` format; automatically converts an HF ViT to a timm-compatible format
+ 3. **Special `hf://` URIs**: For CNN backbones in SAGA attacks (e.g., `hf://lucasddmc/resnet101-stanford40-actions/resnet.pth`)
+
+ The main loader returns four values:
+ ```python
+ model, class_names, label_source, vit_config = load_model_and_labels(model_path, None, device=DEVICE)
+ # vit_config.img_size, vit_config.grid_size, etc. are now available
+ ```
+
+ #### Attention Capture with Forward Hooks
+ All attention extraction uses PyTorch forward hooks on `model.blocks[i].attn` modules. Each hook recomputes Q, K, V from the block's `qkv` projection and captures the softmax attention weights; the hooks are removed after the forward pass. See [visualization.py](utils/visualization.py#L12-L62).
+
+ **Critical**: Attention tensors are immediately moved to CPU to avoid GPU memory accumulation during iterative attacks.
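+
+ A condensed sketch of the hook pattern (simplified; assumes `model` is a loaded timm ViT — the real code in visualization.py also handles hook removal and edge cases):
+
+ ```python
+ import torch
+
+ attentions = []
+
+ def make_hook(attn):
+     def hook(module, inputs, output):
+         x = inputs[0]                               # [B, T, C]
+         B, T, C = x.shape
+         qkv = attn.qkv(x).reshape(B, T, 3, attn.num_heads, C // attn.num_heads)
+         q, k, _ = qkv.permute(2, 0, 3, 1, 4).unbind(0)
+         weights = ((q @ k.transpose(-2, -1)) * attn.scale).softmax(dim=-1)
+         attentions.append(weights.detach().cpu())  # move to CPU immediately
+     return hook
+
+ handles = [blk.attn.register_forward_hook(make_hook(blk.attn)) for blk in model.blocks]
+ # ... run the forward pass ...
+ for h in handles:
+     h.remove()
+ ```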
+
+ **Dynamic grid size**: The `_infer_grid_size_from_attentions()` function automatically detects the grid size from attention tensor shapes, eliminating hardcoded 14×14 assumptions.
+
+ #### Adversarial Attack Iteration Tracking
+ Custom attack classes (e.g., `PGDIterations`, `SAGA`) extend `torchattacks.Attack` and store:
+ - `attentions_per_iter`: List of attention maps per iteration (each iteration = list of layer tensors)
+ - Intermediate adversarial images via `tensor_to_pil()` with ImageNet denormalization
+
+ See [attacks.py](utils/attacks.py#L38-L107) for the denormalization pattern used consistently across attacks.
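+
+ In essence, the pattern is (a sketch; `denormalize_to_pil` is an illustrative name — the repo's helper is `tensor_to_pil`):
+
+ ```python
+ import torch
+ from torchvision.transforms.functional import to_pil_image
+
+ MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+ STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+
+ def denormalize_to_pil(img_tensor):
+     """Undo ImageNet normalization on a [1, 3, H, W] tensor and convert to PIL."""
+     img = img_tensor.detach().cpu().squeeze(0) * STD + MEAN
+     return to_pil_image(img.clamp(0.0, 1.0))
+ ```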
+
+ #### Gradio State Management
+ The attack tab uses multiple Gradio `State` components to cache expensive computations:
+ - `cached_attentions_state`: Raw attention maps from attack iterations
+ - `per_iter_rollout_masks_state`: Pre-computed rollout masks for all iterations
+ - `per_iter_layer_head_masks_state`: Pre-computed masks for all layers/heads (nested structure: `[iter][layer][head]`)
+
+ This avoids re-running attacks when users adjust visualization parameters (discard ratio, head fusion, alpha overlay).
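+
+ A minimal sketch of the wiring (component and handler names here are illustrative, not the ones in app.py):
+
+ ```python
+ import gradio as gr
+
+ def run_attack_stub(img):
+     masks = {"per_iter": []}      # stand-in for the expensive attack + attention capture
+     return "attack done", masks   # visible output + cache fill
+
+ def rerender(ratio, masks):
+     # cheap path: rebuild overlays from cached masks, no attack re-run
+     return f"re-rendered with discard_ratio={ratio}"
+
+ with gr.Blocks() as demo:
+     image_in = gr.Image(type="pil")
+     run_btn = gr.Button("Run attack")
+     ratio = gr.Slider(0.0, 1.0, value=0.9, label="Discard ratio")
+     out = gr.Textbox()
+     masks_state = gr.State(None)  # survives across events within a session
+
+     run_btn.click(run_attack_stub, [image_in], [out, masks_state])
+     ratio.change(rerender, [ratio, masks_state], [out])
+ ```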
+
+ ## Development Workflows
+
+ ### Running the App
+ ```bash
+ # Activate the virtual environment (if it exists)
+ source venv/bin/activate
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run the app locally (default port 7860)
+ python app.py
+
+ # Run on a specific port
+ PORT=8080 python app.py
+ ```
+
+ ### Model Conversion
+ Use [convert_model_with_classes.py](convert_model_with_classes.py) to embed class names into checkpoint files:
+ ```bash
+ python convert_model_with_classes.py
+ ```
+ This extracts class names from the Stanford40 dataset structure (`action_name_###.jpg`) and adds them to the checkpoint as a `class_names` dict.
+
+ ### Adding New Attack Methods
+ 1. Subclass `torchattacks.Attack` in [attacks.py](utils/attacks.py)
+ 2. Store `self.attentions_per_iter` as a list during `forward()` calls
+ 3. Call `capture_outputs_and_attentions(model, x_adv)` at each iteration to extract attention
+ 4. Return `(final_adv_tensor, iteration_images)` where `iteration_images` includes the original plus all intermediate steps
+ 5. Add the attack to the dropdown in [app.py](app.py#L270-L280) and handle its parameters in the `run_attack()` function
+
+ Example: the [SAGA attack](utils/attacks.py#L774-L920) implements ViT+CNN gradient blending.
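+
+ A bare-bones skeleton following these steps (a sketch, not the repo's code; it assumes the two helpers are importable from `utils.attacks` and that `capture_outputs_and_attentions` returns `(logits, attention_list)`, and it omits the epsilon-space handling described under Critical Conventions):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from torchattacks.attack import Attack
+
+ from utils.attacks import capture_outputs_and_attentions, tensor_to_pil  # assumed import location
+
+ class OneStepIterations(Attack):  # hypothetical attack name
+     def __init__(self, model, eps=8 / 255):
+         super().__init__("OneStepIterations", model)
+         self.eps = eps
+         self.attentions_per_iter = []
+
+     def forward(self, images, labels):
+         images = images.clone().detach().to(self.device).requires_grad_(True)
+         labels = labels.to(self.device)
+
+         # Capture attention on the clean image (iteration 0)
+         logits, attns = capture_outputs_and_attentions(self.model, images)
+         self.attentions_per_iter.append(attns)
+
+         # One FGSM-style step
+         grad = torch.autograd.grad(F.cross_entropy(logits, labels), images)[0]
+         adv = (images + self.eps * grad.sign()).detach()
+
+         # Capture attention on the adversarial image (iteration 1)
+         _, attns_adv = capture_outputs_and_attentions(self.model, adv)
+         self.attentions_per_iter.append(attns_adv)
+
+         iteration_images = [tensor_to_pil(images), tensor_to_pil(adv)]
+         return adv, iteration_images
+ ```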
+
+ ## Critical Conventions
+
+ ### Device Management
+ All code uses the `DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")` pattern. Models and tensors are explicitly moved to the device at load/inference time.
+
+ ### Epsilon in [0,1] Space
+ **Important**: The adversarial perturbation epsilon is defined in the **denormalized [0,1]** image space, not the normalized space. The L∞ distance metric in attack results also uses the denormalized space for user interpretability. See [app.py](app.py#L271-L280).
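+
+ Concretely, the reported distance amounts to something like (a sketch; mean/std as in preprocessing.py):
+
+ ```python
+ import torch
+
+ STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+ def linf_pixel_space(x, x_adv):
+     """L∞ between normalized tensors, measured in denormalized [0,1] space."""
+     return ((x_adv - x) * STD).abs().max().item()  # the means cancel out
+ ```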
+
+ ### Class Name Handling
+ Class names come from three sources (in priority order):
+ 1. An external labels file (currently disabled in code)
+ 2. Embedded in the checkpoint as a `class_names` dict
+ 3. Hugging Face `id2label` from the model config
+
+ Returns `None` if unavailable; the UI falls back to showing class indices.
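+
+ The UI-side fallback is essentially (illustrative, not the exact app.py code):
+
+ ```python
+ label = class_names.get(idx, f"class {idx}") if class_names else f"class {idx}"
+ ```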
+
+ ### Gradio File Input Extraction
+ Use the `_to_path()` helper ([app.py](app.py#L45-L57)) to handle the different Gradio file input formats (a string, a dict with `'name'`, or an object with a `.name` attribute).
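+
+ The logic boils down to (a sketch; see app.py for the exact version):
+
+ ```python
+ def _to_path(file_obj):
+     """Normalize a Gradio file input (str, dict, or file object) to a path string."""
+     if file_obj is None:
+         return None
+     if isinstance(file_obj, str):
+         return file_obj
+     if isinstance(file_obj, dict) and "name" in file_obj:
+         return file_obj["name"]
+     return getattr(file_obj, "name", None)
+ ```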
+
+ ### Custom CSS and Icons
+ The app injects Bootstrap Icons via CDN and custom CSS for panels/tables. Icon constants (e.g., `ICON_SUCCESS`, `ICON_FAIL`) are defined at the top of [app.py](app.py#L26-L30).
+
+ ## External Dependencies
+
+ - **timm**: ViT model architectures (`vit_base_patch16_224` is the default)
+ - **torchattacks**: Base classes for adversarial attacks
+ - **transformers**: Optional, for loading HF Hub models
+ - **gradio**: Version 5.49.1 (specified in requirements)
+
+ ## Testing Strategy
+
+ There are currently no automated tests. Manual testing workflow:
+ 1. Upload a model → check that classification works
+ 2. Run attention visualization → verify the heatmaps align with the predicted class
+ 3. Run an attack → verify the iteration slider shows the progression
+ 4. Toggle the layer/head sliders → verify the attention view updates without re-running the attack
+
+ ## Known Limitations
+
+ - Only timm ViT architectures (tiny, small, base, large) with patch sizes 16 and 32 are supported
+ - No support for non-standard ViT variants (DeiT distillation token, Swin hierarchical, BEiT) without additional conversion
+ - Custom CSS may break with Gradio version updates
+ - No batch processing (one image at a time)
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
 from typing import Optional, List, Tuple
 from pathlib import Path
 
- from utils.model_loader import load_model_and_labels
+ from utils.model_loader import load_model_and_labels, ViTConfig
 from utils.preprocessing import get_default_transform, preprocess_image
 from utils.inference import predict_topk
 from utils.attacks import PGDIterations, FGSM, SAGA, MIFGSM, TGR
@@ -32,7 +32,6 @@ ICON_CHART = '<i class="bi bi-bar-chart-line-fill vitviz-bi" aria-hidden="true">
 ICON_RULER = '<i class="bi bi-speedometer2 vitviz-bi" aria-hidden="true"></i>'
 ICON_GEAR = '<i class="bi bi-gear-fill vitviz-bi" aria-hidden="true"></i>'
 
- transform = get_default_transform()
 # Optional CNN backbone used by the "SAGA (with CNN gradient)" mode.
 # Can be a local path (e.g., "models/resnet.pth") or a checkpoint on the Hugging Face Hub.
 RESNET_BACKBONE_SPEC = "hf://lucasddmc/resnet101-stanford40-actions/resnet.pth"
@@ -89,14 +88,15 @@ def classify_image(model_file, use_hf_vit: bool, image):
     model_path = HF_VIT_MODEL_SPEC if use_hf_vit else _to_path(model_file)
 
     # Load model and labels
-     model, class_names, label_source = load_model_and_labels(model_path, None, device=DEVICE)
+     model, class_names, label_source, vit_config = load_model_and_labels(model_path, None, device=DEVICE)
     # _print_model_heads(model)
 
-     # Process the image
+     # Process the image with a dynamic transform based on the model
     if not (isinstance(image, str) or isinstance(image, Image.Image)):
         return "Please upload a valid image"
 
-     img_tensor = preprocess_image(image, transform=transform).to(DEVICE)
+     dynamic_transform = get_default_transform(img_size=vit_config.img_size)
+     img_tensor = preprocess_image(image, transform=dynamic_transform).to(DEVICE)
 
     # Inference
     top_prob, top_idx, num_classes, probabilities = predict_topk(model, img_tensor, top_k=5)
@@ -159,11 +159,12 @@ def visualize_attention(
 
     # Load model and labels
     model_path = HF_VIT_MODEL_SPEC if use_hf_vit else _to_path(model_file)
-     model, class_names, label_source = load_model_and_labels(model_path, None, device=DEVICE)
+     model, class_names, label_source, vit_config = load_model_and_labels(model_path, None, device=DEVICE)
     # _print_model_heads(model)
 
-     # Process the image
-     img_tensor = preprocess_image(image, transform=transform).to(DEVICE)
+     # Process the image with a dynamic transform based on the model
+     dynamic_transform = get_default_transform(img_size=vit_config.img_size)
+     img_tensor = preprocess_image(image, transform=dynamic_transform).to(DEVICE)
 
     # Prediction
     top_prob, top_idx, num_classes, _ = predict_topk(model, img_tensor, top_k=1, device=DEVICE)
@@ -249,11 +250,12 @@ def run_attack(
 
     # Load model and labels
    model_path = HF_VIT_MODEL_SPEC if use_hf_vit else _to_path(model_file)
-     model, class_names, label_source = load_model_and_labels(model_path, None, device=DEVICE)
+     model, class_names, label_source, vit_config = load_model_and_labels(model_path, None, device=DEVICE)
     # _print_model_heads(model)
 
-     # Process the image
-     img_tensor = preprocess_image(image, transform=transform).to(DEVICE)
+     # Process the image with a dynamic transform based on the model
+     dynamic_transform = get_default_transform(img_size=vit_config.img_size)
+     img_tensor = preprocess_image(image, transform=dynamic_transform).to(DEVICE)
 
     # Original prediction (top-5 for comparison)
     top_prob_orig, top_idx_orig, num_classes, _ = predict_topk(model, img_tensor, top_k=5, device=DEVICE)
convert_model_with_classes.py → deprecated/convert_model_with_classes.py RENAMED
File without changes
utils/model_loader.py CHANGED
@@ -1,6 +1,7 @@
 import pickle
 import torch
 import timm
+ from dataclasses import dataclass
 from typing import Optional, Tuple, Dict, Any
 
 try:
@@ -11,6 +12,128 @@ except Exception: # pragma: no cover
 DEVICE_DEFAULT = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
+ @dataclass
+ class ViTConfig:
+     """ViT architecture configuration extracted dynamically from the model."""
+     embed_dim: int = 768
+     num_heads: int = 12
+     num_layers: int = 12
+     patch_size: int = 16
+     img_size: int = 224
+     num_classes: int = 1000
+
+     @property
+     def grid_size(self) -> int:
+         """Patch grid size (e.g., 224/16 = 14)."""
+         return self.img_size // self.patch_size
+
+     @property
+     def num_patches(self) -> int:
+         """Total number of patches (e.g., 14*14 = 196)."""
+         return self.grid_size ** 2
+
+     @property
+     def timm_model_name(self) -> str:
+         """Return the timm model name matching this configuration."""
+         # Mapping based on embed_dim and num_heads
+         size_map = {
+             (192, 3): 'tiny',
+             (384, 6): 'small',
+             (768, 12): 'base',
+             (1024, 16): 'large',
+             (1280, 16): 'huge',
+         }
+         size = size_map.get((self.embed_dim, self.num_heads), 'base')
+         return f"vit_{size}_patch{self.patch_size}_{self.img_size}"
+
+
+ def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
+     """Infer the ViT configuration from a loaded timm model."""
+     config = ViTConfig()
+
+     # Extract img_size and patch_size from patch_embed
+     if hasattr(model, 'patch_embed'):
+         pe = model.patch_embed
+         if hasattr(pe, 'img_size'):
+             img_size = pe.img_size
+             config.img_size = img_size[0] if isinstance(img_size, (tuple, list)) else img_size
+         if hasattr(pe, 'patch_size'):
+             patch_size = pe.patch_size
+             config.patch_size = patch_size[0] if isinstance(patch_size, (tuple, list)) else patch_size
+
+     # Extract num_layers, embed_dim, num_heads from the blocks
+     if hasattr(model, 'blocks') and len(model.blocks) > 0:
+         config.num_layers = len(model.blocks)
+         block = model.blocks[0]
+         if hasattr(block, 'attn'):
+             attn = block.attn
+             if hasattr(attn, 'num_heads'):
+                 config.num_heads = attn.num_heads
+             if hasattr(attn, 'qkv') and hasattr(attn.qkv, 'in_features'):
+                 config.embed_dim = attn.qkv.in_features
+
+     # Extract num_classes from the head
+     if hasattr(model, 'head') and hasattr(model.head, 'out_features'):
+         config.num_classes = model.head.out_features
+     elif hasattr(model, 'head') and hasattr(model.head, 'weight'):
+         config.num_classes = model.head.weight.shape[0]
+
+     return config
+
+
+ def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
+     """Infer the ViT configuration from a state_dict."""
+     config = ViTConfig()
+
+     # Infer num_layers by counting blocks
+     layer_indices = set()
+     for key in state_dict.keys():
+         if key.startswith('blocks.') and '.attn.' in key:
+             # blocks.0.attn.qkv.weight -> extract 0
+             idx = int(key.split('.')[1])
+             layer_indices.add(idx)
+     if layer_indices:
+         config.num_layers = max(layer_indices) + 1
+
+     # Infer embed_dim and num_heads from the first block
+     qkv_key = 'blocks.0.attn.qkv.weight'
+     if qkv_key in state_dict:
+         qkv_weight = state_dict[qkv_key]
+         # qkv.weight shape: [3*embed_dim, embed_dim]
+         config.embed_dim = qkv_weight.shape[1]
+
+     # Infer num_heads from the proj weight, heuristically
+     proj_key = 'blocks.0.attn.proj.weight'
+     if proj_key in state_dict:
+         # proj.weight shape: [embed_dim, embed_dim]
+         embed_dim = state_dict[proj_key].shape[0]
+         # Heuristic: the typical head_dim is 64
+         config.num_heads = embed_dim // 64
+
+     # Infer num_classes from the head
+     head_key = 'head.weight'
+     if head_key in state_dict:
+         config.num_classes = state_dict[head_key].shape[0]
+
+     # Infer patch_size and img_size from patch_embed
+     patch_proj_key = 'patch_embed.proj.weight'
+     if patch_proj_key in state_dict:
+         # shape: [embed_dim, 3, patch_size, patch_size]
+         patch_weight = state_dict[patch_proj_key]
+         config.patch_size = patch_weight.shape[2]
+
+     # Infer img_size from pos_embed
+     pos_embed_key = 'pos_embed'
+     if pos_embed_key in state_dict:
+         # shape: [1, num_patches+1, embed_dim]
+         num_tokens = state_dict[pos_embed_key].shape[1]
+         num_patches = num_tokens - 1  # -1 for the CLS token
+         grid_size = int(num_patches ** 0.5)
+         config.img_size = grid_size * config.patch_size
+
+     return config
+
+
 def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]:
     if not isinstance(id2label, dict):
         return None
@@ -77,8 +200,12 @@ def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor], num_layer
     return out
 
 
- def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]]]:
-     """Load a ViT from the Hugging Face Hub and return an equivalent timm model."""
+ def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
+     """Load a ViT from the Hugging Face Hub and return an equivalent timm model.
+
+     Returns:
+         (model, class_names, config)
+     """
     if AutoModelForImageClassification is None:
         raise RuntimeError("transformers is not installed; install 'transformers' to load from the Hugging Face Hub.")
 
@@ -88,14 +215,40 @@ def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = No
     cfg = getattr(hf_model, "config", None)
     num_labels = int(getattr(cfg, "num_labels", 1000)) if cfg is not None else 1000
     num_layers = int(getattr(cfg, "num_hidden_layers", 12)) if cfg is not None else 12
+     hidden_size = int(getattr(cfg, "hidden_size", 768)) if cfg is not None else 768
+     num_heads = int(getattr(cfg, "num_attention_heads", 12)) if cfg is not None else 12
+     patch_size = int(getattr(cfg, "patch_size", 16)) if cfg is not None else 16
+     img_size = int(getattr(cfg, "image_size", 224)) if cfg is not None else 224
     class_names = _hf_id2label_to_class_names(getattr(cfg, "id2label", None)) if cfg is not None else None
 
-     timm_model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_labels)
+     # Build the dynamic config
+     vit_config = ViTConfig(
+         embed_dim=hidden_size,
+         num_heads=num_heads,
+         num_layers=num_layers,
+         patch_size=patch_size,
+         img_size=img_size,
+         num_classes=num_labels
+     )
+
+     # Try to find the matching timm model
+     timm_name = vit_config.timm_model_name
+     try:
+         timm_model = timm.create_model(timm_name, pretrained=False, num_classes=num_labels)
+     except Exception:
+         # Fall back to vit_base_patch16_224 if the model does not exist
+         print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
+         timm_model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=num_labels)
+
     timm_sd = _convert_hf_vit_to_timm_state_dict(hf_model.state_dict(), num_layers=num_layers)
     timm_model.load_state_dict(timm_sd, strict=False)
     timm_model = timm_model.to(device)
     timm_model.eval()
-     return timm_model, class_names
+
+     # Refresh the config with the actual values from the loaded model
+     vit_config = infer_config_from_model(timm_model)
+
+     return timm_model, class_names, vit_config
 
 
 class CustomUnpickler(pickle.Unpickler):
@@ -193,59 +346,94 @@ def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int,
     return None
 
 
- def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> torch.nn.Module:
-     """Build a model from a checkpoint, which may be a dict, a state_dict, or the model itself."""
+ def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
+     """Build a model from a checkpoint, which may be a dict, a state_dict, or the model itself.
+
+     Returns:
+         (model, config) - the loaded model and the inferred configuration
+     """
     device = device or DEVICE_DEFAULT
+     config: Optional[ViTConfig] = None
+
     if isinstance(checkpoint, dict):
         if 'model' in checkpoint:
             model = checkpoint['model']
+             config = infer_config_from_model(model)
         elif 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
+             config = infer_config_from_state_dict(state_dict)
             num_classes = infer_num_classes(state_dict)
-             # TODO: handle timm, plain PyTorch, etc. dynamically
-             model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
+             config.num_classes = num_classes
+             # Use the inferred architecture
+             timm_name = config.timm_model_name
+             try:
+                 model = timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
+             except Exception:
+                 print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
+                 model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
             model.load_state_dict(state_dict)
         elif 'model_state_dict' in checkpoint:
             # New format with embedded class_names
             state_dict = checkpoint['model_state_dict']
+             config = infer_config_from_state_dict(state_dict)
             num_classes = infer_num_classes(state_dict)
-             # TODO: handle timm, plain PyTorch, etc. dynamically
-             model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
+             config.num_classes = num_classes
+             # Use the inferred architecture
+             timm_name = config.timm_model_name
+             try:
+                 model = timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
+             except Exception:
+                 print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
+                 model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
             model.load_state_dict(state_dict)
         else:
             # assume the dict is a state_dict
+             config = infer_config_from_state_dict(checkpoint)
             num_classes = infer_num_classes(checkpoint)
-             # TODO: handle timm, plain PyTorch, etc. dynamically
-             model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
+             config.num_classes = num_classes
+             # Use the inferred architecture
+             timm_name = config.timm_model_name
+             try:
+                 model = timm.create_model(timm_name, pretrained=False, num_classes=num_classes)
+             except Exception:
+                 print(f"[ViTViz] timm model '{timm_name}' not found, using vit_base_patch16_224")
+                 model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
             model.load_state_dict(checkpoint)
     else:
         # full model saved via torch.save(model, ...)
         model = checkpoint
+         config = infer_config_from_model(model)
 
     model = model.to(device)
     model.eval()
-     return model
+
+     # Make sure the config is populated
+     if config is None:
+         config = infer_config_from_model(model)
+
+     return model, config
 
 
 def load_model_and_labels(
     model_path: str,
     labels_file: Optional[str] = None,
     device: Optional[torch.device] = None,
- ) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str]]:
+ ) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
     """
     ** Main function **
     Loads the model and, if available, class names.
 
-     Returns: (model, class_names, label_source) where label_source ∈ {"file", "checkpoint", None}
+     Returns: (model, class_names, label_source, config) where label_source ∈ {"file", "checkpoint", "hf", None}
     None if no class names are available.
+     config holds the ViT architecture configuration (embed_dim, num_heads, grid_size, etc.)
     """
     device = device or DEVICE_DEFAULT
 
     # Load directly from the Hugging Face Hub (Transformers -> timm)
     if isinstance(model_path, str) and model_path.startswith("hf-model://"):
         model_id = model_path[len("hf-model://"):].strip("/")
-         model, class_names = load_vit_from_huggingface(model_id, device=device)
-         return model, class_names, 'hf'
+         model, class_names, config = load_vit_from_huggingface(model_id, device=device)
+         return model, class_names, 'hf', config
 
     checkpoint = load_checkpoint(model_path, device=device)
     class_names_ckpt = extract_class_names(checkpoint)
@@ -260,5 +448,5 @@ def load_model_and_labels(
     class_names = class_names_ckpt
     source = 'checkpoint' if class_names_ckpt else None
 
-     model = build_model_from_checkpoint(checkpoint, device=device)
-     return model, class_names, source
+     model, config = build_model_from_checkpoint(checkpoint, device=device)
+     return model, class_names, source, config
utils/preprocessing.py CHANGED
@@ -4,13 +4,20 @@ from torchvision import transforms
 import torch
 
 
- # TODO: implement adapters for different models with other class types
-
- def get_default_transform() -> transforms.Compose:
-     """Default transform (Resize+CenterCrop+Normalize) compatible with ImageNet models."""
+ def get_default_transform(img_size: int = 224) -> transforms.Compose:
+     """Default transform (Resize+CenterCrop+Normalize) compatible with ImageNet models.
+
+     Args:
+         img_size: Input image size expected by the model (default: 224)
+
+     Returns:
+         Compose of preprocessing transforms
+     """
+     # Proportional resize: 256 for 224, scaled accordingly for other sizes
+     resize_size = int(img_size * 256 / 224)
     return transforms.Compose([
-         transforms.Resize(256),
-         transforms.CenterCrop(224),
+         transforms.Resize(resize_size),
+         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
utils/visualization.py CHANGED
@@ -77,6 +77,24 @@ def extract_attention_maps(model, image: torch.Tensor) -> list:
     return attentions
 
 
+ def _infer_grid_size_from_attentions(attentions_per_iter: list) -> int:
+     """Infer the grid size from the attention tensors."""
+     if not attentions_per_iter:
+         return 14
+     for iter_attns in attentions_per_iter:
+         if not iter_attns:
+             continue
+         for layer_tensor in iter_attns:
+             if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4:
+                 # shape: [B, H, T, T] where T = num_patches + 1 (CLS)
+                 num_tokens = layer_tensor.shape[-1]
+                 num_patches = num_tokens - 1
+                 side = int(num_patches ** 0.5)
+                 if side * side == num_patches:
+                     return side
+     return 14  # fallback
+
+
 def extract_layer_head_masks(
     attentions_per_iter: list,
     layer_idx: int,
@@ -98,10 +116,14 @@
     masks = []
     if attentions_per_iter is None or len(attentions_per_iter) == 0:
         return masks
+
+     # Infer grid_size dynamically
+     default_grid = _infer_grid_size_from_attentions(attentions_per_iter)
     eps = 1e-8
+
     for iter_attns in attentions_per_iter:
         if not iter_attns or layer_idx < 0 or layer_idx >= len(iter_attns):
-             masks.append(np.zeros((14, 14), dtype=np.float32))
+             masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
             continue
         layer_tensor = iter_attns[layer_idx]
         if isinstance(layer_tensor, torch.Tensor):
@@ -109,7 +131,7 @@
         else:
             att = torch.as_tensor(layer_tensor)
         if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1):
-             masks.append(np.zeros((14, 14), dtype=np.float32))
+             masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
             continue
         att_head = att[0, head_idx]  # [T,T]
         vec = att_head[0] if cls_only else att_head.mean(dim=0)
@@ -117,7 +139,7 @@
         tokens = vec_patches.numel()
         side = int(tokens ** 0.5)
         if side * side != tokens:
-             masks.append(np.zeros((14, 14), dtype=np.float32))
+             masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
             continue
         mask = vec_patches.reshape(side, side)
         mask = mask / (mask.max() + eps)
@@ -164,6 +186,18 @@ def compute_layer_head_masks_from_cached_attns(iter_attns: List[torch.Tensor], c
     """
     per_layer_head_masks: List[List[np.ndarray]] = []
     eps = 1e-8
+
+     # Infer grid_size from the first valid tensor
+     default_grid = 14
+     for layer_tensor in iter_attns:
+         if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4:
+             num_tokens = layer_tensor.shape[-1]
+             num_patches = num_tokens - 1
+             side = int(num_patches ** 0.5)
+             if side * side == num_patches:
+                 default_grid = side
+                 break
+
     for li, layer_tensor in enumerate(iter_attns):
         if isinstance(layer_tensor, torch.Tensor):
             att = layer_tensor.detach().cpu()
@@ -183,7 +217,7 @@ def compute_layer_head_masks_from_cached_attns(iter_attns: List[torch.Tensor], c
             side = int(tokens ** 0.5)
             if side * side != tokens:
                 # print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: tokens {tokens} not square -> side={side}")
-                 heads_masks.append(np.zeros((14, 14), dtype=np.float32))
+                 heads_masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
                 continue
             mask = vec_patches.reshape(side, side)
             mmax = float(mask.max())
@@ -460,10 +494,13 @@ def extract_last_layer_head_masks(
     if attentions_per_iter is None or len(attentions_per_iter) == 0:
         return masks
 
+     # Infer grid_size dynamically
+     default_grid = _infer_grid_size_from_attentions(attentions_per_iter)
     eps = 1e-8
+
     for iter_attns in attentions_per_iter:
         if not iter_attns:
-             masks.append(np.zeros((14, 14), dtype=np.float32))
+             masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
             print("Empty attentions for this iteration.")
             continue
         # Last layer
@@ -475,7 +512,7 @@
 
         # Expected: [B, H, T, T] with B=1
         if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1):
-             masks.append(np.zeros((14, 14), dtype=np.float32))
+             masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
             print("Invalid attention in the last layer.")
             continue
 
@@ -495,7 +532,7 @@
         side = int(tokens ** 0.5)
         if side * side != tokens:
             # fallback: normalize and return coherent zeros
-             masks.append(np.zeros((14, 14), dtype=np.float32))
+             masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
             print("Number of patches does not form a square grid.")
             continue
 