| | import pickle |
| | import torch |
| | import timm |
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple, Dict, Any |
| |
|
| | |
| | try: |
| | from timm.models.vision_transformer import VisionTransformer |
| | except ImportError: |
| | VisionTransformer = None |
| |
|
| | try: |
| | from transformers import AutoModelForImageClassification |
| | except Exception: |
| | AutoModelForImageClassification = None |
| |
|
| | |
| | try: |
| | from safetensors.torch import load_file as load_safetensors |
| | except ImportError: |
| | load_safetensors = None |
| |
|
# Default device for model loading: the current CUDA device when torch can see
# a GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE_DEFAULT = torch.device("cuda")
else:
    DEVICE_DEFAULT = torch.device("cpu")
| |
|
| |
|
@dataclass
class ViTConfig:
    """ViT architecture hyper-parameters, inferred dynamically from a model or checkpoint.

    The defaults describe ViT-Base with 16px patches on 224px inputs and a
    1000-way classifier (``timm_model_name`` -> "vit_base_patch16_224").
    """
    embed_dim: int = 768     # token / hidden width
    num_heads: int = 12      # attention heads per transformer block
    num_layers: int = 12     # number of transformer blocks (depth)
    patch_size: int = 16     # patch side length, in pixels
    img_size: int = 224      # input side length, in pixels (square inputs assumed)
    num_classes: int = 1000  # classifier output size
    mlp_ratio: float = 4.0   # MLP hidden width divided by embed_dim
    qkv_bias: bool = True    # whether the fused qkv projection carries a bias

    @property
    def grid_size(self) -> int:
        """Number of patches along one image side (e.g. 224 // 16 == 14)."""
        patches_per_side = self.img_size // self.patch_size
        return patches_per_side

    @property
    def num_patches(self) -> int:
        """Total number of patches in the square grid (e.g. 14 ** 2 == 196)."""
        side = self.grid_size
        return side ** 2

    @property
    def timm_model_name(self) -> str:
        """Human-readable timm-style model name (informational only).

        Standard (embed_dim, num_heads) pairs map to tiny/small/base/large/huge;
        anything else is labelled 'custom'.
        """
        known_sizes = {
            (192, 3): 'tiny',
            (384, 6): 'small',
            (768, 12): 'base',
            (1024, 16): 'large',
            (1280, 16): 'huge',
        }
        signature = (self.embed_dim, self.num_heads)
        size_label = known_sizes[signature] if signature in known_sizes else 'custom'
        return f"vit_{size_label}_patch{self.patch_size}_{self.img_size}"
| |
|
| |
|
def create_vit_from_config(config: ViTConfig, device: Optional[torch.device] = None) -> torch.nn.Module:
    """Build a timm ``VisionTransformer`` straight from an inferred ``ViTConfig``.

    Unlike going through timm's registered names (``vit_base_patch16_224``, ...),
    this accepts arbitrary architecture shapes.

    Args:
        config: Architecture hyper-parameters to instantiate.
        device: Target device; ``DEVICE_DEFAULT`` when omitted.

    Returns:
        A freshly constructed (untrained) model, already moved to the device.

    Raises:
        RuntimeError: if timm's ``VisionTransformer`` could not be imported.
    """
    if VisionTransformer is None:
        raise RuntimeError("VisionTransformer não disponível. Verifique a instalação do timm.")

    target_device = device or DEVICE_DEFAULT

    # RGB input, a single class token, and classification from that token.
    architecture = dict(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_chans=3,
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.num_layers,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        class_token=True,
        global_pool='token',
    )
    return VisionTransformer(**architecture).to(target_device)
| |
|
| |
|
def _strip_state_dict_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Remove common framework wrapper prefixes (Lightning, DDP, ...) from state_dict keys.

    Handled prefixes:
      - 'model.'    (PyTorch Lightning)
      - 'module.'   (DataParallel / DistributedDataParallel)
      - 'encoder.'  (some self-supervised learning frameworks)
      - 'backbone.' (some detection frameworks)

    Prefixes are stripped repeatedly, so nested wrappers such as
    'model.module.blocks.0...' (Lightning around DDP) become 'blocks.0...'.

    Returns:
        The same dict object when no key carries a known prefix; otherwise a
        new dict whose keys have all leading known prefixes removed.
    """
    prefixes = ('model.', 'module.', 'encoder.', 'backbone.')

    def strip_all(key: str) -> str:
        # Each pass removes one leading prefix; stop when none matches.
        # Terminates because every successful pass shortens the key.
        while True:
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    break
            else:
                return key

    # First (key order, then prefix order) prefix found, used only for the log line.
    detected_prefix = next(
        (prefix for key in state_dict for prefix in prefixes if key.startswith(prefix)),
        None,
    )
    if detected_prefix is None:
        return state_dict

    print(f"[ViTViz] Detectado prefixo '{detected_prefix}' nas keys do state_dict (Lightning/DDP). Removendo...")

    return {strip_all(key): value for key, value in state_dict.items()}
| |
|
| |
|
def validate_vit_structure(model: torch.nn.Module) -> Tuple[bool, str]:
    """Check that ``model`` has the layout of a timm-compatible ViT.

    Inspects only the first block: it must have an ``attn`` module that
    exposes both ``qkv`` and ``num_heads``.

    Returns:
        (is_valid, error_message): error_message is "" when valid, otherwise
        a description of the first missing piece.
    """
    if not hasattr(model, 'blocks'):
        return False, "Modelo não tem atributo 'blocks'. Não é um ViT compatível."
    if not len(model.blocks):
        return False, "Modelo tem 'blocks' vazio."

    first_block = model.blocks[0]
    if not hasattr(first_block, 'attn'):
        return False, "Bloco não tem atributo 'attn'. Estrutura incompatível."

    # Attributes the attention module must expose, checked in order.
    required_attn_attrs = (
        ('qkv', "Módulo de atenção não tem 'qkv'. Estrutura incompatível."),
        ('num_heads', "Módulo de atenção não tem 'num_heads'. Estrutura incompatível."),
    )
    attention = first_block.attn
    for attr_name, failure_message in required_attn_attrs:
        if not hasattr(attention, attr_name):
            return False, failure_message

    return True, ""
| |
|
| |
|
def infer_config_from_model(model: torch.nn.Module) -> ViTConfig:
    """Infer a ViTConfig from an already-instantiated timm-style ViT.

    Every field keeps its ViTConfig default when the attribute it is derived
    from cannot be found on ``model``.

    Derivations:
      - img_size / patch_size: ``model.patch_embed`` (int or (h, w) pair; first entry used)
      - num_layers: ``len(model.blocks)``
      - num_heads / embed_dim / qkv_bias: first block's ``attn`` and its ``qkv`` Linear
      - mlp_ratio: first block's ``mlp.fc1.out_features / embed_dim``
      - num_classes: ``model.head`` (Linear-like), else ``model.num_classes``
    """
    config = ViTConfig()

    def first(value):
        # timm may store sizes either as an int or as an (h, w) pair.
        return value[0] if isinstance(value, (tuple, list)) else value

    patch_embed = getattr(model, 'patch_embed', None)
    if patch_embed is not None:
        img_size = getattr(patch_embed, 'img_size', None)
        if img_size is not None:  # may be None for dynamic-size patch embeddings
            config.img_size = first(img_size)
        patch_size = getattr(patch_embed, 'patch_size', None)
        if patch_size is not None:
            config.patch_size = first(patch_size)

    blocks = getattr(model, 'blocks', None)
    if blocks is not None and len(blocks) > 0:
        config.num_layers = len(blocks)
        block = blocks[0]

        attn = getattr(block, 'attn', None)
        if attn is not None:
            if hasattr(attn, 'num_heads'):
                config.num_heads = attn.num_heads
            qkv = getattr(attn, 'qkv', None)
            if qkv is not None:
                if hasattr(qkv, 'in_features'):
                    config.embed_dim = qkv.in_features
                if hasattr(qkv, 'bias'):
                    # nn.Linear(bias=False) registers ``bias`` as None.
                    config.qkv_bias = qkv.bias is not None

        fc1 = getattr(getattr(block, 'mlp', None), 'fc1', None)
        if fc1 is not None and hasattr(fc1, 'out_features') and config.embed_dim:
            config.mlp_ratio = fc1.out_features / config.embed_dim

    head = getattr(model, 'head', None)
    if head is not None and hasattr(head, 'out_features'):
        config.num_classes = head.out_features
    elif head is not None and hasattr(head, 'weight'):
        config.num_classes = head.weight.shape[0]
    elif isinstance(getattr(model, 'num_classes', None), int):
        # Head without weights (e.g. nn.Identity on a headless model): fall back
        # to the model's declared class count instead of the 1000 default.
        config.num_classes = model.num_classes

    return config
| |
|
| |
|
def infer_config_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> ViTConfig:
    """Infer a ViTConfig from a timm-style ViT state_dict (keys like 'blocks.0.attn.qkv.weight').

    Every field keeps its ViTConfig default when the keys it is derived from
    are absent.

    Derivations:
      - num_layers: highest N among 'blocks.N.*.attn.*' keys, plus one
      - embed_dim: input width of block 0's fused qkv Linear
      - num_heads: not encoded in any weight shape, so it is looked up for the
        standard ViT widths and otherwise guessed from common head sizes
      - qkv_bias: presence of 'blocks.0.attn.qkv.bias'
      - mlp_ratio: block 0's MLP hidden width / embed_dim
      - num_classes: rows of 'head.weight'
      - patch_size: kernel size of the patch-embedding conv
      - img_size: side of the (assumed square) patch grid in 'pos_embed',
        after removing one class token, times patch_size
    """
    import math

    config = ViTConfig()

    # Depth: count distinct block indices that carry attention weights.
    layer_indices = set()
    for key in state_dict:
        if key.startswith('blocks.') and '.attn.' in key:
            index = key.split('.')[1]
            if index.isdigit():
                layer_indices.add(int(index))
    if layer_indices:
        config.num_layers = max(layer_indices) + 1

    qkv_key = 'blocks.0.attn.qkv.weight'
    proj_key = 'blocks.0.attn.proj.weight'

    if qkv_key in state_dict:
        # PyTorch Linear weights are (out_features, in_features); in_features is the embed dim.
        config.embed_dim = state_dict[qkv_key].shape[1]

    if proj_key in state_dict and qkv_key in state_dict:
        embed_dim = state_dict[proj_key].shape[0]
        qkv_out = state_dict[qkv_key].shape[0]

        # Heads of standard timm ViT widths. Huge/giant/gigantic use 16 heads
        # (head_dim 80/88/104), which a head_dim=64 guess would get wrong.
        known_heads = {192: 3, 384: 6, 512: 8, 768: 12, 1024: 16, 1280: 16, 1408: 16, 1664: 16}

        if qkv_out == 3 * embed_dim:
            if embed_dim in known_heads:
                config.num_heads = known_heads[embed_dim]
            else:
                for head_dim in (64, 32, 96, 48, 128):
                    if embed_dim % head_dim == 0:
                        config.num_heads = embed_dim // head_dim
                        break
        else:
            # Attention width differs from embed_dim: fall back to common head counts.
            for nh in (12, 16, 8, 6, 24, 4, 3):
                if embed_dim % nh == 0:
                    config.num_heads = nh
                    break

    config.qkv_bias = 'blocks.0.attn.qkv.bias' in state_dict

    mlp_fc1_key = 'blocks.0.mlp.fc1.weight'
    if mlp_fc1_key in state_dict and config.embed_dim > 0:
        mlp_hidden = state_dict[mlp_fc1_key].shape[0]
        config.mlp_ratio = mlp_hidden / config.embed_dim

    if 'head.weight' in state_dict:
        config.num_classes = state_dict['head.weight'].shape[0]

    patch_proj_key = 'patch_embed.proj.weight'
    if patch_proj_key in state_dict:
        patch_weight = state_dict[patch_proj_key]
        # Conv2d weight is (out_ch, in_ch, kh, kw); kh is the patch size.
        if len(patch_weight.shape) == 4:
            config.patch_size = patch_weight.shape[2]

    if 'pos_embed' in state_dict:
        # Token axis is dim 1; one token is the class token (create_vit_from_config
        # always builds with class_token=True).
        num_patches = state_dict['pos_embed'].shape[1] - 1
        grid_size = math.isqrt(num_patches) if num_patches > 0 else 0
        if grid_size > 0:
            config.img_size = grid_size * config.patch_size

    return config
| |
|
| |
|
def _hf_id2label_to_class_names(id2label: Any) -> Optional[Dict[int, str]]:
    """Normalise a Hugging Face ``config.id2label`` into ``{int index: str name}``.

    Entries whose key cannot be converted to int (or whose value cannot be
    stringified) are skipped. Returns None when the input is not a dict or no
    entry survives.
    """
    if not isinstance(id2label, dict):
        return None

    mapping: Dict[int, str] = {}
    for raw_index, raw_label in id2label.items():
        try:
            index, label = int(raw_index), str(raw_label)
        except Exception:
            continue
        mapping[index] = label

    return mapping if mapping else None
| |
|
| |
|
def _convert_hf_vit_to_timm_state_dict(hf_sd: Dict[str, torch.Tensor], num_layers: int) -> Dict[str, torch.Tensor]:
    """Convert a Hugging Face Transformers ViT state_dict to timm ViT key names.

    Works for any width/depth (the first ``num_layers`` encoder layers are
    converted). It targets a timm ``VisionTransformer`` with a class token and
    a ``head`` classifier. HF's separate query/key/value projections are
    concatenated, in q, k, v order along dim 0, into timm's fused ``attn.qkv``.

    Args:
        hf_sd: State dict with HF ViT keys (``vit.embeddings.*``,
            ``vit.encoder.layer.N.*``, ``vit.layernorm.*``, optional ``classifier.*``).
        num_layers: Number of encoder layers to convert.

    Returns:
        New dict with timm-style keys. ``blocks.N.attn.qkv.bias`` is emitted
        only when the HF layer has q/k/v biases (models built with
        ``qkv_bias=False`` have none). ``head.*`` is emitted only when a
        classifier is present.

    Raises:
        KeyError: if a required HF key is missing, or a layer has only some of
            its query/key/value biases.
    """
    out: Dict[str, torch.Tensor] = {}

    def get(key: str) -> torch.Tensor:
        if key not in hf_sd:
            raise KeyError(f"Missing key in HF state_dict: {key}")
        return hf_sd[key]

    # Embeddings.
    out["cls_token"] = get("vit.embeddings.cls_token")
    out["pos_embed"] = get("vit.embeddings.position_embeddings")
    out["patch_embed.proj.weight"] = get("vit.embeddings.patch_embeddings.projection.weight")
    out["patch_embed.proj.bias"] = get("vit.embeddings.patch_embeddings.projection.bias")

    qkv_names = ("query", "key", "value")
    for i in range(num_layers):
        src = f"vit.encoder.layer.{i}"
        dst = f"blocks.{i}"

        out[f"{dst}.norm1.weight"] = get(f"{src}.layernorm_before.weight")
        out[f"{dst}.norm1.bias"] = get(f"{src}.layernorm_before.bias")
        out[f"{dst}.norm2.weight"] = get(f"{src}.layernorm_after.weight")
        out[f"{dst}.norm2.bias"] = get(f"{src}.layernorm_after.bias")

        attn_src = f"{src}.attention.attention"
        out[f"{dst}.attn.qkv.weight"] = torch.cat(
            [get(f"{attn_src}.{name}.weight") for name in qkv_names], dim=0
        )
        bias_keys = [f"{attn_src}.{name}.bias" for name in qkv_names]
        if any(key in hf_sd for key in bias_keys):
            # get() raises KeyError if only some of the three biases exist.
            out[f"{dst}.attn.qkv.bias"] = torch.cat([get(key) for key in bias_keys], dim=0)

        out[f"{dst}.attn.proj.weight"] = get(f"{src}.attention.output.dense.weight")
        out[f"{dst}.attn.proj.bias"] = get(f"{src}.attention.output.dense.bias")

        out[f"{dst}.mlp.fc1.weight"] = get(f"{src}.intermediate.dense.weight")
        out[f"{dst}.mlp.fc1.bias"] = get(f"{src}.intermediate.dense.bias")
        out[f"{dst}.mlp.fc2.weight"] = get(f"{src}.output.dense.weight")
        out[f"{dst}.mlp.fc2.bias"] = get(f"{src}.output.dense.bias")

    # Final norm.
    out["norm.weight"] = get("vit.layernorm.weight")
    out["norm.bias"] = get("vit.layernorm.bias")

    # Classifier head, only if the HF model has one.
    if "classifier.weight" in hf_sd and "classifier.bias" in hf_sd:
        out["head.weight"] = get("classifier.weight")
        out["head.bias"] = get("classifier.bias")

    return out
| |
|
| |
|
def load_vit_from_huggingface(model_id: str, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], ViTConfig]:
    """Load a ViT from the Hugging Face Hub and return an equivalent timm model.

    The HF model's config determines the timm architecture. Its weights are
    converted with ``_convert_hf_vit_to_timm_state_dict`` and loaded
    non-strictly. Any timm parameters left without weights, or converted
    entries that were not used, are printed as warnings rather than silently
    ignored.

    Args:
        model_id: Hub model identifier passed to ``from_pretrained``.
        device: Target device; ``DEVICE_DEFAULT`` when omitted.

    Returns:
        (model, class_names, config): the timm model in eval mode, the HF
        ``id2label`` mapping as ``{int: str}`` (or None), and the ViTConfig used.

    Raises:
        RuntimeError: if ``transformers`` is not installed.
        KeyError: if the HF weights do not follow the ViT key layout the converter expects.
    """
    if AutoModelForImageClassification is None:
        raise RuntimeError("transformers não está instalado; instale 'transformers' para carregar do Hugging Face.")

    device = device or DEVICE_DEFAULT
    hf_model = AutoModelForImageClassification.from_pretrained(model_id)
    hf_model.eval()

    # getattr(None, name, default) returns default, so a missing config is handled too.
    cfg = getattr(hf_model, "config", None)
    num_labels = int(getattr(cfg, "num_labels", 1000))
    num_layers = int(getattr(cfg, "num_hidden_layers", 12))
    hidden_size = int(getattr(cfg, "hidden_size", 768))
    num_heads = int(getattr(cfg, "num_attention_heads", 12))
    patch_size = int(getattr(cfg, "patch_size", 16))
    img_size = int(getattr(cfg, "image_size", 224))
    intermediate_size = int(getattr(cfg, "intermediate_size", hidden_size * 4))
    qkv_bias = bool(getattr(cfg, "qkv_bias", True))
    class_names = _hf_id2label_to_class_names(getattr(cfg, "id2label", None))

    vit_config = ViTConfig(
        embed_dim=hidden_size,
        num_heads=num_heads,
        num_layers=num_layers,
        patch_size=patch_size,
        img_size=img_size,
        num_classes=num_labels,
        mlp_ratio=intermediate_size / hidden_size,
        qkv_bias=qkv_bias,
    )

    print(f"[ViTViz] Carregando do HuggingFace: {vit_config.timm_model_name} "
          f"(embed_dim={vit_config.embed_dim}, heads={vit_config.num_heads}, "
          f"layers={vit_config.num_layers})")

    # NOTE(review): HF config may carry a LayerNorm eps that differs from the
    # timm model's default; outputs can differ slightly. Confirm if exact parity matters.
    timm_model = create_vit_from_config(vit_config, device=device)

    timm_sd = _convert_hf_vit_to_timm_state_dict(hf_model.state_dict(), num_layers=num_layers)
    incompatible = timm_model.load_state_dict(timm_sd, strict=False)
    # strict=False keeps loading best-effort, but mismatches mean some weights
    # are still randomly initialised; make that visible.
    if incompatible.missing_keys:
        print(f"[ViTViz] Aviso: {len(incompatible.missing_keys)} parâmetro(s) sem peso após a conversão do HuggingFace "
              f"(ex.: {list(incompatible.missing_keys)[:5]}).")
    if incompatible.unexpected_keys:
        print(f"[ViTViz] Aviso: {len(incompatible.unexpected_keys)} chave(s) convertida(s) não utilizada(s) "
              f"(ex.: {list(incompatible.unexpected_keys)[:5]}).")
    timm_model.eval()

    return timm_model, class_names, vit_config
| |
|
| |
|
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that replaces classes it cannot import with placeholder classes.

    Checkpoints often pickle project-specific objects (configs, wrappers, ...)
    whose defining module is not importable here. Instead of aborting, such
    globals resolve to a freshly created placeholder class with the same name.
    The placeholder accepts any constructor arguments and any pickled state,
    so the rest of the checkpoint can still be recovered.

    SECURITY: classes that *are* importable are still resolved normally, so
    this is as unsafe as plain pickle on untrusted input. Only use it on
    trusted files.
    """

    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except Exception:
            # Any resolution failure (missing module/attribute, broken import)
            # falls back to a placeholder rather than failing the whole load.
            return self._make_placeholder(module, name)

    @staticmethod
    def _make_placeholder(module, name):
        """Create a stand-in class named ``name`` that tolerates pickle's construction protocols."""

        def __init__(self, *args, **kwargs):
            # REDUCE calls the class with its original constructor arguments;
            # keep them for inspection instead of raising TypeError.
            self._pickle_args = args
            self._pickle_kwargs = kwargs

        def __setstate__(self, state):
            # Mirror pickle's default BUILD handling for dict state and
            # (dict_state, slot_state) pairs; keep any other state verbatim.
            slot_state = None
            if (isinstance(state, tuple) and len(state) == 2
                    and all(part is None or isinstance(part, dict) for part in state)):
                state, slot_state = state
            if isinstance(state, dict):
                self.__dict__.update(state)
            elif state is not None:
                self._pickle_state = state
            if isinstance(slot_state, dict):
                for key, value in slot_state.items():
                    setattr(self, key, value)

        return type(name, (), {
            "__init__": __init__,
            "__setstate__": __setstate__,
            "__module__": module,
        })
| |
|
| |
|
def load_checkpoint(model_path: str, device: Optional[torch.device] = None) -> Any:
    """Load a checkpoint/model object from ``model_path``.

    Supported formats:
      - .pth / .pt: PyTorch checkpoint (``torch.load``)
      - .safetensors: Hugging Face safetensors file (tensor-only, no pickle)

    Loading strategy for pickle-based files:
      1. plain ``torch.load``;
      2. if a pickled class cannot be resolved, ``torch.load`` again with
         ``CustomUnpickler`` as the pickle module, so missing classes become
         placeholders while torch still handles tensor storages;
      3. as a last resort, raw ``CustomUnpickler`` for plain pickle files not
         written by ``torch.save``.

    SECURITY: pickle-based files are loaded with ``weights_only=False``, which
    can execute arbitrary code embedded in the file. Only load trusted files.

    Returns:
        The loaded object: a full model, a bare state_dict, or a checkpoint
        dict wrapping one of them.

    Raises:
        ImportError: for .safetensors files when ``safetensors`` is missing.
    """
    import types

    device = device or DEVICE_DEFAULT

    if model_path.lower().endswith('.safetensors'):
        if load_safetensors is None:
            raise ImportError(
                "safetensors não está instalado. Instale com: pip install safetensors"
            )
        return load_safetensors(model_path, device=str(device))

    try:
        return torch.load(model_path, map_location=device, weights_only=False)
    except (AttributeError, ModuleNotFoundError, RuntimeError):
        pass

    # torch.load subclasses ``pickle_module.Unpickler`` and calls
    # ``pickle_module.load``; routing both through CustomUnpickler keeps torch's
    # storage handling (which a raw unpickler lacks) while tolerating missing classes.
    lenient_pickle = types.ModuleType("vitviz_lenient_pickle")
    lenient_pickle.Unpickler = CustomUnpickler
    lenient_pickle.load = lambda f, **kwargs: CustomUnpickler(f, **kwargs).load()
    try:
        return torch.load(model_path, map_location=device, weights_only=False,
                          pickle_module=lenient_pickle)
    except Exception as lenient_error:
        torch_error = lenient_error

    try:
        with open(model_path, 'rb') as f:
            return CustomUnpickler(f).load()
    except Exception as raw_error:
        # Keep the torch-side failure visible; it is usually the informative one.
        raise raw_error from torch_error
| |
|
| |
|
def infer_num_classes(state_dict: Dict[str, torch.Tensor]) -> int:
    """Infer the number of classes from a state_dict's classifier head.

    The canonical timm key 'head.weight' is preferred. Otherwise the first key
    containing both 'head' and 'weight' whose tensor has at least 2 dimensions
    is used. That covers Linear weights of shape (num_classes, in_features) and
    skips 1-D normalisation weights such as 'head.norm.weight', which would
    report the feature width instead.

    Returns 1000 (the ImageNet default) when no suitable tensor is found.
    """
    def is_matrix(tensor) -> bool:
        return len(getattr(tensor, 'shape', ())) >= 2

    exact = state_dict.get('head.weight')
    if is_matrix(exact):
        return int(exact.shape[0])

    for key, tensor in state_dict.items():
        if 'head' in key and 'weight' in key and is_matrix(tensor):
            return int(tensor.shape[0])
    return 1000
| |
|
| |
|
def extract_class_names(checkpoint: Any) -> Optional[Dict[int, str]]:
    """Try to extract an ``{index: class_name}`` mapping stored in a checkpoint dict.

    Candidate keys are checked in a fixed order, and the first one whose value
    can be interpreted wins. Accepted value layouts:
      - list/tuple of names -> enumerated from 0
      - {int: name} -> returned as a plain dict
      - {name: int} (e.g. a ``class_to_idx`` mapping) -> inverted
      - {"0": name, ...} with numeric-string keys (e.g. after a JSON round
        trip) -> keys converted to int

    A dict that fits none of these is skipped and the next candidate key is
    tried, so the result is always index-keyed.

    Returns None if ``checkpoint`` is not a dict or nothing usable is found.
    """
    if not isinstance(checkpoint, dict):
        return None

    possible_keys = (
        'class_names', 'classes', 'class_to_idx', 'idx_to_class',
        'label_names', 'labels', 'class_labels',
    )

    for key in possible_keys:
        if key not in checkpoint:
            continue
        labels = checkpoint[key]

        if isinstance(labels, (list, tuple)):
            return dict(enumerate(labels))

        if isinstance(labels, dict):
            if all(isinstance(k, int) for k in labels):
                return dict(labels)
            if all(isinstance(v, int) for v in labels.values()):
                return {v: k for k, v in labels.items()}
            if all(isinstance(k, str) for k in labels):
                try:
                    return {int(k): v for k, v in labels.items()}
                except ValueError:
                    pass  # non-numeric string keys: not an index mapping
    return None
| |
|
| |
|
def load_class_names_from_file(labels_file: Optional[str]) -> Optional[Dict[int, str]]:
    """Load class names from a .txt file (one per line) or a .json file.

    Accepted JSON layouts:
      - list of names -> enumerated from 0
      - {"<index>": name, ...} -> numeric keys converted to int (non-numeric keys skipped)
      - {name: index, ...} -> inverted (used when no key is numeric)

    Blank lines in .txt files are skipped. The extension check is
    case-insensitive, and both formats are read as UTF-8 with a leading
    byte-order mark ignored.

    Returns None when ``labels_file`` is empty/None, when the JSON layout is
    not recognised, or when the file cannot be read or parsed. In the last
    case a warning is printed.
    """
    if not labels_file:
        return None
    import json

    is_json = str(labels_file).lower().endswith('.json')
    try:
        # 'utf-8-sig' drops a leading BOM, which would otherwise make json.load
        # fail and end up glued to the first class name of a .txt file.
        with open(labels_file, 'r', encoding='utf-8-sig') as f:
            if not is_json:
                lines = [line.strip() for line in f if line.strip()]
                return dict(enumerate(lines))
            data = json.load(f)
    except (OSError, ValueError) as exc:
        print(f"[ViTViz] Aviso: não foi possível ler o arquivo de labels '{labels_file}': {exc}")
        return None

    if isinstance(data, list):
        return dict(enumerate(data))
    if isinstance(data, dict):
        out: Dict[int, str] = {}
        for k, v in data.items():
            try:
                out[int(k)] = v
            except ValueError:
                continue  # non-numeric key; maybe a {name: index} layout (checked below)
        if out:
            return out
        if all(isinstance(v, int) for v in data.values()):
            return {v: k for k, v in data.items()}
    return None
| |
|
| |
|
def _build_vit_from_state_dict(state_dict: Dict[str, torch.Tensor], device: torch.device) -> Tuple[torch.nn.Module, ViTConfig]:
    """Create a timm ViT shaped after ``state_dict`` and load its weights (non-strict, with a report)."""
    state_dict = _strip_state_dict_prefix(state_dict)
    config = infer_config_from_state_dict(state_dict)
    print(f"[ViTViz] Arquitetura inferida: {config.timm_model_name} "
          f"(embed_dim={config.embed_dim}, heads={config.num_heads}, "
          f"layers={config.num_layers}, patch={config.patch_size}, img={config.img_size})")

    model = create_vit_from_config(config, device=device)

    # strict=False tolerates entries that do not belong to the timm ViT, but
    # report mismatches so silently random-initialised weights become visible.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"[ViTViz] Aviso: {len(incompatible.missing_keys)} parâmetro(s) do modelo sem peso no checkpoint "
              f"(ex.: {list(incompatible.missing_keys)[:5]}).")
    if incompatible.unexpected_keys:
        print(f"[ViTViz] Aviso: {len(incompatible.unexpected_keys)} chave(s) do checkpoint ignorada(s) "
              f"(ex.: {list(incompatible.unexpected_keys)[:5]}).")
    return model, config


def build_model_from_checkpoint(checkpoint: Any, device: Optional[torch.device] = None) -> Tuple[torch.nn.Module, ViTConfig]:
    """Build a model from a checkpoint that may be a dict, a bare state_dict, or a model object.

    Supports arbitrary ViT architectures, not just timm's predefined names.

    Resolution order for dict checkpoints:
      1. 'model' holding a model object -> used directly (validated);
         'model' holding a dict is treated as a state_dict (e.g. checkpoints
         saved as ``{'model': state_dict, ...}``);
      2. 'state_dict' (e.g. PyTorch Lightning);
      3. 'model_state_dict';
      4. otherwise the dict itself is treated as a state_dict.
    A non-dict checkpoint is assumed to be the model itself.

    Returns:
        (model, config): the model on ``device`` in eval mode, and its inferred configuration.

    Raises:
        ValueError: if a model object does not have a timm-compatible ViT structure.
    """
    device = device or DEVICE_DEFAULT

    if isinstance(checkpoint, dict) and 'pytorch-lightning_version' in checkpoint:
        print(f"[ViTViz] Detectado checkpoint PyTorch Lightning (v{checkpoint.get('pytorch-lightning_version', '?')})")

    if isinstance(checkpoint, dict):
        if 'model' in checkpoint and not isinstance(checkpoint['model'], dict):
            model = checkpoint['model']
            config = infer_config_from_model(model)
            is_valid, error_msg = validate_vit_structure(model)
            if not is_valid:
                raise ValueError(f"Modelo inválido: {error_msg}")
        else:
            for wrapper_key in ('model', 'state_dict', 'model_state_dict'):
                if wrapper_key in checkpoint:
                    state_dict = checkpoint[wrapper_key]
                    break
            else:
                state_dict = checkpoint  # bare state_dict
            model, config = _build_vit_from_state_dict(state_dict, device)
    else:
        model = checkpoint
        is_valid, error_msg = validate_vit_structure(model)
        if not is_valid:
            raise ValueError(f"Modelo inválido: {error_msg}")
        config = infer_config_from_model(model)

    model = model.to(device)
    model.eval()
    return model, config
| |
|
| |
|
def load_model_and_labels(
    model_path: str,
    labels_file: Optional[str] = None,
    device: Optional[torch.device] = None,
) -> Tuple[torch.nn.Module, Optional[Dict[int, str]], Optional[str], ViTConfig]:
    """
    ** Main entry point **
    Load a model and, when available, its class names.

    Args:
        model_path: Local checkpoint path (.pth/.pt/.safetensors) or
            ``"hf-model://<model_id>"`` for a Hugging Face Hub ViT.
        labels_file: Optional .txt/.json class-name file. When it yields
            names, they take precedence over names embedded in the checkpoint
            or the HF config.
        device: Target device; ``DEVICE_DEFAULT`` when omitted.

    Returns:
        (model, class_names, labels_source, config) where labels_source is
        "file", "checkpoint", "hf", or None when no class names are available
        (class_names is then None as well). config holds the ViT architecture
        (embed_dim, num_heads, grid_size, ...).
    """
    device = device or DEVICE_DEFAULT

    # An explicit labels file wins over any labels embedded in the model source.
    file_class_names = load_class_names_from_file(labels_file)

    if isinstance(model_path, str) and model_path.startswith("hf-model://"):
        model_id = model_path[len("hf-model://"):].strip("/")
        model, embedded_names, config = load_vit_from_huggingface(model_id, device=device)
        embedded_source = 'hf'
    else:
        checkpoint = load_checkpoint(model_path, device=device)
        embedded_names = extract_class_names(checkpoint)
        embedded_source = 'checkpoint'
        model, config = build_model_from_checkpoint(checkpoint, device=device)

    if file_class_names:
        return model, file_class_names, 'file', config
    if embedded_names:
        return model, embedded_names, embedded_source, config
    return model, None, None, config
| |
|