# ============================================================================
# DISTILLED CONSENSUS BERT — 500K Scale
#
# Self-contained pipeline:
#   1. Extract 5 BERT-family embeddings on 500K CC12M captions
#   2. Whitened Procrustes alignment
#   3. Generate consensus targets (centroid of aligned embeddings)
#   4. Train small standalone transformer from scratch
#   5. No expert models needed at inference
#
# NOTE(review): header previously said "200K Scale"; the operative config
# (n_samples=500000, cache_dir=.../consensus_500k) and the warm-start logic
# (which treats consensus_200k as a *previous* run) show 500K is current.
# ============================================================================
import math
import os
import time
import json
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Expert models: (hub_id, short_name, tokenizer max sequence length).
MODELS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 8192),
    ("FacebookAI/roberta-base", "roberta", 512),
    ("albert/albert-base-v2", "albert", 512),
    ("distilbert/distilbert-base-uncased", "distil", 512),
]


@dataclass
class Config:
    """Hyper-parameters for extraction, alignment, and student training."""
    # Data
    n_samples: int = 500000
    n_val: int = 5000
    min_caption_len: int = 50          # minimum caption length, in characters
    extract_batch: int = 1024
    cache_dir: str = "/home/claude/consensus_500k"
    # Student architecture
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 8192                # position embedding capacity
    tokenize_len: int = 512            # actual padding length (captions avg ~100 tokens)
    output_dim: int = 768
    dropout: float = 0.1
    # Training
    epochs: int = 30
    batch_size: int = 128              # sequences are tokenize_len=512, not max_len=8192
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 1000
    grad_clip: float = 1.0
    seed: int = 42
    # Loss weights
    nce_weight: float = 1.0
    mse_weight: float = 1.0
    cv_weight: float = 0.1
    cv_target: float = 0.084           # target coefficient-of-variation of simplex volumes


CFG = Config()

print("=" * 65)
print("DISTILLED CONSENSUS BERT — 500K Scale")  # was "200K"; fixed to match CFG
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Samples: {CFG.n_samples:,}")

# ══════════════════════════════════════════════════════════════════
# EXTRACTION
# ══════════════════════════════════════════════════════════════════


def load_captions(n, min_len=50):
    """Stream CC12M-LLaVA captions from the HF hub until `n` captions
    longer than `min_len` characters have been collected.

    Returns a list of caption strings (length <= n; exactly n if the
    stream has enough qualifying rows).
    """
    from datasets import load_dataset
    print(f"\n Loading captions (n={n:,})...")
    # streaming=True: iterate without downloading the whole dataset.
    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > min_len:
            captions.append(cap)
            if len(captions) >= n:
                break
    print(f" Got {len(captions):,} captions")
    return captions


@torch.no_grad()
def extract_one(model_name, short_name, captions, max_len, batch_size):
    """Embed every caption with one pretrained encoder.

    Pooling is attention-mask-weighted mean over the last hidden state,
    so padding tokens do not contribute. Returns a CPU tensor of shape
    (len(captions), hidden_size); the model is freed before returning.
    """
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Extracting: {short_name} ({model_name})...")
    model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    dim = model.config.hidden_size
    n_params = sum(p.numel() for p in model.parameters())
    print(f" dim={dim}, {n_params:,} params")
    all_emb = []
    for i in tqdm(range(0, len(captions), batch_size), desc=f" {short_name}"):
        batch = captions[i:i+batch_size]
        inputs = tokenizer(batch, max_length=max_len, padding=True,
                           truncation=True, return_tensors="pt").to(DEVICE)
        out = model(**inputs)
        # Mean-pool over real tokens only; clamp avoids 0-division on
        # a (theoretically) all-pad row.
        mask = inputs.attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1)
        all_emb.append(pooled.cpu())
    emb = torch.cat(all_emb)
    print(f" Shape: {emb.shape}")
    del model
    torch.cuda.empty_cache()
    return emb


def extract_all():
    """Produce (embeds, captions) for all MODELS, using the on-disk cache.

    If every per-model .pt file plus captions.json exists in CFG.cache_dir,
    load those; otherwise stream captions, extract each model's embeddings,
    pad/truncate every embedding to 768 dims, and write the cache.
    Returns (dict short_name -> (N, 768) tensor, list of N captions).
    """
    os.makedirs(CFG.cache_dir, exist_ok=True)
    caps_path = os.path.join(CFG.cache_dir, "captions.json")
    all_cached = all(
        os.path.exists(os.path.join(CFG.cache_dir, f"{s}.pt"))
        for _, s, _ in MODELS)
    if all_cached and os.path.exists(caps_path):
        # Fast path: everything already extracted on a previous run.
        print("\n Loading cached embeddings...")
        embeds = {}
        for _, short, _ in MODELS:
            embeds[short] = torch.load(
                os.path.join(CFG.cache_dir, f"{short}.pt"), weights_only=True)
            print(f" {short}: {embeds[short].shape}")
        with open(caps_path) as f:
            captions = json.load(f)
        return embeds, captions
    captions = load_captions(CFG.n_samples, CFG.min_caption_len)
    embeds = {}
    for model_name, short, model_max_len in MODELS:
        emb = extract_one(model_name, short, captions, model_max_len,
                          CFG.extract_batch)
        # Force a common 768-dim space: zero-pad narrower models,
        # truncate wider ones.
        if emb.shape[1] != 768:
            if emb.shape[1] < 768:
                emb = F.pad(emb, (0, 768 - emb.shape[1]))
            else:
                emb = emb[:, :768]
        embeds[short] = emb
        torch.save(emb, os.path.join(CFG.cache_dir, f"{short}.pt"))
    with open(caps_path, "w") as f:
        json.dump(captions, f)
    return embeds, captions


# ══════════════════════════════════════════════════════════════════
# WHITENED PROCRUSTES + CONSENSUS
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return cov^(-1/2) for a symmetric PSD matrix via eigendecomposition.

    Eigenvalues are clamped to `eps` so near-singular covariances stay
    invertible.
    """
    evals, evecs = torch.linalg.eigh(cov)
    evals = torch.clamp(evals, min=eps)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T


def procrustes_align(source, target, n_align=10000):
    """Fit a whitened-Procrustes map from `source` space to `target` space.

    Uses the first min(n_align, N) rows: center both sets, whiten each by
    its own cov^(-1/2), L2-normalize, then solve the orthogonal Procrustes
    problem by SVD for the rotation R.

    Returns a dict with the fitted pieces ("rotation", "source_mean",
    "source_whitener", "target_unwhitener") plus diagnostic mean cosines
    before (centered raw spaces) and after (whitened+rotated) alignment.
    """
    N = min(n_align, source.shape[0], target.shape[0])
    S = source[:N].float()
    T = target[:N].float()
    s_mean = S.mean(0, keepdim=True)
    t_mean = T.mean(0, keepdim=True)
    Sc = S - s_mean
    Tc = T - t_mean
    N_s = Sc.shape[0]
    # Sample covariances; max(...) guards the degenerate N_s == 1 case.
    s_cov = (Sc.T @ Sc) / max(N_s - 1, 1)
    t_cov = (Tc.T @ Tc) / max(N_s - 1, 1)
    s_whiten = symmetric_inv_sqrt(s_cov)
    t_whiten = symmetric_inv_sqrt(t_cov)
    Sc_w = F.normalize(Sc @ s_whiten, dim=-1)
    Tc_w = F.normalize(Tc @ t_whiten, dim=-1)
    cos_before = F.cosine_similarity(Sc, Tc, dim=-1).mean().item()
    # Orthogonal Procrustes: R = U V^T from SVD of the cross-covariance.
    U, _, Vt = torch.linalg.svd(Tc_w.T @ Sc_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(Sc_w @ R.T, Tc_w, dim=-1).mean().item()
    return {
        "rotation": R,
        "source_mean": s_mean.squeeze(0),
        "source_whitener": s_whiten,
        # pinv un-whitens back into the target's covariance structure.
        "target_unwhitener": torch.linalg.pinv(t_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }


def apply_align(emb, a):
    """Apply a fitted alignment dict `a` (from procrustes_align) to `emb`.

    Pipeline: center by source mean -> whiten -> rotate -> un-whiten into
    target space. Note the target mean is NOT added back, so the result
    lives in the *centered* target space.
    """
    x = emb.float() - a["source_mean"]
    x = x @ a["source_whitener"]
    x = x @ a["rotation"].T
    x = x @ a["target_unwhitener"]
    return x


def generate_consensus(embeds):
    """Align all to bert space, take normalized centroid as target.

    Each model's embeddings (including bert itself, as an identity-ish
    sanity case) are Procrustes-aligned to the "bert" reference, averaged,
    and L2-normalized. Prints per-model alignment diagnostics and the
    cosine of each aligned set against the consensus on a 5K subset.
    Returns the (N, 768) consensus target tensor.
    """
    print(f"\n{'='*65}")
    print("WHITENED PROCRUSTES ALIGNMENT + CONSENSUS")
    print(f"{'='*65}")
    ref_name = "bert"
    names = [s for _, s, _ in MODELS]
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref_name])
        aligned[name] = apply_align(embeds[name], info)
        label = " (ref)" if name == ref_name else ""
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}{label}")
    # Consensus = normalized centroid of all 5 aligned embeddings
    # This is what the five-BERT experiment proved: the centroid IS the consensus
    # to three decimal places regardless of seed. No learned model needed.
    centroid = sum(aligned[n] for n in names) / len(names)
    consensus = F.normalize(centroid, dim=-1)
    # Verify geometry: consensus should be roughly equidistant (in cosine)
    # from every aligned expert.
    N_check = min(5000, consensus.shape[0])
    for name in names:
        cos = F.cosine_similarity(
            consensus[:N_check], aligned[name][:N_check], dim=-1).mean().item()
        print(f" cos(consensus, {name:10s}): {cos:.4f}")
    return consensus


# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL
# ══════════════════════════════════════════════════════════════════
class CaptionEncoder(nn.Module):
    """Small from-scratch transformer that maps token ids to a unit-norm
    `output_dim` embedding in the consensus space.

    Learned token + absolute position embeddings, pre-norm transformer
    encoder stack, masked mean pooling, then a 2-layer projection head.
    """

    def __init__(self, vocab_size=30522, max_len=128, d_model=384,
                 n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token-id sequences.

        input_ids: (B, L) int tensor, L <= max_len.
        attention_mask: optional (B, L); 1 = real token, 0 = padding. If
        omitted, padding is inferred from pad_token_id positions.
        Returns L2-normalized (B, output_dim) embeddings.
        """
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))
        # key_padding_mask convention: True = position is ignored.
        if attention_mask is not None:
            kpm = ~attention_mask.bool()
        else:
            kpm = (input_ids == self.pad_token_id)
        x = self.encoder(x, src_key_padding_mask=kpm)
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~kpm).unsqueeze(-1).float()
        # Masked mean pool; clamp guards an all-padding row.
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)


# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared (V-1)-simplex volume of each point set via the
    Cayley–Menger determinant.

    pts: (B, V, D) batches of V points. Returns (B,) squared volumes
    (can be slightly negative from numerical error — callers relu it).
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)
    B, V, _ = d2.shape
    # Border the squared-distance matrix with a row/column of ones.
    cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
    # vol^2 = (-1)^V / (2^(V-1) ((V-1)!)^2) * det(CM)
    s = (-1.0)**V; f = math.factorial(V-1)
    return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)


def cv_loss(emb, target=0.084, n_samples=16):
    """Differentiable penalty |CV - target| where CV is the coefficient of
    variation of 4-simplex volumes over `n_samples` random 5-point subsets
    of the batch. Returns 0 for batches smaller than 5.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        # relu + eps: clamp tiny negative det noise before sqrt.
        vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
    stacked = torch.stack(vols)
    cv = stacked.std() / (stacked.mean() + 1e-8)
    return (cv - target).abs()


def cv_metric(emb, n=200):
    """Non-differentiable evaluation twin of cv_loss: CV of simplex volumes
    over `n` random 5-point subsets, as a float. Returns 0.0 when the batch
    is too small or fewer than 10 positive volumes were sampled.
    """
    B = emb.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = np.array(vols)
    return float(a.std() / (a.mean() + 1e-8))


def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-aligned batches `a` and `b`.

    Both sides are L2-normalized; positives are the matching rows. Returns
    (loss, top-1 retrieval accuracy of a->b under no_grad).
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = (a @ b.T) / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    # Average of a->b and b->a cross-entropies (CLIP-style).
    loss = (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.T, labels)) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc


# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
def train():
    """End-to-end pipeline: extract/cache expert embeddings, build the
    consensus targets, train the student encoder against them, then run a
    final retrieval/geometry evaluation. All artifacts are written under
    CFG.cache_dir/student.
    """
    torch.manual_seed(CFG.seed)
    torch.cuda.manual_seed_all(CFG.seed)
    np.random.seed(CFG.seed)

    # ── Extract + Align + Consensus ──
    embeds, captions = extract_all()
    consensus = generate_consensus(embeds)
    # Free the raw embeddings (5 x N x 768 floats) before tokenization.
    del embeds
    torch.cuda.empty_cache()
    import gc; gc.collect()

    # ── Tokenize ──
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    print(f"\n Tokenizer: bert-base-uncased (vocab={tokenizer.vocab_size})")
    print(" Pre-tokenizing...")
    # Tokenize in chunks to avoid memory issues
    all_ids, all_masks = [], []
    chunk = 50000
    for i in tqdm(range(0, len(captions), chunk), desc=" Tokenizing"):
        j = min(i + chunk, len(captions))
        tokens = tokenizer(captions[i:j], max_length=CFG.tokenize_len,
                           padding="max_length", truncation=True,
                           return_tensors="pt")
        all_ids.append(tokens["input_ids"])
        all_masks.append(tokens["attention_mask"])
    input_ids = torch.cat(all_ids)
    attention_mask = torch.cat(all_masks)
    real_lens = attention_mask.sum(1).float()
    print(f" Token lengths: mean={real_lens.mean():.0f} "
          f"median={real_lens.median():.0f} "
          f">{CFG.tokenize_len}: {(real_lens >= CFG.tokenize_len).float().mean():.1%}")
    print(f" Padded to: {CFG.tokenize_len} (model supports up to {CFG.max_len})")

    # Split: last n_val rows are validation; rows stay aligned with
    # the consensus targets by index.
    n_train = len(captions) - CFG.n_val
    print(f" Train: {n_train:,}, Val: {CFG.n_val:,}")
    # Move to GPU — assumes the full tokenized corpus fits in device
    # memory at tokenize_len=512 (TODO confirm for larger runs).
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    # ── Student ──
    print(f"\n{'='*65}")
    print("STUDENT MODEL")
    print(f"{'='*65}")
    student = CaptionEncoder(
        vocab_size=tokenizer.vocab_size,
        max_len=CFG.max_len,
        d_model=CFG.d_model,
        n_heads=CFG.n_heads,
        n_layers=CFG.n_layers,
        d_ff=CFG.d_ff,
        output_dim=CFG.output_dim,
        dropout=CFG.dropout,
        pad_token_id=tokenizer.pad_token_id,
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Architecture: {CFG.n_layers}L, {CFG.d_model}d, {CFG.n_heads}h, {CFG.d_ff} FFN")
    print(f" Output: {CFG.output_dim}-dim (consensus space)")
    print(f" Parameters: {n_params:,}")
    size_mb = sum(p.numel() * p.element_size() for p in student.parameters()) / 1e6
    print(f" Size: {size_mb:.1f} MB")

    # ── Warm-start from previous checkpoint if available ──
    # First directory found wins; shape-matching tensors are copied,
    # shorter position-embedding tables are extended in place.
    for prev_dir in ["/home/claude/consensus_200k/student",
                     "/home/claude/distilled_consensus"]:
        prev_ckpt = os.path.join(prev_dir, "best_model.pt")
        if os.path.exists(prev_ckpt):
            print(f"\n Warm-starting from: {prev_ckpt}")
            prev_state = torch.load(prev_ckpt, weights_only=True,
                                    map_location=DEVICE)
            current_state = student.state_dict()
            loaded, extended, skipped = 0, 0, 0
            for name, param in prev_state.items():
                if name not in current_state:
                    skipped += 1
                    continue
                if param.shape == current_state[name].shape:
                    current_state[name] = param
                    loaded += 1
                elif "pos_emb" in name and param.shape[0] < current_state[name].shape[0]:
                    # Extend position embeddings: copy old positions, init new ones
                    old_len = param.shape[0]
                    current_state[name][:old_len] = param
                    nn.init.normal_(current_state[name][old_len:], std=0.02)
                    extended += 1
                    print(f" Extended {name}: {param.shape[0]}→{current_state[name].shape[0]}")
                else:
                    skipped += 1
            student.load_state_dict(current_state)
            print(f" Loaded: {loaded}, Extended: {extended}, Skipped: {skipped}")
            break
    else:
        # for/else: no checkpoint directory existed.
        print("\n Training from scratch (no previous checkpoint found)")

    # ── Optimizer ──
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG.lr,
                                  weight_decay=CFG.weight_decay)
    n_batches = n_train // CFG.batch_size
    total_steps = n_batches * CFG.epochs
    # Linear warmup for warmup_steps, then cosine decay to eta_min;
    # stepped once per optimizer step (per batch), not per epoch.
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=CFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps - CFG.warmup_steps, 1),
             eta_min=1e-6)],
        milestones=[CFG.warmup_steps])
    os.makedirs(CFG.cache_dir, exist_ok=True)
    save_dir = os.path.join(CFG.cache_dir, "student")
    os.makedirs(save_dir, exist_ok=True)

    # ── Train ──
    print(f"\n{'='*65}")
    print(f"TRAINING ({CFG.epochs} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")
    all_metrics = {"config": {k: str(v) for k, v in vars(CFG).items()},
                   "epochs": []}
    best_val_cos = 0.0
    for epoch in range(CFG.epochs):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        losses = {"total": 0, "nce": 0, "mse": 0}
        metrics = {"acc": 0, "cos": 0}
        n = 0
        t0 = time.time()
        for i in range(0, n_train, CFG.batch_size):
            idx = perm[i:i+CFG.batch_size]
            # Skip tiny trailing batches — too few rows for the InfoNCE
            # contrast and the 5-point CV geometry to be meaningful.
            if len(idx) < 8:
                continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=CFG.cv_target)
            loss = CFG.nce_weight * l_nce + CFG.mse_weight * l_mse + CFG.cv_weight * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), CFG.grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            losses["total"] += loss.item()
            losses["nce"] += l_nce.item()
            losses["mse"] += l_mse.item()
            metrics["acc"] += acc
            metrics["cos"] += cos
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)
        # Val
        student.eval()
        with torch.no_grad():
            val_embs = []
            for vi in range(0, CFG.n_val, 512):
                vj = min(vi + 512, CFG.n_val)
                ve = student(val_ids[vi:vj], val_mask[vi:vj])
                val_embs.append(ve)
            val_emb = torch.cat(val_embs)
            # InfoNCE / CV on a 2K subset to bound memory.
            _, val_acc = infonce(val_emb[:2000], val_targets[:2000])
            val_cos = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
            val_cv = cv_metric(val_emb[:2000])
        summary = {
            "epoch": epoch + 1,
            "elapsed": elapsed,
            "loss": losses["total"] / d,
            "train_acc": metrics["acc"] / d,
            "train_cos": metrics["cos"] / d,
            "val_acc": val_acc,
            "val_cos": val_cos,
            "val_cv": val_cv,
        }
        all_metrics["epochs"].append(summary)
        print(f" E{epoch+1:2d}: {elapsed:.0f}s "
              f"loss={summary['loss']:.4f} "
              f"t_acc={summary['train_acc']:.3f} t_cos={summary['train_cos']:.3f} "
              f"v_acc={summary['val_acc']:.3f} v_cos={summary['val_cos']:.3f} "
              f"v_cv={summary['val_cv']:.3f}")
        # Checkpointing: best-by-val-cosine, plus periodic snapshots.
        if val_cos > best_val_cos:
            best_val_cos = val_cos
            torch.save(student.state_dict(),
                       os.path.join(save_dir, "best_model.pt"))
        if (epoch + 1) % 10 == 0:
            torch.save(student.state_dict(),
                       os.path.join(save_dir, f"model_e{epoch+1:02d}.pt"))

    # Final save
    torch.save(student.state_dict(), os.path.join(save_dir, "final_model.pt"))
    tokenizer.save_pretrained(os.path.join(save_dir, "tokenizer"))
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)

    # ══════════════════════════════════════════════════════════════
    # FINAL EVAL
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*65}")
    print("FINAL EVALUATION")
    print(f"{'='*65}")
    # Evaluate the best (not final) checkpoint.
    student.load_state_dict(
        torch.load(os.path.join(save_dir, "best_model.pt"),
                   weights_only=True, map_location=DEVICE))
    student.eval()
    with torch.no_grad():
        val_embs = []
        for vi in range(0, CFG.n_val, 512):
            vj = min(vi + 512, CFG.n_val)
            ve = student(val_ids[vi:vj], val_mask[vi:vj])
            val_embs.append(ve)
        val_emb = torch.cat(val_embs)
        # Retrieval (on 2K subset for memory)
        sub = min(2000, CFG.n_val)
        sim = val_emb[:sub] @ val_targets[:sub].T
        labels = torch.arange(sub, device=DEVICE)
        r1 = (sim.argmax(1) == labels).float().mean().item()
        r5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        r10 = (sim.topk(10, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        cos_match = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
        final_cv = cv_metric(val_emb[:2000])
    print(f" Retrieval (student → consensus):")
    print(f" R@1: {r1:.4f}")
    print(f" R@5: {r5:.4f}")
    print(f" R@10: {r10:.4f}")
    print(f" Cosine: {cos_match:.4f}")
    print(f" CV: {final_cv:.4f} (target: {CFG.cv_target})")
    print(f" Model: {n_params:,} params, {size_mb:.1f} MB")

    # Standalone test
    print(f"\n Standalone similarity test:")
    test = [
        "A cat sitting on a windowsill watching birds",
        "A golden retriever playing fetch on the beach",
        "A still life painting with flowers and fruit",
        "An aerial photograph of a city skyline at night",
        "A child riding a bicycle through autumn leaves",
    ]
    with torch.no_grad():
        tok = tokenizer(test, max_length=CFG.tokenize_len,
                        padding="max_length", truncation=True,
                        return_tensors="pt").to(DEVICE)
        embs = student(tok["input_ids"], tok["attention_mask"])
        sim = embs @ embs.T
    for i in range(len(test)):
        for j in range(i+1, len(test)):
            print(f" [{i}]↔[{j}]: {sim[i,j]:.3f} "
                  f"({test[i][:35]}↔{test[j][:35]})")
    print(f"\n Saved to: {save_dir}/")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")


if __name__ == "__main__":
    train()