# ============================================================================
# DISTILLED CONSENSUS BERT - 500k scale, expanded for 8192
#
# Self-contained pipeline:
# 1. Extract 5 BERT-family embeddings on 500k CC12M captions
# 2. Whitened Procrustes alignment
# 3. Generate consensus targets (centroid of aligned embeddings)
# 4. Train small standalone transformer from scratch
# 5. No expert models needed at inference
# ============================================================================
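# Note: all extraction stages cache to Config.cache_dir, so re-runs reuse the
# saved embeddings and captions; run this file directly to execute the full
# pipeline (the __main__ guard at the bottom calls train()).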
import math
import os
import time
import json
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
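# Expert ensemble: (hf_model_id, short_name, max tokens at extraction time).
# Any hidden size that differs from 768 is padded/truncated to 768 after pooling.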
MODELS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 8192),
    ("FacebookAI/roberta-base", "roberta", 512),
    ("albert/albert-base-v2", "albert", 512),
    ("distilbert/distilbert-base-uncased", "distil", 512),
]
@dataclass  # as a dataclass the fields become instance attributes, so vars(CFG) captures the config for the metrics dump
class Config:
    # Data
    n_samples: int = 500000
    n_val: int = 5000
    min_caption_len: int = 50
    extract_batch: int = 1024
    cache_dir: str = "/home/claude/consensus_500k"
    # Student architecture
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 8192  # position embedding capacity
    tokenize_len: int = 512  # actual padding length (captions avg ~100 tokens)
    output_dim: int = 768
    dropout: float = 0.1
    # Training
    epochs: int = 30
    batch_size: int = 128  # sequences are tokenize_len=512, not max_len=8192
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    grad_clip: float = 1.0
    seed: int = 42
    # Loss
    nce_weight: float = 1.0
    mse_weight: float = 1.0
    cv_weight: float = 0.1
    cv_target: float = 0.084
CFG = Config()
| print("=" * 65) | |
| print("DISTILLED CONSENSUS BERT β 500k Scale") | |
| print("=" * 65) | |
| print(f" Device: {DEVICE}") | |
| print(f" Samples: {CFG.n_samples:,}") | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # EXTRACTION | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def load_captions(n, min_len=50):
    from datasets import load_dataset
    print(f"\n Loading captions (n={n:,})...")
    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > min_len:
            captions.append(cap)
        if len(captions) >= n:
            break
    print(f" Got {len(captions):,} captions")
    return captions
@torch.no_grad()  # extraction only: no gradients needed, avoids building graphs at batch=1024
def extract_one(model_name, short_name, captions, max_len, batch_size):
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Extracting: {short_name} ({model_name})...")
    model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dim = model.config.hidden_size
    n_params = sum(p.numel() for p in model.parameters())
    print(f" dim={dim}, {n_params:,} params")
    all_emb = []
    for i in tqdm(range(0, len(captions), batch_size), desc=f" {short_name}"):
        batch = captions[i:i+batch_size]
        inputs = tokenizer(batch, max_length=max_len, padding=True,
                           truncation=True, return_tensors="pt").to(DEVICE)
        out = model(**inputs)
        # Masked mean pooling over the last hidden state
        mask = inputs.attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
        all_emb.append(pooled.cpu())
    emb = torch.cat(all_emb)
    print(f" Shape: {emb.shape}")
    del model
    torch.cuda.empty_cache()
    return emb
def extract_all():
    os.makedirs(CFG.cache_dir, exist_ok=True)
    caps_path = os.path.join(CFG.cache_dir, "captions.json")
    all_cached = all(
        os.path.exists(os.path.join(CFG.cache_dir, f"{s}.pt"))
        for _, s, _ in MODELS)
    if all_cached and os.path.exists(caps_path):
        print("\n Loading cached embeddings...")
        embeds = {}
        for _, short, _ in MODELS:
            embeds[short] = torch.load(
                os.path.join(CFG.cache_dir, f"{short}.pt"), weights_only=True)
            print(f" {short}: {embeds[short].shape}")
        with open(caps_path) as f:
            captions = json.load(f)
        return embeds, captions
    captions = load_captions(CFG.n_samples, CFG.min_caption_len)
    embeds = {}
    for model_name, short, model_max_len in MODELS:
        emb = extract_one(model_name, short, captions,
                          model_max_len, CFG.extract_batch)
        if emb.shape[1] != 768:
            if emb.shape[1] < 768:
                emb = F.pad(emb, (0, 768 - emb.shape[1]))
            else:
                emb = emb[:, :768]
        embeds[short] = emb
        torch.save(emb, os.path.join(CFG.cache_dir, f"{short}.pt"))
    with open(caps_path, "w") as f:
        json.dump(captions, f)
    return embeds, captions
# ──────────────────────────────────────────────────────────────────
# WHITENED PROCRUSTES + CONSENSUS
# ──────────────────────────────────────────────────────────────────
def symmetric_inv_sqrt(cov, eps=1e-6):
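    """Symmetric inverse square root cov^(-1/2) via eigendecomposition (ZCA-style whitener)."""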
    evals, evecs = torch.linalg.eigh(cov)
    evals = torch.clamp(evals, min=eps)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T
def procrustes_align(source, target, n_align=10000):
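    """Whitened orthogonal Procrustes: center and whiten both sets with their own
    covariance factors, then find the rotation R minimizing ||Sc_w @ R.T - Tc_w||_F
    via the SVD of Tc_w.T @ Sc_w. Returns the rotation plus the transforms needed
    to replay the mapping on the full embedding matrix."""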
    N = min(n_align, source.shape[0], target.shape[0])
    S = source[:N].float()
    T = target[:N].float()
    s_mean = S.mean(0, keepdim=True)
    t_mean = T.mean(0, keepdim=True)
    Sc = S - s_mean
    Tc = T - t_mean
    N_s = Sc.shape[0]
    s_cov = (Sc.T @ Sc) / max(N_s - 1, 1)
    t_cov = (Tc.T @ Tc) / max(N_s - 1, 1)
    s_whiten = symmetric_inv_sqrt(s_cov)
    t_whiten = symmetric_inv_sqrt(t_cov)
    Sc_w = F.normalize(Sc @ s_whiten, dim=-1)
    Tc_w = F.normalize(Tc @ t_whiten, dim=-1)
    cos_before = F.cosine_similarity(Sc, Tc, dim=-1).mean().item()
    U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(Sc_w @ R.T, Tc_w, dim=-1).mean().item()
    return {
        "rotation": R, "source_mean": s_mean.squeeze(0),
        "source_whitener": s_whiten,
        "target_unwhitener": torch.linalg.pinv(t_whiten),
        "cos_before": cos_before, "cos_after": cos_after,
    }
def apply_align(emb, a):
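    """Map source-space embeddings into the reference space: subtract the source mean,
    whiten, rotate, then un-whiten with the target covariance factor (the target mean
    is not re-added; the centroid is re-normalized downstream)."""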
    x = emb.float() - a["source_mean"]
    x = x @ a["source_whitener"]
    x = x @ a["rotation"].T
    x = x @ a["target_unwhitener"]
    return x
def generate_consensus(embeds):
    """Align all to bert space, take normalized centroid as target."""
    print(f"\n{'='*65}")
    print("WHITENED PROCRUSTES ALIGNMENT + CONSENSUS")
    print(f"{'='*65}")
    ref_name = "bert"
    names = [s for _, s, _ in MODELS]
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref_name])
        aligned[name] = apply_align(embeds[name], info)
        label = " (ref)" if name == ref_name else ""
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}{label}")
    # Consensus = normalized centroid of all 5 aligned embeddings
    # This is what the five-BERT experiment proved: the centroid IS the consensus
    # to three decimal places regardless of seed. No learned model needed.
    centroid = sum(aligned[n] for n in names) / len(names)
    consensus = F.normalize(centroid, dim=-1)
    # Verify geometry
    N_check = min(5000, consensus.shape[0])
    for name in names:
        cos = F.cosine_similarity(
            consensus[:N_check], aligned[name][:N_check], dim=-1).mean().item()
        print(f" cos(consensus, {name:10s}): {cos:.4f}")
    return consensus
# ──────────────────────────────────────────────────────────────────
# STUDENT MODEL
# ──────────────────────────────────────────────────────────────────
class CaptionEncoder(nn.Module):
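    """Small from-scratch transformer encoder: token + learned position embeddings,
    a pre-norm encoder stack, masked mean pooling, and a projection head whose
    L2-normalized output lives in the 768-d consensus space."""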
    def __init__(self, vocab_size=30522, max_len=128, d_model=384,
                 n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )
    def forward(self, input_ids, attention_mask=None):
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))
        if attention_mask is not None:
            kpm = ~attention_mask.bool()
        else:
            kpm = (input_ids == self.pad_token_id)
        x = self.encoder(x, src_key_padding_mask=kpm)
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~kpm).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
# ──────────────────────────────────────────────────────────────────
# GEOMETRY
# ──────────────────────────────────────────────────────────────────
def cayley_menger_vol2(pts):
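    """Squared simplex volume for each batch of V points via the Cayley-Menger
    determinant: vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM)."""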
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)
    B, V, _ = d2.shape
    cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
    s = (-1.0)**V; f = math.factorial(V-1)
    return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)
def cv_loss(emb, target=0.084, n_samples=16):
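    """Coefficient of variation (std/mean) of random 5-point simplex volumes,
    penalized for deviating from a target value (geometry regularizer)."""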
    B = emb.shape[0]
    if B < 5: return torch.tensor(0.0, device=emb.device)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
    stacked = torch.stack(vols)
    cv = stacked.std() / (stacked.mean() + 1e-8)
    return (cv - target).abs()
def cv_metric(emb, n=200):
    B = emb.shape[0]
    if B < 5: return 0.0
    vols = []
    for _ in range(n):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0: vols.append(v)
    if len(vols) < 10: return 0.0
    a = np.array(vols)
    return float(a.std() / (a.mean() + 1e-8))
def infonce(a, b, temperature=0.07):
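    """Symmetric InfoNCE over in-batch negatives; returns (loss, top-1 accuracy)."""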
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc
# ──────────────────────────────────────────────────────────────────
# TRAINING
# ──────────────────────────────────────────────────────────────────
def train():
    torch.manual_seed(CFG.seed)
    torch.cuda.manual_seed_all(CFG.seed)
    np.random.seed(CFG.seed)
    # ── Extract + Align + Consensus ──
    embeds, captions = extract_all()
    consensus = generate_consensus(embeds)
    # Free the raw embeddings
    del embeds
    torch.cuda.empty_cache()
    import gc; gc.collect()
    # ── Tokenize ──
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    print(f"\n Tokenizer: bert-base-uncased (vocab={tokenizer.vocab_size})")
    print(" Pre-tokenizing...")
    # Tokenize in chunks to avoid memory issues
    all_ids, all_masks = [], []
    chunk = 50000
    for i in tqdm(range(0, len(captions), chunk), desc=" Tokenizing"):
        j = min(i + chunk, len(captions))
        tokens = tokenizer(captions[i:j], max_length=CFG.tokenize_len,
                           padding="max_length", truncation=True,
                           return_tensors="pt")
        all_ids.append(tokens["input_ids"])
        all_masks.append(tokens["attention_mask"])
    input_ids = torch.cat(all_ids)
    attention_mask = torch.cat(all_masks)
    real_lens = attention_mask.sum(1).float()
    print(f" Token lengths: mean={real_lens.mean():.0f} "
          f"median={real_lens.median():.0f} "
          f">={CFG.tokenize_len}: {(real_lens >= CFG.tokenize_len).float().mean():.1%}")
    print(f" Padded to: {CFG.tokenize_len} (model supports up to {CFG.max_len})")
    # Split
    n_train = len(captions) - CFG.n_val
    print(f" Train: {n_train:,}, Val: {CFG.n_val:,}")
    # Move to GPU
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)
    # ── Student ──
    print(f"\n{'='*65}")
    print("STUDENT MODEL")
    print(f"{'='*65}")
    student = CaptionEncoder(
        vocab_size=tokenizer.vocab_size,
        max_len=CFG.max_len,
        d_model=CFG.d_model,
        n_heads=CFG.n_heads,
        n_layers=CFG.n_layers,
        d_ff=CFG.d_ff,
        output_dim=CFG.output_dim,
        dropout=CFG.dropout,
        pad_token_id=tokenizer.pad_token_id,
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Architecture: {CFG.n_layers}L, {CFG.d_model}d, {CFG.n_heads}h, {CFG.d_ff} FFN")
    print(f" Output: {CFG.output_dim}-dim (consensus space)")
    print(f" Parameters: {n_params:,}")
    size_mb = sum(p.numel() * p.element_size() for p in student.parameters()) / 1e6
    print(f" Size: {size_mb:.1f} MB")
    # ── Warm-start from a previous checkpoint if available (e.g. the 200k student run) ──
    for prev_dir in ["/home/claude/consensus_200k/student",
                     "/home/claude/distilled_consensus"]:
        prev_ckpt = os.path.join(prev_dir, "best_model.pt")
        if os.path.exists(prev_ckpt):
            print(f"\n Warm-starting from: {prev_ckpt}")
            prev_state = torch.load(prev_ckpt, weights_only=True, map_location=DEVICE)
            current_state = student.state_dict()
            loaded, extended, skipped = 0, 0, 0
            for name, param in prev_state.items():
                if name not in current_state:
                    skipped += 1
                    continue
                if param.shape == current_state[name].shape:
                    current_state[name] = param
                    loaded += 1
                elif "pos_emb" in name and param.shape[0] < current_state[name].shape[0]:
                    # Extend position embeddings: copy old positions, init new ones
                    old_len = param.shape[0]
                    current_state[name][:old_len] = param
                    nn.init.normal_(current_state[name][old_len:], std=0.02)
                    extended += 1
                    print(f" Extended {name}: {param.shape[0]}→{current_state[name].shape[0]}")
                else:
                    skipped += 1
            student.load_state_dict(current_state)
            print(f" Loaded: {loaded}, Extended: {extended}, Skipped: {skipped}")
            break
    else:
        print("\n Training from scratch (no previous checkpoint found)")
    # ── Optimizer ──
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG.lr,
                                  weight_decay=CFG.weight_decay)
    n_batches = n_train // CFG.batch_size
    total_steps = n_batches * CFG.epochs
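    # LR schedule: linear warmup over warmup_steps, then cosine decay to eta_min.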
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=CFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps - CFG.warmup_steps, 1),
             eta_min=1e-6)],
        milestones=[CFG.warmup_steps])
    os.makedirs(CFG.cache_dir, exist_ok=True)
    save_dir = os.path.join(CFG.cache_dir, "student")
    os.makedirs(save_dir, exist_ok=True)
    # ── Train ──
    print(f"\n{'='*65}")
    print(f"TRAINING ({CFG.epochs} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")
    all_metrics = {"config": {k: str(v) for k, v in vars(CFG).items()}, "epochs": []}
    best_val_cos = 0.0
    for epoch in range(CFG.epochs):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        losses = {"total": 0, "nce": 0, "mse": 0}
        metrics = {"acc": 0, "cos": 0}
        n = 0
        t0 = time.time()
        for i in range(0, n_train, CFG.batch_size):
            idx = perm[i:i+CFG.batch_size]
            if len(idx) < 8: continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
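            # Composite objective: InfoNCE (contrastive alignment to the consensus
            # target against in-batch negatives), MSE (pointwise coordinate match),
            # and a CV term regularizing the spread of simplex volumes.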
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=CFG.cv_target)
            loss = CFG.nce_weight * l_nce + CFG.mse_weight * l_mse + CFG.cv_weight * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), CFG.grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            losses["total"] += loss.item()
            losses["nce"] += l_nce.item()
            losses["mse"] += l_mse.item()
            metrics["acc"] += acc
            metrics["cos"] += cos
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)
        # Val
        student.eval()
        with torch.no_grad():
            val_embs = []
            for vi in range(0, CFG.n_val, 512):
                vj = min(vi + 512, CFG.n_val)
                ve = student(val_ids[vi:vj], val_mask[vi:vj])
                val_embs.append(ve)
            val_emb = torch.cat(val_embs)
            _, val_acc = infonce(val_emb[:2000], val_targets[:2000])
            val_cos = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
            val_cv = cv_metric(val_emb[:2000])
        summary = {
            "epoch": epoch + 1, "elapsed": elapsed,
            "loss": losses["total"] / d,
            "train_acc": metrics["acc"] / d,
            "train_cos": metrics["cos"] / d,
            "val_acc": val_acc, "val_cos": val_cos, "val_cv": val_cv,
        }
        all_metrics["epochs"].append(summary)
        print(f" E{epoch+1:2d}: {elapsed:.0f}s "
              f"loss={summary['loss']:.4f} "
              f"t_acc={summary['train_acc']:.3f} t_cos={summary['train_cos']:.3f} "
              f"v_acc={summary['val_acc']:.3f} v_cos={summary['val_cos']:.3f} "
              f"v_cv={summary['val_cv']:.3f}")
        if val_cos > best_val_cos:
            best_val_cos = val_cos
            torch.save(student.state_dict(), os.path.join(save_dir, "best_model.pt"))
        if (epoch + 1) % 10 == 0:
            torch.save(student.state_dict(),
                       os.path.join(save_dir, f"model_e{epoch+1:02d}.pt"))
    # Final save
    torch.save(student.state_dict(), os.path.join(save_dir, "final_model.pt"))
    tokenizer.save_pretrained(os.path.join(save_dir, "tokenizer"))
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)
    # ──────────────────────────────────────────────────────────────
    # FINAL EVAL
    # ──────────────────────────────────────────────────────────────
    print(f"\n{'='*65}")
    print("FINAL EVALUATION")
    print(f"{'='*65}")
    student.load_state_dict(
        torch.load(os.path.join(save_dir, "best_model.pt"),
                   weights_only=True, map_location=DEVICE))
    student.eval()
    with torch.no_grad():
        val_embs = []
        for vi in range(0, CFG.n_val, 512):
            vj = min(vi + 512, CFG.n_val)
            ve = student(val_ids[vi:vj], val_mask[vi:vj])
            val_embs.append(ve)
        val_emb = torch.cat(val_embs)
        # Retrieval (on 2K subset for memory)
        sub = min(2000, CFG.n_val)
        sim = val_emb[:sub] @ val_targets[:sub].T
        labels = torch.arange(sub, device=DEVICE)
        r1 = (sim.argmax(1) == labels).float().mean().item()
        r5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        r10 = (sim.topk(10, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        cos_match = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
        final_cv = cv_metric(val_emb[:2000])
    print(f" Retrieval (student → consensus):")
    print(f" R@1: {r1:.4f}")
    print(f" R@5: {r5:.4f}")
    print(f" R@10: {r10:.4f}")
    print(f" Cosine: {cos_match:.4f}")
    print(f" CV: {final_cv:.4f} (target: {CFG.cv_target})")
    print(f" Model: {n_params:,} params, {size_mb:.1f} MB")
    # Standalone test
    print(f"\n Standalone similarity test:")
    test = [
        "A cat sitting on a windowsill watching birds",
        "A golden retriever playing fetch on the beach",
        "A still life painting with flowers and fruit",
        "An aerial photograph of a city skyline at night",
        "A child riding a bicycle through autumn leaves",
    ]
    with torch.no_grad():
        tok = tokenizer(test, max_length=CFG.tokenize_len, padding="max_length",
                        truncation=True, return_tensors="pt").to(DEVICE)
        embs = student(tok["input_ids"], tok["attention_mask"])
        sim = embs @ embs.T
    for i in range(len(test)):
        for j in range(i+1, len(test)):
            print(f" [{i}]↔[{j}]: {sim[i,j]:.3f} "
                  f"({test[i][:35]} ↔ {test[j][:35]})")
    print(f"\n Saved to: {save_dir}/")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
if __name__ == "__main__":
    train()