"""Utilities for loading Vision Transformer (ViT) checkpoints from multiple
sources (plain PyTorch files, PyTorch Lightning, safetensors, Hugging Face
Hub) and rebuilding an equivalent timm ``VisionTransformer``.

The architecture (embed_dim, num_heads, num_layers, patch/img size, ...) is
inferred dynamically either from a live model or from a raw ``state_dict``,
so arbitrary ViT variants are supported — not only the predefined timm names.
"""

import pickle
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch

# Optional heavy dependencies: this module degrades gracefully when they are
# missing and raises a descriptive error only at the point of use.
# Import VisionTransformer directly so models with custom architectures can
# be built (not limited to the predefined timm model names).
try:
    import timm  # noqa: F401  # kept so `module.timm` stays available to callers
    from timm.models.vision_transformer import VisionTransformer
except ImportError:
    timm = None  # type: ignore[assignment]
    VisionTransformer = None

try:
    from transformers import AutoModelForImageClassification
except Exception:  # pragma: no cover
    AutoModelForImageClassification = None

# safetensors support (modern HuggingFace weight format)
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None

DEVICE_DEFAULT = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class ViTConfig:
    """ViT architecture configuration, extracted dynamically from a model."""

    embed_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    patch_size: int = 16
    img_size: int = 224
    num_classes: int = 1000
    mlp_ratio: float = 4.0
    qkv_bias: bool = True

    @property
    def grid_size(self) -> int:
        """Patch-grid side length (e.g. 224 // 16 = 14)."""
        return self.img_size // self.patch_size

    @property
    def num_patches(self) -> int:
        """Total number of patches (e.g. 14 * 14 = 196)."""
        return self.grid_size ** 2

    @property
    def timm_model_name(self) -> str:
        """Return the matching timm model name (informational only)."""
        # Mapping keyed on (embed_dim, num_heads); unknown combos -> 'custom'.
        size_map = {
            (192, 3): 'tiny',
            (384, 6): 'small',
            (768, 12): 'base',
            (1024, 16): 'large',
            (1280, 16): 'huge',
        }
        size = size_map.get((self.embed_dim, self.num_heads), 'custom')
        return f"vit_{size}_patch{self.patch_size}_{self.img_size}"


def create_vit_from_config(config: ViTConfig,
                           device: Optional[torch.device] = None) -> torch.nn.Module:
    """Build a ViT model directly from an inferred configuration.

    This allows arbitrary architectures, not just the predefined timm names
    (vit_base_patch16_224, etc.).

    Raises:
        RuntimeError: if timm's VisionTransformer could not be imported.
    """
    device = device or DEVICE_DEFAULT
    if VisionTransformer is None:
        raise RuntimeError("VisionTransformer não disponível. Verifique a instalação do timm.")
    model = VisionTransformer(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_chans=3,
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.num_layers,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        class_token=True,
        global_pool='token',
    )
    return model.to(device)


def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove common framework prefixes (Lightning, DDP, etc.) from state_dict keys.

    Handled prefixes:
    - 'model.'    (PyTorch Lightning)
    - 'module.'   (DataParallel / DistributedDataParallel)
    - 'encoder.'  (some self-supervised learning frameworks)
    - 'backbone.' (some detection frameworks)

    Returns:
        state_dict with prefixes stripped (or the original dict, untouched,
        when no key carries a known prefix).
    """
    prefixes = ['model.', 'module.', 'encoder.', 'backbone.']

    # Detect the first key that carries any known prefix.
    detected_prefix = next(
        (prefix for key in state_dict for prefix in prefixes if key.startswith(prefix)),
        None,
    )
    if detected_prefix is None:
        return state_dict

    print(f"[ViTViz] Detectado prefixo '{detected_prefix}' nas keys do state_dict (Lightning/DDP). Removendo...")
    new_sd: Dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        new_key = key
        for prefix in prefixes:
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                break
        new_sd[new_key] = value
    return new_sd


def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
    """Check that the model has the structure expected of a timm-compatible ViT.

    Returns:
        (is_valid, error_message) — when invalid, error_message describes why.
    """
    if not hasattr(model, 'blocks'):
        return False, "Modelo não tem atributo 'blocks'. Não é um ViT compatível."
    if len(model.blocks) == 0:
        return False, "Modelo tem 'blocks' vazio."
    block = model.blocks[0]
    if not hasattr(block, 'attn'):
        return False, "Bloco não tem atributo 'attn'. Estrutura incompatível."
    attn = block.attn
    if not hasattr(attn, 'qkv'):
        return False, "Módulo de atenção não tem 'qkv'. Estrutura incompatível."
    if not hasattr(attn, 'num_heads'):
        return False, "Módulo de atenção não tem 'num_heads'. Estrutura incompatível."
    return True, ""


def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
    """Infer a ViTConfig by inspecting a loaded timm-style model."""
    config = ViTConfig()

    # img_size / patch_size from the patch embedding module.
    if hasattr(model, 'patch_embed'):
        pe = model.patch_embed
        if hasattr(pe, 'img_size'):
            img_size = pe.img_size
            config.img_size = img_size[0] if isinstance(img_size, (tuple, list)) else img_size
        if hasattr(pe, 'patch_size'):
            patch_size = pe.patch_size
            config.patch_size = patch_size[0] if isinstance(patch_size, (tuple, list)) else patch_size

    # num_layers / embed_dim / num_heads from the transformer blocks.
    if hasattr(model, 'blocks') and len(model.blocks) > 0:
        config.num_layers = len(model.blocks)
        block = model.blocks[0]
        if hasattr(block, 'attn'):
            attn = block.attn
            if hasattr(attn, 'num_heads'):
                config.num_heads = attn.num_heads
            if hasattr(attn, 'qkv') and hasattr(attn.qkv, 'in_features'):
                config.embed_dim = attn.qkv.in_features

    # num_classes from the classification head.
    if hasattr(model, 'head') and hasattr(model.head, 'out_features'):
        config.num_classes = model.head.out_features
    elif hasattr(model, 'head') and hasattr(model.head, 'weight'):
        config.num_classes = model.head.weight.shape[0]

    return config


def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
    """Infer a ViTConfig from raw state_dict tensor shapes."""
    config = ViTConfig()

    # num_layers: count distinct block indices among attention keys
    # (e.g. 'blocks.0.attn.qkv.weight' -> index 0).
    layer_indices = {
        int(key.split('.')[1])
        for key in state_dict
        if key.startswith('blocks.') and '.attn.' in key
    }
    if layer_indices:
        config.num_layers = max(layer_indices) + 1

    # embed_dim from the first block's qkv weight: shape [3*embed_dim, embed_dim].
    qkv_key = 'blocks.0.attn.qkv.weight'
    if qkv_key in state_dict:
        config.embed_dim = state_dict[qkv_key].shape[1]

    # num_heads cannot be read from shapes alone (head_dim varies by model),
    # so use a known-architecture table first, then a head_dim heuristic.
    proj_key = 'blocks.0.attn.proj.weight'
    if proj_key in state_dict and qkv_key in state_dict:
        embed_dim = state_dict[proj_key].shape[0]
        qkv_out = state_dict[qkv_key].shape[0]  # expected 3*embed_dim
        if qkv_out == 3 * embed_dim:
            # Standard ViT widths -> canonical head counts. ViT-Huge (1280)
            # uses head_dim=80, which the generic loop below would miss
            # (64 divides 1280, wrongly yielding 20 heads).
            known_heads = {192: 3, 384: 6, 768: 12, 1024: 16, 1280: 16}
            if embed_dim in known_heads:
                config.num_heads = known_heads[embed_dim]
            else:
                # Try common head_dims in order of preference.
                for head_dim in (64, 32, 96, 48, 128):
                    if embed_dim % head_dim == 0:
                        config.num_heads = embed_dim // head_dim
                        break
        else:
            # Fallback: assume num_heads divides embed_dim evenly and try
            # common head counts.
            for nh in (12, 16, 8, 6, 24, 4, 3):
                if embed_dim % nh == 0:
                    config.num_heads = nh
                    break

    # qkv_bias: present iff the bias tensor exists.
    config.qkv_bias = 'blocks.0.attn.qkv.bias' in state_dict

    # mlp_ratio from the MLP hidden width.
    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
    if mlp_fc1_key in state_dict and config.embed_dim > 0:
        config.mlp_ratio = state_dict[mlp_fc1_key].shape[0] / config.embed_dim

    # num_classes from the classification head.
    if 'head.weight' in state_dict:
        config.num_classes = state_dict['head.weight'].shape[0]

    # patch_size from the patch projection: [embed_dim, 3, patch, patch].
    if 'patch_embed.proj.weight' in state_dict:
        config.patch_size = state_dict['patch_embed.proj.weight'].shape[2]

    # img_size from pos_embed: [1, num_patches + 1, embed_dim] (CLS token).
    if 'pos_embed' in state_dict:
        num_patches = state_dict['pos_embed'].shape[1] - 1  # -1 for CLS token
        grid_size = int(num_patches ** 0.5)
        config.img_size = grid_size * config.patch_size

    return config


def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]:
    """Normalize a HuggingFace ``id2label`` mapping to {int: str}, or None."""
    if not isinstance(id2label, dict):
        return None
    out: Dict[int, str] = {}
    for k, v in id2label.items():
        try:
            out[int(k)] = str(v)
        except Exception:
            continue  # skip non-numeric keys
    return out or None


def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor],
                                       num_layers: int) -> Dict[str, torch.Tensor]:
    """Convert a HuggingFace Transformers ViT state_dict to timm ViT keys.

    Target layout: timm "vit_base_patch16_224"-style keys. The per-layer
    separate q/k/v projections are concatenated into timm's fused qkv.

    Raises:
        KeyError: when an expected HF key is missing.
    """
    out: Dict[str, torch.Tensor] = {}

    def get(key: str) -> torch.Tensor:
        if key not in hf_sd:
            raise KeyError(f"Missing key in HF state_dict: {key}")
        return hf_sd[key]

    # Embeddings
    out["cls_token"] = get("vit.embeddings.cls_token")
    out["pos_embed"] = get("vit.embeddings.position_embeddings")
    out["patch_embed.proj.weight"] = get("vit.embeddings.patch_embeddings.projection.weight")
    out["patch_embed.proj.bias"] = get("vit.embeddings.patch_embeddings.projection.bias")

    # Encoder blocks
    for i in range(num_layers):
        prefix = f"vit.encoder.layer.{i}"
        out[f"blocks.{i}.norm1.weight"] = get(f"{prefix}.layernorm_before.weight")
        out[f"blocks.{i}.norm1.bias"] = get(f"{prefix}.layernorm_before.bias")
        out[f"blocks.{i}.norm2.weight"] = get(f"{prefix}.layernorm_after.weight")
        out[f"blocks.{i}.norm2.bias"] = get(f"{prefix}.layernorm_after.bias")
        qw = get(f"{prefix}.attention.attention.query.weight")
        kw = get(f"{prefix}.attention.attention.key.weight")
        vw = get(f"{prefix}.attention.attention.value.weight")
        qb = get(f"{prefix}.attention.attention.query.bias")
        kb = get(f"{prefix}.attention.attention.key.bias")
        vb = get(f"{prefix}.attention.attention.value.bias")
        # timm fuses q, k, v into one projection along the output dim.
        out[f"blocks.{i}.attn.qkv.weight"] = torch.cat([qw, kw, vw], dim=0)
        out[f"blocks.{i}.attn.qkv.bias"] = torch.cat([qb, kb, vb], dim=0)
        out[f"blocks.{i}.attn.proj.weight"] = get(f"{prefix}.attention.output.dense.weight")
        out[f"blocks.{i}.attn.proj.bias"] = get(f"{prefix}.attention.output.dense.bias")
        out[f"blocks.{i}.mlp.fc1.weight"] = get(f"{prefix}.intermediate.dense.weight")
        out[f"blocks.{i}.mlp.fc1.bias"] = get(f"{prefix}.intermediate.dense.bias")
        out[f"blocks.{i}.mlp.fc2.weight"] = get(f"{prefix}.output.dense.weight")
        out[f"blocks.{i}.mlp.fc2.bias"] = get(f"{prefix}.output.dense.bias")

    out["norm.weight"] = get("vit.layernorm.weight")
    out["norm.bias"] = get("vit.layernorm.bias")

    # Classifier (optional: headless models have no classifier)
    if "classifier.weight" in hf_sd and "classifier.bias" in hf_sd:
        out["head.weight"] = get("classifier.weight")
        out["head.bias"] = get("classifier.bias")

    return out


def load_vit_from_huggingface(model_id: str,
                              device: Optional[torch.device] = None
                              ) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
    """Load a ViT from the Hugging Face Hub and return an equivalent timm model.

    Returns:
        (model, class_names, config)

    Raises:
        RuntimeError: if the ``transformers`` package is not installed.
    """
    if AutoModelForImageClassification is None:
        raise RuntimeError("transformers não está instalado; instale 'transformers' para carregar do Hugging Face.")
    device = device or DEVICE_DEFAULT

    hf_model = AutoModelForImageClassification.from_pretrained(model_id)
    hf_model.eval()

    # Pull architecture hyper-parameters from the HF config, with ViT-Base
    # defaults when the config (or a field) is absent.
    cfg = getattr(hf_model, "config", None)
    num_labels = int(getattr(cfg, "num_labels", 1000)) if cfg is not None else 1000
    num_layers = int(getattr(cfg, "num_hidden_layers", 12)) if cfg is not None else 12
    hidden_size = int(getattr(cfg, "hidden_size", 768)) if cfg is not None else 768
    num_heads = int(getattr(cfg, "num_attention_heads", 12)) if cfg is not None else 12
    patch_size = int(getattr(cfg, "patch_size", 16)) if cfg is not None else 16
    img_size = int(getattr(cfg, "image_size", 224)) if cfg is not None else 224
    intermediate_size = int(getattr(cfg, "intermediate_size", hidden_size * 4)) if cfg is not None else hidden_size * 4
    qkv_bias = bool(getattr(cfg, "qkv_bias", True)) if cfg is not None else True
    class_names = _hf_id2label_to_class_names(getattr(cfg, "id2label", None)) if cfg is not None else None

    vit_config = ViTConfig(
        embed_dim=hidden_size,
        num_heads=num_heads,
        num_layers=num_layers,
        patch_size=patch_size,
        img_size=img_size,
        num_classes=num_labels,
        mlp_ratio=intermediate_size / hidden_size,
        qkv_bias=qkv_bias,
    )
    print(f"[ViTViz] Carregando do HuggingFace: {vit_config.timm_model_name} "
          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
          f"layers={vit_config.num_layers})")

    # Build the custom-architecture timm model, then load converted weights.
    timm_model = create_vit_from_config(vit_config, device=device)
    timm_sd = _convert_hf_vit_to_timm_state_dict(hf_model.state_dict(), num_layers=num_layers)
    # strict=False: tolerate extra/missing keys (e.g. headless checkpoints).
    timm_model.load_state_dict(timm_sd, strict=False)
    timm_model.eval()
    return timm_model, class_names, vit_config


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that substitutes dynamically-created dummies for missing classes.

    SECURITY NOTE: unpickling executes arbitrary code — only use on trusted
    checkpoint files.
    """

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except Exception:
            # Create a dummy class with the same name so unpickling can proceed.
            return type(name, (), {})


def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
    """Load a checkpoint/model from the given path.

    Supported formats:
    - .pth / .pt (and similar): PyTorch checkpoint via torch.load
    - .safetensors: modern HuggingFace format (safer and faster)

    Returns the loaded object (full model, state_dict, or checkpoint dict).

    Raises:
        ImportError: for .safetensors files when safetensors is not installed.
    """
    device = device or DEVICE_DEFAULT

    if model_path.endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                "safetensors não está instalado. Instale com: pip install safetensors"
            )
        # safetensors always yields a state_dict (full models are not supported).
        return load_safetensors(model_path, device=str(device))

    # Default PyTorch format (.pth, .pt, .ckpt, ...).
    # SECURITY NOTE: weights_only=False allows arbitrary pickled objects —
    # only load trusted files.
    try:
        return torch.load(model_path, map_location=device, weights_only=False)
    except (AttributeError, ModuleNotFoundError, RuntimeError):
        # Fallback for missing classes or version conflicts in the pickle.
        with open(model_path, 'rb') as f:
            return CustomUnpickler(f).load()


def infer_num_classes(state_dict: Dict[str, torch.Tensor]) -> int:
    """Infer the number of classes from the head layer of a state_dict.

    Falls back to 1000 (ImageNet default) when no head weight is found.
    """
    for key, tensor in state_dict.items():
        if 'head' in key and 'weight' in key and hasattr(tensor, 'shape'):
            return tensor.shape[0]
    return 1000


def extract_class_names(checkpoint: Any) -> Optional[Dict[int, str]]:
    """Try to extract class names embedded in a checkpoint dict (if any)."""
    if not isinstance(checkpoint, dict):
        return None
    possible_keys = [
        'class_names', 'classes', 'class_to_idx', 'idx_to_class',
        'label_names', 'labels', 'class_labels',
    ]
    for key in possible_keys:
        if key not in checkpoint:
            continue
        labels = checkpoint[key]
        if not labels:
            # Skip empty containers: `all()` on an empty dict is True, which
            # would otherwise return a useless (and falsy) empty mapping.
            continue
        if isinstance(labels, list):
            return dict(enumerate(labels))
        if isinstance(labels, dict):
            # Already idx -> name.
            if all(isinstance(k, int) for k in labels.keys()):
                return labels  # type: ignore[return-value]
            # name -> idx: invert.
            if all(isinstance(v, int) for v in labels.values()):
                return {v: k for k, v in labels.items()}
            return labels  # type: ignore[return-value]
    return None


def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int, str]]:
    """Load class names from a .txt (one per line) or .json (list or dict) file.

    Returns None when the path is falsy, the file cannot be parsed, or no
    usable mapping is found.
    """
    if not labels_file:
        return None
    import json
    try:
        if labels_file.endswith('.json'):
            with open(labels_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                return dict(enumerate(data))
            if isinstance(data, dict):
                out: Dict[int, str] = {}
                for k, v in data.items():
                    try:
                        out[int(k)] = v
                    except Exception:
                        pass  # ignore non-numeric keys
                if out:
                    return out
                # Fallback: dict may be name -> idx.
                if all(isinstance(v, int) for v in data.values()):
                    return {v: k for k, v in data.items()}
            return None
        with open(labels_file, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f if line.strip()]
        return dict(enumerate(lines))
    except Exception:
        return None


def _build_from_state_dict(state_dict: Dict[str, torch.Tensor],
                           device: torch.device) -> Tuple[torch.nn.Module, ViTConfig]:
    """Strip prefixes, infer the architecture, build the model and load weights."""
    state_dict = _strip_state_dict_prefix(state_dict)
    config = infer_config_from_state_dict(state_dict)
    print(f"[ViTViz] Arquitetura inferida: {config.timm_model_name} "
          f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
          f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")
    model = create_vit_from_config(config, device=device)
    # strict=False to support variants such as CLIP (norm_pre, etc.).
    model.load_state_dict(state_dict, strict=False)
    return model, config


def build_model_from_checkpoint(checkpoint: Any,
                                device: Optional[torch.device] = None
                                ) -> Tuple[torch.nn.Module, ViTConfig]:
    """Build a model from a checkpoint that may be a dict, state_dict, or model.

    Supports arbitrary ViT architectures, not only the predefined timm names.

    Returns:
        (model, config) — loaded model and inferred configuration.

    Raises:
        ValueError: when a full model fails ViT structure validation.
    """
    device = device or DEVICE_DEFAULT
    config: Optional[ViTConfig] = None

    # Log when this looks like a PyTorch Lightning checkpoint.
    if isinstance(checkpoint, dict) and 'pytorch-lightning_version' in checkpoint:
        print(f"[ViTViz] Detectado checkpoint PyTorch Lightning (v{checkpoint.get('pytorch-lightning_version', '?')})")

    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            # A full model object stored inside the dict.
            model = checkpoint['model']
            config = infer_config_from_model(model)
            is_valid, error_msg = validate_vit_structure(model)
            if not is_valid:
                raise ValueError(f"Modelo inválido: {error_msg}")
        elif 'state_dict' in checkpoint:
            model, config = _build_from_state_dict(checkpoint['state_dict'], device)
        elif 'model_state_dict' in checkpoint:
            # Newer format that may carry class_names alongside the weights.
            model, config = _build_from_state_dict(checkpoint['model_state_dict'], device)
        else:
            # Assume the dict itself is a bare state_dict.
            model, config = _build_from_state_dict(checkpoint, device)
    else:
        # A full model saved via torch.save(model, ...).
        model = checkpoint
        is_valid, error_msg = validate_vit_structure(model)
        if not is_valid:
            raise ValueError(f"Modelo inválido: {error_msg}")
        config = infer_config_from_model(model)

    model = model.to(device)
    model.eval()

    if config is None:
        config = infer_config_from_model(model)
    return model, config


def load_model_and_labels(
    model_path: str,
    labels_file: Optional[str] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
    """** Main entry point ** Load a model and, when available, class names.

    Returns:
        (model, class_names, labels_source, config) where
        labels_source ∈ {"file", "checkpoint", "hf", None};
        None when no class names are available.
        config holds the ViT architecture (embed_dim, num_heads, grid_size, ...).
    """
    device = device or DEVICE_DEFAULT

    # Load directly from the Hugging Face Hub (Transformers -> timm).
    if isinstance(model_path, str) and model_path.startswith("hf-model://"):
        model_id = model_path[len("hf-model://"):].strip("/")
        model, class_names, config = load_vit_from_huggingface(model_id, device=device)
        return model, class_names, 'hf', config

    checkpoint = load_checkpoint(model_path, device=device)

    # An explicit labels file takes priority over names embedded in the
    # checkpoint (previously the labels_file argument was silently ignored).
    class_names_ckpt = extract_class_names(checkpoint)
    class_names_file = load_class_names_from_file(labels_file)
    class_names = class_names_file or class_names_ckpt
    source: Optional[str] = None
    if class_names_file:
        source = 'file'
    elif class_names_ckpt:
        source = 'checkpoint'

    model, config = build_model_from_checkpoint(checkpoint, device=device)
    return model, class_names, source, config