import pickle
import torch
import timm
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

# Import VisionTransformer directly so we can build models with custom architectures
try:
    from timm.models.vision_transformer import VisionTransformer
except ImportError:
    VisionTransformer = None

try:
    from transformers import AutoModelForImageClassification
except Exception:  # pragma: no cover
    AutoModelForImageClassification = None

# safetensors support (the modern HuggingFace format)
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None

DEVICE_DEFAULT = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class ViTConfig:
    """Configuração de arquitetura ViT extraída dinamicamente do modelo."""
    embed_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    patch_size: int = 16
    img_size: int = 224
    num_classes: int = 1000
    mlp_ratio: float = 4.0
    qkv_bias: bool = True
    
    @property
    def grid_size(self) -> int:
        """Tamanho do grid de patches (ex: 224/16 = 14)."""
        return self.img_size // self.patch_size
    
    @property
    def num_patches(self) -> int:
        """Número total de patches (ex: 14*14 = 196)."""
        return self.grid_size ** 2
    
    @property
    def timm_model_name(self) -> str:
        """Retorna o nome do modelo timm correspondente (para fins informativos)."""
        # Mapeamento baseado em embed_dim e num_heads
        size_map = {
            (192, 3): 'tiny',
            (384, 6): 'small', 
            (768, 12): 'base',
            (1024, 16): 'large',
            (1280, 16): 'huge',
        }
        size = size_map.get((self.embed_dim, self.num_heads), 'custom')
        return f"vit_{size}_patch{self.patch_size}_{self.img_size}"

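# Illustrative ViTConfig example (the defaults correspond to ViT-Base):
#   cfg = ViTConfig()       # embed_dim=768, num_heads=12, patch_size=16, img_size=224
#   cfg.grid_size           # -> 14   (224 // 16)
#   cfg.num_patches         # -> 196  (14 * 14)
#   cfg.timm_model_name     # -> 'vit_base_patch16_224'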

def create_vit_from_config(config: ViTConfig, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Cria um modelo ViT diretamente a partir da configuração inferida.
    
    Isso permite criar modelos com arquiteturas arbitrárias, não limitadas
    aos nomes predefinidos do timm (vit_base_patch16_224, etc.).
    """
    device = device or DEVICE_DEFAULT
    
    if VisionTransformer is None:
        raise RuntimeError("VisionTransformer não disponível. Verifique a instalação do timm.")
    
    model = VisionTransformer(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_chans=3,
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.num_layers,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        class_token=True,
        global_pool='token',
    )
    
    return model.to(device)


def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove prefixos comuns de frameworks (Lightning, DDP, etc.) das keys do state_dict.
    
    Prefixos tratados:
    - 'model.' (PyTorch Lightning)
    - 'module.' (DataParallel/DistributedDataParallel)
    - 'encoder.' (alguns frameworks de self-supervised learning)
    - 'backbone.' (alguns frameworks de detecção)
    
    Returns:
        state_dict com keys sem prefixo
    """
    prefixes = ['model.', 'module.', 'encoder.', 'backbone.']
    
    # Check whether any key carries one of the prefixes
    has_prefix = False
    detected_prefix = None
    for key in state_dict.keys():
        for prefix in prefixes:
            if key.startswith(prefix):
                has_prefix = True
                detected_prefix = prefix
                break
        if has_prefix:
            break
    
    if not has_prefix:
        return state_dict
    
    print(f"[ViTViz] Detectado prefixo '{detected_prefix}' nas keys do state_dict (Lightning/DDP). Removendo...")
    
    new_sd: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        new_key = key
        for prefix in prefixes:
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                break
        new_sd[new_key] = value
    
    return new_sd

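# Example (illustrative): a PyTorch Lightning checkpoint typically prefixes every key,
#   {'model.cls_token': t0, 'model.blocks.0.attn.qkv.weight': t1, ...}
# which this helper rewrites to
#   {'cls_token': t0, 'blocks.0.attn.qkv.weight': t1, ...}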

def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
    """Valida se o modelo tem a estrutura esperada de um ViT timm-compatível.
    
    Returns:
        (is_valid, error_message) - se inválido, error_message descreve o problema
    """
    if not hasattr(model, 'blocks'):
        return False, "Modelo não tem atributo 'blocks'. Não é um ViT compatível."
    
    if len(model.blocks) == 0:
        return False, "Modelo tem 'blocks' vazio."
    
    block = model.blocks[0]
    if not hasattr(block, 'attn'):
        return False, "Bloco não tem atributo 'attn'. Estrutura incompatível."
    
    attn = block.attn
    if not hasattr(attn, 'qkv'):
        return False, "Módulo de atenção não tem 'qkv'. Estrutura incompatível."
    
    if not hasattr(attn, 'num_heads'):
        return False, "Módulo de atenção não tem 'num_heads'. Estrutura incompatível."
    
    return True, ""


def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
    """Infere configuração ViT a partir de um modelo timm carregado."""
    config = ViTConfig()
    
    # Extract img_size and patch_size from patch_embed
    if hasattr(model, 'patch_embed'):
        pe = model.patch_embed
        if hasattr(pe, 'img_size'):
            img_size = pe.img_size
            config.img_size = img_size[0] if isinstance(img_size, (tuple, list)) else img_size
        if hasattr(pe, 'patch_size'):
            patch_size = pe.patch_size
            config.patch_size = patch_size[0] if isinstance(patch_size, (tuple, list)) else patch_size
    
    # Extract num_layers, embed_dim and num_heads from the blocks
    if hasattr(model, 'blocks') and len(model.blocks) > 0:
        config.num_layers = len(model.blocks)
        block = model.blocks[0]
        if hasattr(block, 'attn'):
            attn = block.attn
            if hasattr(attn, 'num_heads'):
                config.num_heads = attn.num_heads
            if hasattr(attn, 'qkv') and hasattr(attn.qkv, 'in_features'):
                config.embed_dim = attn.qkv.in_features
    
    # Extract num_classes from the head
    if hasattr(model, 'head') and hasattr(model.head, 'out_features'):
        config.num_classes = model.head.out_features
    elif hasattr(model, 'head') and hasattr(model.head, 'weight'):
        config.num_classes = model.head.weight.shape[0]
    
    return config


def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
    """Infere configuração ViT a partir de um state_dict."""
    config = ViTConfig()
    
    # Infer num_layers by counting blocks
    layer_indices = set()
    for key in state_dict.keys():
        if key.startswith('blocks.') and '.attn.' in key:
            # blocks.0.attn.qkv.weight -> extract the 0
            idx = int(key.split('.')[1])
            layer_indices.add(idx)
    if layer_indices:
        config.num_layers = max(layer_indices) + 1
    
    # Infer embed_dim from the first block
    qkv_key = 'blocks.0.attn.qkv.weight'
    if qkv_key in state_dict:
        qkv_weight = state_dict[qkv_key]
        # qkv.weight shape: [3*embed_dim, embed_dim]
        config.embed_dim = qkv_weight.shape[1]
        # num_heads cannot be read directly from the qkv shape: the output is
        # 3*embed_dim = 3*num_heads*head_dim, so num_heads = (qkv_out // 3) // head_dim,
        # but head_dim varies across models. We infer num_heads by another route below.
    
    # Infer num_heads: try several methods
    proj_key = 'blocks.0.attn.proj.weight'
    if proj_key in state_dict and qkv_key in state_dict:
        embed_dim = state_dict[proj_key].shape[0]
        qkv_out = state_dict[qkv_key].shape[0]  # 3*embed_dim
        
        # Method 1: if qkv_out == 3*embed_dim, try common head_dims (64, 32, 96)
        if qkv_out == 3 * embed_dim:
            # Try common head_dims in order of preference
            for head_dim in [64, 32, 96, 48, 128]:
                if embed_dim % head_dim == 0:
                    config.num_heads = embed_dim // head_dim
                    break
            else:
                # Fallback: assume num_heads divides embed_dim evenly
                # and try common num_heads values
                for nh in [12, 16, 8, 6, 24, 4, 3]:
                    if embed_dim % nh == 0:
                        config.num_heads = nh
                        break
    
    # Infer qkv_bias
    qkv_bias_key = 'blocks.0.attn.qkv.bias'
    config.qkv_bias = qkv_bias_key in state_dict
    
    # Infer mlp_ratio from the MLP
    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
    if mlp_fc1_key in state_dict and config.embed_dim > 0:
        mlp_hidden = state_dict[mlp_fc1_key].shape[0]
        config.mlp_ratio = mlp_hidden / config.embed_dim
    
    # Infer num_classes from the head
    head_key = 'head.weight'
    if head_key in state_dict:
        config.num_classes = state_dict[head_key].shape[0]
    
    # Infer patch_size from patch_embed
    patch_proj_key = 'patch_embed.proj.weight'
    if patch_proj_key in state_dict:
        # shape: [embed_dim, 3, patch_size, patch_size]
        patch_weight = state_dict[patch_proj_key]
        config.patch_size = patch_weight.shape[2]
    
    # Infer img_size from pos_embed
    pos_embed_key = 'pos_embed'
    if pos_embed_key in state_dict:
        # shape: [1, num_patches+1, embed_dim]
        num_tokens = state_dict[pos_embed_key].shape[1]
        num_patches = num_tokens - 1  # -1 for the CLS token
        grid_size = int(num_patches ** 0.5)
        config.img_size = grid_size * config.patch_size
    
    return config

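# Round-trip sanity check (illustrative): a config inferred from a freshly built
# model's state_dict should match the ViTConfig used to create it.
#   cfg = ViTConfig(embed_dim=384, num_heads=6, num_layers=12)
#   m = create_vit_from_config(cfg, device=torch.device("cpu"))
#   inferred = infer_config_from_state_dict(m.state_dict())
#   assert (inferred.embed_dim, inferred.num_heads, inferred.num_layers) == (384, 6, 12)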

def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]:
    if not isinstance(id2label, dict):
        return None
    out: Dict[int, str] = {}
    for k, v in id2label.items():
        try:
            out[int(k)] = str(v)
        except Exception:
            continue
    return out or None


def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor], num_layers: int) -> Dict[str, torch.Tensor]:
    """Converte state_dict de ViT (Hugging Face Transformers) para chaves do timm ViT.

    Alvo: timm "vit_base_patch16_224".
    """
    out: Dict[str, torch.Tensor] = {}

    def get(key: str) -> torch.Tensor:
        if key not in hf_sd:
            raise KeyError(f"Missing key in HF state_dict: {key}")
        return hf_sd[key]

    # Embeddings
    out["cls_token"] = get("vit.embeddings.cls_token")
    out["pos_embed"] = get("vit.embeddings.position_embeddings")
    out["patch_embed.proj.weight"] = get("vit.embeddings.patch_embeddings.projection.weight")
    out["patch_embed.proj.bias"] = get("vit.embeddings.patch_embeddings.projection.bias")

    # Encoder blocks
    for i in range(num_layers):
        prefix = f"vit.encoder.layer.{i}"
        out[f"blocks.{i}.norm1.weight"] = get(f"{prefix}.layernorm_before.weight")
        out[f"blocks.{i}.norm1.bias"] = get(f"{prefix}.layernorm_before.bias")
        out[f"blocks.{i}.norm2.weight"] = get(f"{prefix}.layernorm_after.weight")
        out[f"blocks.{i}.norm2.bias"] = get(f"{prefix}.layernorm_after.bias")

        qw = get(f"{prefix}.attention.attention.query.weight")
        kw = get(f"{prefix}.attention.attention.key.weight")
        vw = get(f"{prefix}.attention.attention.value.weight")
        qb = get(f"{prefix}.attention.attention.query.bias")
        kb = get(f"{prefix}.attention.attention.key.bias")
        vb = get(f"{prefix}.attention.attention.value.bias")
        out[f"blocks.{i}.attn.qkv.weight"] = torch.cat([qw, kw, vw], dim=0)
        out[f"blocks.{i}.attn.qkv.bias"] = torch.cat([qb, kb, vb], dim=0)

        out[f"blocks.{i}.attn.proj.weight"] = get(f"{prefix}.attention.output.dense.weight")
        out[f"blocks.{i}.attn.proj.bias"] = get(f"{prefix}.attention.output.dense.bias")

        out[f"blocks.{i}.mlp.fc1.weight"] = get(f"{prefix}.intermediate.dense.weight")
        out[f"blocks.{i}.mlp.fc1.bias"] = get(f"{prefix}.intermediate.dense.bias")
        out[f"blocks.{i}.mlp.fc2.weight"] = get(f"{prefix}.output.dense.weight")
        out[f"blocks.{i}.mlp.fc2.bias"] = get(f"{prefix}.output.dense.bias")

    out["norm.weight"] = get("vit.layernorm.weight")
    out["norm.bias"] = get("vit.layernorm.bias")

    # Classifier
    if "classifier.weight" in hf_sd and "classifier.bias" in hf_sd:
        out["head.weight"] = get("classifier.weight")
        out["head.bias"] = get("classifier.bias")

    return out

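# Key-mapping example (illustrative, ViT-Base, D = 768): HF stores separate
# query/key/value projections, which the converter fuses into timm's single qkv.
#   HF:   vit.encoder.layer.0.attention.attention.query.weight   [768, 768]
#         ...key.weight / ...value.weight                        [768, 768] each
#   timm: blocks.0.attn.qkv.weight = cat([q, k, v], dim=0)       [2304, 768]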

def _convert_hf_timm_wrapper_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Converte state_dict de TimmWrapper (Transformers) para formato timm ViT.

    Exemplo de origem: chaves com prefixo ``timm_model.``.
    """
    out: Dict[str, torch.Tensor] = {}

    for key, value in hf_sd.items():
        if key.startswith("timm_model."):
            out[key[len("timm_model."):]] = value
        elif key.startswith("classifier."):
            # Some wrappers keep a separate head under 'classifier.'
            out[f"head.{key[len('classifier.'):]}"] = value

    if not out:
        raise ValueError("State_dict de TimmWrapper sem chaves reconhecidas (timm_model.* / classifier.*).")

    return out


def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
    """Carrega ViT do Hugging Face Hub e retorna um modelo timm equivalente.
    
    Returns:
        (model, class_names, config)
    """
    if AutoModelForImageClassification is None:
        raise RuntimeError("transformers não está instalado; instale 'transformers' para carregar do Hugging Face.")

    device = device or DEVICE_DEFAULT
    hf_model = AutoModelForImageClassification.from_pretrained(model_id)
    hf_model.eval()
    cfg = getattr(hf_model, "config", None)
    class_names = _hf_id2label_to_class_names(getattr(cfg, "id2label", None)) if cfg is not None else None

    hf_sd = hf_model.state_dict()
    if any(key.startswith("timm_model.") for key in hf_sd.keys()):
        timm_sd = _convert_hf_timm_wrapper_to_timm_state_dict(hf_sd)
    else:
        num_layers = int(getattr(cfg, "num_hidden_layers", 12)) if cfg is not None else 12
        timm_sd = _convert_hf_vit_to_timm_state_dict(hf_sd, num_layers=num_layers)

    vit_config = infer_config_from_state_dict(timm_sd)
    if cfg is not None and hasattr(cfg, "num_labels"):
        try:
            vit_config.num_classes = int(getattr(cfg, "num_labels"))
        except Exception:
            pass

    print(f"[ViTViz] Carregando do HuggingFace: {vit_config.timm_model_name} "
          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
          f"layers={vit_config.num_layers}, patch={vit_config.patch_size}, img={vit_config.img_size})")

    timm_model = create_vit_from_config(vit_config, device=device)
    timm_model.load_state_dict(timm_sd, strict=False)
    timm_model.eval()
    
    return timm_model, class_names, vit_config

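# Usage sketch (illustrative; "google/vit-base-patch16-224" is a public HF model id):
#   model, names, cfg = load_vit_from_huggingface("google/vit-base-patch16-224")
#   # names maps class indices to labels taken from the HF config's id2label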

class CustomUnpickler(pickle.Unpickler):
    """Unpickler que ignora classes customizadas ausentes criando dummies dinamicamente."""

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except Exception:
            # Create a dummy class with the same name so unpickling can proceed
            return type(name, (), {})


def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
    """Carrega um checkpoint/modelo do caminho informado.
    
    Suporta formatos:
    - .pth / .pt: PyTorch checkpoint (torch.load)
    - .safetensors: Formato moderno do HuggingFace (mais seguro e rápido)

    Retorna o objeto carregado (modelo completo, state_dict ou dict de checkpoint).
    """
    device = device or DEVICE_DEFAULT
    
    # Detect the safetensors format
    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                "safetensors não está instalado. Instale com: pip install safetensors"
            )
        # safetensors always returns a state_dict (it cannot store a full model)
        state_dict = load_safetensors(model_path, device=str(device))
        return state_dict
    
    # Standard PyTorch format (.pth, .pt, .ckpt, etc.)
    try:
        return torch.load(model_path, map_location=device, weights_only=False)
    except (AttributeError, ModuleNotFoundError, RuntimeError):
        # Fallback for missing classes or version conflicts
        with open(model_path, 'rb') as f:
            return CustomUnpickler(f).load()


def infer_num_classes(state_dict: Dict[str, torch.Tensor]) -> int:
    """Infere o número de classes a partir do state_dict (camada de head).

    Caso não encontre, retorna 1000 (padrão ImageNet).
    """
    for key, tensor in state_dict.items():
        if 'head' in key and 'weight' in key and hasattr(tensor, 'shape'):
            return tensor.shape[0]
    return 1000


def extract_class_names(checkpoint: Any) -> Optional[Dict[int, str]]:
    """Tenta extrair nomes de classes de um checkpoint (se presente)."""
    if not isinstance(checkpoint, dict):
        return None

    possible_keys = [
        'class_names', 'classes', 'class_to_idx', 'idx_to_class',
        'label_names', 'labels', 'class_labels'
    ]

    for key in possible_keys:
        if key in checkpoint:
            labels = checkpoint[key]
            if isinstance(labels, list):
                return {i: name for i, name in enumerate(labels)}
            if isinstance(labels, dict):
                # Already idx -> name
                if all(isinstance(k, int) for k in labels.keys()):
                    return labels  # type: ignore[return-value]
                # name -> idx: invert the mapping
                if all(isinstance(v, int) for v in labels.values()):
                    return {v: k for k, v in labels.items()}
            return labels  # type: ignore[return-value]
    return None


def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int, str]]:
    """Carrega nomes de classes de um arquivo .txt (um por linha) ou .json (lista ou dict)."""
    if not labels_file:
        return None
    import json
    try:
        if labels_file.endswith('.json'):
            with open(labels_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
                if isinstance(data, list):
                    return {i: name for i, name in enumerate(data)}
                if isinstance(data, dict):
                    out: Dict[int, str] = {}
                    for k, v in data.items():
                        try:
                            out[int(k)] = v
                        except Exception:
                            # Skip non-numeric keys
                            pass
                    if out:
                        return out
                    # fallback when the mapping is name -> idx
                    if all(isinstance(v, int) for v in data.values()):
                        return {v: k for k, v in data.items()}
                    return None
        else:
            with open(labels_file, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]
                return {i: name for i, name in enumerate(lines)}
    except Exception:
        return None


def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
    """Constroi um modelo a partir de um checkpoint que pode ser um dict, state_dict ou o próprio modelo.
    
    Suporta arquiteturas ViT arbitrárias, não limitadas aos nomes predefinidos do timm.
    
    Returns:
        (model, config) - modelo carregado e configuração inferida
    """
    device = device or DEVICE_DEFAULT
    config: Optional[ViTConfig] = None
    
    # Detect and log PyTorch Lightning checkpoints
    if isinstance(checkpoint, dict) and 'pytorch-lightning_version' in checkpoint:
        print(f"[ViTViz] Detectado checkpoint PyTorch Lightning (v{checkpoint.get('pytorch-lightning_version', '?')})")
    
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            # Full model stored inside the dict
            model = checkpoint['model']
            config = infer_config_from_model(model)
            # Validate structure
            is_valid, error_msg = validate_vit_structure(model)
            if not is_valid:
                raise ValueError(f"Invalid model: {error_msg}")
        else:
            # 'state_dict' (PyTorch Lightning), 'model_state_dict' (newer format
            # with embedded class_names), or the dict itself as a plain state_dict
            state_dict = checkpoint
            for key in ('state_dict', 'model_state_dict'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break
            # Strip framework prefixes (Lightning, DDP, etc.)
            state_dict = _strip_state_dict_prefix(state_dict)
            config = infer_config_from_state_dict(state_dict)
            print(f"[ViTViz] Inferred architecture: {config.timm_model_name} "
                  f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
                  f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
            # Build a model with the custom architecture
            model = create_vit_from_config(config, device=device)
            # strict=False tolerates variants such as CLIP (norm_pre, etc.)
            model.load_state_dict(state_dict, strict=False)
    else:
        # full model saved via torch.save(model, ...)
        model = checkpoint
        # Validate structure
        is_valid, error_msg = validate_vit_structure(model)
        if not is_valid:
            raise ValueError(f"Invalid model: {error_msg}")
        config = infer_config_from_model(model)

    model = model.to(device)
    model.eval()
    
    # Make sure config is populated
    if config is None:
        config = infer_config_from_model(model)
    
    return model, config

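# Accepted checkpoint shapes (illustrative):
#   torch.save(model, p)                                  -> full nn.Module
#   torch.save(model.state_dict(), p)                     -> plain state_dict
#   torch.save({'state_dict': sd, ...}, p)                -> Lightning-style dict
#   torch.save({'model_state_dict': sd, 'class_names': [...]}, p)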

def load_model_and_labels(
    model_path: str,
    labels_file: Optional[str] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
    """
    ** Main entry point **
    Load the model and, when available, the class names.

    Returns: (model, class_names, labels_source, config) where labels_source ∈ {"file", "checkpoint", "hf", None}
        None when no class names are available.
        config holds the ViT architecture configuration (embed_dim, num_heads, grid_size, etc.)
    """
    device = device or DEVICE_DEFAULT

    # Load directly from the Hugging Face Hub (Transformers -> timm)
    if isinstance(model_path, str) and model_path.startswith("hf-model://"):
        model_id = model_path[len("hf-model://"):].strip("/")
        model, class_names, config = load_vit_from_huggingface(model_id, device=device)
        return model, class_names, 'hf', config

    checkpoint = load_checkpoint(model_path, device=device)
    class_names_ckpt = extract_class_names(checkpoint)
    class_names_file = load_class_names_from_file(labels_file)

    # An explicit labels file takes precedence over names embedded in the checkpoint.
    class_names = class_names_file or class_names_ckpt
    source: Optional[str] = None
    if class_names_file:
        source = 'file'
    elif class_names_ckpt:
        source = 'checkpoint'

    model, config = build_model_from_checkpoint(checkpoint, device=device)
    return model, class_names, source, config
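

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; "model.pth" below is a placeholder,
    # not a file shipped with this module): load a checkpoint, report the
    # inferred architecture, and run a dummy forward pass.
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else "model.pth"
    model, class_names, source, config = load_model_and_labels(path)
    print(f"Loaded {config.timm_model_name} (labels source: {source})")

    dummy = torch.randn(1, 3, config.img_size, config.img_size, device=DEVICE_DEFAULT)
    with torch.no_grad():
        logits = model(dummy)
    print(f"Logits shape: {tuple(logits.shape)}")  # expected (1, config.num_classes)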