"""
Pentachora batch generation and model creation.

Assumes the following are already defined in the calling scope:
    - `vocab` (crystal vocabulary exposing `encode_batch`)
    - `PentachoronStabilizer`
    - `BaselineViT`
    - `get_default_device()` (a guarded fallback is defined just below)
"""

import torch
import torch.optim as optim  # used by create_scheduler()
import numpy as np

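# `get_default_device` is called throughout this module but, like `vocab` and
# `BaselineViT`, is expected to come from the surrounding session. A minimal guarded
# fallback (an assumption, not part of the original API) so the file can also run
# standalone:
if 'get_default_device' not in globals():
    def get_default_device() -> torch.device:
        """Fallback device helper: prefer CUDA when available."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
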
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
    'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
    'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
    'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
    'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
    'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
    'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
    'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
]

MODEL_CONFIGS = {
    'vit_beatrix_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 16,
        'num_heads': 8,
        'mlp_ratio': 1.0,
        'margin_type': 'cosface',
        'margin_m': 0.30,
        'scale_s': 30.0,
    },
    'vit_beatrix_arc_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 16,
        'num_heads': 8,
        'mlp_ratio': 2.0,
        'margin_type': 'arcface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_arc': {
        'embed_dim': 64,
        'vocab_dim': 64,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        'margin_type': 'arcface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_cos': {
        'embed_dim': 64,
        'vocab_dim': 64,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_nano_128_cos': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_mini_cos': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_beatrix_mini_cos_large_margin': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 25,
        'num_heads': 8,
        'mlp_ratio': 8.0,
        'margin_type': 'cosface',
        'margin_m': 0.7086,
        'scale_s': 30.0,
    },
    'vit_zana_nano': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 4,
        'num_heads': 2,
        'mlp_ratio': 2.0,
    },
    'vit_beatrix_base_cos': {
        'embed_dim': 512,
        'vocab_dim': 512,
        'depth': 25,
        'num_heads': 16,
        'mlp_ratio': 8.0,
        'margin_type': 'cosface',
        'margin_m': 0.2914,
        'scale_s': 30.0,
    },
    'vit_zana_nano_deep': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 8,
        'num_heads': 4,
        'mlp_ratio': 2.0,
    },
    'vit_zana_shaper': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 32,
        'num_heads': 8,
        'mlp_ratio': 4.0,
    },
    'vit_zana_nano_thicc': {
        'embed_dim': 128,
        'vocab_dim': 128,
        'depth': 4,
        'num_heads': 8,
        'mlp_ratio': 4.0,
    },
    'vit_zana_micro': {
        'embed_dim': 500,
        'vocab_dim': 25,
        'depth': 6,
        'num_heads': 2,
        'mlp_ratio': 2.0,
    },
    'vit_zana_micro_500': {
        'embed_dim': 500,
        'vocab_dim': 25,
        'depth': 6,
        'num_heads': 5,
        'mlp_ratio': 2.0,
    },
    'vit_zana_base': {
        'embed_dim': 512,
        'vocab_dim': 512,
        'depth': 16,
        'num_heads': 4,
        'mlp_ratio': 4.0,
    },
    'vit_ursula_nano_1000': {
        'embed_dim': 1000,
        'vocab_dim': 500,
        'depth': 4,
        'num_heads': 50,
        'mlp_ratio': 4.0,
    },
    'vit_ursula_nano': {
        'embed_dim': 1000,
        'vocab_dim': 25,
        'depth': 4,
        'num_heads': 10,
        'mlp_ratio': 4.0,
    },
    'tiny': {
        'embed_dim': 192,
        'vocab_dim': 192,
        'depth': 12,
        'num_heads': 3,
        'mlp_ratio': 4.0,
    },
    'vit_ursula_mini': {
        'embed_dim': 256,
        'vocab_dim': 256,
        'depth': 12,
        'num_heads': 4,
        'mlp_ratio': 4.0,
    },
    'small': {
        'embed_dim': 384,
        'vocab_dim': 384,
        'depth': 12,
        'num_heads': 6,
        'mlp_ratio': 4.0,
    },
    'base': {
        'embed_dim': 768,
        'vocab_dim': 768,
        'depth': 12,
        'num_heads': 12,
        'mlp_ratio': 4.0,
    },
    'wide_shallow': {
        'embed_dim': 1024,
        'vocab_dim': 1024,
        'depth': 4,
        'num_heads': 16,
        'mlp_ratio': 2.0,
    },
    'narrow_deep': {
        'embed_dim': 192,
        'vocab_dim': 192,
        'depth': 24,
        'num_heads': 3,
        'mlp_ratio': 4.0,
    },
}

""" |
|
|
Updated pentachora batch generation and model creation for L1 norm. |
|
|
Add this modification to your existing build_model function. |
|
|
""" |
|
|
|
|
|
def build_model(variant='small', **override_params):
    """
    Build a model with explicit parameter handling - no hidden kwargs.

    Args:
        variant: Model variant name from MODEL_CONFIGS
        **override_params: Individual parameter overrides

    Returns:
        model: BaselineViT model with frozen pentachora
    """
    assert variant in MODEL_CONFIGS, f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}"
    base_config = MODEL_CONFIGS[variant].copy()

    # Architecture
    embed_dim = override_params.get('embed_dim', base_config.get('embed_dim', 512))
    vocab_dim = override_params.get('vocab_dim', base_config.get('vocab_dim', 512))
    depth = override_params.get('depth', base_config.get('depth', 12))
    num_heads = override_params.get('num_heads', base_config.get('num_heads', 8))
    mlp_ratio = override_params.get('mlp_ratio', base_config.get('mlp_ratio', 4.0))

    # Input geometry
    img_size = override_params.get('img_size', base_config.get('img_size', 32))
    patch_size = override_params.get('patch_size', base_config.get('patch_size', 4))

    # Regularization
    dropout = override_params.get('dropout', base_config.get('dropout', 0.0))
    attn_dropout = override_params.get('attn_dropout', base_config.get('attn_dropout', 0.0))

    # Geometry / similarity
    similarity_mode = override_params.get('similarity_mode', base_config.get('similarity_mode', 'rose'))
    norm_type = override_params.get('norm_type', base_config.get('norm_type', 'l1'))

    # Classification head
    head_type = override_params.get('head_type', base_config.get('head_type', 'roseface'))
    prototype_mode = override_params.get('prototype_mode', base_config.get('prototype_mode', 'centroid'))
    margin_type = override_params.get('margin_type', base_config.get('margin_type', 'cosface'))
    margin_m = float(override_params.get('margin_m', base_config.get('margin_m', 0.30)))
    scale_s = float(override_params.get('scale_s', base_config.get('scale_s', 30.0)))
    apply_margin_train_only = override_params.get('apply_margin_train_only',
                                                  base_config.get('apply_margin_train_only', False))

    num_classes = len(CIFAR100_CLASSES)

    print(f"Building {variant}:")
    print(f"  Architecture: embed={embed_dim}, vocab={vocab_dim}, depth={depth}, heads={num_heads}")
    print(f"  Image: {img_size}x{img_size}, patch={patch_size}x{patch_size}")
    print(f"  RoseFace: {margin_type}, m={margin_m:.4f}, s={scale_s:.1f}")
    print(f"  Norm: {norm_type}, Similarity: {similarity_mode}")

    print(f"Generating {num_classes} pentachora from vocabulary...")
    class_names = CIFAR100_CLASSES[:num_classes]

    pentachora_np_list = vocab.encode_batch(class_names, generate=True)

    raw_penta_list = [torch.tensor(penta, dtype=torch.float32) for penta in pentachora_np_list]

    # Resize pentachora along the feature axis when the encoded dimension differs from
    # vocab_dim; linear interpolation handles both down- and upsampling.
    device = get_default_device()
    pentachora_list = []
    for i, penta in enumerate(raw_penta_list):
        current_dim = penta.shape[-1]
        if current_dim != vocab_dim:
            indices = torch.linspace(0, current_dim - 1, vocab_dim)
            left_idx = indices.floor().long()
            right_idx = (left_idx + 1).clamp(max=current_dim - 1)
            alpha = (indices - left_idx.float()).unsqueeze(0)  # [1, vocab_dim]
            penta_resized = penta[:, left_idx] * (1 - alpha) + penta[:, right_idx] * alpha
            if i == 0:
                direction = "Downsampling" if current_dim > vocab_dim else "Upsampling"
                print(f"  {direction} pentachora from {current_dim} to {vocab_dim}")
            pentachora_list.append(penta_resized.to(device))
        else:
            pentachora_list.append(penta.detach().clone().to(device))

    print(f"Using {num_classes} pentachora (norm: {norm_type})")

    model = BaselineViT(
        pentachora_list=pentachora_list,
        vocab_dim=vocab_dim,
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        attn_dropout=attn_dropout,
        similarity_mode=similarity_mode,
        norm_type=norm_type,
        head_type=head_type,
        prototype_mode=prototype_mode,
        margin_type=margin_type,
        margin_m=margin_m,
        scale_s=scale_s,
        apply_margin_train_only=apply_margin_train_only
    )

    model.config = {
        'variant': variant,
        'vocab_dim': vocab_dim,
        'embed_dim': embed_dim,
        'depth': depth,
        'num_heads': num_heads,
        'mlp_ratio': mlp_ratio,
        'img_size': img_size,
        'patch_size': patch_size,
        'dropout': dropout,
        'attn_dropout': attn_dropout,
        'similarity_mode': similarity_mode,
        'norm_type': norm_type,
        'head_type': head_type,
        'prototype_mode': prototype_mode,
        'margin_type': margin_type,
        'margin_m': margin_m,
        'scale_s': scale_s,
        'apply_margin_train_only': apply_margin_train_only,
        'num_classes': num_classes,
    }

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    print("\nDiagnostic: Checking pentachora status...")
    for i, penta in enumerate(model.class_pentachora[:3]):
        print(f"Pentachora {i}:")
        print(f"  vertices requires_grad: {penta.vertices.requires_grad}")
        print(f"  vertices mean: {penta.vertices.mean().item():.6f}")
        print(f"  vertices std: {penta.vertices.std().item():.6f}")

    print("\nMain model parameters:")
    if hasattr(model, 'patch_embed'):
        print(f"  patch_embed.weight mean: {model.patch_embed.weight.mean().item():.6f}")
        print(f"  patch_embed.weight std: {model.patch_embed.weight.std().item():.6f}")

    print(f"\nModel: {variant}")
    print(f"  Classes: {num_classes}")
    print(f"  Normalization: {norm_type.upper()}")
    print(f"  Total params: {total_params:,}")
    print(f"  Trainable params: {trainable_params:,}")
    print(f"  Frozen pentachora params: {frozen_params:,}")

    return model

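# Usage sketch (not from the original file): per-call overrides win over MODEL_CONFIGS
# entries. Wrapped in a helper so importing this module stays side-effect free.
def example_build_with_overrides():
    """Build the nano cosface variant but swap in an ArcFace margin."""
    model = build_model('vit_beatrix_nano_cos', margin_type='arcface', margin_m=0.35, dropout=0.1)
    return model.to(get_default_device())
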
import os
import json
import math
from pathlib import Path
import torch
import numpy as np

try:
    from safetensors.torch import save_file, load_file
except Exception as e:
    raise RuntimeError("safetensors is required: pip install safetensors") from e


def _get_device():
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def _jsonify_obj(obj) -> dict:
    """Turn a config object or dict into a JSON-safe dict."""
    if obj is None:
        return {}
    if isinstance(obj, dict):
        return obj
    out = {}
    for k in dir(obj):
        if k.startswith('_'):
            continue
        v = getattr(obj, k)
        if callable(v):
            continue
        if isinstance(v, torch.Tensor):
            v = v.tolist()
        elif isinstance(v, np.ndarray):
            v = v.tolist()
        out[k] = v
    return out

def _ensure_model_config_dict(model):
    """Guarantee model.config is a dict describing the head- and geometry-relevant fields."""
    if hasattr(model, "config") and isinstance(model.config, dict):
        return model.config
    cfg = {
        "arch": type(model).__name__,
        "num_classes": getattr(model, "num_classes", None),
        "embed_dim": getattr(model, "embed_dim", None),
        "pentachora_dim": getattr(model, "pentachora_dim", None),
        "img_size": getattr(model, "img_size", 32),
        "patch_size": getattr(model, "patch_size", 4),
        "norm_type": getattr(model, "norm_type", None),
        "similarity_mode": getattr(model, "similarity_mode", None),
        "head_type": getattr(model, "head_type", None),
        "prototype_mode": getattr(model, "prototype_mode", None),
        "margin_type": getattr(model, "margin_type", None),
        "margin_m": float(getattr(model, "margin_m", 0.0)) if hasattr(model, "margin_m") else None,
        "scale_s": float(getattr(model, "scale_s", 1.0)) if hasattr(model, "scale_s") else None,
    }
    model.config = cfg
    return cfg


def _collect_state_tensors(state_dict):
    return {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

def _session_dir(paths: dict) -> Path:
    root = Path(paths["save_dir"])
    return root / f"{paths['model_variant']}_{paths['session_timestamp']}"


def _find_local_checkpoint(paths: dict) -> tuple[Path | None, Path | None, Path | None]:
    """
    Return (weights_path, model_config_path, vocab_path) from the session dir.
    Prefer 'best_*.safetensors'; fall back to the most recent '*.safetensors'.
    Returns (None, None, None) when nothing is found.
    """
    sdir = _session_dir(paths)
    if not sdir.exists():
        return None, None, None
    safes = sorted(sdir.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
    if not safes:
        return None, None, None

    bests = [p for p in safes if p.name.startswith("best_")]
    w = bests[-1] if bests else safes[-1]
    model_cfg = sdir / w.name.replace(".safetensors", "_model_config.json")
    vocab_json = sdir / w.name.replace(".safetensors", "_vocabulary.json")
    return w, (model_cfg if model_cfg.exists() else None), (vocab_json if vocab_json.exists() else None)

def _load_saved_vocabulary(vocab_json_path: Path) -> list[torch.Tensor]:
    """Return a list of [5, D] tensors from a saved crystal JSON."""
    with open(vocab_json_path, "r") as f:
        data = json.load(f)
    crystals = data.get("crystal_to_token", [])

    penta_list = []
    for item in crystals:
        arr = torch.tensor(item["crystal"], dtype=torch.float32)
        penta_list.append(arr)
    return penta_list

def save_existing_model(
    model,
    paths: dict,
    model_config=None,
    training_config=None,
    *,
    filename_base: str | None = None,
    save_vocabulary: bool = True,
    push_to_hub: bool | None = None
):
    """
    Save the model to disk and optionally upload it to the HF Hub.

    Args:
        model: BaselineViT instance
        paths: {
            'save_dir': str,
            'model_variant': str,
            'session_timestamp': str,
            # (optional, for naming)
            'epoch': int,
            'val_acc': float,
            'is_best': bool,
            # hub
            'hub_repo': str,
            'hub_token': str|None,
        }
        model_config: dict or object (optional; if None, built from the model)
        training_config: TrainingConfig or dict (optional; saved to JSON)
        filename_base: override the base filename; if None, derived from epoch/acc/best
        save_vocabulary: write *_vocabulary.json from model.class_pentachora
        push_to_hub: overrides paths.get('push_to_hub')
    """
    sess_dir = _session_dir(paths)
    sess_dir.mkdir(parents=True, exist_ok=True)

    if filename_base is None:
        ep = paths.get("epoch")
        acc = paths.get("val_acc")
        is_best = bool(paths.get("is_best", False))
        tag = f"epoch{int(ep):03d}_acc{float(acc):.2f}" if (ep is not None and acc is not None) else "snapshot"
        filename_base = f"{'best_' if is_best else 'checkpoint_'}{tag}"

    weights_path = sess_dir / f"{filename_base}.safetensors"
    state = _collect_state_tensors(model.state_dict())
    save_file(state, str(weights_path))

    cfg_dict = _jsonify_obj(model_config) or _ensure_model_config_dict(model)
    model_cfg_path = sess_dir / f"{filename_base}_model_config.json"
    with open(model_cfg_path, "w") as f:
        json.dump(cfg_dict, f, indent=2, default=str)

    if training_config is not None:
        train_cfg_dict = _jsonify_obj(training_config)
        train_cfg_path = sess_dir / f"{filename_base}_training_config.json"
        with open(train_cfg_path, "w") as f:
            json.dump(train_cfg_dict, f, indent=2, default=str)
    else:
        train_cfg_path = None

    vocab_path = None
    if save_vocabulary and hasattr(model, "class_pentachora") and model.class_pentachora is not None:
        crystals = torch.stack([p.vertices for p in model.class_pentachora], dim=0).detach().cpu().numpy().tolist()
        vocab_data = {
            "vocab_dim": getattr(model, "pentachora_dim", None),
            "num_classes": len(model.class_pentachora),
            "num_vertices": 5,
            "tokens": CIFAR100_CLASSES[: len(crystals)],
            "crystal_to_token": [
                {"index": i, "token": CIFAR100_CLASSES[i], "crystal": crystals[i]}
                for i in range(len(crystals))
            ],
        }
        vocab_path = sess_dir / f"{filename_base}_vocabulary.json"
        with open(vocab_path, "w") as f:
            json.dump(vocab_data, f, indent=2)

    print(f"✓ Saved weights: {weights_path.name}")
    print(f"✓ Saved model config: {model_cfg_path.name}")
    if train_cfg_path:
        print(f"✓ Saved training config: {train_cfg_path.name}")
    if vocab_path:
        print(f"✓ Saved vocabulary: {vocab_path.name}")

    do_push = push_to_hub if push_to_hub is not None else paths.get("push_to_hub", False)
    if do_push:
        try:
            from huggingface_hub import HfApi, create_repo
            hub_repo = paths["hub_repo"]
            hub_token = paths.get("hub_token")
            subfolder = f"models/{paths['model_variant']}/{paths['session_timestamp']}"

            api = HfApi(token=hub_token)
            try:
                create_repo(hub_repo, token=hub_token, private=True, exist_ok=True)
            except Exception:
                pass

            def _up(p: Path):
                api.upload_file(
                    path_or_fileobj=str(p),
                    path_in_repo=f"{subfolder}/{p.name}",
                    repo_id=hub_repo,
                    repo_type="model"
                )

            _up(weights_path)
            _up(model_cfg_path)
            if train_cfg_path:
                _up(train_cfg_path)
            if vocab_path:
                _up(vocab_path)
            print(f"✓ Pushed to hub: {hub_repo}/{subfolder}")
        except Exception as e:
            print(f"⚠ Hub upload failed: {e}")

    return {
        "weights": weights_path,
        "model_config": model_cfg_path,
        "training_config": train_cfg_path,
        "vocabulary": vocab_path,
        "session_dir": sess_dir
    }

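# Usage sketch (the directory and session values below are hypothetical, not from the
# original training loop): the `paths` dict drives both the on-disk layout and the
# optional Hub upload.
def example_save(model):
    paths = {
        'save_dir': './checkpoints',
        'model_variant': 'vit_beatrix_shaper',
        'session_timestamp': '20250101_120000',
        'epoch': 10,
        'val_acc': 30.30,
        'is_best': True,
        'push_to_hub': False,
    }
    return save_existing_model(model, paths, model_config=getattr(model, 'config', None))
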
def load_existing_model(
    model_path: str | Path | None,
    paths: dict | None,
    model_config=None,
    training_config=None,
    *,
    from_hub: bool = False,
    prefer_best: bool = True,
    map_location: str | torch.device | None = None
):
    """
    Load a saved model (weights + config), reconstruct the architecture via build_model,
    and return a ready-to-use model. If a saved vocabulary is present, reuse it.

    Args:
        model_path: explicit path to a .safetensors file; if None, resolve from `paths`
        paths: {
            'save_dir': str, 'model_variant': str, 'session_timestamp': str,
            # (for hub)
            'hub_repo': str, 'hub_token': str|None
        }
        from_hub: if True, pull from the HF Hub subfolder models/{variant}/{session}/
        prefer_best: when scanning a folder, pick 'best_*.safetensors' if available
        map_location: optional torch map_location

    Returns:
        model (on the default device), resolved_paths dict
    """
    device = _get_device() if map_location is None else map_location

    # Resolve weights / config / vocabulary paths
    if model_path is not None:
        weights_path = Path(model_path)
        base = weights_path.name.replace(".safetensors", "")
        session_dir = weights_path.parent
        model_cfg_path = session_dir / f"{base}_model_config.json"
        vocab_path = session_dir / f"{base}_vocabulary.json"
    elif from_hub:
        try:
            from huggingface_hub import hf_hub_download  # capability check only
        except Exception as e:
            raise RuntimeError("huggingface_hub is required for from_hub=True") from e
        # Hub checkpoints live under models/{variant}/{session}/, but the exact
        # epoch/accuracy-tagged filename cannot be guessed here.
        raise RuntimeError(
            "When loading from Hub, please supply the explicit .safetensors filename in model_path "
            "(e.g., '.../best_epoch010_acc30.30.safetensors') or download locally first."
        )
    else:
        weights_path, model_cfg_path, vocab_path = _find_local_checkpoint(paths)
        if weights_path is None:
            raise FileNotFoundError("No checkpoint found in session folder")

    # Load the saved model config if available
    if model_cfg_path and model_cfg_path.exists():
        with open(model_cfg_path, "r") as f:
            cfg = json.load(f)
    else:
        cfg = _jsonify_obj(model_config)

    variant = cfg.get("variant", paths.get("model_variant") if paths else None)
    if variant is None:
        raise ValueError("Model variant not found in config; pass paths['model_variant'] or include 'variant'.")

    overrides = {}
    for k in ("embed_dim", "vocab_dim", "depth", "num_heads", "mlp_ratio",
              "img_size", "patch_size", "dropout", "attn_dropout",
              "norm_type", "similarity_mode",
              "head_type", "prototype_mode", "margin_type", "margin_m", "scale_s",
              "apply_margin_train_only"):
        if k in cfg and cfg[k] is not None:
            overrides[k] = cfg[k]

    # build_model derives num_classes from CIFAR100_CLASSES; the saved value is kept
    # only as a sanity-check override.
    if 'num_classes' in cfg:
        overrides['num_classes'] = cfg['num_classes']

    # Rebuild the architecture (requires the global `vocab` to regenerate pentachora)
    model = build_model(variant=variant, **overrides).to(device)
    if 'get_default_device' in globals():
        model = model.to(get_default_device())

    # Prefer the saved pentachora vocabulary over freshly generated ones
    if vocab_path and vocab_path.exists():
        saved_penta = _load_saved_vocabulary(vocab_path)
        if hasattr(model, "class_pentachora") and len(saved_penta) == len(model.class_pentachora):
            new_list = []
            for p in saved_penta:
                new_list.append(type(model.class_pentachora[0])(p, norm_type=getattr(model, "norm_type", "l1")))

            import torch.nn as nn
            model.class_pentachora = nn.ModuleList(new_list)

    # Load weights
    sd = load_file(str(weights_path), device='cpu')
    print(f"\nCheckpoint contains {len(sd)} keys")
    print(f"First 5 keys: {list(sd.keys())[:5]}")

    # Handle torch.compile checkpoints ('_orig_mod.' prefix)
    has_orig_mod = any(k.startswith("_orig_mod.") for k in sd.keys())
    if has_orig_mod:
        print("Detected compiled model checkpoint (_orig_mod. prefix)")

    fixed = {}
    for k, v in sd.items():
        new_key = k[len("_orig_mod."):] if k.startswith("_orig_mod.") else k
        fixed[new_key] = v

    model_state = model.state_dict()
    print(f"\nModel expects {len(model_state)} keys")
    print(f"First 5 expected: {list(model_state.keys())[:5]}")

    checkpoint_keys = set(fixed.keys())
    model_keys = set(model_state.keys())

    missing_in_checkpoint = model_keys - checkpoint_keys
    unexpected_in_checkpoint = checkpoint_keys - model_keys

    print(f"\nKeys in model but not in checkpoint: {len(missing_in_checkpoint)}")
    if missing_in_checkpoint and len(missing_in_checkpoint) < 10:
        print(f"  Missing: {list(missing_in_checkpoint)}")

    print(f"Keys in checkpoint but not in model: {len(unexpected_in_checkpoint)}")
    if unexpected_in_checkpoint and len(unexpected_in_checkpoint) < 10:
        print(f"  Unexpected: {list(unexpected_in_checkpoint)}")

    try:
        model.load_state_dict(fixed, strict=True)
        print("✓ Strict load successful - all weights loaded")
    except RuntimeError as e:
        print(f"⚠ Strict load failed: {e}")
        incompatible = model.load_state_dict(fixed, strict=False)
        print("Loaded with strict=False")
        print(f"  Missing keys: {len(incompatible.missing_keys)}")
        print(f"  Unexpected keys: {len(incompatible.unexpected_keys)}")

        critical_missing = [k for k in incompatible.missing_keys if 'weight' in k or 'bias' in k]
        if critical_missing:
            print(f"  ⚠ Critical missing weights: {critical_missing[:5]}")

    sample_weight = next(iter(model.parameters()))
    print("\nFirst parameter stats:")
    print(f"  Shape: {sample_weight.shape}")
    print(f"  Mean: {sample_weight.mean().item():.6f}")
    print(f"  Std: {sample_weight.std().item():.6f}")
    print(f"  Min: {sample_weight.min().item():.6f}")
    print(f"  Max: {sample_weight.max().item():.6f}")

    model.eval()
    return model, {
        "weights": weights_path,
        "model_config": model_cfg_path,
        "vocabulary": vocab_path,
        "session_dir": weights_path.parent
    }

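# Usage sketch (checkpoint names and session values are hypothetical): either let the
# loader scan the session folder or point it at an explicit .safetensors file.
def example_load():
    paths = {
        'save_dir': './checkpoints',
        'model_variant': 'vit_beatrix_shaper',
        'session_timestamp': '20250101_120000',
    }
    model, resolved = load_existing_model(None, paths)
    # Explicit file instead:
    # model, resolved = load_existing_model(
    #     './checkpoints/vit_beatrix_shaper_20250101_120000/best_epoch010_acc30.30.safetensors', paths)
    return model, resolved
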
def get_parameter_groups(model, weight_decay):
    """Create optimizer parameter groups with weight-decay handling."""
    no_decay = ['bias', 'LayerNorm.weight', 'norm']
    params_decay = []
    params_no_decay = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            if any(nd in name for nd in no_decay):
                params_no_decay.append(param)
            else:
                params_decay.append(param)

    return [
        {'params': params_decay, 'weight_decay': weight_decay},
        {'params': params_no_decay, 'weight_decay': 0.0}
    ]

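# Sketch of how the parameter groups are typically consumed (learning rate and decay
# values are placeholders, not from the original training config):
def example_optimizer(model):
    groups = get_parameter_groups(model, weight_decay=0.05)
    return torch.optim.AdamW(groups, lr=3e-4, betas=(0.9, 0.999))
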
def create_scheduler(optimizer, config, start_epoch=0):
    """Create a cosine LR scheduler with linear warmup."""
    def lr_lambda(epoch):
        if epoch < config.warmup_epochs:
            return epoch / config.warmup_epochs
        if config.epochs <= config.warmup_epochs:
            return 1.0
        return 0.5 * (1 + np.cos(np.pi * (epoch - config.warmup_epochs) /
                                 (config.epochs - config.warmup_epochs)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Fast-forward when resuming from a checkpoint
    for _ in range(start_epoch):
        scheduler.step()

    return scheduler

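# Wiring sketch: create_scheduler only reads `epochs` and `warmup_epochs` from its
# config argument, so a SimpleNamespace stands in here (placeholder values):
def example_scheduler(model):
    from types import SimpleNamespace
    cfg = SimpleNamespace(epochs=100, warmup_epochs=5)
    opt = torch.optim.AdamW(get_parameter_groups(model, weight_decay=0.05), lr=3e-4)
    return create_scheduler(opt, cfg, start_epoch=0)
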
def count_parameters(model):
    """Count model parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'total': total, 'trainable': trainable}

if __name__ == "__main__": |
|
|
print("Testing model loader...") |
|
|
print("=" * 50) |
|
|
|
|
|
|
|
|
model = build_model('vit_beatrix_shaper').to(get_default_device()) |
|
|
|
|
|
|
|
|
|
|
|
x = torch.randn(4, 3, 32, 32).to(get_default_device()) |
|
|
output = model(x) |
|
|
|
|
|
print(f"\nForward pass successful!") |
|
|
print(f" Input shape: {x.shape}") |
|
|
print(f" Logits shape: {output['logits'].shape}") |
|
|
print(f" Similarities shape: {output['similarities'].shape}") |
|
|
|
|
|
print("\n✓ Model loader working correctly!") |