File size: 25,259 Bytes
e11edb1 98cf39b e11edb1 d012f9c f8ab302 d012f9c e11edb1 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b d012f9c 98cf39b f8ab302 99483e9 98cf39b f8ab302 99483e9 d012f9c 99483e9 d012f9c f8ab302 98cf39b f8ab302 e11edb1 d012f9c e11edb1 d012f9c e11edb1 98cf39b d012f9c 98cf39b e11edb1 98cf39b d012f9c e11edb1 d012f9c e11edb1 98cf39b d012f9c e11edb1 d012f9c 98cf39b d012f9c 98d71cd d012f9c 98cf39b d012f9c e11edb1 d012f9c 98cf39b d012f9c e11edb1 d012f9c 98cf39b e11edb1 98cf39b e11edb1 98cf39b e11edb1 98cf39b e11edb1 98cf39b e11edb1 f8ab302 98cf39b f8ab302 e11edb1 ceb2d70 e11edb1 98cf39b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 
359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 | import pickle
import torch
import timm
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any
# Importar VisionTransformer diretamente para criar modelos com arquiteturas customizadas
try:
from timm.models.vision_transformer import VisionTransformer
except ImportError:
VisionTransformer = None
try:
from transformers import AutoModelForImageClassification
except Exception: # pragma: no cover
AutoModelForImageClassification = None
# Suporte a safetensors (formato moderno do HuggingFace)
try:
from safetensors.torch import load_file as load_safetensors
except ImportError:
load_safetensors = None
DEVICE_DEFAULT = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@dataclass
class ViTConfig:
    """ViT architecture configuration, inferred dynamically from a model/checkpoint."""
    embed_dim: int = 768      # transformer hidden size
    num_heads: int = 12       # attention heads per block
    num_layers: int = 12      # number of transformer blocks
    patch_size: int = 16      # side of one square patch, in pixels
    img_size: int = 224       # input image side, in pixels
    num_classes: int = 1000   # classifier output size (ImageNet default)
    mlp_ratio: float = 4.0    # MLP hidden width relative to embed_dim
    qkv_bias: bool = True     # whether the fused qkv projection has a bias

    @property
    def grid_size(self) -> int:
        """Side length of the patch grid (e.g. 224 // 16 = 14)."""
        return self.img_size // self.patch_size

    @property
    def num_patches(self) -> int:
        """Total number of patches (e.g. 14 * 14 = 196)."""
        return self.grid_size * self.grid_size

    @property
    def timm_model_name(self) -> str:
        """Corresponding timm model name (informational only).

        The size tag is looked up from (embed_dim, num_heads); unknown
        combinations are labeled 'custom'.
        """
        known_sizes = {
            (192, 3): 'tiny',
            (384, 6): 'small',
            (768, 12): 'base',
            (1024, 16): 'large',
            (1280, 16): 'huge',
        }
        size_tag = known_sizes.get((self.embed_dim, self.num_heads), 'custom')
        return f"vit_{size_tag}_patch{self.patch_size}_{self.img_size}"
def create_vit_from_config(config: ViTConfig, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Build a ViT model directly from an inferred configuration.

    This allows creating models with arbitrary architectures instead of being
    limited to the predefined timm names (vit_base_patch16_224, etc.).

    Raises:
        RuntimeError: when timm's VisionTransformer could not be imported.
    """
    if VisionTransformer is None:
        raise RuntimeError("VisionTransformer não disponível. Verifique a instalação do timm.")
    target_device = device if device is not None else DEVICE_DEFAULT
    vit = VisionTransformer(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_chans=3,  # RGB input
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.num_layers,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        class_token=True,      # standard ViT uses a CLS token...
        global_pool='token',   # ...and classifies from it
    )
    return vit.to(target_device)
def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Remove prefixos comuns de frameworks (Lightning, DDP, etc.) das keys do state_dict.
Prefixos tratados:
- 'model.' (PyTorch Lightning)
- 'module.' (DataParallel/DistributedDataParallel)
- 'encoder.' (alguns frameworks de self-supervised learning)
- 'backbone.' (alguns frameworks de detecção)
Returns:
state_dict com keys sem prefixo
"""
prefixes = ['model.', 'module.', 'encoder.', 'backbone.']
# Verificar se alguma key tem prefixo
has_prefix = False
detected_prefix = None
for key in state_dict.keys():
for prefix in prefixes:
if key.startswith(prefix):
has_prefix = True
detected_prefix = prefix
break
if has_prefix:
break
if not has_prefix:
return state_dict
print(f"[ViTViz] Detectado prefixo '{detected_prefix}' nas keys do state_dict (Lightning/DDP). Removendo...")
new_sd: Dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
new_key = key
for prefix in prefixes:
if key.startswith(prefix):
new_key = key[len(prefix):]
break
new_sd[new_key] = value
return new_sd
def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
    """Check whether the model has the structure expected of a timm-compatible ViT.

    Only the first block is inspected; the others are assumed homogeneous.

    Returns:
        (is_valid, error_message) - when invalid, error_message describes the problem.
    """
    if not hasattr(model, 'blocks'):
        return False, "Modelo não tem atributo 'blocks'. Não é um ViT compatível."
    if len(model.blocks) == 0:
        return False, "Modelo tem 'blocks' vazio."

    first_block = model.blocks[0]
    if not hasattr(first_block, 'attn'):
        return False, "Bloco não tem atributo 'attn'. Estrutura incompatível."

    attention = first_block.attn
    missing = [
        attr for attr in ('qkv', 'num_heads') if not hasattr(attention, attr)
    ]
    if 'qkv' in missing:
        return False, "Módulo de atenção não tem 'qkv'. Estrutura incompatível."
    if 'num_heads' in missing:
        return False, "Módulo de atenção não tem 'num_heads'. Estrutura incompatível."
    return True, ""
def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
    """Infer a ViT configuration from a loaded timm model.

    Fields that cannot be read off the model keep the ViTConfig defaults.
    """
    def _scalar(value):
        # timm stores sizes either as an int or as an (h, w) tuple/list.
        return value[0] if isinstance(value, (tuple, list)) else value

    config = ViTConfig()

    # img_size / patch_size come from the patch embedding module.
    if hasattr(model, 'patch_embed'):
        patch_embed = model.patch_embed
        if hasattr(patch_embed, 'img_size'):
            config.img_size = _scalar(patch_embed.img_size)
        if hasattr(patch_embed, 'patch_size'):
            config.patch_size = _scalar(patch_embed.patch_size)

    # num_layers / embed_dim / num_heads come from the transformer blocks.
    if hasattr(model, 'blocks') and len(model.blocks) > 0:
        config.num_layers = len(model.blocks)
        first_block = model.blocks[0]
        if hasattr(first_block, 'attn'):
            attention = first_block.attn
            if hasattr(attention, 'num_heads'):
                config.num_heads = attention.num_heads
            if hasattr(attention, 'qkv') and hasattr(attention.qkv, 'in_features'):
                config.embed_dim = attention.qkv.in_features

    # num_classes comes from the classification head.
    if hasattr(model, 'head') and hasattr(model.head, 'out_features'):
        config.num_classes = model.head.out_features
    elif hasattr(model, 'head') and hasattr(model.head, 'weight'):
        config.num_classes = model.head.weight.shape[0]

    return config
def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
    """Infer a ViT configuration from a state_dict.

    Shapes of well-known keys (qkv, proj, mlp, head, patch_embed, pos_embed)
    are used to reconstruct embed_dim, num_heads, num_layers, etc. Fields
    that cannot be inferred keep the ViTConfig defaults.
    """
    config = ViTConfig()

    # num_layers: count distinct block indices from keys like blocks.<i>.attn.*
    layer_indices = set()
    for key in state_dict.keys():
        if key.startswith('blocks.') and '.attn.' in key:
            idx = int(key.split('.')[1])
            layer_indices.add(idx)
    if layer_indices:
        config.num_layers = max(layer_indices) + 1

    # embed_dim: qkv.weight has shape [3*embed_dim, embed_dim]
    qkv_key = 'blocks.0.attn.qkv.weight'
    if qkv_key in state_dict:
        config.embed_dim = state_dict[qkv_key].shape[1]

    # num_heads is not stored in the weights, so it must be guessed.
    proj_key = 'blocks.0.attn.proj.weight'
    if proj_key in state_dict and qkv_key in state_dict:
        embed_dim = state_dict[proj_key].shape[0]
        qkv_out = state_dict[qkv_key].shape[0]  # expected 3*embed_dim
        # FIX: check the canonical ViT sizes first (same table as
        # ViTConfig.timm_model_name). Previously ViT-Huge (embed_dim=1280,
        # 16 heads, head_dim=80) was mis-inferred as 20 heads because the
        # head_dim=64 heuristic was tried first.
        known_heads = {192: 3, 384: 6, 768: 12, 1024: 16, 1280: 16}
        if embed_dim in known_heads:
            config.num_heads = known_heads[embed_dim]
        elif qkv_out == 3 * embed_dim:
            # Try common head_dims in order of preference
            for head_dim in [64, 32, 96, 48, 128]:
                if embed_dim % head_dim == 0:
                    config.num_heads = embed_dim // head_dim
                    break
        else:
            # Fallback: assume num_heads divides embed_dim evenly;
            # try common head counts
            for nh in [12, 16, 8, 6, 24, 4, 3]:
                if embed_dim % nh == 0:
                    config.num_heads = nh
                    break

    # qkv_bias: present iff the bias tensor exists
    config.qkv_bias = 'blocks.0.attn.qkv.bias' in state_dict

    # mlp_ratio: hidden width of fc1 relative to embed_dim
    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
    if mlp_fc1_key in state_dict and config.embed_dim > 0:
        config.mlp_ratio = state_dict[mlp_fc1_key].shape[0] / config.embed_dim

    # num_classes: rows of the classification head weight
    head_key = 'head.weight'
    if head_key in state_dict:
        config.num_classes = state_dict[head_key].shape[0]

    # patch_size: kernel of the patch-embedding conv [embed_dim, 3, ps, ps]
    patch_proj_key = 'patch_embed.proj.weight'
    if patch_proj_key in state_dict:
        config.patch_size = state_dict[patch_proj_key].shape[2]

    # img_size: derived from pos_embed [1, num_patches+1, embed_dim]
    pos_embed_key = 'pos_embed'
    if pos_embed_key in state_dict:
        num_tokens = state_dict[pos_embed_key].shape[1]
        num_patches = num_tokens - 1  # minus the CLS token
        grid_size = int(num_patches ** 0.5)
        config.img_size = grid_size * config.patch_size

    return config
def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]:
if not isinstance(id2label, dict):
return None
out: Dict[int, str] = {}
for k, v in id2label.items():
try:
out[int(k)] = str(v)
except Exception:
continue
return out or None
def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor], num_layers: int) -> Dict[str, torch.Tensor]:
"""Converte state_dict de ViT (Hugging Face Transformers) para chaves do timm ViT.
Alvo: timm "vit_base_patch16_224".
"""
out: Dict[str, torch.Tensor] = {}
def get(key: str) -> torch.Tensor:
if key not in hf_sd:
raise KeyError(f"Missing key in HF state_dict: {key}")
return hf_sd[key]
# Embeddings
out["cls_token"] = get("vit.embeddings.cls_token")
out["pos_embed"] = get("vit.embeddings.position_embeddings")
out["patch_embed.proj.weight"] = get("vit.embeddings.patch_embeddings.projection.weight")
out["patch_embed.proj.bias"] = get("vit.embeddings.patch_embeddings.projection.bias")
# Encoder blocks
for i in range(num_layers):
prefix = f"vit.encoder.layer.{i}"
out[f"blocks.{i}.norm1.weight"] = get(f"{prefix}.layernorm_before.weight")
out[f"blocks.{i}.norm1.bias"] = get(f"{prefix}.layernorm_before.bias")
out[f"blocks.{i}.norm2.weight"] = get(f"{prefix}.layernorm_after.weight")
out[f"blocks.{i}.norm2.bias"] = get(f"{prefix}.layernorm_after.bias")
qw = get(f"{prefix}.attention.attention.query.weight")
kw = get(f"{prefix}.attention.attention.key.weight")
vw = get(f"{prefix}.attention.attention.value.weight")
qb = get(f"{prefix}.attention.attention.query.bias")
kb = get(f"{prefix}.attention.attention.key.bias")
vb = get(f"{prefix}.attention.attention.value.bias")
out[f"blocks.{i}.attn.qkv.weight"] = torch.cat([qw, kw, vw], dim=0)
out[f"blocks.{i}.attn.qkv.bias"] = torch.cat([qb, kb, vb], dim=0)
out[f"blocks.{i}.attn.proj.weight"] = get(f"{prefix}.attention.output.dense.weight")
out[f"blocks.{i}.attn.proj.bias"] = get(f"{prefix}.attention.output.dense.bias")
out[f"blocks.{i}.mlp.fc1.weight"] = get(f"{prefix}.intermediate.dense.weight")
out[f"blocks.{i}.mlp.fc1.bias"] = get(f"{prefix}.intermediate.dense.bias")
out[f"blocks.{i}.mlp.fc2.weight"] = get(f"{prefix}.output.dense.weight")
out[f"blocks.{i}.mlp.fc2.bias"] = get(f"{prefix}.output.dense.bias")
out["norm.weight"] = get("vit.layernorm.weight")
out["norm.bias"] = get("vit.layernorm.bias")
# Classifier
if "classifier.weight" in hf_sd and "classifier.bias" in hf_sd:
out["head.weight"] = get("classifier.weight")
out["head.bias"] = get("classifier.bias")
return out
def _convert_hf_timm_wrapper_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Converte state_dict de TimmWrapper (Transformers) para formato timm ViT.
Exemplo de origem: chaves com prefixo ``timm_model.``.
"""
out: Dict[str, torch.Tensor] = {}
for key, value in hf_sd.items():
if key.startswith("timm_model."):
out[key[len("timm_model."):]] = value
elif key.startswith("classifier."):
# Alguns wrappers usam head separado como classifier.
out[f"head.{key[len('classifier.'):]}"] = value
if not out:
raise ValueError("State_dict de TimmWrapper sem chaves reconhecidas (timm_model.* / classifier.*).")
return out
def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
    """Load a ViT from the Hugging Face Hub and return an equivalent timm model.

    Returns:
        (model, class_names, config)

    Raises:
        RuntimeError: when the `transformers` package is not installed.
    """
    if AutoModelForImageClassification is None:
        raise RuntimeError("transformers não está instalado; instale 'transformers' para carregar do Hugging Face.")
    target_device = device or DEVICE_DEFAULT

    hf_model = AutoModelForImageClassification.from_pretrained(model_id)
    hf_model.eval()

    # Class names come from the HF config's id2label mapping, when present.
    hf_cfg = getattr(hf_model, "config", None)
    class_names = None
    if hf_cfg is not None:
        class_names = _hf_id2label_to_class_names(getattr(hf_cfg, "id2label", None))

    hf_sd = hf_model.state_dict()
    if any(key.startswith("timm_model.") for key in hf_sd.keys()):
        # TimmWrapper checkpoints just need their prefix stripped.
        timm_sd = _convert_hf_timm_wrapper_to_timm_state_dict(hf_sd)
    else:
        # Native HF ViT: remap keys layer by layer.
        depth = 12 if hf_cfg is None else int(getattr(hf_cfg, "num_hidden_layers", 12))
        timm_sd = _convert_hf_vit_to_timm_state_dict(hf_sd, num_layers=depth)

    vit_config = infer_config_from_state_dict(timm_sd)
    # Prefer the explicit label count from the HF config over the inferred one.
    if hf_cfg is not None and hasattr(hf_cfg, "num_labels"):
        try:
            vit_config.num_classes = int(getattr(hf_cfg, "num_labels"))
        except Exception:
            pass

    print(f"[ViTViz] Carregando do HuggingFace: {vit_config.timm_model_name} "
          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
          f"layers={vit_config.num_layers}, patch={vit_config.patch_size}, img={vit_config.img_size})")

    timm_model = create_vit_from_config(vit_config, device=target_device)
    # strict=False: the converted dict may lack optional keys (e.g. the head)
    timm_model.load_state_dict(timm_sd, strict=False)
    timm_model.eval()
    return timm_model, class_names, vit_config
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that tolerates missing custom classes by creating dummies dynamically.

    Used as a fallback when torch.load fails because the checkpoint pickles
    classes that are not importable in this environment.

    NOTE(security): unpickling can execute arbitrary code; only use on
    trusted checkpoint files.
    """
    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except Exception:
            # Create a dummy class with the same name so unpickling can proceed
            return type(name, (), {})
def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
    """Load a checkpoint/model from the given path.

    Supported formats:
        - .pth / .pt: PyTorch checkpoint (torch.load)
        - .safetensors: modern HuggingFace format (safer and faster)

    Returns the loaded object (full model, state_dict, or checkpoint dict).

    Raises:
        ImportError: for .safetensors files when `safetensors` is not installed.
    """
    target_device = device or DEVICE_DEFAULT

    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                "safetensors não está instalado. Instale com: pip install safetensors"
            )
        # safetensors can only hold tensors, so the result is always a state_dict
        return load_safetensors(model_path, device=str(target_device))

    # Standard PyTorch formats (.pth, .pt, .ckpt, ...).
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load trusted files.
    try:
        return torch.load(model_path, map_location=target_device, weights_only=False)
    except (AttributeError, ModuleNotFoundError, RuntimeError):
        # Fallback for missing classes / version conflicts: permissive unpickler
        with open(model_path, 'rb') as f:
            return CustomUnpickler(f).load()
def infer_num_classes(state_dict: Dict[str, torch.Tensor]) -> int:
    """Infer the number of classes from the state_dict (head layer weight).

    Falls back to 1000 (the ImageNet default) when no head weight is found.
    """
    for name, tensor in state_dict.items():
        # Substring match: any key containing both 'head' and 'weight' counts.
        if 'head' in name and 'weight' in name and hasattr(tensor, 'shape'):
            return tensor.shape[0]
    return 1000
def extract_class_names(checkpoint: Any) -> Optional[Dict[int, str]]:
    """Try to extract class names from a checkpoint dict (when present).

    The first matching key wins. Lists become {index: name}; dicts are
    returned as-is when keyed by int, inverted when they map name -> index.
    """
    if not isinstance(checkpoint, dict):
        return None
    candidate_keys = (
        'class_names', 'classes', 'class_to_idx', 'idx_to_class',
        'label_names', 'labels', 'class_labels'
    )
    for candidate in candidate_keys:
        if candidate not in checkpoint:
            continue
        labels = checkpoint[candidate]
        if isinstance(labels, list):
            return dict(enumerate(labels))
        if isinstance(labels, dict):
            if all(isinstance(k, int) for k in labels.keys()):
                return labels  # type: ignore[return-value]  # already idx -> name
            if all(isinstance(v, int) for v in labels.values()):
                return {idx: name for name, idx in labels.items()}  # name -> idx
            return labels  # type: ignore[return-value]
    return None
def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int, str]]:
    """Load class names from a .txt file (one per line) or .json (list or dict).

    Returns None when no file is given or when reading/parsing fails.
    """
    if not labels_file:
        return None
    import json
    try:
        if not labels_file.endswith('.json'):
            # Plain text: one label per line; blank lines are skipped.
            with open(labels_file, 'r', encoding='utf-8') as fh:
                names = [line.strip() for line in fh if line.strip()]
            return dict(enumerate(names))

        with open(labels_file, 'r', encoding='utf-8') as fh:
            data = json.load(fh)
        if isinstance(data, list):
            return dict(enumerate(data))
        if isinstance(data, dict):
            parsed: Dict[int, str] = {}
            for raw_key, name in data.items():
                try:
                    parsed[int(raw_key)] = name
                except Exception:
                    pass  # ignore non-numeric keys
            if parsed:
                return parsed
            # fallback: the dict may be a name -> index mapping
            if all(isinstance(v, int) for v in data.values()):
                return {idx: name for name, idx in data.items()}
        return None
    except Exception:
        # Best-effort loader: any I/O or parse error yields "no labels".
        return None
def _load_from_state_dict(state_dict: Dict[str, torch.Tensor], device: torch.device) -> Tuple[torch.nn.Module, ViTConfig]:
    """Infer the architecture from a raw state_dict and build a matching ViT.

    Shared by every state_dict-shaped branch of build_model_from_checkpoint
    (previously three byte-identical copies of this logic).
    """
    # Strip framework prefixes (Lightning, DDP, etc.) before inference.
    state_dict = _strip_state_dict_prefix(state_dict)
    config = infer_config_from_state_dict(state_dict)
    print(f"[ViTViz] Arquitetura inferida: {config.timm_model_name} "
          f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
          f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
    model = create_vit_from_config(config, device=device)
    # strict=False to support variants such as CLIP (norm_pre, etc.)
    model.load_state_dict(state_dict, strict=False)
    return model, config


def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
    """Build a model from a checkpoint that may be a dict, a state_dict, or a full model.

    Supports arbitrary ViT architectures, not only the predefined timm names.

    Returns:
        (model, config) - the loaded model and the inferred configuration

    Raises:
        ValueError: when a full model fails ViT structure validation.
    """
    device = device or DEVICE_DEFAULT
    config: Optional[ViTConfig] = None

    # Detect and log PyTorch Lightning checkpoints.
    if isinstance(checkpoint, dict) and 'pytorch-lightning_version' in checkpoint:
        print(f"[ViTViz] Detectado checkpoint PyTorch Lightning (v{checkpoint.get('pytorch-lightning_version', '?')})")

    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            # A complete model stored inside the dict.
            model = checkpoint['model']
            config = infer_config_from_model(model)
            is_valid, error_msg = validate_vit_structure(model)
            if not is_valid:
                raise ValueError(f"Modelo inválido: {error_msg}")
        elif 'state_dict' in checkpoint:
            model, config = _load_from_state_dict(checkpoint['state_dict'], device)
        elif 'model_state_dict' in checkpoint:
            # Newer format that also embeds class_names.
            model, config = _load_from_state_dict(checkpoint['model_state_dict'], device)
        else:
            # Assume the dict itself is a bare state_dict.
            model, config = _load_from_state_dict(checkpoint, device)
    else:
        # A full model saved via torch.save(model, ...).
        model = checkpoint
        is_valid, error_msg = validate_vit_structure(model)
        if not is_valid:
            raise ValueError(f"Modelo inválido: {error_msg}")
        config = infer_config_from_model(model)

    model = model.to(device)
    model.eval()
    # Safety net: make sure a config is always returned.
    if config is None:
        config = infer_config_from_model(model)
    return model, config
def load_model_and_labels(
    model_path: str,
    labels_file: Optional[str] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
    """
    ** Main entry point **

    Loads the model and, when available, class names.

    Returns:
        (model, class_names, labels_source, config) where
        labels_source ∈ {"file", "checkpoint", "hf", None};
        class_names is None when no names are available;
        config holds the ViT architecture (embed_dim, num_heads, grid_size, etc.)
    """
    device = device or DEVICE_DEFAULT

    # Load directly from the Hugging Face Hub (Transformers -> timm).
    if isinstance(model_path, str) and model_path.startswith("hf-model://"):
        model_id = model_path[len("hf-model://"):].strip("/")
        model, class_names, config = load_vit_from_huggingface(model_id, device=device)
        return model, class_names, 'hf', config

    checkpoint = load_checkpoint(model_path, device=device)

    # FIX: labels_file was previously ignored (its handling was commented
    # out), contradicting the documented "file" source. An explicit labels
    # file takes precedence over names embedded in the checkpoint.
    class_names_ckpt = extract_class_names(checkpoint)
    class_names_file = load_class_names_from_file(labels_file)
    class_names = class_names_file or class_names_ckpt
    source: Optional[str] = None
    if class_names_file:
        source = 'file'
    elif class_names_ckpt:
        source = 'checkpoint'

    model, config = build_model_from_checkpoint(checkpoint, device=device)
    return model, class_names, source, config
|