# ViTViz / utils/model_loader.py
# Author: lucasddmc
# Commit d012f9c - "feat: make possible to use different ViT formats and architectures"
import math
import pickle
import types
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

import torch
import timm

# Import VisionTransformer directly so models with custom architectures can be built
try:
    from timm.models.vision_transformer import VisionTransformer
except ImportError:
    VisionTransformer = None

try:
    from transformers import AutoModelForImageClassification
except Exception:  # pragma: no cover
    AutoModelForImageClassification = None

# safetensors support (modern Hugging Face format)
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None
# Device used by every loader in this module when the caller passes none:
# CUDA when torch reports it available, otherwise the CPU.
DEVICE_DEFAULT = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@dataclass
class ViTConfig:
    """ViT architecture description, filled in dynamically from a model or checkpoint.

    The defaults describe ViT-Base/16 at 224 px with 1000 classes. Loaders
    overwrite whichever fields they can determine.
    """
    # Width of the token embeddings (hidden size).
    embed_dim: int = 768
    # Attention heads per transformer block.
    num_heads: int = 12
    # Number of transformer blocks.
    num_layers: int = 12
    # Patch side length in pixels (patches are treated as square).
    patch_size: int = 16
    # Input image side length in pixels (images are treated as square).
    img_size: int = 224
    # Number of outputs of the classification head.
    num_classes: int = 1000
    # MLP hidden width divided by embed_dim.
    mlp_ratio: float = 4.0
    # Whether the fused qkv projection has a bias.
    qkv_bias: bool = True

    @property
    def grid_size(self) -> int:
        """Patches along one image side (e.g. 224 // 16 == 14)."""
        return self.img_size // self.patch_size

    @property
    def num_patches(self) -> int:
        """Total number of patches (e.g. 14 * 14 == 196)."""
        side = self.grid_size
        return side * side

    @property
    def timm_model_name(self) -> str:
        """timm-style model name, for display only.

        The size label is looked up from (embed_dim, num_heads). Unknown
        combinations are labelled 'custom'.
        """
        variant = {
            (192, 3): 'tiny',
            (384, 6): 'small',
            (768, 12): 'base',
            (1024, 16): 'large',
            (1280, 16): 'huge',
        }.get((self.embed_dim, self.num_heads), 'custom')
        return "vit_{}_patch{}_{}".format(variant, self.patch_size, self.img_size)
def create_vit_from_config(config: ViTConfig, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Instantiate a timm VisionTransformer matching ``config``.

    The model is built directly instead of through a registered timm name, so
    arbitrary widths, depths and patch sizes work. It is not limited to
    vit_base_patch16_224 and the other predefined names.

    Args:
        config: architecture to build.
        device: target device; ``DEVICE_DEFAULT`` when omitted.

    Returns:
        A freshly constructed (untrained) model moved to the target device.

    Raises:
        RuntimeError: if timm's VisionTransformer could not be imported.
    """
    target = device if device else DEVICE_DEFAULT
    if VisionTransformer is None:
        raise RuntimeError("VisionTransformer não disponível. Verifique a instalação do timm.")
    architecture = dict(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_chans=3,
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.num_layers,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        class_token=True,
        global_pool='token',
    )
    vit = VisionTransformer(**architecture)
    return vit.to(target)
def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Remove prefixos comuns de frameworks (Lightning, DDP, etc.) das keys do state_dict.
Prefixos tratados:
- 'model.' (PyTorch Lightning)
- 'module.' (DataParallel/DistributedDataParallel)
- 'encoder.' (alguns frameworks de self-supervised learning)
- 'backbone.' (alguns frameworks de detecção)
Returns:
state_dict com keys sem prefixo
"""
prefixes = ['model.', 'module.', 'encoder.', 'backbone.']
# Verificar se alguma key tem prefixo
has_prefix = False
detected_prefix = None
for key in state_dict.keys():
for prefix in prefixes:
if key.startswith(prefix):
has_prefix = True
detected_prefix = prefix
break
if has_prefix:
break
if not has_prefix:
return state_dict
print(f"[ViTViz] Detectado prefixo '{detected_prefix}' nas keys do state_dict (Lightning/DDP). Removendo...")
new_sd: Dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
new_key = key
for prefix in prefixes:
if key.startswith(prefix):
new_key = key[len(prefix):]
break
new_sd[new_key] = value
return new_sd
def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
    """Check that ``model`` exposes the layout of a timm-compatible ViT.

    Only the first block is inspected. It must have an ``attn`` module that
    provides ``qkv`` and ``num_heads``.

    Returns:
        (is_valid, error_message). error_message is "" when valid, otherwise
        it describes the first problem found.
    """
    if not hasattr(model, 'blocks'):
        return False, "Modelo não tem atributo 'blocks'. Não é um ViT compatível."
    if not len(model.blocks):
        return False, "Modelo tem 'blocks' vazio."

    first_block = model.blocks[0]
    if not hasattr(first_block, 'attn'):
        return False, "Bloco não tem atributo 'attn'. Estrutura incompatível."

    attention = first_block.attn
    required_attn_attrs = (
        ('qkv', "Módulo de atenção não tem 'qkv'. Estrutura incompatível."),
        ('num_heads', "Módulo de atenção não tem 'num_heads'. Estrutura incompatível."),
    )
    for attr_name, message in required_attn_attrs:
        if not hasattr(attention, attr_name):
            return False, message
    return True, ""
def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
    """Infer a ViTConfig from an instantiated timm-style ViT.

    The following attributes are read when present:
      - ``patch_embed.img_size`` and ``patch_embed.patch_size``; the first
        element is used when the value is a tuple or list.
      - the number of ``blocks``.
      - ``blocks[0].attn.num_heads``.
      - ``blocks[0].attn.qkv.in_features``, used as embed_dim.
      - ``blocks[0].attn.qkv.bias``; no bias tensor means qkv_bias=False.
      - ``blocks[0].mlp.fc1.out_features``, used for mlp_ratio.
      - ``head.out_features`` or ``head.weight``, used for num_classes.

    Anything missing, or a size attribute set to None, keeps the ViTConfig
    default, so ``grid_size`` stays computable.
    """
    config = ViTConfig()

    patch_embed = getattr(model, 'patch_embed', None)
    if patch_embed is not None:
        img_size = getattr(patch_embed, 'img_size', None)
        if img_size is not None:
            config.img_size = img_size[0] if isinstance(img_size, (tuple, list)) else img_size
        patch_size = getattr(patch_embed, 'patch_size', None)
        if patch_size is not None:
            config.patch_size = patch_size[0] if isinstance(patch_size, (tuple, list)) else patch_size

    blocks = getattr(model, 'blocks', None)
    if blocks is not None and len(blocks) > 0:
        config.num_layers = len(blocks)
        first_block = blocks[0]

        attn = getattr(first_block, 'attn', None)
        if attn is not None:
            if hasattr(attn, 'num_heads'):
                config.num_heads = attn.num_heads
            qkv = getattr(attn, 'qkv', None)
            if qkv is not None and hasattr(qkv, 'in_features'):
                config.embed_dim = qkv.in_features
                # Mirrors infer_config_from_state_dict: no qkv bias tensor -> qkv_bias=False.
                config.qkv_bias = getattr(qkv, 'bias', None) is not None

        fc1 = getattr(getattr(first_block, 'mlp', None), 'fc1', None)
        if fc1 is not None and hasattr(fc1, 'out_features') and config.embed_dim > 0:
            config.mlp_ratio = fc1.out_features / config.embed_dim

    head = getattr(model, 'head', None)
    if head is not None:
        if hasattr(head, 'out_features'):
            config.num_classes = head.out_features
        elif hasattr(head, 'weight'):
            config.num_classes = head.weight.shape[0]

    return config
# Widths of well-known ViT variants whose head_dim is not 64, mapped to their
# published head count (ViT-H/14: 1280, ViT-g/14: 1408, ViT-G/14: 1664; 16 heads each).
_KNOWN_NUM_HEADS_BY_WIDTH = {1280: 16, 1408: 16, 1664: 16}


def _infer_num_heads(embed_dim: int, qkv_out: int) -> Optional[int]:
    """Guess the attention head count from layer widths; None when no guess applies.

    The head count is not stored in a state_dict, so this is a heuristic.
    Known widths come first. Then, when qkv is the usual fused [3*embed_dim]
    projection, common head_dims are tried. Finally, common head counts that
    divide embed_dim are tried.
    """
    if qkv_out == 3 * embed_dim:
        if embed_dim in _KNOWN_NUM_HEADS_BY_WIDTH:
            return _KNOWN_NUM_HEADS_BY_WIDTH[embed_dim]
        for head_dim in (64, 32, 96, 48, 128):
            if embed_dim % head_dim == 0:
                return embed_dim // head_dim
    for num_heads in (12, 16, 8, 6, 24, 4, 3):
        if embed_dim % num_heads == 0:
            return num_heads
    return None


def _infer_grid_size(num_tokens: int) -> int:
    """Return the patch-grid side from the number of pos_embed tokens.

    A class-token slot is assumed first (num_tokens - 1 must be a perfect
    square), then no prefix token (num_tokens itself square). If neither is
    square, it falls back to floor(sqrt(num_tokens - 1)), the historical
    behaviour.
    """
    for num_prefix in (1, 0):
        n = num_tokens - num_prefix
        if n > 0:
            side = math.isqrt(n)
            if side * side == n:
                return side
    return math.isqrt(max(num_tokens - 1, 0))


def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
    """Infer a ViTConfig from a timm-keyed ViT state_dict.

    Tensor shapes determine embed_dim, num_layers, mlp_ratio, num_classes,
    patch_size, qkv_bias and img_size. num_heads cannot be read from shapes
    and is estimated heuristically. Fields whose keys are absent keep the
    ViTConfig defaults.
    """
    config = ViTConfig()

    # Count transformer blocks: "blocks.<idx>.attn.*" -> idx.
    layer_indices = set()
    for key in state_dict:
        parts = key.split('.')
        if parts[0] == 'blocks' and len(parts) > 1 and parts[1].isdigit() and '.attn.' in key:
            layer_indices.add(int(parts[1]))
    if layer_indices:
        config.num_layers = max(layer_indices) + 1

    qkv_key = 'blocks.0.attn.qkv.weight'
    patch_proj_key = 'patch_embed.proj.weight'

    # embed_dim: qkv.weight is [3*embed_dim, embed_dim]; otherwise use the
    # patch projection, whose output channels are the embedding width.
    if qkv_key in state_dict:
        config.embed_dim = state_dict[qkv_key].shape[1]
    elif patch_proj_key in state_dict:
        config.embed_dim = state_dict[patch_proj_key].shape[0]

    proj_key = 'blocks.0.attn.proj.weight'
    if proj_key in state_dict and qkv_key in state_dict:
        num_heads = _infer_num_heads(state_dict[proj_key].shape[0], state_dict[qkv_key].shape[0])
        if num_heads is not None:
            config.num_heads = num_heads

    config.qkv_bias = 'blocks.0.attn.qkv.bias' in state_dict

    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
    if mlp_fc1_key in state_dict and config.embed_dim > 0:
        config.mlp_ratio = state_dict[mlp_fc1_key].shape[0] / config.embed_dim

    if 'head.weight' in state_dict:
        config.num_classes = state_dict['head.weight'].shape[0]

    # patch_embed.proj.weight: [embed_dim, in_chans, patch_size, patch_size]
    if patch_proj_key in state_dict:
        config.patch_size = state_dict[patch_proj_key].shape[2]

    # pos_embed: [1, num_tokens, embed_dim]
    if 'pos_embed' in state_dict:
        grid_size = _infer_grid_size(int(state_dict['pos_embed'].shape[1]))
        config.img_size = grid_size * config.patch_size

    return config
def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]:
if not isinstance(id2label, dict):
return None
out: Dict[int, str] = {}
for k, v in id2label.items():
try:
out[int(k)] = str(v)
except Exception:
continue
return out or None
def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor], num_layers: int) -> Dict[str, torch.Tensor]:
"""Converte state_dict de ViT (Hugging Face Transformers) para chaves do timm ViT.
Alvo: timm "vit_base_patch16_224".
"""
out: Dict[str, torch.Tensor] = {}
def get(key: str) -> torch.Tensor:
if key not in hf_sd:
raise KeyError(f"Missing key in HF state_dict: {key}")
return hf_sd[key]
# Embeddings
out["cls_token"] = get("vit.embeddings.cls_token")
out["pos_embed"] = get("vit.embeddings.position_embeddings")
out["patch_embed.proj.weight"] = get("vit.embeddings.patch_embeddings.projection.weight")
out["patch_embed.proj.bias"] = get("vit.embeddings.patch_embeddings.projection.bias")
# Encoder blocks
for i in range(num_layers):
prefix = f"vit.encoder.layer.{i}"
out[f"blocks.{i}.norm1.weight"] = get(f"{prefix}.layernorm_before.weight")
out[f"blocks.{i}.norm1.bias"] = get(f"{prefix}.layernorm_before.bias")
out[f"blocks.{i}.norm2.weight"] = get(f"{prefix}.layernorm_after.weight")
out[f"blocks.{i}.norm2.bias"] = get(f"{prefix}.layernorm_after.bias")
qw = get(f"{prefix}.attention.attention.query.weight")
kw = get(f"{prefix}.attention.attention.key.weight")
vw = get(f"{prefix}.attention.attention.value.weight")
qb = get(f"{prefix}.attention.attention.query.bias")
kb = get(f"{prefix}.attention.attention.key.bias")
vb = get(f"{prefix}.attention.attention.value.bias")
out[f"blocks.{i}.attn.qkv.weight"] = torch.cat([qw, kw, vw], dim=0)
out[f"blocks.{i}.attn.qkv.bias"] = torch.cat([qb, kb, vb], dim=0)
out[f"blocks.{i}.attn.proj.weight"] = get(f"{prefix}.attention.output.dense.weight")
out[f"blocks.{i}.attn.proj.bias"] = get(f"{prefix}.attention.output.dense.bias")
out[f"blocks.{i}.mlp.fc1.weight"] = get(f"{prefix}.intermediate.dense.weight")
out[f"blocks.{i}.mlp.fc1.bias"] = get(f"{prefix}.intermediate.dense.bias")
out[f"blocks.{i}.mlp.fc2.weight"] = get(f"{prefix}.output.dense.weight")
out[f"blocks.{i}.mlp.fc2.bias"] = get(f"{prefix}.output.dense.bias")
out["norm.weight"] = get("vit.layernorm.weight")
out["norm.bias"] = get("vit.layernorm.bias")
# Classifier
if "classifier.weight" in hf_sd and "classifier.bias" in hf_sd:
out["head.weight"] = get("classifier.weight")
out["head.bias"] = get("classifier.bias")
return out
def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
    """Load a ViT from the Hugging Face Hub and rebuild it as an equivalent timm model.

    The architecture is read from the HF config. When the config or an
    attribute is absent, ViT-Base/16-224 values are used (1000 labels,
    12 layers, 768 hidden, 12 heads, intermediate = 4 * hidden, qkv bias on).
    The weights are converted to timm key names and loaded non-strictly.

    Args:
        model_id: Hub repository id, e.g. "org/name".
        device: target device; ``DEVICE_DEFAULT`` when omitted.

    Returns:
        (model, class_names, config). The model is in eval mode; class_names
        come from the config's ``id2label`` and may be None.

    Raises:
        RuntimeError: if ``transformers`` is not installed.
    """
    if AutoModelForImageClassification is None:
        raise RuntimeError("transformers não está instalado; instale 'transformers' para carregar do Hugging Face.")
    target_device = device if device else DEVICE_DEFAULT

    hf_model = AutoModelForImageClassification.from_pretrained(model_id)
    hf_model.eval()
    hf_cfg = getattr(hf_model, "config", None)

    def cfg_int(attr: str, fallback: int) -> int:
        # Read an integer HF config attribute, or fall back when config/attr is absent.
        if hf_cfg is None:
            return fallback
        return int(getattr(hf_cfg, attr, fallback))

    num_labels = cfg_int("num_labels", 1000)
    num_layers = cfg_int("num_hidden_layers", 12)
    hidden_size = cfg_int("hidden_size", 768)
    num_heads = cfg_int("num_attention_heads", 12)
    patch_size = cfg_int("patch_size", 16)
    img_size = cfg_int("image_size", 224)
    intermediate_size = cfg_int("intermediate_size", hidden_size * 4)
    if hf_cfg is None:
        qkv_bias = True
        class_names = None
    else:
        qkv_bias = bool(getattr(hf_cfg, "qkv_bias", True))
        class_names = _hf_id2label_to_class_names(getattr(hf_cfg, "id2label", None))

    vit_config = ViTConfig(
        embed_dim=hidden_size,
        num_heads=num_heads,
        num_layers=num_layers,
        patch_size=patch_size,
        img_size=img_size,
        num_classes=num_labels,
        mlp_ratio=intermediate_size / hidden_size,
        qkv_bias=qkv_bias,
    )
    print(f"[ViTViz] Carregando do HuggingFace: {vit_config.timm_model_name} "
          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
          f"layers={vit_config.num_layers})")

    timm_model = create_vit_from_config(vit_config, device=target_device)
    converted = _convert_hf_vit_to_timm_state_dict(hf_model.state_dict(), num_layers=num_layers)
    timm_model.load_state_dict(converted, strict=False)
    timm_model.eval()
    return timm_model, class_names, vit_config
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that replaces unresolvable classes with empty placeholder classes.

    When a (module, name) pair cannot be resolved, ``find_class`` returns a
    new class with the same name and no behaviour. Unpickling can then
    continue even though the original definition is unavailable.

    SECURITY: this is still a full pickle loader. Any other resolvable
    callable can be invoked, so only use it on trusted files.
    """

    def find_class(self, module, name):
        try:
            resolved = super().find_class(module, name)
        except Exception:
            # Placeholder class carrying the original name.
            resolved = type(name, (), {})
        return resolved
def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
    """Load a checkpoint/model from ``model_path``.

    Supported formats:
      - .safetensors: Hugging Face safetensors (always a plain state_dict).
      - anything else (.pth, .pt, .ckpt, ...): ``torch.load``.

    If ``torch.load`` fails with AttributeError, ModuleNotFoundError or
    RuntimeError (e.g. a class saved in the file no longer exists), it is
    retried with ``CustomUnpickler``, which replaces missing classes with
    placeholders.

    SECURITY: ``torch.load(weights_only=False)`` and the fallback run pickle,
    which can execute arbitrary code from the file. Only load trusted files.

    Returns:
        The loaded object: a full model, a state_dict or a checkpoint dict.

    Raises:
        ImportError: for .safetensors files when safetensors is not installed.
    """
    device = device or DEVICE_DEFAULT

    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                "safetensors não está instalado. Instale com: pip install safetensors"
            )
        # safetensors only stores tensors, never a full model object.
        return load_safetensors(model_path, device=str(device))

    try:
        return torch.load(model_path, map_location=device, weights_only=False)
    except (AttributeError, ModuleNotFoundError, RuntimeError):
        # Retry through torch.load so its container parsing (zip archive or
        # legacy multi-pickle stream) and tensor storage handling still apply,
        # but resolve classes through CustomUnpickler. Feeding the raw file to
        # a pickle.Unpickler directly does not work: the zip format is not a
        # pickle stream, and the legacy format starts with a magic-number pickle.
        lenient_pickle = types.ModuleType("vitviz_lenient_pickle")
        lenient_pickle.Unpickler = CustomUnpickler
        lenient_pickle.load = pickle.load
        return torch.load(
            model_path,
            map_location=device,
            weights_only=False,
            pickle_module=lenient_pickle,
        )
def infer_num_classes(state_dict: Dict[str, torch.Tensor]) -> int:
    """Infer the number of classes from the classifier head weights.

    An exact classifier key ('head.weight', or '<prefix>.head.weight') is
    preferred. Only if none exists is the first key containing both 'head'
    and 'weight' used. That looser match can hit unrelated tensors, such as
    attention or projection-head weights. 0-d tensors are ignored.

    Returns:
        ``shape[0]`` of the selected tensor, or 1000 (ImageNet default) when
        nothing matches.
    """
    def _rows(tensor: Any) -> Optional[int]:
        shape = getattr(tensor, 'shape', None)
        return shape[0] if shape is not None and len(shape) > 0 else None

    for key, tensor in state_dict.items():
        if key == 'head.weight' or key.endswith('.head.weight'):
            rows = _rows(tensor)
            if rows is not None:
                return rows

    for key, tensor in state_dict.items():
        if 'head' in key and 'weight' in key:
            rows = _rows(tensor)
            if rows is not None:
                return rows
    return 1000
def extract_class_names(checkpoint: Any) -> Optional[Dict[int, str]]:
    """Try to extract an {index: class name} mapping from a checkpoint dict.

    Keys are checked in a fixed order. The first usable value wins:
      - a non-empty list/tuple of names -> enumerated;
      - a dict with int keys -> returned as idx->name;
      - a dict with int values -> inverted from name->idx;
      - a dict with digit-string keys (e.g. after a JSON round-trip) ->
        keys converted to int;
      - any other non-empty dict -> returned unchanged.
    Empty containers and other value types are skipped.

    Returns:
        The mapping, or None if ``checkpoint`` is not a dict or no key
        yields names.
    """
    if not isinstance(checkpoint, dict):
        return None

    possible_keys = [
        'class_names', 'classes', 'class_to_idx', 'idx_to_class',
        'label_names', 'labels', 'class_labels'
    ]
    for key in possible_keys:
        labels = checkpoint.get(key)
        if isinstance(labels, (list, tuple)):
            if labels:
                return dict(enumerate(labels))
            continue
        if not isinstance(labels, dict) or not labels:
            continue
        if all(isinstance(k, int) for k in labels):
            return labels  # already idx -> name
        if all(isinstance(v, int) for v in labels.values()):
            return {v: k for k, v in labels.items()}  # name -> idx
        if all(isinstance(k, str) and k.isdigit() for k in labels):
            return {int(k): v for k, v in labels.items()}
        return labels  # type: ignore[return-value]
    return None
def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int, str]]:
    """Load class names from a .txt file (one per line) or a .json file.

    JSON may be a list of names, an {"<idx>": name} dict, or a {name: idx}
    dict. The .json extension is matched case-insensitively. Blank lines in
    .txt files are skipped.

    Loading is best-effort. A missing, unreadable or malformed file prints a
    warning and returns None instead of raising.

    Returns:
        {index: name}, or None when ``labels_file`` is empty/None, the file
        cannot be read, or the JSON layout is not recognised.
    """
    if not labels_file:
        return None
    import json

    try:
        if labels_file.lower().endswith('.json'):
            with open(labels_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        else:
            with open(labels_file, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]
            return dict(enumerate(lines))
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[ViTViz] Não foi possível ler o arquivo de labels '{labels_file}': {err}")
        return None

    if isinstance(data, list):
        return dict(enumerate(data))
    if isinstance(data, dict):
        by_index: Dict[int, str] = {}
        for k, v in data.items():
            try:
                by_index[int(k)] = v
            except (TypeError, ValueError):
                continue  # non-numeric key
        if by_index:
            return by_index
        # Fallback: name -> idx
        if data and all(isinstance(v, int) for v in data.values()):
            return {v: k for k, v in data.items()}
    return None
def _build_model_from_state_dict(
    state_dict: Dict[str, torch.Tensor], device: torch.device
) -> Tuple[torch.nn.Module, ViTConfig]:
    """Strip framework prefixes, infer the ViT architecture, build it and load the weights.

    Loading is non-strict, so checkpoint keys the timm model lacks (e.g.
    CLIP's norm_pre) are ignored. Model parameters the checkpoint does not
    provide are reported, because they keep their fresh initialisation.
    """
    state_dict = _strip_state_dict_prefix(state_dict)
    config = infer_config_from_state_dict(state_dict)
    print(f"[ViTViz] Arquitetura inferida: {config.timm_model_name} "
          f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
          f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")

    model = create_vit_from_config(config, device=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = list(incompatible.missing_keys)
    if missing:
        preview = ", ".join(missing[:5]) + (", ..." if len(missing) > 5 else "")
        print(f"[ViTViz] Aviso: {len(missing)} parâmetro(s) do modelo não encontrado(s) no checkpoint "
              f"e mantido(s) com inicialização aleatória: {preview}")
    return model, config


def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
    """Build a model from a checkpoint that may be a dict, a state_dict or the model itself.

    Accepted layouts, checked in this order:
      - ``{'model': dict}``: weights stored under 'model', as saved by
        training scripts such as DeiT's.
      - ``{'model': nn.Module}``: full model inside a dict (validated).
      - ``{'state_dict': ...}``: e.g. PyTorch Lightning.
      - ``{'model_state_dict': ...}``.
      - any other dict, treated as a bare state_dict.
      - a non-dict object, treated as a full model saved with
        ``torch.save(model, ...)`` (validated).

    For state_dicts the architecture is inferred from tensor shapes, so it is
    not limited to timm's predefined model names.

    Returns:
        (model, config): the model moved to ``device`` in eval mode, and its
        configuration.

    Raises:
        ValueError: if a full model does not have the expected timm ViT structure.
    """
    device = device or DEVICE_DEFAULT
    config: Optional[ViTConfig] = None

    if isinstance(checkpoint, dict):
        if 'pytorch-lightning_version' in checkpoint:
            print(f"[ViTViz] Detectado checkpoint PyTorch Lightning (v{checkpoint.get('pytorch-lightning_version', '?')})")

        if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
            model, config = _build_model_from_state_dict(checkpoint['model'], device)
        elif 'model' in checkpoint:
            model = checkpoint['model']
            config = infer_config_from_model(model)
            is_valid, error_msg = validate_vit_structure(model)
            if not is_valid:
                raise ValueError(f"Modelo inválido: {error_msg}")
        elif 'state_dict' in checkpoint:
            model, config = _build_model_from_state_dict(checkpoint['state_dict'], device)
        elif 'model_state_dict' in checkpoint:
            # Newer format that may also embed class names next to the weights.
            model, config = _build_model_from_state_dict(checkpoint['model_state_dict'], device)
        else:
            model, config = _build_model_from_state_dict(checkpoint, device)
    else:
        model = checkpoint
        is_valid, error_msg = validate_vit_structure(model)
        if not is_valid:
            raise ValueError(f"Modelo inválido: {error_msg}")
        config = infer_config_from_model(model)

    model = model.to(device)
    model.eval()

    if config is None:
        config = infer_config_from_model(model)
    return model, config
def load_model_and_labels(
    model_path: str,
    labels_file: Optional[str] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
    """** Main entry point **: load a model and, when available, its class names.

    ``model_path`` is either a local checkpoint path or
    ``hf-model://<repo_id>``. The latter fetches a Transformers ViT from the
    Hugging Face Hub and converts it to timm.

    Class names come from ``labels_file`` (.txt/.json) when it yields any.
    Otherwise they come from the checkpoint or the HF config.

    Returns:
        (model, class_names, source, config). ``source`` is one of "file",
        "checkpoint" or "hf". It is None when no class names are available,
        in which case class_names is None as well. ``config`` describes the
        ViT architecture (embed_dim, num_heads, grid_size, ...).
    """
    device = device or DEVICE_DEFAULT
    # Best-effort: returns None (after a warning) if the file is unusable.
    file_names = load_class_names_from_file(labels_file)

    if isinstance(model_path, str) and model_path.startswith("hf-model://"):
        model_id = model_path[len("hf-model://"):].strip("/")
        model, embedded_names, config = load_vit_from_huggingface(model_id, device=device)
        embedded_source = 'hf'
    else:
        checkpoint = load_checkpoint(model_path, device=device)
        embedded_names = extract_class_names(checkpoint)
        embedded_source = 'checkpoint'
        model, config = build_model_from_checkpoint(checkpoint, device=device)

    if file_names:
        return model, file_names, 'file', config
    if embedded_names:
        return model, embedded_names, embedded_source, config
    return model, None, None, config