| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import math |
| | import os |
| | import time |
| | import json |
| | from dataclasses import dataclass |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from tqdm import tqdm |
| |
|
# Run everything on the GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Teacher ensemble: (HF model id, short cache name, max sequence length).
# Each teacher's embeddings are aligned into a shared space and averaged
# into the consensus target the student distills from.
MODELS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 8192),
    ("FacebookAI/roberta-base", "roberta", 512),
    ("albert/albert-base-v2", "albert", 512),
    ("distilbert/distilbert-base-uncased", "distil", 512),
]
| |
|
@dataclass
class Config:
    """Hyperparameters for extraction, the student architecture, and training."""

    # --- data ---
    n_samples: int = 500000                          # captions to stream from the source dataset
    n_val: int = 5000                                # tail of the samples held out for validation
    min_caption_len: int = 50                        # minimum caption length (chars) to keep
    extract_batch: int = 1024                        # batch size for teacher embedding extraction
    cache_dir: str = "/home/claude/consensus_500k"   # on-disk cache for embeddings/captions

    # --- student architecture ---
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 8192          # positional-embedding capacity of the student
    tokenize_len: int = 512      # sequence length actually used at train time
    output_dim: int = 768        # must match the consensus embedding dimension
    dropout: float = 0.1

    # --- optimization ---
    epochs: int = 30
    batch_size: int = 128
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 1000     # linear warmup steps before cosine decay
    grad_clip: float = 1.0
    seed: int = 42

    # --- loss weights ---
    nce_weight: float = 1.0      # InfoNCE (contrastive) term
    mse_weight: float = 1.0      # direct regression to the consensus target
    cv_weight: float = 0.1       # simplex-volume CV regularizer
    cv_target: float = 0.084     # target coefficient of variation for cv_loss


# Single module-level config instance used throughout the script.
CFG = Config()
| |
|
# Startup banner.
# NOTE(review): the title says "200K Scale" but CFG.n_samples is 500,000 —
# the label looks stale from an earlier run; confirm intended scale.
print("=" * 65)
print("DISTILLED CONSENSUS BERT β 200K Scale")
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Samples: {CFG.n_samples:,}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_captions(n, min_len=50):
    """Stream up to *n* LLaVA captions longer than *min_len* chars from CC12M."""
    from datasets import load_dataset
    print(f"\n Loading captions (n={n:,})...")
    stream = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
    kept = []
    for record in stream:
        text = record.get("caption_llava", "")
        if isinstance(text, str) and len(text) > min_len:
            kept.append(text)
            if len(kept) >= n:
                break
    print(f" Got {len(kept):,} captions")
    return kept
| |
|
| |
|
@torch.no_grad()
def extract_one(model_name, short_name, captions, max_len, batch_size):
    """Embed *captions* with one teacher via masked mean pooling.

    Returns a CPU float tensor of shape (len(captions), hidden_size); the
    teacher is freed from GPU memory before returning.
    """
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Extracting: {short_name} ({model_name})...")
    model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dim = model.config.hidden_size
    n_params = sum(p.numel() for p in model.parameters())
    print(f" dim={dim}, {n_params:,} params")

    chunks = []
    for start in tqdm(range(0, len(captions), batch_size), desc=f" {short_name}"):
        text_batch = captions[start:start + batch_size]
        enc = tokenizer(text_batch, max_length=max_len, padding=True,
                        truncation=True, return_tensors="pt").to(DEVICE)
        hidden = model(**enc).last_hidden_state
        weights = enc.attention_mask.unsqueeze(-1).float()
        # Mean over real (non-pad) tokens only.
        mean_pooled = (hidden * weights).sum(1) / weights.sum(1).clamp(min=1)
        chunks.append(mean_pooled.cpu())

    result = torch.cat(chunks)
    print(f" Shape: {result.shape}")
    del model
    torch.cuda.empty_cache()
    return result
| |
|
| |
|
def extract_all():
    """Return ({short_name: (N, 768) embeddings}, captions), using the disk cache.

    On a full cache hit everything is loaded from CFG.cache_dir. Otherwise
    captions are streamed, every teacher in MODELS is run, each embedding
    table is padded/truncated to a common 768 dims, and all artifacts are
    written back to the cache.
    """
    os.makedirs(CFG.cache_dir, exist_ok=True)
    caps_path = os.path.join(CFG.cache_dir, "captions.json")

    have_all = all(os.path.exists(os.path.join(CFG.cache_dir, f"{short}.pt"))
                   for _, short, _ in MODELS)

    if have_all and os.path.exists(caps_path):
        print("\n Loading cached embeddings...")
        table = {}
        for _, short, _ in MODELS:
            table[short] = torch.load(
                os.path.join(CFG.cache_dir, f"{short}.pt"), weights_only=True)
            print(f" {short}: {table[short].shape}")
        with open(caps_path) as fh:
            captions = json.load(fh)
        return table, captions

    captions = load_captions(CFG.n_samples, CFG.min_caption_len)

    table = {}
    for model_name, short, model_max_len in MODELS:
        emb = extract_one(model_name, short, captions,
                          model_max_len, CFG.extract_batch)
        width = emb.shape[1]
        # Project every teacher into a common 768-dim space:
        # zero-pad narrow outputs, truncate wide ones.
        if width < 768:
            emb = F.pad(emb, (0, 768 - width))
        elif width > 768:
            emb = emb[:, :768]
        table[short] = emb
        torch.save(emb, os.path.join(CFG.cache_dir, f"{short}.pt"))

    with open(caps_path, "w") as fh:
        json.dump(captions, fh)

    return table, captions
| |
|
| |
|
| | |
| | |
| | |
| |
|
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric PSD matrix via eigendecomposition."""
    eigvals, eigvecs = torch.linalg.eigh(cov)
    # Floor the spectrum so near-null directions don't blow up under rsqrt.
    eigvals = eigvals.clamp(min=eps)
    return eigvecs @ torch.diag(eigvals.rsqrt()) @ eigvecs.T
| |
|
| |
|
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal map from *source* space onto *target* space.

    Uses up to *n_align* paired rows: center both sides, whiten each with the
    inverse square root of its covariance, then solve the orthogonal
    Procrustes problem via SVD. Returns the pieces apply_align() consumes plus
    before/after mean-cosine diagnostics.
    """
    n_rows = min(n_align, source.shape[0], target.shape[0])
    src = source[:n_rows].float()
    tgt = target[:n_rows].float()
    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    m = src_c.shape[0]

    src_cov = (src_c.T @ src_c) / max(m - 1, 1)
    tgt_cov = (tgt_c.T @ tgt_c) / max(m - 1, 1)
    src_whiten = symmetric_inv_sqrt(src_cov)
    tgt_whiten = symmetric_inv_sqrt(tgt_cov)

    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    # Orthogonal Procrustes solution: R = U V^T of the cross-correlation.
    U, _, Vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = U @ Vt

    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()

    return {
        "rotation": rotation,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }
| |
|
| |
|
def apply_align(emb, a):
    """Map *emb* into the reference space using an alignment from procrustes_align."""
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
| |
|
| |
|
def generate_consensus(embeds):
    """Align all to bert space, take normalized centroid as target."""
    print(f"\n{'='*65}")
    print("WHITENED PROCRUSTES ALIGNMENT + CONSENSUS")
    print(f"{'='*65}")

    ref_name = "bert"
    names = [short for _, short, _ in MODELS]

    # Align every teacher (including the reference itself) into bert space.
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref_name])
        aligned[name] = apply_align(embeds[name], info)
        label = " (ref)" if name == ref_name else ""
        print(f" {name:10s}: cos {info['cos_before']:.4f} β {info['cos_after']:.4f}{label}")

    # The distillation target is the unit-normalized mean of the aligned teachers.
    centroid = sum(aligned[name] for name in names) / len(names)
    consensus = F.normalize(centroid, dim=-1)

    # Diagnostic: how closely each teacher agrees with the consensus.
    n_check = min(5000, consensus.shape[0])
    for name in names:
        agreement = F.cosine_similarity(
            consensus[:n_check], aligned[name][:n_check], dim=-1).mean().item()
        print(f" cos(consensus, {name:10s}): {agreement:.4f}")

    return consensus
| |
|
| |
|
| | |
| | |
| | |
| |
|
class CaptionEncoder(nn.Module):
    """Small Transformer encoder mapping token ids to unit-norm embeddings.

    Learned token + position embeddings feed a pre-norm TransformerEncoder;
    the output is mean-pooled over non-pad tokens, projected to *output_dim*,
    and L2-normalized.
    """

    def __init__(self, vocab_size=30522, max_len=128, d_model=384,
                 n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode (B, L) token ids into (B, output_dim) L2-normalized vectors."""
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        hidden = self.emb_drop(self.emb_norm(
            self.token_emb(input_ids) + self.pos_emb(positions)))

        # Key-padding mask is True where attention must be blocked; without an
        # explicit mask, pad positions are inferred from pad_token_id.
        if attention_mask is None:
            pad_mask = input_ids == self.pad_token_id
        else:
            pad_mask = ~attention_mask.bool()

        hidden = self.encoder(hidden, src_key_padding_mask=pad_mask)

        # Mean-pool over real tokens only (clamp guards all-pad rows).
        if attention_mask is None:
            weights = (~pad_mask).unsqueeze(-1).float()
        else:
            weights = attention_mask.unsqueeze(-1).float()
        pooled = (hidden * weights).sum(1) / weights.sum(1).clamp(min=1)

        return F.normalize(self.output_proj(pooled), dim=-1)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each point set, via Cayley-Menger.

    *pts* is (B, V, D); returns a (B,) tensor of squared (V-1)-simplex volumes
    computed from the Cayley-Menger determinant of pairwise squared distances.
    """
    pts = pts.float()
    pairwise = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = (pairwise * pairwise).sum(-1)
    B, V, _ = sq_dist.shape
    cm = torch.zeros(B, V + 1, V + 1, device=sq_dist.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = sq_dist
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    return sign / ((2.0 ** (V - 1)) * fact * fact) * torch.linalg.det(cm)
| |
|
def cv_loss(emb, target=0.084, n_samples=16):
    """Penalize deviation of the simplex-volume coefficient of variation from *target*.

    Draws *n_samples* random 5-point subsets from the batch and measures the
    CV of their simplex volumes. Returns 0 when the batch is too small to
    form a 5-point subset.
    """
    batch = emb.shape[0]
    if batch < 5:
        return torch.tensor(0.0, device=emb.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu + eps keeps sqrt differentiable and clips tiny negatives.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    stacked = torch.stack(volumes)
    cv = stacked.std() / (stacked.mean() + 1e-8)
    return (cv - target).abs()
| |
|
def cv_metric(emb, n=200):
    """Monte-Carlo estimate (plain float, no grad) of the 5-point simplex-volume CV."""
    batch = emb.shape[0]
    if batch < 5:
        return 0.0
    volumes = []
    for _ in range(n):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    # Too few usable samples makes the statistic meaningless; report 0.
    if len(volumes) < 10:
        return 0.0
    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
| |
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between paired embedding batches.

    Returns (loss, top-1 accuracy); diagonal pairs are the positives.
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    # Average the two retrieval directions (a->b and b->a).
    loss = (F.cross_entropy(logits, labels)
            + F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc
| |
|
| |
|
| | |
| | |
| | |
| |
|
def train():
    """End-to-end distillation pipeline.

    Steps: (1) load/extract teacher embeddings and build the consensus target,
    (2) pre-tokenize all captions with the BERT tokenizer, (3) train the
    CaptionEncoder student with InfoNCE + MSE + CV losses, checkpointing on
    best validation cosine, (4) evaluate retrieval against the consensus and
    run a qualitative similarity check.
    """
    torch.manual_seed(CFG.seed)
    torch.cuda.manual_seed_all(CFG.seed)
    np.random.seed(CFG.seed)

    # ----- consensus targets -----
    embeds, captions = extract_all()
    consensus = generate_consensus(embeds)

    # The per-teacher embedding tables are large; free them immediately.
    del embeds
    torch.cuda.empty_cache()
    import gc; gc.collect()

    # ----- tokenizer (student reuses the BERT vocabulary) -----
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    print(f"\n Tokenizer: bert-base-uncased (vocab={tokenizer.vocab_size})")

    print(" Pre-tokenizing...")
    # Tokenize in chunks so the tokenizer never holds the whole corpus at once.
    all_ids, all_masks = [], []
    chunk = 50000
    for i in tqdm(range(0, len(captions), chunk), desc=" Tokenizing"):
        j = min(i + chunk, len(captions))
        tokens = tokenizer(captions[i:j], max_length=CFG.tokenize_len,
                           padding="max_length", truncation=True,
                           return_tensors="pt")
        all_ids.append(tokens["input_ids"])
        all_masks.append(tokens["attention_mask"])

    input_ids = torch.cat(all_ids)
    attention_mask = torch.cat(all_masks)

    # Length statistics of the padded corpus (diagnostic only).
    real_lens = attention_mask.sum(1).float()
    print(f" Token lengths: mean={real_lens.mean():.0f} "
          f"median={real_lens.median():.0f} "
          f">{CFG.tokenize_len}: {(real_lens >= CFG.tokenize_len).float().mean():.1%}")
    print(f" Padded to: {CFG.tokenize_len} (model supports up to {CFG.max_len})")

    # Validation set is the tail of the (unshuffled) caption order.
    n_train = len(captions) - CFG.n_val
    print(f" Train: {n_train:,}, Val: {CFG.n_val:,}")

    # NOTE(review): the entire tokenized corpus plus targets are made resident
    # on DEVICE up front; at 500k x 512 tokens this is several GB — confirm
    # the target GPU has room alongside the model and optimizer state.
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    # ----- student model -----
    print(f"\n{'='*65}")
    print("STUDENT MODEL")
    print(f"{'='*65}")

    student = CaptionEncoder(
        vocab_size=tokenizer.vocab_size,
        max_len=CFG.max_len,
        d_model=CFG.d_model,
        n_heads=CFG.n_heads,
        n_layers=CFG.n_layers,
        d_ff=CFG.d_ff,
        output_dim=CFG.output_dim,
        dropout=CFG.dropout,
        pad_token_id=tokenizer.pad_token_id,
    ).to(DEVICE)

    n_params = sum(p.numel() for p in student.parameters())
    print(f" Architecture: {CFG.n_layers}L, {CFG.d_model}d, {CFG.n_heads}h, {CFG.d_ff} FFN")
    print(f" Output: {CFG.output_dim}-dim (consensus space)")
    print(f" Parameters: {n_params:,}")
    size_mb = sum(p.numel() * p.element_size() for p in student.parameters()) / 1e6
    print(f" Size: {size_mb:.1f} MB")

    # Warm start from the first earlier-run checkpoint found: shape-matching
    # tensors are copied, a shorter pos_emb is extended with freshly
    # initialized rows, everything else is skipped.
    for prev_dir in ["/home/claude/consensus_200k/student",
                     "/home/claude/distilled_consensus"]:
        prev_ckpt = os.path.join(prev_dir, "best_model.pt")
        if os.path.exists(prev_ckpt):
            print(f"\n Warm-starting from: {prev_ckpt}")
            prev_state = torch.load(prev_ckpt, weights_only=True, map_location=DEVICE)
            current_state = student.state_dict()

            loaded, extended, skipped = 0, 0, 0
            for name, param in prev_state.items():
                if name not in current_state:
                    skipped += 1
                    continue
                if param.shape == current_state[name].shape:
                    current_state[name] = param
                    loaded += 1
                elif "pos_emb" in name and param.shape[0] < current_state[name].shape[0]:
                    # Keep the learned positions; randomly init the new tail.
                    old_len = param.shape[0]
                    current_state[name][:old_len] = param
                    nn.init.normal_(current_state[name][old_len:], std=0.02)
                    extended += 1
                    print(f" Extended {name}: {param.shape[0]}β{current_state[name].shape[0]}")
                else:
                    skipped += 1

            student.load_state_dict(current_state)
            print(f" Loaded: {loaded}, Extended: {extended}, Skipped: {skipped}")
            break
    else:
        # for-else: runs only if no checkpoint directory matched above.
        print("\n Training from scratch (no previous checkpoint found)")

    # ----- optimizer and per-step LR schedule (linear warmup -> cosine) -----
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG.lr,
                                  weight_decay=CFG.weight_decay)
    n_batches = n_train // CFG.batch_size
    total_steps = n_batches * CFG.epochs
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=CFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps - CFG.warmup_steps, 1),
             eta_min=1e-6)],
        milestones=[CFG.warmup_steps])

    os.makedirs(CFG.cache_dir, exist_ok=True)
    save_dir = os.path.join(CFG.cache_dir, "student")
    os.makedirs(save_dir, exist_ok=True)

    # ----- training loop -----
    print(f"\n{'='*65}")
    print(f"TRAINING ({CFG.epochs} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")

    all_metrics = {"config": {k: str(v) for k, v in vars(CFG).items()}, "epochs": []}
    # NOTE(review): starts at 0.0, so best_model.pt is only written once
    # val_cos exceeds 0; the final-eval load below assumes that happens.
    best_val_cos = 0.0

    for epoch in range(CFG.epochs):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        losses = {"total": 0, "nce": 0, "mse": 0}
        metrics = {"acc": 0, "cos": 0}
        n = 0
        t0 = time.time()

        for i in range(0, n_train, CFG.batch_size):
            idx = perm[i:i+CFG.batch_size]
            # Skip tiny trailing batches: InfoNCE needs enough negatives.
            if len(idx) < 8: continue

            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]

            # Contrastive + regression + geometric-regularity terms.
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=CFG.cv_target)

            loss = CFG.nce_weight * l_nce + CFG.mse_weight * l_mse + CFG.cv_weight * l_cv

            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), CFG.grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()  # schedule advances per optimizer step

            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()

            losses["total"] += loss.item()
            losses["nce"] += l_nce.item()
            losses["mse"] += l_mse.item()
            metrics["acc"] += acc
            metrics["cos"] += cos
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)  # guard against division by zero if every batch was skipped

        # ----- validation (batched to bound memory) -----
        student.eval()
        with torch.no_grad():
            val_embs = []
            for vi in range(0, CFG.n_val, 512):
                vj = min(vi + 512, CFG.n_val)
                ve = student(val_ids[vi:vj], val_mask[vi:vj])
                val_embs.append(ve)
            val_emb = torch.cat(val_embs)
            # InfoNCE accuracy on a 2k subset keeps the similarity matrix small.
            _, val_acc = infonce(val_emb[:2000], val_targets[:2000])
            val_cos = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
            val_cv = cv_metric(val_emb[:2000])

        summary = {
            "epoch": epoch + 1, "elapsed": elapsed,
            "loss": losses["total"] / d,
            "train_acc": metrics["acc"] / d,
            "train_cos": metrics["cos"] / d,
            "val_acc": val_acc, "val_cos": val_cos, "val_cv": val_cv,
        }
        all_metrics["epochs"].append(summary)

        print(f" E{epoch+1:2d}: {elapsed:.0f}s "
              f"loss={summary['loss']:.4f} "
              f"t_acc={summary['train_acc']:.3f} t_cos={summary['train_cos']:.3f} "
              f"v_acc={summary['val_acc']:.3f} v_cos={summary['val_cos']:.3f} "
              f"v_cv={summary['val_cv']:.3f}")

        # Keep the best checkpoint by validation cosine...
        if val_cos > best_val_cos:
            best_val_cos = val_cos
            torch.save(student.state_dict(), os.path.join(save_dir, "best_model.pt"))

        # ...plus a periodic snapshot every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(student.state_dict(),
                       os.path.join(save_dir, f"model_e{epoch+1:02d}.pt"))

    # ----- persist final weights, tokenizer, and metrics -----
    torch.save(student.state_dict(), os.path.join(save_dir, "final_model.pt"))
    tokenizer.save_pretrained(os.path.join(save_dir, "tokenizer"))
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)

    # ----- final evaluation on the best checkpoint -----
    print(f"\n{'='*65}")
    print("FINAL EVALUATION")
    print(f"{'='*65}")

    student.load_state_dict(
        torch.load(os.path.join(save_dir, "best_model.pt"),
                   weights_only=True, map_location=DEVICE))
    student.eval()

    with torch.no_grad():
        val_embs = []
        for vi in range(0, CFG.n_val, 512):
            vj = min(vi + 512, CFG.n_val)
            ve = student(val_ids[vi:vj], val_mask[vi:vj])
            val_embs.append(ve)
        val_emb = torch.cat(val_embs)

    # Retrieval: each student embedding should rank its own consensus target
    # first among a pool of `sub` candidates.
    sub = min(2000, CFG.n_val)
    sim = val_emb[:sub] @ val_targets[:sub].T
    labels = torch.arange(sub, device=DEVICE)
    r1 = (sim.argmax(1) == labels).float().mean().item()
    r5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
    r10 = (sim.topk(10, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()

    cos_match = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
    final_cv = cv_metric(val_emb[:2000])

    print(f" Retrieval (student β consensus):")
    print(f" R@1: {r1:.4f}")
    print(f" R@5: {r5:.4f}")
    print(f" R@10: {r10:.4f}")
    print(f" Cosine: {cos_match:.4f}")
    print(f" CV: {final_cv:.4f} (target: {CFG.cv_target})")
    print(f" Model: {n_params:,} params, {size_mb:.1f} MB")

    # Qualitative sanity check: pairwise similarities of unrelated captions.
    print(f"\n Standalone similarity test:")
    test = [
        "A cat sitting on a windowsill watching birds",
        "A golden retriever playing fetch on the beach",
        "A still life painting with flowers and fruit",
        "An aerial photograph of a city skyline at night",
        "A child riding a bicycle through autumn leaves",
    ]
    with torch.no_grad():
        tok = tokenizer(test, max_length=CFG.tokenize_len, padding="max_length",
                        truncation=True, return_tensors="pt").to(DEVICE)
        embs = student(tok["input_ids"], tok["attention_mask"])
        sim = embs @ embs.T
    for i in range(len(test)):
        for j in range(i+1, len(test)):
            print(f" [{i}]β[{j}]: {sim[i,j]:.3f} "
                  f"({test[i][:35]}β{test[j][:35]})")

    print(f"\n Saved to: {save_dir}/")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
| |
|
| |
|
# Script entry point: run the full distillation pipeline.
if __name__ == "__main__":
    train()