| | """ |
| | Pentachora batch generation and model creation. |
| | Assumes vocab is already loaded as 'vocab'. |
| | Assumes PentachoronStabilizer is already loaded. |
| | Assumes BaselineViT is already loaded. |
| | """ |
| |
|
| | import torch |
| | import numpy as np |
| |
|
| | |
# The 100 CIFAR-100 fine-label class names, listed in canonical label-index
# order (index i == dataset label i). Used both to size the model head and to
# name the pentachora crystals in saved vocabulary JSON.
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
# Named architecture presets consumed by build_model().
#
# Common keys:
#   embed_dim  - transformer width
#   vocab_dim  - pentachora crystal dimension (crystals are resized to this)
#   depth      - number of transformer blocks
#   num_heads  - attention heads (embed_dim should divide evenly by this)
#   mlp_ratio  - MLP hidden width as a multiple of embed_dim
#
# Variants that also carry margin_type / margin_m / scale_s configure the
# RoseFace margin head explicitly; all other variants inherit build_model()'s
# defaults (cosface, m=0.30, s=30.0).
MODEL_CONFIGS = {

    # --- 'beatrix' family: margin-head ViTs --------------------------------
    'vit_beatrix_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 16,
        'num_heads': 8,
        'mlp_ratio': 1.0,
        # RoseFace margin head
        'margin_type': 'cosface',
        'margin_m': 0.30,
        'scale_s': 30.0,
    },
    'vit_beatrix_arc_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 16,
        'num_heads': 8,
        'mlp_ratio': 2.0,
        # RoseFace margin head (ArcFace-style additive angular margin)
        'margin_type': 'arcface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_arc': {
        'embed_dim': 64,
        'vocab_dim': 64,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,

        'margin_type': 'arcface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_cos': {
        'embed_dim': 64,
        'vocab_dim': 64,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,

        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_128_cos': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,

        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_mini_cos': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,

        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_mini_cos_large_margin': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        # Same as vit_beatrix_mini_cos but with ~2.4x larger margin
        'margin_type': 'cosface',
        'margin_m': 0.7086,
        'scale_s': 30.0,
    },

    # --- 'zana' family: plain heads, default margin settings ---------------
    'vit_zana_nano': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 4,
        'num_heads': 2,
        'mlp_ratio': 2.0
    },
    'vit_beatrix_base_cos': {
        'embed_dim': 512,
        'vocab_dim': 512,
        'depth': 25,
        'num_heads': 16,
        'mlp_ratio': 8.0,

        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_zana_nano_deep': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 8,
        'num_heads': 4,
        'mlp_ratio': 2.0
    },
    'vit_zana_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 32,
        'num_heads': 8,
        'mlp_ratio': 4.0
    },
    'vit_zana_nano_thicc': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 4,
        'num_heads': 8,
        'mlp_ratio': 4.0
    },
    # NOTE(review): embed_dim != vocab_dim below is intentional-looking
    # (crystals are interpolated to vocab_dim in build_model) — confirm.
    'vit_zana_micro': {
        'embed_dim': 500,
        'vocab_dim': 25,
        'depth': 6,
        'num_heads': 2,
        'mlp_ratio': 2.0
    },
    'vit_zana_micro_500': {
        'embed_dim': 500,
        'vocab_dim': 25,
        'depth': 6,
        'num_heads': 5,
        'mlp_ratio': 2.0
    },

    'vit_zana_base': {
        'embed_dim': 512,
        'vocab_dim': 512,
        'depth': 16,
        'num_heads': 4,
        'mlp_ratio': 4.0
    },

    # --- 'ursula' family: very wide, shallow -------------------------------
    'vit_ursula_nano_1000': {
        'embed_dim': 1000,
        'vocab_dim': 500,
        'depth': 4,
        'num_heads': 50,
        'mlp_ratio': 4.0
    },
    'vit_ursula_nano': {
        'embed_dim': 1000,
        'vocab_dim': 25,
        'depth': 4,
        'num_heads': 10,
        'mlp_ratio': 4.0
    },

    # --- Standard ViT-style presets ----------------------------------------
    'tiny': {
        'embed_dim': 192,
        'vocab_dim': 192,
        'depth': 12,
        'num_heads': 3,
        'mlp_ratio': 4.0
    },

    'vit_ursula_mini': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 12,
        'num_heads': 4,
        'mlp_ratio': 4.0
    },

    'small': {
        'embed_dim': 384,
        'vocab_dim': 384,
        'depth': 12,
        'num_heads': 6,
        'mlp_ratio': 4.0
    },

    'base': {
        'embed_dim': 768,
        'vocab_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4.0
    },

    # --- Shape-extreme presets ---------------------------------------------
    'wide_shallow': {
        'embed_dim': 1024,
        'vocab_dim': 1024,
        'depth': 4,
        'num_heads': 16,
        'mlp_ratio': 2.0
    },

    'narrow_deep': {
        'embed_dim': 192,
        'vocab_dim': 192,
        'depth': 24,
        'num_heads': 3,
        'mlp_ratio': 4.0
    },
}
| |
|
| |
|
| | """ |
| | Updated pentachora batch generation and model creation for L1 norm. |
| | Add this modification to your existing build_model function. |
| | """ |
| |
|
def build_model(variant='small', **override_params):
    """
    Build model with explicit parameter handling - no hidden kwargs.

    Args:
        variant: Model variant name from MODEL_CONFIGS
        **override_params: Individual parameter overrides

    Returns:
        model: BaselineViT model with frozen pentachora

    NOTE: relies on module-level globals assumed loaded beforehand (see the
    module docstring): `vocab` (class-name encoder), `BaselineViT`, and
    `get_default_device`.
    """
    assert variant in MODEL_CONFIGS, f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}"
    base_config = MODEL_CONFIGS[variant].copy()

    # Resolution order for every knob: explicit override -> variant config -> default.

    # Core architecture.
    embed_dim = override_params.get('embed_dim', base_config.get('embed_dim', 512))
    vocab_dim = override_params.get('vocab_dim', base_config.get('vocab_dim', 512))
    depth = override_params.get('depth', base_config.get('depth', 12))
    num_heads = override_params.get('num_heads', base_config.get('num_heads', 8))
    mlp_ratio = override_params.get('mlp_ratio', base_config.get('mlp_ratio', 4.0))

    # Input geometry (CIFAR-sized defaults).
    img_size = override_params.get('img_size', base_config.get('img_size', 32))
    patch_size = override_params.get('patch_size', base_config.get('patch_size', 4))

    # Regularization.
    dropout = override_params.get('dropout', base_config.get('dropout', 0.0))
    attn_dropout = override_params.get('attn_dropout', base_config.get('attn_dropout', 0.0))

    # Similarity / normalization used by the classification geometry.
    similarity_mode = override_params.get('similarity_mode', base_config.get('similarity_mode', 'rose'))
    norm_type = override_params.get('norm_type', base_config.get('norm_type', 'l1'))

    # Margin-head (RoseFace-style) settings.
    head_type = override_params.get('head_type', base_config.get('head_type', 'roseface'))
    prototype_mode = override_params.get('prototype_mode', base_config.get('prototype_mode', 'centroid'))
    margin_type = override_params.get('margin_type', base_config.get('margin_type', 'cosface'))
    margin_m = float(override_params.get('margin_m', base_config.get('margin_m', 0.30)))
    scale_s = float(override_params.get('scale_s', base_config.get('scale_s', 30.0)))
    apply_margin_train_only = override_params.get('apply_margin_train_only',
                                                  base_config.get('apply_margin_train_only', False))

    # One pentachoron (5-vertex crystal) per CIFAR-100 class.
    num_classes = len(CIFAR100_CLASSES)

    print(f"Building {variant}:")
    print(f" Architecture: embed={embed_dim}, vocab={vocab_dim}, depth={depth}, heads={num_heads}")
    print(f" Image: {img_size}x{img_size}, patch={patch_size}x{patch_size}")
    print(f" RoseFace: {margin_type}, m={margin_m:.4f}, s={scale_s:.1f}")
    print(f" Norm: {norm_type}, Similarity: {similarity_mode}")

    print(f"Generating {num_classes} pentachora from vocabulary...")
    class_names = CIFAR100_CLASSES[:num_classes]

    # `vocab` is a module-level global; generate=True creates missing crystals.
    pentachora_np_list = vocab.encode_batch(class_names, generate=True)

    raw_penta_list = [torch.tensor(penta, dtype=torch.float32) for penta in pentachora_np_list]

    # Resize each [5, D] crystal to vocab_dim via per-vertex linear interpolation.
    pentachora_list = []
    for i, penta in enumerate(raw_penta_list):
        if penta.shape[-1] != vocab_dim:
            current_dim = penta.shape[-1]

            if current_dim > vocab_dim:
                # Downsample: vectorized linear interpolation per vertex.
                resized_vertices = []
                for v in range(penta.shape[0]):
                    indices = torch.linspace(0, current_dim - 1, vocab_dim)
                    vertex = penta[v]
                    left_idx = indices.floor().long()
                    right_idx = (left_idx + 1).clamp(max=current_dim - 1)
                    alpha = indices - left_idx.float()
                    interpolated = vertex[left_idx] * (1 - alpha) + vertex[right_idx] * alpha
                    resized_vertices.append(interpolated)
                penta_resized = torch.stack(resized_vertices)
                if i == 0:
                    print(f" Downsampling pentachora from {current_dim} to {vocab_dim}")
            else:
                # Upsample: same linear interpolation, computed element-by-element.
                # NOTE(review): mathematically this mirrors the vectorized branch
                # above; the per-element Python loop is O(vocab_dim) slow and
                # could reuse the vectorized form.
                resized_vertices = []
                for v in range(penta.shape[0]):
                    vertex = penta[v]
                    x = torch.linspace(0, current_dim - 1, vocab_dim)
                    interpolated = torch.zeros(vocab_dim, dtype=vertex.dtype, device=vertex.device)
                    for j in range(vocab_dim):
                        if x[j] <= 0:
                            interpolated[j] = vertex[0]
                        elif x[j] >= current_dim - 1:
                            interpolated[j] = vertex[-1]
                        else:
                            left = int(x[j])
                            alpha = x[j] - left
                            interpolated[j] = vertex[left] * (1 - alpha) + vertex[left + 1] * alpha
                    resized_vertices.append(interpolated)
                penta_resized = torch.stack(resized_vertices)
                if i == 0:
                    print(f" Upsampling pentachora from {current_dim} to {vocab_dim}")

            # NOTE(review): resized crystals are appended on CPU while the
            # unresized path below moves tensors to get_default_device() —
            # device placement is inconsistent; confirm BaselineViT
            # re-registers/moves the list, otherwise this may break on CUDA.
            pentachora_list.append(penta_resized)
        else:
            pentachora_list.append(penta.detach().clone().to(get_default_device()))

    print(f"Using {num_classes} L1-normalized pentachora")

    # Construct the backbone; pentachora are handed over (and presumably
    # frozen) inside BaselineViT — see the class for ownership details.
    model = BaselineViT(
        pentachora_list=pentachora_list,
        vocab_dim=vocab_dim,
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        attn_dropout=attn_dropout,
        similarity_mode=similarity_mode,
        norm_type=norm_type,
        head_type=head_type,
        prototype_mode=prototype_mode,
        margin_type=margin_type,
        margin_m=margin_m,
        scale_s=scale_s,
        apply_margin_train_only=apply_margin_train_only
    )

    # Attach the fully-resolved config so save_existing_model() can persist it
    # and load_existing_model() can rebuild the exact same architecture.
    model.config = {
        'variant': variant,
        'vocab_dim': vocab_dim,
        'embed_dim': embed_dim,
        'depth': depth,
        'num_heads': num_heads,
        'mlp_ratio': mlp_ratio,
        'img_size': img_size,
        'patch_size': patch_size,
        'dropout': dropout,
        'attn_dropout': attn_dropout,
        'similarity_mode': similarity_mode,
        'norm_type': norm_type,
        'head_type': head_type,
        'prototype_mode': prototype_mode,
        'margin_type': margin_type,
        'margin_m': margin_m,
        'scale_s': scale_s,
        'apply_margin_train_only': apply_margin_train_only,
        'num_classes': num_classes,
    }

    # Parameter accounting (frozen = pentachora with requires_grad=False).
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    # Sanity diagnostics on the first few class pentachora.
    print("\nDiagnostic: Checking pentachora status...")
    for i, penta in enumerate(model.class_pentachora[:3]):
        print(f"Pentachora {i}:")
        print(f" vertices requires_grad: {penta.vertices.requires_grad}")
        print(f" vertices mean: {penta.vertices.mean().item():.6f}")
        print(f" vertices std: {penta.vertices.std().item():.6f}")

    # Quick look at the main trunk weights (initialization sanity check).
    print("\nMain model parameters:")
    if hasattr(model, 'patch_embed'):
        print(f" patch_embed.weight mean: {model.patch_embed.weight.mean().item():.6f}")
        print(f" patch_embed.weight std: {model.patch_embed.weight.std().item():.6f}")

    print(f"\nModel: {variant}")
    print(f" Classes: {num_classes}")
    print(f" Normalization: {norm_type.upper()}")
    print(f" Total params: {total_params:,}")
    print(f" Trainable params: {trainable_params:,}")
    print(f" Frozen pentachora params: {frozen_params:,}")

    return model
| |
|
| | |
| | |
| | |
| | import os, json, math |
| | from pathlib import Path |
| | import torch |
| | import numpy as np |
| |
|
# Hard dependency for the checkpoint code below: weights are (de)serialized
# exclusively via safetensors. Fail fast with an actionable message.
try:
    from safetensors.torch import save_file, load_file
except Exception as e:
    raise RuntimeError("safetensors is required: pip install safetensors") from e
| |
|
| | def _get_device(): |
| | return torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| |
|
| | def _jsonify_obj(obj) -> dict: |
| | """Turn a config object or dict into a JSON-safe dict.""" |
| | if obj is None: |
| | return {} |
| | if isinstance(obj, dict): |
| | return obj |
| | out = {} |
| | for k in dir(obj): |
| | if k.startswith('_'): |
| | continue |
| | v = getattr(obj, k) |
| | if callable(v): |
| | continue |
| | if isinstance(v, torch.Tensor): |
| | v = v.tolist() |
| | elif isinstance(v, np.ndarray): |
| | v = v.tolist() |
| | out[k] = v |
| | return out |
| |
|
| | def _ensure_model_config_dict(model): |
| | """Guarantee model.config is a dict describing the head + geometry relevant fields.""" |
| | if hasattr(model, "config") and isinstance(model.config, dict): |
| | return model.config |
| | cfg = { |
| | "arch": type(model).__name__, |
| | "num_classes": getattr(model, "num_classes", None), |
| | "embed_dim": getattr(model, "embed_dim", None), |
| | "pentachora_dim": getattr(model, "pentachora_dim", None), |
| | "img_size": getattr(model, "img_size", 32), |
| | "patch_size": getattr(model, "patch_size", 4), |
| | "norm_type": getattr(model, "norm_type", None), |
| | "similarity_mode": getattr(model, "similarity_mode", None), |
| | "head_type": getattr(model, "head_type", None), |
| | "prototype_mode": getattr(model, "prototype_mode", None), |
| | "margin_type": getattr(model, "margin_type", None), |
| | "margin_m": float(getattr(model, "margin_m", 0.0)) if hasattr(model, "margin_m") else None, |
| | "scale_s": float(getattr(model, "scale_s", 1.0)) if hasattr(model, "scale_s") else None, |
| | } |
| | model.config = cfg |
| | return cfg |
| |
|
| | def _collect_state_tensors(state_dict): |
| | return {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)} |
| |
|
| | def _session_dir(paths: dict) -> Path: |
| | root = Path(paths["save_dir"]) |
| | return root / f"{paths['model_variant']}_{paths['session_timestamp']}" |
| |
|
def _find_local_checkpoint(paths: dict) -> tuple[Path | None, Path | None, Path | None]:
    """
    Return (weights_path, model_config_path, vocab_path) from the session dir.
    Prefer 'best_*.safetensors'; fall back to most recent '*.safetensors'.

    All three elements are None when the session dir does not exist or holds
    no safetensors checkpoints. (The annotation previously claimed the first
    element was always a Path, which contradicted the None returns below.)
    """
    sdir = _session_dir(paths)
    if not sdir.exists():
        return None, None, None
    # Oldest -> newest by modification time.
    safes = sorted(sdir.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
    if not safes:
        return None, None, None

    bests = [p for p in safes if p.name.startswith("best_")]
    w = bests[-1] if bests else safes[-1]
    # Sidecar JSONs share the checkpoint's base name.
    model_cfg = sdir / w.name.replace(".safetensors", "_model_config.json")
    vocab = sdir / w.name.replace(".safetensors", "_vocabulary.json")
    return w, (model_cfg if model_cfg.exists() else None), (vocab if vocab.exists() else None)
| |
|
| | def _load_saved_vocabulary(vocab_json_path: Path) -> list[torch.Tensor]: |
| | """Return list of [5,D] tensors from saved crystal JSON.""" |
| | with open(vocab_json_path, "r") as f: |
| | data = json.load(f) |
| | crystals = data.get("crystal_to_token", []) |
| | |
| | penta_list = [] |
| | for item in crystals: |
| | arr = torch.tensor(item["crystal"], dtype=torch.float32) |
| | penta_list.append(arr) |
| | return penta_list |
| |
|
| | |
| | |
| | |
def save_existing_model(
    model,
    paths: dict,
    model_config=None,
    training_config=None,
    *,
    filename_base: str | None = None,
    save_vocabulary: bool = True,
    push_to_hub: bool | None = None
):
    """
    Save the model to disk, and optionally upload to the HF Hub.

    Args:
        model: BaselineViT instance
        paths: {
            'save_dir': str,
            'model_variant': str,
            'session_timestamp': str,
            # (optional for naming)
            'epoch': int,
            'val_acc': float,
            'is_best': bool,
            # hub
            'hub_repo': str,
            'hub_token': str|None,
        }
        model_config: dict or object (optional; if None, built from model)
        training_config: TrainingConfig or dict (optional; saved to JSON)
        filename_base: override the base filename; if None, derived from epoch/acc/best
        save_vocabulary: write *_vocabulary.json from model.class_pentachora
        push_to_hub: override paths.get('push_to_hub')

    Returns:
        dict with paths of every artifact written plus the session dir.
    """
    # NOTE(review): `device` is unused below; kept as-is for behavior parity.
    device = _get_device()
    sess_dir = _session_dir(paths)
    sess_dir.mkdir(parents=True, exist_ok=True)

    # Derive a filename like 'best_epoch010_acc30.30' unless caller supplied one.
    if filename_base is None:
        ep = paths.get("epoch")
        acc = paths.get("val_acc")
        is_best = bool(paths.get("is_best", False))
        tag = f"epoch{int(ep):03d}_acc{float(acc):.2f}" if (ep is not None and acc is not None) else "snapshot"
        filename_base = f"{'best_' if is_best else 'checkpoint_'}{tag}"

    # 1) Weights: tensors only (safetensors rejects non-tensor values).
    weights_path = sess_dir / f"{filename_base}.safetensors"
    state = _collect_state_tensors(model.state_dict())
    save_file(state, str(weights_path))

    # 2) Model config JSON; falls back to a dict reconstructed from the model
    #    when no (or an empty) explicit config was given.
    cfg_dict = _jsonify_obj(model_config) or _ensure_model_config_dict(model)
    model_cfg_path = sess_dir / f"{filename_base}_model_config.json"
    with open(model_cfg_path, "w") as f:
        json.dump(cfg_dict, f, indent=2, default=str)

    # 3) Optional training config JSON.
    if training_config is not None:
        train_cfg_dict = _jsonify_obj(training_config)
        train_cfg_path = sess_dir / f"{filename_base}_training_config.json"
        with open(train_cfg_path, "w") as f:
            json.dump(train_cfg_dict, f, indent=2, default=str)
    else:
        train_cfg_path = None

    # 4) Optional vocabulary dump: one [5,D] crystal per class, paired with
    #    its CIFAR-100 token, so load_existing_model() can restore geometry.
    vocab_path = None
    if save_vocabulary and hasattr(model, "class_pentachora") and model.class_pentachora is not None:
        crystals = torch.stack([p.vertices for p in model.class_pentachora], dim=0).detach().cpu().numpy().tolist()
        vocab_data = {
            "vocab_dim": getattr(model, "pentachora_dim", None),
            "num_classes": len(model.class_pentachora),
            "num_vertices": 5,
            "tokens": CIFAR100_CLASSES[: len(crystals)],
            "crystal_to_token": [
                {"index": i, "token": CIFAR100_CLASSES[i], "crystal": crystals[i]}
                for i in range(len(crystals))
            ],
        }
        vocab_path = sess_dir / f"{filename_base}_vocabulary.json"
        with open(vocab_path, "w") as f:
            json.dump(vocab_data, f, indent=2)

    print(f"✓ Saved weights: {weights_path.name}")
    print(f"✓ Saved model config: {model_cfg_path.name}")
    if train_cfg_path:
        print(f"✓ Saved training config: {train_cfg_path.name}")
    if vocab_path:
        print(f"✓ Saved vocabulary: {vocab_path.name}")

    # 5) Optional HF Hub upload. Failures are reported but never fatal — the
    #    local save above has already succeeded at this point.
    do_push = push_to_hub if push_to_hub is not None else paths.get("push_to_hub", False)
    if do_push:
        try:
            from huggingface_hub import HfApi, create_repo
            hub_repo = paths["hub_repo"]
            hub_token = paths.get("hub_token")
            subfolder = f"models/{paths['model_variant']}/{paths['session_timestamp']}"

            api = HfApi(token=hub_token)
            # Repo may already exist; ignore creation errors.
            try:
                create_repo(hub_repo, token=hub_token, private=True, exist_ok=True)
            except Exception:
                pass

            def _up(p: Path):
                # Mirror local filenames under models/{variant}/{session}/ in the repo.
                api.upload_file(
                    path_or_fileobj=str(p),
                    path_in_repo=f"{subfolder}/{p.name}",
                    repo_id=hub_repo,
                    repo_type="model"
                )

            _up(weights_path); _up(model_cfg_path)
            if train_cfg_path: _up(train_cfg_path)
            if vocab_path: _up(vocab_path)
            print(f"✓ Pushed to hub: {hub_repo}/{subfolder}")
        except Exception as e:
            print(f"⚠ Hub upload failed: {e}")

    return {
        "weights": weights_path,
        "model_config": model_cfg_path,
        "training_config": train_cfg_path,
        "vocabulary": vocab_path,
        "session_dir": sess_dir
    }
| |
|
| | |
| | |
| | |
def load_existing_model(
    model_path: str | Path | None,
    paths: dict | None,
    model_config=None,
    training_config=None,
    *,
    from_hub: bool = False,
    prefer_best: bool = True,
    map_location: str | torch.device | None = None
):
    """
    Load a saved model (weights + config), reconstruct the architecture via build_model,
    and return a ready-to-use model. If a saved vocabulary is present, reuse it.

    Args:
        model_path: explicit path to a .safetensors file; if None, resolve from `paths`
        paths: {
            'save_dir': str, 'model_variant': str, 'session_timestamp': str,
            # (for hub)
            'hub_repo': str, 'hub_token': str|None
        }
        from_hub: if True, pull from HF Hub subfolder models/{variant}/{session}/
        prefer_best: when scanning a folder, pick 'best_*.safetensors' if available
        map_location: optional torch map_location

    Returns:
        model (on default device), resolved_paths dict

    NOTE(review): `training_config` and `prefer_best` are accepted but never
    read in this implementation; the hub path always raises (see below).
    """
    device = _get_device() if map_location is None else map_location

    # --- Resolve checkpoint + sidecar JSON locations -----------------------
    if model_path is not None:
        # Explicit file: sidecars are expected next to it with matching names.
        weights_path = Path(model_path)
        base = weights_path.name.replace(".safetensors", "")
        session_dir = weights_path.parent
        model_cfg_path = session_dir / f"{base}_model_config.json"
        vocab_path = session_dir / f"{base}_vocabulary.json"
    elif from_hub:
        try:
            from huggingface_hub import hf_hub_download
        except Exception as e:
            raise RuntimeError("huggingface_hub is required for from_hub=True") from e
        hub_repo = paths["hub_repo"]
        subfolder = f"models/{paths['model_variant']}/{paths['session_timestamp']}"

        # NOTE(review): this candidate loop is a dead placeholder — its body
        # does nothing and the RuntimeError below is raised unconditionally.
        # Hub loading by prefix is intentionally unsupported for now.
        candidates = ["best", "checkpoint"]
        weights_path = None
        for pref in candidates:
            try:
                fname = None

                pass
            except Exception:
                pass
        raise RuntimeError(
            "When loading from Hub, please supply the explicit .safetensors filename in model_path "
            "(e.g., '.../best_epoch010_acc30.30.safetensors') or download locally first."
        )
    else:
        # Scan the session folder for the newest / best checkpoint.
        weights_path, model_cfg_path, vocab_path = _find_local_checkpoint(paths)
        if weights_path is None:
            raise FileNotFoundError("No checkpoint found in session folder")

    # --- Load the persisted model config (or fall back to the argument) ----
    if model_cfg_path and model_cfg_path.exists():
        with open(model_cfg_path, "r") as f:
            cfg = json.load(f)
    else:
        cfg = _jsonify_obj(model_config)

    # The variant name is required to rebuild the architecture.
    variant = cfg.get("variant", paths.get("model_variant") if paths else None)
    if variant is None:
        raise ValueError("Model variant not found in config; pass paths['model_variant'] or include 'variant'.")

    # Forward every architecture/head knob that was persisted.
    overrides = {}

    for k in ("embed_dim","vocab_dim","depth","num_heads","mlp_ratio",
              "img_size","patch_size","dropout","attn_dropout",
              "norm_type","similarity_mode",
              "head_type","prototype_mode","margin_type","margin_m","scale_s",
              "apply_margin_train_only"):
        if k in cfg and cfg[k] is not None:
            overrides[k] = cfg[k]

    # NOTE(review): 'vocabulary' can never be in overrides given the key list
    # above — this guard is dead code kept for safety.
    if 'vocabulary' in overrides:
        overrides.pop('vocabulary')
    if 'num_classes' in cfg:
        overrides['num_classes'] = cfg['num_classes']

    # build_model() requires the global `vocab` to regenerate crystals; both
    # branches construct identically, the first additionally re-homes the
    # model via get_default_device() when that helper exists.
    if 'vocab' in globals() and (not ('pentachora_list' in overrides)):

        model = build_model(variant=variant, **overrides).to(device)
        if 'get_default_device' in globals():
            model = model.to(get_default_device())
    else:
        model = build_model(variant=variant, **overrides).to(device)

    # Prefer the exact crystals saved alongside the checkpoint over the
    # freshly generated ones, when counts line up.
    if 'vocab' in globals() and vocab_path and vocab_path.exists():
        saved_penta = _load_saved_vocabulary(vocab_path)
        if hasattr(model, "class_pentachora") and len(saved_penta) == len(model.class_pentachora):
            # Rebuild each entry with the same wrapper class as the existing
            # pentachora (assumes its __init__ is (vertices, norm_type=...)).
            new_list = []
            for p in saved_penta:
                new_list.append(type(model.class_pentachora[0])(p, norm_type=getattr(model, "norm_type", "l1")))

            import torch.nn as nn
            model.class_pentachora = nn.ModuleList(new_list)

    # --- Load weights -------------------------------------------------------
    sd = load_file(str(weights_path), device='cpu')
    print(f"\nCheckpoint contains {len(sd)} keys")
    print(f"First 5 keys: {list(sd.keys())[:5]}")

    # torch.compile() wraps modules and prefixes keys with '_orig_mod.'.
    has_orig_mod = any(k.startswith("_orig_mod.") for k in sd.keys())
    if has_orig_mod:
        print("Detected compiled model checkpoint (_orig_mod. prefix)")

    # Strip the compile prefix (len('_orig_mod.') == 10) so keys match.
    fixed = {}
    for k, v in sd.items():
        new_key = k[10:] if k.startswith("_orig_mod.") else k
        fixed[new_key] = v

    # Report key coverage before attempting the load.
    model_state = model.state_dict()
    print(f"\nModel expects {len(model_state)} keys")
    print(f"First 5 expected: {list(model_state.keys())[:5]}")

    checkpoint_keys = set(fixed.keys())
    model_keys = set(model_state.keys())

    missing_in_checkpoint = model_keys - checkpoint_keys
    unexpected_in_checkpoint = checkpoint_keys - model_keys

    print(f"\nKeys in model but not in checkpoint: {len(missing_in_checkpoint)}")
    if missing_in_checkpoint and len(missing_in_checkpoint) < 10:
        print(f" Missing: {list(missing_in_checkpoint)[:10]}")

    print(f"Keys in checkpoint but not in model: {len(unexpected_in_checkpoint)}")
    if unexpected_in_checkpoint and len(unexpected_in_checkpoint) < 10:
        print(f" Unexpected: {list(unexpected_in_checkpoint)[:10]}")

    # Try a strict load first; degrade to strict=False with a report.
    try:
        model.load_state_dict(fixed, strict=True)
        print("✓ Strict load successful - all weights loaded")
    except RuntimeError as e:
        print(f"⚠ Strict load failed: {e}")

        incompatible = model.load_state_dict(fixed, strict=False)
        print(f"Loaded with strict=False")
        print(f" Missing keys: {len(incompatible.missing_keys)}")
        print(f" Unexpected keys: {len(incompatible.unexpected_keys)}")

        # Highlight missing learnable weights, which likely break inference.
        critical_missing = [k for k in incompatible.missing_keys if 'weight' in k or 'bias' in k]
        if critical_missing:
            print(f" ⚠ Critical missing weights: {critical_missing[:5]}")

    # Spot-check one parameter to confirm weights are non-degenerate.
    sample_weight = next(iter(model.parameters()))
    print(f"\nFirst parameter stats:")
    print(f" Shape: {sample_weight.shape}")
    print(f" Mean: {sample_weight.mean().item():.6f}")
    print(f" Std: {sample_weight.std().item():.6f}")
    print(f" Min: {sample_weight.min().item():.6f}")
    print(f" Max: {sample_weight.max().item():.6f}")

    model.eval()
    return model, {
        "weights": weights_path,
        "model_config": model_cfg_path,
        "vocabulary": vocab_path,
        "session_dir": weights_path.parent
    }
| |
|
def get_parameter_groups(model, weight_decay):
    """Create parameter groups with weight decay handling.

    Biases and normalization weights get weight_decay=0.0; everything else
    (that requires grad) gets the supplied decay.
    """
    skip_tokens = ['bias', 'LayerNorm.weight', 'norm']
    decay_params, no_decay_params = [], []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen params never reach the optimizer
        target = no_decay_params if any(tok in name for tok in skip_tokens) else decay_params
        target.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
| |
|
def create_scheduler(optimizer, config, start_epoch=0):
    """Create a cosine-annealing LR scheduler with linear warmup.

    Args:
        optimizer: torch optimizer whose LR will be scheduled.
        config: object exposing `warmup_epochs` and `epochs` attributes.
        start_epoch: epochs already completed; the schedule is fast-forwarded
            so resumed runs continue where they left off.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def lr_lambda(epoch):
        # Linear warmup: factor ramps 0 -> 1 over warmup_epochs.
        if epoch < config.warmup_epochs:
            return epoch / config.warmup_epochs
        # Degenerate schedule (nothing after warmup): hold full LR.
        if config.epochs <= config.warmup_epochs:
            return 1.0
        # Cosine decay from 1 -> 0 over the remaining epochs.
        return 0.5 * (1 + np.cos(np.pi * (epoch - config.warmup_epochs) /
                                 (config.epochs - config.warmup_epochs)))

    # BUG FIX: the original referenced bare `optim`, which is never imported
    # in this module (only `torch` is) and raised NameError at call time.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Fast-forward when resuming from a checkpoint.
    for _ in range(start_epoch):
        scheduler.step()

    return scheduler
| |
|
def count_parameters(model):
    """Count model parameters, split into total and trainable."""
    counts = {'total': 0, 'trainable': 0}
    for param in model.parameters():
        n = param.numel()
        counts['total'] += n
        if param.requires_grad:
            counts['trainable'] += n
    return counts
| |
|
| | |
if __name__ == "__main__":
    # Smoke test: build a model and run one forward pass on random input.
    print("Testing model loader...")
    print("=" * 50)

    device = get_default_device()
    model = build_model('vit_beatrix_shaper').to(device)

    batch = torch.randn(4, 3, 32, 32).to(device)
    output = model(batch)

    print(f"\nForward pass successful!")
    print(f" Input shape: {batch.shape}")
    print(f" Logits shape: {output['logits'].shape}")
    print(f" Similarities shape: {output['similarities'].shape}")

    print("\n✓ Model loader working correctly!")