""" Pentachora batch generation and model creation. Assumes vocab is already loaded as 'vocab'. Assumes PentachoronStabilizer is already loaded. Assumes BaselineViT is already loaded. """ import torch import numpy as np # CIFAR-100 class names CIFAR100_CLASSES = [ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm' ] #config = { # 'head_type': 'roseface', # 'roseface' | 'legacy' # 'prototype_mode': 'centroid', # 'centroid' | 'rose5' | 'max_vertex' # 'margin_type': 'cosface', # 'arcface' | 'cosface' | 'sphereface' # 'margin_m': 0.30, # 'scale_s': 30.0, # 'apply_margin_train_only': False, # 'norm_type': 'l1', # 'l1' | 'l2' normalization # 'similarity_mode': 'rose', # legacy #} # Model variant configurations MODEL_CONFIGS = { # Ultra-light 'vit_beatrix_shaper': { 'embed_dim': 256, 'vocab_dim': 256, 'depth': 16, 'num_heads': 8, 'mlp_ratio': 1.0, #'norm_type': 'l1', 'margin_type': 'cosface', 'margin_m': 0.30, 'scale_s': 30.0, }, 'vit_beatrix_arc_shaper': { 'embed_dim': 256, 'vocab_dim': 256, 'depth': 16, 'num_heads': 8, 'mlp_ratio': 2.0, #'norm_type': 'l1', 'margin_type': 'arcface', 'margin_m': 0.2914, 'scale_s': 30.0, }, 'vit_beatrix_nano_arc': { 'embed_dim': 64, 'vocab_dim': 64, 'depth': 25, 'num_heads': 8, 'mlp_ratio': 8.0, #'norm_type': 'l1', 'margin_type': 'arcface', 'margin_m': 0.2914, 'scale_s': 30.0, }, 'vit_beatrix_nano_cos': { 'embed_dim': 64, 'vocab_dim': 64, 'depth': 25, 'num_heads': 8, 'mlp_ratio': 8.0, #'norm_type': 'l1', 'margin_type': 'cosface', 'margin_m': 0.2914, 'scale_s': 30.0, }, 'vit_beatrix_nano_128_cos': { 'embed_dim': 128, 'vocab_dim': 128, 'depth': 25, 'num_heads': 8, 'mlp_ratio': 8.0, #'norm_type': 'l1', 'margin_type': 'cosface', 'margin_m': 0.2914, 'scale_s': 30.0, }, 'vit_beatrix_mini_cos': { 'embed_dim': 256, 'vocab_dim': 256, 'depth': 25, 'num_heads': 8, 'mlp_ratio': 8.0, #'norm_type': 'l1', 'margin_type': 'cosface', 'margin_m': 0.2914, 'scale_s': 30.0, }, 'vit_beatrix_mini_cos_large_margin': { 'embed_dim': 256, 'vocab_dim': 256, 'depth': 25, 'num_heads': 8, 'mlp_ratio': 8.0, #'norm_type': 'l1', 'margin_type': 'cosface', 'margin_m': 0.7086, 'scale_s': 30.0, }, 'vit_zana_nano': { 'embed_dim': 128, 'vocab_dim': 128, 'depth': 4, 'num_heads': 2, 'mlp_ratio': 2.0 }, 'vit_beatrix_base_cos': { 'embed_dim': 512, 'vocab_dim': 512, 'depth': 25, 'num_heads': 16, 'mlp_ratio': 8.0, #'norm_type': 'l1', 'margin_type': 'cosface', 'margin_m': 0.2914, 'scale_s': 30.0, }, 'vit_zana_nano_deep': { 'embed_dim': 128, 'vocab_dim': 128, 'depth': 8, 'num_heads': 4, 'mlp_ratio': 2.0 }, 'vit_zana_shaper': { 'embed_dim': 256, 'vocab_dim': 256, 'depth': 32, 'num_heads': 8, 'mlp_ratio': 4.0 }, 'vit_zana_nano_thicc': { 'embed_dim': 128, 'vocab_dim': 128, 'depth': 4, 'num_heads': 8, 'mlp_ratio': 4.0 }, 'vit_zana_micro': { 'embed_dim': 500, 'vocab_dim': 25, 'depth': 6, 'num_heads': 2, 'mlp_ratio': 2.0 }, 'vit_zana_micro_500': { 'embed_dim': 500, 'vocab_dim': 25, 'depth': 6, 'num_heads': 5, 'mlp_ratio': 2.0 }, 'vit_zana_base': { 'embed_dim': 512, 'vocab_dim': 512, 'depth': 16, 'num_heads': 4, 'mlp_ratio': 4.0 }, 'vit_ursula_nano_1000': { 'embed_dim': 1000, 'vocab_dim': 500, 'depth': 4, 'num_heads': 50, 'mlp_ratio': 4.0 }, 'vit_ursula_nano': { 'embed_dim': 1000, 'vocab_dim': 25, 'depth': 4, 'num_heads': 10, 'mlp_ratio': 4.0 }, # Lightweight 'tiny': { 'embed_dim': 192, 'vocab_dim': 192, 'depth': 12, 'num_heads': 3, 'mlp_ratio': 4.0 }, 'vit_ursula_mini': { 'embed_dim': 256, 'vocab_dim': 256, 'depth': 12, 'num_heads': 4, 'mlp_ratio': 4.0 }, # Standard 'small': { 'embed_dim': 384, 'vocab_dim': 384, 'depth': 12, 'num_heads': 6, 'mlp_ratio': 4.0 }, 'base': { 'embed_dim': 768, 'vocab_dim': 768, 'depth': 12, 'num_heads': 12, 'mlp_ratio': 4.0 }, # Experimental 'wide_shallow': { 'embed_dim': 1024, 'vocab_dim': 1024, 'depth': 4, 'num_heads': 16, 'mlp_ratio': 2.0 }, 'narrow_deep': { 'embed_dim': 192, 'vocab_dim': 192, 'depth': 24, 'num_heads': 3, 'mlp_ratio': 4.0 }, } """ Updated pentachora batch generation and model creation for L1 norm. Add this modification to your existing build_model function. """ def build_model(variant='small', **override_params): """ Build model with explicit parameter handling - no hidden kwargs. Args: variant: Model variant name from MODEL_CONFIGS **override_params: Individual parameter overrides Returns: model: BaselineViT model with frozen pentachora """ assert variant in MODEL_CONFIGS, f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}" base_config = MODEL_CONFIGS[variant].copy() # EXPLICIT parameter extraction with defaults # Core architecture parameters embed_dim = override_params.get('embed_dim', base_config.get('embed_dim', 512)) vocab_dim = override_params.get('vocab_dim', base_config.get('vocab_dim', 512)) depth = override_params.get('depth', base_config.get('depth', 12)) num_heads = override_params.get('num_heads', base_config.get('num_heads', 8)) mlp_ratio = override_params.get('mlp_ratio', base_config.get('mlp_ratio', 4.0)) # Image and patch parameters img_size = override_params.get('img_size', base_config.get('img_size', 32)) patch_size = override_params.get('patch_size', base_config.get('patch_size', 4)) # Regularization parameters dropout = override_params.get('dropout', base_config.get('dropout', 0.0)) attn_dropout = override_params.get('attn_dropout', base_config.get('attn_dropout', 0.0)) # Pentachora geometry parameters similarity_mode = override_params.get('similarity_mode', base_config.get('similarity_mode', 'rose')) norm_type = override_params.get('norm_type', base_config.get('norm_type', 'l1')) # RoseFace head parameters head_type = override_params.get('head_type', base_config.get('head_type', 'roseface')) prototype_mode = override_params.get('prototype_mode', base_config.get('prototype_mode', 'centroid')) margin_type = override_params.get('margin_type', base_config.get('margin_type', 'cosface')) margin_m = float(override_params.get('margin_m', base_config.get('margin_m', 0.30))) scale_s = float(override_params.get('scale_s', base_config.get('scale_s', 30.0))) apply_margin_train_only = override_params.get('apply_margin_train_only', base_config.get('apply_margin_train_only', False)) # Dataset configuration num_classes = len(CIFAR100_CLASSES) # Print what we're building print(f"Building {variant}:") print(f" Architecture: embed={embed_dim}, vocab={vocab_dim}, depth={depth}, heads={num_heads}") print(f" Image: {img_size}x{img_size}, patch={patch_size}x{patch_size}") print(f" RoseFace: {margin_type}, m={margin_m:.4f}, s={scale_s:.1f}") print(f" Norm: {norm_type}, Similarity: {similarity_mode}") # Generate pentachora from vocab print(f"Generating {num_classes} pentachora from vocabulary...") class_names = CIFAR100_CLASSES[:num_classes] # vocab.encode_batch returns List[np.ndarray] where each is (5, vocab_dim) pentachora_np_list = vocab.encode_batch(class_names, generate=True) # Convert to torch tensors raw_penta_list = [torch.tensor(penta, dtype=torch.float32) for penta in pentachora_np_list] # Handle dimension mismatch if needed pentachora_list = [] for i, penta in enumerate(raw_penta_list): if penta.shape[-1] != vocab_dim: current_dim = penta.shape[-1] if current_dim > vocab_dim: # Downsample via linear interpolation resized_vertices = [] for v in range(penta.shape[0]): indices = torch.linspace(0, current_dim - 1, vocab_dim) vertex = penta[v] left_idx = indices.floor().long() right_idx = (left_idx + 1).clamp(max=current_dim - 1) alpha = indices - left_idx.float() interpolated = vertex[left_idx] * (1 - alpha) + vertex[right_idx] * alpha resized_vertices.append(interpolated) penta_resized = torch.stack(resized_vertices) if i == 0: # Only print once print(f" Downsampling pentachora from {current_dim} to {vocab_dim}") else: # Upsample via linear interpolation resized_vertices = [] for v in range(penta.shape[0]): vertex = penta[v] x = torch.linspace(0, current_dim - 1, vocab_dim) interpolated = torch.zeros(vocab_dim, dtype=vertex.dtype, device=vertex.device) for j in range(vocab_dim): if x[j] <= 0: interpolated[j] = vertex[0] elif x[j] >= current_dim - 1: interpolated[j] = vertex[-1] else: left = int(x[j]) alpha = x[j] - left interpolated[j] = vertex[left] * (1 - alpha) + vertex[left + 1] * alpha resized_vertices.append(interpolated) penta_resized = torch.stack(resized_vertices) if i == 0: # Only print once print(f" Upsampling pentachora from {current_dim} to {vocab_dim}") pentachora_list.append(penta_resized) else: pentachora_list.append(penta.detach().clone().to(get_default_device())) print(f"Using {num_classes} L1-normalized pentachora") # Create model with EXPLICIT parameters - no **kwargs model = BaselineViT( pentachora_list=pentachora_list, vocab_dim=vocab_dim, img_size=img_size, patch_size=patch_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout, attn_dropout=attn_dropout, similarity_mode=similarity_mode, norm_type=norm_type, head_type=head_type, prototype_mode=prototype_mode, margin_type=margin_type, margin_m=margin_m, scale_s=scale_s, apply_margin_train_only=apply_margin_train_only ) # Store complete config for checkpoint saving model.config = { 'variant': variant, 'vocab_dim': vocab_dim, 'embed_dim': embed_dim, 'depth': depth, 'num_heads': num_heads, 'mlp_ratio': mlp_ratio, 'img_size': img_size, 'patch_size': patch_size, 'dropout': dropout, 'attn_dropout': attn_dropout, 'similarity_mode': similarity_mode, 'norm_type': norm_type, 'head_type': head_type, 'prototype_mode': prototype_mode, 'margin_type': margin_type, 'margin_m': margin_m, 'scale_s': scale_s, 'apply_margin_train_only': apply_margin_train_only, 'num_classes': num_classes, } # Print model statistics total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) frozen_params = total_params - trainable_params # After creating model, before returning print("\nDiagnostic: Checking pentachora status...") for i, penta in enumerate(model.class_pentachora[:3]): # Check first 3 print(f"Pentachora {i}:") print(f" vertices requires_grad: {penta.vertices.requires_grad}") print(f" vertices mean: {penta.vertices.mean().item():.6f}") print(f" vertices std: {penta.vertices.std().item():.6f}") # Check a main model parameter print("\nMain model parameters:") if hasattr(model, 'patch_embed'): print(f" patch_embed.weight mean: {model.patch_embed.weight.mean().item():.6f}") print(f" patch_embed.weight std: {model.patch_embed.weight.std().item():.6f}") print(f"\nModel: {variant}") print(f" Classes: {num_classes}") print(f" Normalization: {norm_type.upper()}") print(f" Total params: {total_params:,}") print(f" Trainable params: {trainable_params:,}") print(f" Frozen pentachora params: {frozen_params:,}") return model # ========================= # Minimal load/save helpers # ========================= import os, json, math from pathlib import Path import torch import numpy as np try: from safetensors.torch import save_file, load_file except Exception as e: raise RuntimeError("safetensors is required: pip install safetensors") from e def _get_device(): return torch.device('cuda' if torch.cuda.is_available() else 'cpu') def _jsonify_obj(obj) -> dict: """Turn a config object or dict into a JSON-safe dict.""" if obj is None: return {} if isinstance(obj, dict): return obj out = {} for k in dir(obj): if k.startswith('_'): continue v = getattr(obj, k) if callable(v): continue if isinstance(v, torch.Tensor): v = v.tolist() elif isinstance(v, np.ndarray): v = v.tolist() out[k] = v return out def _ensure_model_config_dict(model): """Guarantee model.config is a dict describing the head + geometry relevant fields.""" if hasattr(model, "config") and isinstance(model.config, dict): return model.config cfg = { "arch": type(model).__name__, "num_classes": getattr(model, "num_classes", None), "embed_dim": getattr(model, "embed_dim", None), "pentachora_dim": getattr(model, "pentachora_dim", None), "img_size": getattr(model, "img_size", 32), "patch_size": getattr(model, "patch_size", 4), "norm_type": getattr(model, "norm_type", None), "similarity_mode": getattr(model, "similarity_mode", None), "head_type": getattr(model, "head_type", None), "prototype_mode": getattr(model, "prototype_mode", None), "margin_type": getattr(model, "margin_type", None), "margin_m": float(getattr(model, "margin_m", 0.0)) if hasattr(model, "margin_m") else None, "scale_s": float(getattr(model, "scale_s", 1.0)) if hasattr(model, "scale_s") else None, } model.config = cfg return cfg def _collect_state_tensors(state_dict): return {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} def _session_dir(paths: dict) -> Path: root = Path(paths["save_dir"]) return root / f"{paths['model_variant']}_{paths['session_timestamp']}" def _find_local_checkpoint(paths: dict) -> tuple[Path, Path | None, Path | None]: """ Return (weights_path, model_config_path, vocab_path) from the session dir. Prefer 'best_*.safetensors'; fall back to most recent '*.safetensors'. """ sdir = _session_dir(paths) if not sdir.exists(): return None, None, None safes = sorted(sdir.glob("*.safetensors"), key=lambda p: p.stat().st_mtime) if not safes: return None, None, None # prefer 'best_' if present bests = [p for p in safes if p.name.startswith("best_")] w = bests[-1] if bests else safes[-1] model_cfg = sdir / w.name.replace(".safetensors", "_model_config.json") vocab = sdir / w.name.replace(".safetensors", "_vocabulary.json") return w, (model_cfg if model_cfg.exists() else None), (vocab if vocab.exists() else None) def _load_saved_vocabulary(vocab_json_path: Path) -> list[torch.Tensor]: """Return list of [5,D] tensors from saved crystal JSON.""" with open(vocab_json_path, "r") as f: data = json.load(f) crystals = data.get("crystal_to_token", []) # crystals[i]['crystal'] is [5,D] list penta_list = [] for item in crystals: arr = torch.tensor(item["crystal"], dtype=torch.float32) penta_list.append(arr) return penta_list # ========================================= # SAVE: weights + model/training/vocabulary # ========================================= def save_existing_model( model, paths: dict, model_config=None, training_config=None, *, filename_base: str | None = None, save_vocabulary: bool = True, push_to_hub: bool | None = None ): """ Save the model to disk, and optionally upload to the HF Hub. Args: model: BaselineViT instance paths: { 'save_dir': str, 'model_variant': str, 'session_timestamp': str, # (optional for naming) 'epoch': int, 'val_acc': float, 'is_best': bool, # hub 'hub_repo': str, 'hub_token': str|None, } model_config: dict or object (optional; if None, built from model) training_config: TrainingConfig or dict (optional; saved to JSON) filename_base: override the base filename; if None, derived from epoch/acc/best save_vocabulary: write *_vocabulary.json from model.class_pentachora push_to_hub: override paths.get('push_to_hub') """ device = _get_device() sess_dir = _session_dir(paths) sess_dir.mkdir(parents=True, exist_ok=True) # ---- filename base if filename_base is None: ep = paths.get("epoch") acc = paths.get("val_acc") is_best = bool(paths.get("is_best", False)) tag = f"epoch{int(ep):03d}_acc{float(acc):.2f}" if (ep is not None and acc is not None) else "snapshot" filename_base = f"{'best_' if is_best else 'checkpoint_'}{tag}" # ---- weights weights_path = sess_dir / f"{filename_base}.safetensors" state = _collect_state_tensors(model.state_dict()) save_file(state, str(weights_path)) # ---- model config cfg_dict = _jsonify_obj(model_config) or _ensure_model_config_dict(model) model_cfg_path = sess_dir / f"{filename_base}_model_config.json" with open(model_cfg_path, "w") as f: json.dump(cfg_dict, f, indent=2, default=str) # ---- training config (metadata) if training_config is not None: train_cfg_dict = _jsonify_obj(training_config) train_cfg_path = sess_dir / f"{filename_base}_training_config.json" with open(train_cfg_path, "w") as f: json.dump(train_cfg_dict, f, indent=2, default=str) else: train_cfg_path = None # ---- vocabulary vocab_path = None if save_vocabulary and hasattr(model, "class_pentachora") and model.class_pentachora is not None: crystals = torch.stack([p.vertices for p in model.class_pentachora], dim=0).detach().cpu().numpy().tolist() vocab_data = { "vocab_dim": getattr(model, "pentachora_dim", None), "num_classes": len(model.class_pentachora), "num_vertices": 5, "tokens": CIFAR100_CLASSES[: len(crystals)], "crystal_to_token": [ {"index": i, "token": CIFAR100_CLASSES[i], "crystal": crystals[i]} for i in range(len(crystals)) ], } vocab_path = sess_dir / f"{filename_base}_vocabulary.json" with open(vocab_path, "w") as f: json.dump(vocab_data, f, indent=2) print(f"✓ Saved weights: {weights_path.name}") print(f"✓ Saved model config: {model_cfg_path.name}") if train_cfg_path: print(f"✓ Saved training config: {train_cfg_path.name}") if vocab_path: print(f"✓ Saved vocabulary: {vocab_path.name}") # ---- optional hub upload do_push = push_to_hub if push_to_hub is not None else paths.get("push_to_hub", False) if do_push: try: from huggingface_hub import HfApi, create_repo hub_repo = paths["hub_repo"] hub_token = paths.get("hub_token") subfolder = f"models/{paths['model_variant']}/{paths['session_timestamp']}" api = HfApi(token=hub_token) try: create_repo(hub_repo, token=hub_token, private=True, exist_ok=True) except Exception: pass def _up(p: Path): api.upload_file( path_or_fileobj=str(p), path_in_repo=f"{subfolder}/{p.name}", repo_id=hub_repo, repo_type="model" ) _up(weights_path); _up(model_cfg_path) if train_cfg_path: _up(train_cfg_path) if vocab_path: _up(vocab_path) print(f"✓ Pushed to hub: {hub_repo}/{subfolder}") except Exception as e: print(f"⚠ Hub upload failed: {e}") return { "weights": weights_path, "model_config": model_cfg_path, "training_config": train_cfg_path, "vocabulary": vocab_path, "session_dir": sess_dir } # ========================================= # LOAD: from disk or hub subfolder # ========================================= def load_existing_model( model_path: str | Path | None, paths: dict | None, model_config=None, training_config=None, *, from_hub: bool = False, prefer_best: bool = True, map_location: str | torch.device | None = None ): """ Load a saved model (weights + config), reconstruct the architecture via build_model, and return a ready-to-use model. If a saved vocabulary is present, reuse it. Args: model_path: explicit path to a .safetensors file; if None, resolve from `paths` paths: { 'save_dir': str, 'model_variant': str, 'session_timestamp': str, # (for hub) 'hub_repo': str, 'hub_token': str|None } from_hub: if True, pull from HF Hub subfolder models/{variant}/{session}/ prefer_best: when scanning a folder, pick 'best_*.safetensors' if available map_location: optional torch map_location Returns: model (on default device), resolved_paths dict """ device = _get_device() if map_location is None else map_location # ---------- resolve source files ---------- if model_path is not None: weights_path = Path(model_path) base = weights_path.name.replace(".safetensors", "") session_dir = weights_path.parent model_cfg_path = session_dir / f"{base}_model_config.json" vocab_path = session_dir / f"{base}_vocabulary.json" elif from_hub: try: from huggingface_hub import hf_hub_download except Exception as e: raise RuntimeError("huggingface_hub is required for from_hub=True") from e hub_repo = paths["hub_repo"] subfolder = f"models/{paths['model_variant']}/{paths['session_timestamp']}" # Download index (weights); prefer 'best_' by asking caller to pass the exact name or we try both # We will download repo file list is not available here; caller should pass model_path if you want a specific file. # Fallback: try canonical 'best_' name; else 'checkpoint_'. candidates = ["best", "checkpoint"] weights_path = None for pref in candidates: try: fname = None # look for any .safetensors in subfolder; require caller to provide exact file if multiple # Here we try a common name; if it fails, raise with guidance # (You can extend to list_repo_files if needed.) # Attempt pattern-less download will fail; so require explicit file or local resolution. # Safer approach: user supplies explicit model_path for hub. pass except Exception: pass raise RuntimeError( "When loading from Hub, please supply the explicit .safetensors filename in model_path " "(e.g., '.../best_epoch010_acc30.30.safetensors') or download locally first." ) else: # resolve from local session dir weights_path, model_cfg_path, vocab_path = _find_local_checkpoint(paths) if weights_path is None: raise FileNotFoundError("No checkpoint found in session folder") # ---------- read model config ---------- # prefer on-disk config; else use provided model_config; else minimal override dict if model_cfg_path and model_cfg_path.exists(): with open(model_cfg_path, "r") as f: cfg = json.load(f) else: cfg = _jsonify_obj(model_config) # variant + overrides to rebuild the model variant = cfg.get("variant", paths.get("model_variant") if paths else None) if variant is None: raise ValueError("Model variant not found in config; pass paths['model_variant'] or include 'variant'.") overrides = {} # allow restoring head settings if present for k in ("embed_dim","vocab_dim","depth","num_heads","mlp_ratio", "img_size","patch_size","dropout","attn_dropout", "norm_type","similarity_mode", "head_type","prototype_mode","margin_type","margin_m","scale_s", "apply_margin_train_only"): if k in cfg and cfg[k] is not None: overrides[k] = cfg[k] # ---------- rebuild model via your factory ---------- # IMPORTANT: if a saved vocabulary exists, load it to reproduce exact pentachora if 'vocabulary' in overrides: # just in case overrides.pop('vocabulary') if 'num_classes' in cfg: overrides['num_classes'] = cfg['num_classes'] # not used directly by build_model but okay to keep if 'vocab' in globals() and (not ('pentachora_list' in overrides)): # build_model will use vocab.encode_batch; if we have a saved vocab JSON, override afterwards model = build_model(variant=variant, **overrides).to(device) if 'get_default_device' in globals(): model = model.to(get_default_device()) else: model = build_model(variant=variant, **overrides).to(device) # if a vocabulary JSON exists, replace model.class_pentachora with saved crystals if 'vocab' in globals() and vocab_path and vocab_path.exists(): saved_penta = _load_saved_vocabulary(vocab_path) # list of [5,D] if hasattr(model, "class_pentachora") and len(saved_penta) == len(model.class_pentachora): # swap in the exact saved pentachora new_list = [] for p in saved_penta: new_list.append(type(model.class_pentachora[0])(p, norm_type=getattr(model, "norm_type", "l1"))) # rebuild ModuleList import torch.nn as nn model.class_pentachora = nn.ModuleList(new_list) # update normalized buffers inside PentachoraEmbedding if needed (constructor already handles it) # ---------- load weights ---------- sd = load_file(str(weights_path), device='cpu') print(f"\nCheckpoint contains {len(sd)} keys") print(f"First 5 keys: {list(sd.keys())[:5]}") # Check for compiled model prefix has_orig_mod = any(k.startswith("_orig_mod.") for k in sd.keys()) if has_orig_mod: print("Detected compiled model checkpoint (_orig_mod. prefix)") # Strip _orig_mod. if present fixed = {} for k, v in sd.items(): new_key = k[10:] if k.startswith("_orig_mod.") else k fixed[new_key] = v # Get model state dict for comparison model_state = model.state_dict() print(f"\nModel expects {len(model_state)} keys") print(f"First 5 expected: {list(model_state.keys())[:5]}") # Find mismatches checkpoint_keys = set(fixed.keys()) model_keys = set(model_state.keys()) missing_in_checkpoint = model_keys - checkpoint_keys unexpected_in_checkpoint = checkpoint_keys - model_keys print(f"\nKeys in model but not in checkpoint: {len(missing_in_checkpoint)}") if missing_in_checkpoint and len(missing_in_checkpoint) < 10: print(f" Missing: {list(missing_in_checkpoint)[:10]}") print(f"Keys in checkpoint but not in model: {len(unexpected_in_checkpoint)}") if unexpected_in_checkpoint and len(unexpected_in_checkpoint) < 10: print(f" Unexpected: {list(unexpected_in_checkpoint)[:10]}") # Load with strict=True to see the actual error try: model.load_state_dict(fixed, strict=True) print("✓ Strict load successful - all weights loaded") except RuntimeError as e: print(f"⚠ Strict load failed: {e}") # Fall back to non-strict incompatible = model.load_state_dict(fixed, strict=False) print(f"Loaded with strict=False") print(f" Missing keys: {len(incompatible.missing_keys)}") print(f" Unexpected keys: {len(incompatible.unexpected_keys)}") # Check if critical weights are missing critical_missing = [k for k in incompatible.missing_keys if 'weight' in k or 'bias' in k] if critical_missing: print(f" ⚠ Critical missing weights: {critical_missing[:5]}") # Verify weights aren't zero sample_weight = next(iter(model.parameters())) print(f"\nFirst parameter stats:") print(f" Shape: {sample_weight.shape}") print(f" Mean: {sample_weight.mean().item():.6f}") print(f" Std: {sample_weight.std().item():.6f}") print(f" Min: {sample_weight.min().item():.6f}") print(f" Max: {sample_weight.max().item():.6f}") model.eval() return model, { "weights": weights_path, "model_config": model_cfg_path, "vocabulary": vocab_path, "session_dir": weights_path.parent } def get_parameter_groups(model, weight_decay): """Create parameter groups with weight decay handling""" no_decay = ['bias', 'LayerNorm.weight', 'norm'] params_decay = [] params_no_decay = [] for name, param in model.named_parameters(): if param.requires_grad: if any(nd in name for nd in no_decay): params_no_decay.append(param) else: params_decay.append(param) return [ {'params': params_decay, 'weight_decay': weight_decay}, {'params': params_no_decay, 'weight_decay': 0.0} ] def create_scheduler(optimizer, config, start_epoch=0): """Create cosine scheduler with warmup""" def lr_lambda(epoch): if epoch < config.warmup_epochs: return epoch / config.warmup_epochs if config.epochs <= config.warmup_epochs: return 1.0 return 0.5 * (1 + np.cos(np.pi * (epoch - config.warmup_epochs) / (config.epochs - config.warmup_epochs))) scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) # Fast-forward to correct epoch if resuming for _ in range(start_epoch): scheduler.step() return scheduler def count_parameters(model): """Count model parameters""" total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) return {'total': total, 'trainable': trainable} # Test loading if __name__ == "__main__": print("Testing model loader...") print("=" * 50) # Test load a small model model = build_model('vit_beatrix_shaper').to(get_default_device()) #model = load_exisiting_model( # Test forward pass x = torch.randn(4, 3, 32, 32).to(get_default_device()) output = model(x) print(f"\nForward pass successful!") print(f" Input shape: {x.shape}") print(f" Logits shape: {output['logits'].shape}") print(f" Similarities shape: {output['similarities'].shape}") print("\n✓ Model loader working correctly!")