File size: 21,173 Bytes
e11edb1
 
 
 
 
 
4833c40
e11edb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4833c40
 
e11edb1
 
 
 
 
 
 
 
98cf39b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4833c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98cf39b
 
 
4833c40
98cf39b
4833c40
 
98cf39b
4833c40
 
 
 
 
 
 
98cf39b
4833c40
 
 
 
 
 
 
98cf39b
4833c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98cf39b
 
 
 
 
 
 
 
 
 
 
 
4833c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98cf39b
4833c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceb2d70
 
 
e11edb1
 
 
 
 
 
 
 
 
 
 
 
 
ceb2d70
e11edb1
 
 
 
ceb2d70
e11edb1
 
 
 
 
 
 
 
ceb2d70
 
 
 
 
 
 
 
 
 
 
 
 
 
e11edb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceb2d70
 
e11edb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ceb2d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4833c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98cf39b
 
4833c40
98cf39b
4833c40
 
98cf39b
4833c40
 
 
 
 
 
 
 
 
 
 
98cf39b
4833c40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98cf39b
4833c40
 
 
 
 
 
 
 
 
 
e11edb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torchvision.models.feature_extraction import create_feature_extractor
from typing import Dict, Tuple, List, Optional


def extract_attention_maps(model, image: torch.Tensor) -> list:
    """
    Extrai attention maps de todas as camadas do ViT usando hooks.
    
    Implementação simplificada e robusta que calcula attention manualmente.

    Args:
        model: Modelo ViT
        image: Tensor de imagem [1, 3, 224, 224]

    Returns:
        attentions: lista de tensores [batch, heads, patches, patches]
    """
    attentions = []
    
    # Função de hook simplificada que captura entrada e calcula attention
    def make_attention_hook():
        def hook(module, input, output):
            x = input[0]  # Input do módulo de atenção
            B, N, C = x.shape
            
            # Verificar se tem os componentes necessários
            if not (hasattr(module, 'qkv') and hasattr(module, 'num_heads')):
                return
            
            # Calcular Q, K, V
            qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            
            # Calcular attention weights
            scale = (C // module.num_heads) ** -0.5
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)
            
            # Salvar (já no CPU para não acumular na GPU)
            attentions.append(attn.detach().cpu())
        
        return hook
    
    # Encontrar e registrar hooks nos módulos de atenção
    hooks = []
    if not hasattr(model, 'blocks'):
        raise ValueError("Modelo não tem atributo 'blocks'. Não é um ViT compatível.")
    
    for i, block in enumerate(model.blocks):
        if hasattr(block, 'attn'):
            hook = block.attn.register_forward_hook(make_attention_hook())
            hooks.append(hook)
    
    if len(hooks) == 0:
        raise ValueError("Não foi possível registrar hooks. Verifique a arquitetura do modelo.")
    
    # Executar forward pass
    model.eval()
    with torch.inference_mode():
        _ = model(image)
    
    # Remover hooks
    for hook in hooks:
        hook.remove()

    # Garantir que capturamos atenções e retornar
    if len(attentions) == 0:
        raise ValueError(
            f"Nenhuma atenção capturada após registrar {len(hooks)} hooks. "
            f"A arquitetura do modelo pode não ser compatível."
        )
    return attentions


def _infer_grid_size_from_attentions(attentions_per_iter: list) -> int:
    """Infere o tamanho do grid a partir dos tensores de atenção."""
    if not attentions_per_iter:
        return 14
    for iter_attns in attentions_per_iter:
        if not iter_attns:
            continue
        for layer_tensor in iter_attns:
            if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4:
                # shape: [B, H, T, T] onde T = num_patches + 1 (CLS)
                num_tokens = layer_tensor.shape[-1]
                num_patches = num_tokens - 1
                side = int(num_patches ** 0.5)
                if side * side == num_patches:
                    return side
    return 14  # fallback


def extract_layer_head_masks(
    attentions_per_iter: list,
    layer_idx: int,
    head_idx: int,
    cls_only: bool = True
) -> list:
    """
    Extrai máscaras por iteração para uma cabeça específica de uma camada arbitrária.

    Args:
        attentions_per_iter: Lista por iteração; cada item é lista de tensores [B, H, T, T] por camada
        layer_idx: Índice da camada (0-based)
        head_idx: Índice da cabeça (0-based)
        cls_only: Se True, usa apenas a atenção do token CLS para os patches

    Returns:
        Lista de máscaras [grid, grid] normalizadas [0,1]
    """
    masks = []
    if attentions_per_iter is None or len(attentions_per_iter) == 0:
        return masks
    
    # Inferir grid_size dinamicamente
    default_grid = _infer_grid_size_from_attentions(attentions_per_iter)
    eps = 1e-8
    
    for iter_attns in attentions_per_iter:
        if not iter_attns or layer_idx < 0 or layer_idx >= len(iter_attns):
            masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
            continue
        layer_tensor = iter_attns[layer_idx]
        if isinstance(layer_tensor, torch.Tensor):
            att = layer_tensor.detach().cpu()
        else:
            att = torch.as_tensor(layer_tensor)
        if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1):
            masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
            continue
        att_head = att[0, head_idx]  # [T,T]
        vec = att_head[0] if cls_only else att_head.mean(dim=0)
        vec_patches = vec[1:]
        tokens = vec_patches.numel()
        side = int(tokens ** 0.5)
        if side * side != tokens:
            masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
            continue
        mask = vec_patches.reshape(side, side)
        mask = mask / (mask.max() + eps)
        masks.append(mask.numpy())
    return masks


def get_num_layers_heads_from_cached(attentions_per_iter: List[List[torch.Tensor]]) -> Tuple[int, int]:
    """
    Inspeciona o cache de atenções para obter número de camadas e cabeças.

    Args:
        attentions_per_iter: Lista por iteração; cada item é lista por camada com tensores [B, H, T, T].

    Returns:
        (num_layers, num_heads)
    """
    if not attentions_per_iter:
        return 0, 0
    first_iter = attentions_per_iter[0]
    if not first_iter:
        return 0, 0
    num_layers = len(first_iter)
    # assume cabeças constantes entre camadas
    h = first_iter[0]
    if isinstance(h, torch.Tensor):
        num_heads = int(h.shape[1]) if h.ndim == 4 else 0
    else:
        h_t = torch.as_tensor(h)
        num_heads = int(h_t.shape[1]) if h_t.ndim == 4 else 0
    return num_layers, num_heads


def compute_layer_head_masks_from_cached_attns(iter_attns: List[torch.Tensor], cls_only: bool = True) -> List[List[np.ndarray]]:
    """
    Para uma iteração, computa máscaras por camada e cabeça.

    Args:
        iter_attns: Lista por camada de tensores [B, H, T, T]
        cls_only: Se True, usa linha do CLS para patches

    Returns:
        Lista [layer] de listas [head] com máscaras [side, side] normalizadas.
    """
    per_layer_head_masks: List[List[np.ndarray]] = []
    eps = 1e-8
    
    # Inferir grid_size do primeiro tensor válido
    default_grid = 14
    for layer_tensor in iter_attns:
        if isinstance(layer_tensor, torch.Tensor) and layer_tensor.ndim == 4:
            num_tokens = layer_tensor.shape[-1]
            num_patches = num_tokens - 1
            side = int(num_patches ** 0.5)
            if side * side == num_patches:
                default_grid = side
                break
    
    for li, layer_tensor in enumerate(iter_attns):
        if isinstance(layer_tensor, torch.Tensor):
            att = layer_tensor.detach().cpu()
        else:
            att = torch.as_tensor(layer_tensor)
        if att.ndim != 4 or att.size(0) < 1:
            # print(f"[ViTViz][compute_layer_head_masks] Iter layer {li}: invalid attention shape {att.shape if hasattr(att,'shape') else type(att)}")
            per_layer_head_masks.append([])
            continue
        heads_masks: List[np.ndarray] = []
        # print(f"[ViTViz][compute_layer_head_masks] Layer {li}: B={att.size(0)}, H={att.size(1)}, T={att.size(2)}")
        for h in range(att.size(1)):
            att_head = att[0, h]  # [T, T]
            vec = att_head[0] if cls_only else att_head.mean(dim=0)
            vec_patches = vec[1:]
            tokens = vec_patches.numel()
            side = int(tokens ** 0.5)
            if side * side != tokens:
                # print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: tokens {tokens} not square -> side={side}")
                heads_masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
                continue
            mask = vec_patches.reshape(side, side)
            mmax = float(mask.max())
            mask = mask / (mmax + eps)
            if mmax == 0:
                # print(f"[ViTViz][compute_layer_head_masks] Layer {li} head {h}: max=0, produced zero mask")
                pass
            heads_masks.append(mask.numpy())
        per_layer_head_masks.append(heads_masks)
    return per_layer_head_masks


def batch_precompute_all_masks(
    attentions_per_iter: List[List[torch.Tensor]],
    discard_ratio: float = 0.9,
    head_fusion: str = 'max',
    precompute_heads: bool = True
) -> Tuple[List[np.ndarray], Optional[List[List[List[np.ndarray]]]]]:
    """
    Pré-computa todas as máscaras de atenção:
    - Rollout por iteração
    - Opcionalmente, por camada/cabeça por iteração

    Args:
        attentions_per_iter: Lista por iteração com listas por camada [B,H,T,T]
        discard_ratio: parâmetro do rollout
        head_fusion: fusão de cabeças no rollout
        precompute_heads: se True, computa todas heads por camada

    Returns:
        (rollout_masks_por_iter, per_iter_layer_head_masks ou None)
    """
    rollout_masks: List[np.ndarray] = []
    per_iter_layer_head_masks: Optional[List[List[List[np.ndarray]]]] = [] if precompute_heads else None

    if not attentions_per_iter:
        return rollout_masks, per_iter_layer_head_masks

    for it_idx, iter_attns in enumerate(attentions_per_iter):
        # Rollout desta iteração
        attentions_cpu = []
        for li, att in enumerate(iter_attns):
            if isinstance(att, torch.Tensor):
                attentions_cpu.append(att.detach().cpu())
            else:
                attentions_cpu.append(torch.as_tensor(att))
        if len(attentions_cpu) == 0:
            # print(f"[ViTViz][batch_precompute] Iter {it_idx}: empty attentions list")
            pass
        rollout_mask = attention_rollout(
            attentions_cpu,
            discard_ratio=discard_ratio,
            head_fusion=head_fusion
        )
        rollout_masks.append(rollout_mask)

        # Heads por camada desta iteração
        if precompute_heads:
            # print(f"[ViTViz][batch_precompute] Iter {it_idx}: computing per-layer/head masks; layers={len(iter_attns)}")
            per_layer = compute_layer_head_masks_from_cached_attns(iter_attns, cls_only=True)
            per_iter_layer_head_masks.append(per_layer)

    return rollout_masks, per_iter_layer_head_masks


def attention_rollout(attentions: list,
                      discard_ratio: float = 0.9,
                      head_fusion: str = 'max') -> np.ndarray:
    """
    Implementa Attention Rollout seguindo a implementação original.
    
    Referência: https://github.com/jacobgil/vit-explain

    Args:
        attentions: Lista de tensores [batch, heads, patches, patches]
        discard_ratio: Proporção de atenções mais fracas a descartar (default: 0.9)
        head_fusion: Como agregar múltiplas cabeças - 'mean', 'max' ou 'min'

    Returns:
        mask: Array numpy [grid_size, grid_size] com valores normalizados [0, 1]
    """
    # Inicializar com matriz identidade
    result = torch.eye(attentions[0].size(-1))
    
    with torch.no_grad():
        for attention in attentions:
            # Agregar heads
            if head_fusion == 'mean':
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == 'max':
                attention_heads_fused = attention.max(axis=1)[0]
            elif head_fusion == 'min':
                attention_heads_fused = attention.min(axis=1)[0]
            else:
                raise ValueError(f"head_fusion deve ser 'mean', 'max' ou 'min'")
            # Aplicar descarte condicional das atenções fracas por amostra
            if discard_ratio > 0.0:
                bsz, tokens, _ = attention_heads_fused.shape
                flat = attention_heads_fused.view(bsz, -1)
                k = int(flat.size(-1) * discard_ratio)
                if k > 0:
                    # Menores valores (largest=False)
                    vals, idxs = torch.topk(flat, k, dim=-1, largest=False)
                    for b in range(bsz):
                        idxs_b = idxs[b]
                        # proteger CLS (posição 0 nas matrizes quadradas)
                        idxs_b = idxs_b[idxs_b != 0]
                        flat[b, idxs_b] = 0
                    attention_heads_fused = flat.view(bsz, tokens, tokens)

            # Adicionar identidade e normalizar
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            
            # CORREÇÃO 3: normalizar sem keepdim
            a = a / a.sum(dim=-1)

            # Rollout recursivo
            result = torch.matmul(a, result)
    
    # Look at the total attention between the class token and the image patches
    mask = result[0, 0, 1:]
    
    # Calcular tamanho do grid
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    
    # Normalizar
    mask = mask / np.max(mask)

    return mask


def create_attention_overlay(original_image: Image.Image, 
                            attention_mask: np.ndarray,
                            alpha: float = 0.5,
                            colormap: str = 'jet') -> Image.Image:
    """
    Cria visualização sobrepondo o mapa de atenção na imagem original.
    
    Segue implementação de referência usando OpenCV.

    Args:
        original_image: Imagem PIL original
        attention_mask: Máscara de atenção [H, W] normalizada [0, 1]
        alpha: Peso da imagem original (0.7 = 70% imagem, 30% heatmap)
        colormap: 'jet' (padrão OpenCV)

    Returns:
        Imagem PIL com overlay de atenção
    """
    import cv2
    
    # Converter PIL para numpy array RGB
    img_np = np.array(original_image).astype(np.float32) / 255.0
    
    # Redimensionar máscara para o tamanho da imagem (224x224 ou tamanho original)
    h, w = img_np.shape[:2]
    mask_resized = cv2.resize(attention_mask, (w, h))
    
    # Aplicar colormap do OpenCV (retorna BGR!)
    heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET)
    
    # CRÍTICO: Converter BGR → RGB (OpenCV usa BGR!)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = heatmap.astype(np.float32) / 255.0
    
    # Blend: alpha * img_original + (1-alpha) * heatmap
    overlay = alpha * img_np + (1 - alpha) * heatmap
    overlay = np.clip(overlay, 0, 1)
    
    # Converter de volta para PIL
    overlay_uint8 = (overlay * 255).astype(np.uint8)
    return Image.fromarray(overlay_uint8)


def extract_attention_for_iterations(
    model,
    iteration_tensors: list,
    discard_ratio: float = 0.9,
    head_fusion: str = 'max'
) -> list:
    """
    [Deprecated when cached attentions are present]
    Extrai mapas de atenção para cada iteração do ataque usando hooks.
    
    Args:
        model: Modelo ViT
        iteration_tensors: Lista de tensors normalizados [1, 3, 224, 224] de cada iteração
        discard_ratio: Proporção de atenções fracas a descartar
        head_fusion: Como agregar heads ('mean', 'max', 'min')
    
    Returns:
        Lista de máscaras de atenção [14, 14] normalizadas [0, 1]
    """
    attention_masks = []
    
    for tensor in iteration_tensors:
        # Extrair attention maps para esta iteração
        attentions = extract_attention_maps(model, tensor)
        
        # Aplicar Attention Rollout
        mask = attention_rollout(
            attentions,
            discard_ratio=discard_ratio,
            head_fusion=head_fusion
        )
        
        attention_masks.append(mask)
    
    return attention_masks


def rollout_from_cached_attentions(
    attentions_per_iter: list,
    discard_ratio: float = 0.9,
    head_fusion: str = 'max'
) -> list:
    """
    Gera máscaras de atenção por iteração a partir de atenções já capturadas no ataque.

    Args:
        attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
        discard_ratio: Proporção de atenções fracas a descartar
        head_fusion: Como agregar heads ('mean', 'max', 'min')

    Returns:
        Lista de máscaras de atenção [grid, grid] normalizadas [0, 1]
    """
    attention_masks = []

    if attentions_per_iter is None or len(attentions_per_iter) == 0:
        return attention_masks

    for layer_attns in attentions_per_iter:
        # layer_attns: lista de tensores por camada [B, H, T, T]
        # Garantir CPU e detach
        attentions_cpu = []
        for att in layer_attns:
            if isinstance(att, torch.Tensor):
                attentions_cpu.append(att.detach().cpu())
            else:
                # já é CPU numpy/tensor? tentar converter via torch.as_tensor
                attentions_cpu.append(torch.as_tensor(att))

        # Aplicar rollout padrão sobre a lista de camadas
        mask = attention_rollout(
            attentions_cpu,
            discard_ratio=discard_ratio,
            head_fusion=head_fusion
        )
        attention_masks.append(mask)

    return attention_masks


def extract_last_layer_head_masks(
    attentions_per_iter: list,
    head_idx: int,
    cls_only: bool = True
) -> list:
    """
    Extrai máscaras por iteração para uma única cabeça da última camada.

    Args:
        attentions_per_iter: Lista por iteração; cada item é a lista de tensores [B, H, T, T] por camada
        head_idx: Índice da cabeça na última camada (0-based)
        cls_only: Se True, usa a atenção do token CLS (linha 0) para os patches

    Returns:
        Lista de máscaras [grid, grid] normalizadas [0, 1]
    """
    masks = []
    if attentions_per_iter is None or len(attentions_per_iter) == 0:
        return masks

    # Inferir grid_size dinamicamente
    default_grid = _infer_grid_size_from_attentions(attentions_per_iter)
    eps = 1e-8
    
    for iter_attns in attentions_per_iter:
        if not iter_attns:
            masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
            print("Atenções vazias para esta iteração.")
            continue
        # Última camada
        last_layer = iter_attns[-1]
        if isinstance(last_layer, torch.Tensor):
            att = last_layer.detach().cpu()
        else:
            att = torch.as_tensor(last_layer)

        # Esperado: [B, H, T, T] com B=1
        if att.ndim != 4 or att.size(0) < 1 or head_idx < 0 or head_idx >= att.size(1):
            masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
            print("Atenção inválida na última camada.")
            continue

        # Selecionar cabeça
        att_head = att[0, head_idx]  # [T, T]

        # Vetor atenção CLS→tokens
        if cls_only:
            vec = att_head[0]  # linha do CLS
        else:
            # média das linhas como alternativa
            vec = att_head.mean(dim=0)

        # Remover CLS e projetar para grade
        vec_patches = vec[1:]
        tokens = vec_patches.numel()
        side = int(tokens ** 0.5)
        if side * side != tokens:
            # fallback: normalizar e retornar zeros coerentes
            masks.append(np.zeros((default_grid, default_grid), dtype=np.float32))
            print("Número de patches não forma uma grade quadrada.")
            continue

        mask = vec_patches.reshape(side, side)
        mask = mask / (mask.max() + eps)
        masks.append(mask.numpy())

    return masks


def create_iteration_attention_overlays(
    iteration_images: list,
    attention_masks: list,
    alpha: float = 0.7
) -> list:
    """
    Cria overlays de atenção para cada iteração do ataque.
    OTIMIZADO para velocidade de renderização.
    
    Args:
        iteration_images: Lista de PIL Images (uma por iteração)
        attention_masks: Lista de máscaras de atenção [14, 14]
        alpha: Transparência do overlay
    
    Returns:
        Lista de PIL Images com heatmaps sobrepostos (comprimidas)
    """
    overlays = []
    
    for img, mask in zip(iteration_images, attention_masks):
        overlay = create_attention_overlay(img, mask, alpha=alpha)
        
        # OTIMIZAÇÃO AGRESSIVA: reduzir para 224x224 JPEG qualidade 75
        overlay = overlay.resize((224, 224), Image.LANCZOS)
        
        # Converter para RGB se necessário (JPEG não suporta RGBA)
        if overlay.mode in ('RGBA', 'LA', 'P'):
            background = Image.new('RGB', overlay.size, (255, 255, 255))
            if overlay.mode == 'P':
                overlay = overlay.convert('RGBA')
            background.paste(overlay, mask=overlay.split()[-1] if overlay.mode == 'RGBA' else None)
            overlay = background
        
        overlays.append(overlay)
    
    return overlays