| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import gc |
| import math |
| import os |
| import time |
| import json |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
# Compute device used for every model and tensor in this script.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Teacher encoders: (Hugging Face model id, short name, tokenizer max_length).
EXPERTS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 512),
]

# Start-up banner (one print call; output is line-for-line identical).
print("\n".join([
    "=" * 65,
    "RAPID PROTOTYPE: 2-Expert Consensus + Alignment Bank",
    "=" * 65,
    f" Device: {DEVICE}",
]))
|
|
|
|
| |
| |
| |
|
|
class MiniStudent(nn.Module):
    """Small Transformer encoder that maps token ids to a unit-norm embedding.

    Token and learned positional embeddings feed a pre-norm Transformer
    encoder. Hidden states are mean-pooled over non-padding positions,
    projected to ``output_dim`` and L2-normalized.
    """

    def __init__(self, vocab_size=30522, max_len=512, d_model=256,
                 n_heads=4, n_layers=4, d_ff=1024, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id

        # Submodules are created in a fixed order: that order determines the
        # RNG draws used for weight init, so seeded runs stay reproducible.
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=n_layers, enable_nested_tensor=False)

        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: 2-D tensor of token ids (batch, seq_len); seq_len must
                not exceed ``max_len`` (positions index ``pos_emb``).
            attention_mask: optional (batch, seq_len) mask, nonzero = real
                token. When omitted, positions equal to ``pad_token_id`` are
                treated as padding.

        Returns:
            (batch, output_dim) tensor with L2-normalized rows.
        """
        _, seq_len = input_ids.shape
        pos = torch.arange(seq_len, device=input_ids.device)[None, :]
        hidden = self.emb_drop(
            self.emb_norm(self.token_emb(input_ids) + self.pos_emb(pos)))

        # `padding` is True where the encoder must ignore a position;
        # `keep` is the float weight used for mean pooling.
        if attention_mask is None:
            padding = input_ids == self.pad_token_id
            keep = (~padding).unsqueeze(-1).float()
        else:
            padding = ~attention_mask.bool()
            keep = attention_mask.unsqueeze(-1).float()

        hidden = self.encoder(hidden, src_key_padding_mask=padding)
        summed = (hidden * keep).sum(1)
        counts = keep.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(summed / counts), dim=-1)
|
|
|
|
| |
| |
| |
|
|
class AlignmentBank(nn.Module):
    """
    Geometric interface layer. Learns to annotate student embeddings
    with per-expert alignment context and anchor distances.

    Trained on frozen student output. Provides geometric memory of
    the expert consensus for downstream heads.

    ``forward`` returns the input embedding concatenated with a ``d_bank``-dim
    geometric context vector, plus a dict of auxiliary losses/diagnostics that
    ``bank_loss`` combines into a training objective.
    """

    def __init__(self, d_embed=768, n_experts=2, n_anchors=128, d_bank=64):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # One (d_embed, d_embed) rotation per expert; identity until loaded
        # by init_from_procrustes.
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)
        ])

        # Per-expert source means, loaded alongside the rotations.
        # NOTE(review): forward() never reads these, so they get no gradient;
        # they are kept for checkpoint (state_dict) compatibility.
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)
        ])

        # Unit-norm anchor directions (random until init_from_procrustes).
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Per-sample geometry features: expert cos | anchor cos | expert mse.
        geo_dim = n_experts + n_anchors + n_experts
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

    @staticmethod
    def _assign(param, value):
        """Copy ``value`` into ``param`` in place, requiring an exact shape match.

        Copying (rather than rebinding ``param.data``) keeps the parameter's
        own storage, so optimizer updates never write into the caller's
        tensors, and a wrongly shaped tensor fails here instead of later.

        Raises:
            ValueError: if the shapes differ.
        """
        if tuple(value.shape) != tuple(param.shape):
            raise ValueError(
                f"shape mismatch: parameter is {tuple(param.shape)}, "
                f"got {tuple(value.shape)}")
        with torch.no_grad():
            param.copy_(value.detach().to(device=param.device, dtype=param.dtype))

    def init_from_procrustes(self, procrustes_results, expert_names,
                             consensus_embeddings=None):
        """Initialize from consensus training artifacts.

        Args:
            procrustes_results: mapping name -> dict with at least "rotation",
                "source_mean" and "cos_after" (the keys procrustes_align
                returns).
            expert_names: names to load, in expert order; only the first
                ``n_experts`` are used.
            consensus_embeddings: optional 2-D tensor whose second dim is
                ``d_embed``; up to ``n_anchors`` evenly spaced rows become the
                initial (normalized) anchors.

        Raises:
            KeyError: if a name or required key is missing.
            ValueError: if a loaded tensor does not match its parameter shape.
        """
        for i, name in enumerate(expert_names[:self.n_experts]):
            info = procrustes_results[name]
            self._assign(self.expert_rotations[i], info["rotation"].float())
            self._assign(self.expert_means[i], info["source_mean"].float())
            print(f" Expert {i} ({name}): rotation loaded, cos_after={info['cos_after']:.4f}")

        if consensus_embeddings is not None and consensus_embeddings.shape[0] > 0:
            n_rows = consensus_embeddings.shape[0]
            n = min(self.n_anchors, n_rows)
            indices = torch.linspace(0, n_rows - 1, n).long()
            picked = consensus_embeddings[indices.to(consensus_embeddings.device)].float()
            with torch.no_grad():
                self.anchors[:n].copy_(
                    F.normalize(picked, dim=-1).to(self.anchors.device))
            print(f" Anchors: {n} initialized from consensus embeddings")

    @staticmethod
    def _simplex_volume_cv(points, n_samples=32, n_vertices=5):
        """Coefficient of variation of volumes of random simplices from ``points``.

        Draws ``n_samples`` random ``n_vertices``-point subsets of the rows of
        ``points`` and evaluates all Cayley-Menger determinants in one batched
        call. Uses the population std (ddof=0), matching ``cv_metric``.
        """
        idx = torch.stack([
            torch.randperm(points.shape[0], device=points.device)[:n_vertices]
            for _ in range(n_samples)
        ])
        pts = points[idx].float()                                   # (S, V, D)
        d2 = (pts.unsqueeze(-2) - pts.unsqueeze(-3)).pow(2).sum(-1)  # (S, V, V)
        S, V, _ = d2.shape
        cm = torch.zeros(S, V + 1, V + 1, device=d2.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        scale = (-1.0) ** V / ((2.0 ** (V - 1)) * math.factorial(V - 1) ** 2)
        vols = torch.sqrt(F.relu(scale * torch.linalg.det(cm)) + 1e-12)
        return vols.std(unbiased=False) / (vols.mean() + 1e-8)

    def forward(self, embedding):
        """
        Annotate embedding with geometric context.

        Args:
            embedding: (B, d_embed); the training script passes L2-normalized
                student embeddings.

        Returns:
            enriched: (B, d_embed + d_bank) -- input followed by geo context
            aux: dict of differentiable loss terms ("expert_agreement",
                "rotation_ortho", "anchor_spread", "anchor_entropy",
                "bank_cv") and float diagnostics.
        """
        B = embedding.shape[0]
        emb = embedding.float()

        # Per-expert round trip emb -> emb @ R -> (emb @ R) @ R^T.
        # NOTE(review): for an orthogonal R this round trip is the identity,
        # so expert_cos ~ 1 and expert_mse ~ 0 for every sample; these
        # features mainly reflect how far each R has drifted from orthogonal.
        cos_cols, mse_cols = [], []
        for R in self.expert_rotations:
            round_trip = (emb @ R) @ R.T
            cos_cols.append(F.cosine_similarity(emb, round_trip, dim=-1))
            mse_cols.append((emb - round_trip).pow(2).mean(dim=-1))
        expert_cos = torch.stack(cos_cols, dim=-1)   # (B, n_experts)
        expert_mse = torch.stack(mse_cols, dim=-1)   # (B, n_experts)

        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T               # (B, n_anchors)

        geo_input = torch.cat([expert_cos, anchor_cos, expert_mse], dim=-1)
        geo_context = self.geo_proj(geo_input)

        # The input passes through untouched; context is appended.
        enriched = torch.cat([embedding, geo_context], dim=-1)

        aux = {}

        # Variance of expert consistency around its per-sample mean.
        expert_mean = expert_cos.mean(dim=-1, keepdim=True)
        aux["expert_agreement"] = (expert_cos - expert_mean).pow(2).mean()

        # Keep every rotation close to orthogonal: mean ||R R^T - I||^2.
        eye = torch.eye(self.d_embed, device=anchors_n.device)
        ortho_loss = sum((R @ R.T - eye).pow(2).mean() for R in self.expert_rotations)
        aux["rotation_ortho"] = ortho_loss / self.n_experts

        # Push anchors apart: mean squared off-diagonal cosine.
        anchor_sim = anchors_n @ anchors_n.T
        anchor_sim.fill_diagonal_(0)
        aux["anchor_spread"] = anchor_sim.pow(2).mean()

        # Entropy of a sharpened (x10) softmax over anchor similarities.
        anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
        entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
        aux["anchor_entropy"] = entropy

        # Simplex-volume CV of the normalized context (needs enough rows).
        if B >= 10:
            aux["bank_cv"] = self._simplex_volume_cv(F.normalize(geo_context, dim=-1))
        else:
            aux["bank_cv"] = torch.tensor(0.0, device=embedding.device)

        # Detached float diagnostics for logging.
        aux["expert_cos_mean"] = expert_cos.mean().item()
        aux["expert_cos_std"] = expert_cos.std().item()
        aux["anchor_max_cos"] = anchor_cos.max(dim=-1).values.mean().item()
        aux["anchor_mean_cos"] = anchor_cos.mean().item()

        return enriched, aux

    def bank_loss(self, aux, cv_target=0.15):
        """Combined bank training loss: weighted sum of the ``aux`` terms.

        ``bank_cv`` is pulled toward ``cv_target`` with an absolute penalty.
        """
        loss = (1.0 * aux["expert_agreement"] +
                1.0 * aux["rotation_ortho"] +
                0.5 * aux["anchor_spread"] +
                0.1 * aux["anchor_entropy"] +
                0.3 * (aux["bank_cv"] - cv_target).abs())
        return loss
|
|
|
|
| |
| |
| |
|
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between paired rows of ``a`` and ``b``.

    Row i of ``a`` is the positive for row i of ``b`` and vice versa; every
    other row in the batch is a negative. Both inputs are L2-normalized and
    their cosine similarities are divided by ``temperature``.

    Returns:
        (loss, acc): the mean of the a->b and b->a cross-entropies (a tensor),
        and the a->b top-1 retrieval accuracy as a Python float.
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    sim = a_unit @ b_unit.T / temperature
    targets = torch.arange(sim.shape[0], device=sim.device)

    forward_ce = F.cross_entropy(sim, targets)
    backward_ce = F.cross_entropy(sim.T, targets)
    loss = (forward_ce + backward_ce) / 2

    with torch.no_grad():
        hits = sim.argmax(dim=-1).eq(targets)
        acc = hits.float().mean().item()
    return loss, acc
|
|
def cayley_menger_vol2(pts):
    """Squared volume of each simplex in a batch via the Cayley-Menger determinant.

    Args:
        pts: (B, V, D) tensor -- B simplices, each with V vertices in D dims.

    Returns:
        (B,) float32 tensor of squared (V-1)-dimensional volumes. Degenerate
        simplices can come out slightly negative from round-off; callers clamp.
    """
    p = pts.float()
    sq_dist = (p.unsqueeze(-2) - p.unsqueeze(-3)).pow(2).sum(-1)
    n_batch, n_vert, _ = sq_dist.shape

    # Bordered matrix [[0, 1^T], [1, D2]].
    border = torch.zeros(n_batch, n_vert + 1, n_vert + 1,
                         device=sq_dist.device, dtype=torch.float32)
    border[:, 1:, 1:] = sq_dist
    border[:, 0, 1:] = 1
    border[:, 1:, 0] = 1

    sign = (-1.0) ** n_vert
    denom = (2.0 ** (n_vert - 1)) * math.factorial(n_vert - 1) ** 2
    return sign / denom * torch.linalg.det(border)
|
|
def cv_loss(emb, target=0.12, n_samples=16):
    """Differentiable penalty |CV - target| on random 5-point simplex volumes.

    Draws ``n_samples`` random 5-row subsets of ``emb`` (N, D), computes all
    simplex volumes in one batched Cayley-Menger call, and takes the
    coefficient of variation (std / mean) of those volumes.

    The std is the population std (ddof=0), the same estimator ``cv_metric``
    uses, because the training loop passes a ``cv_metric`` result as
    ``target``; an unbiased std here would bias the CV upward relative to it.

    Returns:
        Scalar tensor; 0.0 when ``emb`` has fewer than 5 rows.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    idx = torch.stack([
        torch.randperm(B, device=emb.device)[:5] for _ in range(n_samples)
    ])                                                   # (n_samples, 5)
    vols = torch.sqrt(F.relu(cayley_menger_vol2(emb[idx])) + 1e-12)
    cv = vols.std(unbiased=False) / (vols.mean() + 1e-8)
    return (cv - target).abs()
|
|
def cv_metric(emb, n=200):
    """Coefficient of variation of volumes of ``n`` random 5-point simplices.

    Evaluation-only counterpart of ``cv_loss``: it samples ``n`` random 5-row
    subsets of ``emb`` (N, D), computes every Cayley-Menger volume in a single
    batched call (one host sync instead of one per sample), and returns
    std / mean over the volumes using numpy's population std (ddof=0).

    Returns:
        A Python float; 0.0 when ``emb`` has fewer than 5 rows or fewer than
        10 volumes are available.
    """
    B = emb.shape[0]
    if B < 5 or n < 10:
        return 0.0
    with torch.no_grad():
        idx = torch.stack([
            torch.randperm(B, device=emb.device)[:5] for _ in range(n)
        ])                                               # (n, 5)
        vols = torch.sqrt(F.relu(cayley_menger_vol2(emb[idx])) + 1e-12)
        # NOTE(review): the +1e-12 floor makes every volume >= 1e-6, so this
        # filter never drops a sample; filter on the squared volume instead if
        # degenerate simplices are meant to be excluded.
        vols = vols[vols > 0]
    a = vols.double().cpu().numpy()
    if a.size < 10:
        return 0.0
    return float(a.std() / (a.mean() + 1e-8))
|
|
|
|
| |
| |
| |
|
|
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Return the inverse square root of a symmetric matrix ``cov``.

    Uses an eigendecomposition; eigenvalues are floored at ``eps`` before
    taking ``1/sqrt``, which keeps near-singular covariances finite.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_root = torch.diag(eigvals.clamp_min(eps).rsqrt())
    return eigvecs @ inv_root @ eigvecs.T
|
|
def procrustes_align(source, target, n_align=5000):
    """Fit a whitened orthogonal Procrustes map from ``source`` onto ``target``.

    Uses the first ``min(n_align, len(source), len(target))`` paired rows.
    Both sides are mean-centered and whitened with their own covariance, then
    row-normalized. The rotation ``R`` (from an SVD of ``tgt_w^T @ src_w``) is
    fit so that ``src_w @ R.T`` approximates ``tgt_w``.

    Returns:
        dict with "rotation", "source_mean", "source_whitener",
        "target_unwhitener" (pseudo-inverse of the target whitener), and the
        mean row cosine before ("cos_before", centered inputs) and after
        ("cos_after", whitened space) alignment.
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()

    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()

    # Sample covariances (Bessel-corrected, guarded for a single row).
    dof = max(src_c.shape[0] - 1, 1)
    src_whiten = symmetric_inv_sqrt(src_c.T @ src_c / dof)
    tgt_whiten = symmetric_inv_sqrt(tgt_c.T @ tgt_c / dof)

    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)

    u, _, vh = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    rotation = u @ vh
    cos_after = F.cosine_similarity(src_w @ rotation.T, tgt_w, dim=-1).mean().item()

    return {
        "rotation": rotation,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
        "cos_before": cos_before,
        "cos_after": cos_after,
    }
|
|
def apply_align(emb, a):
    """Map ``emb`` into the target space using a ``procrustes_align`` result ``a``.

    Pipeline: subtract the source mean, apply the source whitener, rotate with
    ``rotation.T``, then apply the target un-whitener. The output is float32
    (``emb`` is cast first).
    """
    out = emb.float() - a["source_mean"]
    for matrix in (a["source_whitener"], a["rotation"].T, a["target_unwhitener"]):
        out = out @ matrix
    return out
|
|
|
|
| |
| |
| |
|
|
def _banner(title):
    """Print ``title`` framed by 65-character '=' rules, preceded by a blank line."""
    print(f"\n{'='*65}")
    print(title)
    print(f"{'='*65}")


def _load_captions(n_samples, min_chars=50):
    """Stream up to ``n_samples`` LLaVA captions longer than ``min_chars`` characters."""
    from datasets import load_dataset

    ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                      split="train", streaming=True)
    captions = []
    for row in ds:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > min_chars:
            captions.append(cap)
            if len(captions) >= n_samples:
                break
    return captions


def _extract_expert_embeddings(captions, batch_size=128):
    """Mean-pool each expert's last hidden state over non-padding tokens.

    Returns:
        dict short_name -> CPU tensor with one row per caption.
    """
    from transformers import AutoModel, AutoTokenizer

    embeds = {}
    for model_name, short, max_len in EXPERTS:
        print(f"\n Extracting: {short}...")
        model = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        chunks = []
        with torch.no_grad():
            for i in tqdm(range(0, len(captions), batch_size), desc=f" {short}"):
                batch = captions[i:i + batch_size]
                inputs = tokenizer(batch, max_length=max_len, padding=True,
                                   truncation=True, return_tensors="pt").to(DEVICE)
                out = model(**inputs)
                m = inputs.attention_mask.unsqueeze(-1).float()
                pooled = (out.last_hidden_state * m).sum(1) / m.sum(1).clamp(min=1)
                chunks.append(pooled.cpu())
        embeds[short] = torch.cat(chunks)
        print(f" Shape: {embeds[short].shape}")
        # Release the expert before loading the next one.
        del model
        gc.collect()
        torch.cuda.empty_cache()
    return embeds


def _build_consensus(embeds, ref="bert"):
    """Align every expert onto ``ref`` and average into a unit-norm consensus.

    Returns:
        (names, procrustes_results, consensus, consensus_cv)
    """
    names = [s for _, s, _ in EXPERTS]
    procrustes_results = {}
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref])
        procrustes_results[name] = info
        aligned[name] = apply_align(embeds[name], info)
        print(f" {name:10s}: cos {info['cos_before']:.4f} → {info['cos_after']:.4f}")

    consensus = F.normalize(sum(aligned[n] for n in names) / len(names), dim=-1)
    print(f" Consensus: {consensus.shape}")
    for name in names:
        cos = F.cosine_similarity(consensus[:2000], aligned[name][:2000], dim=-1).mean().item()
        print(f" cos(consensus, {name}): {cos:.4f}")

    consensus_cv = cv_metric(consensus[:2000].to(DEVICE))
    print(f" Consensus CV: {consensus_cv:.4f}")
    return names, procrustes_results, consensus, consensus_cv


def _train_student(student, train_ids, train_mask, train_targets,
                   val_ids, val_mask, val_targets, consensus_cv, batch, epochs=5):
    """Distill consensus targets into ``student`` (InfoNCE + MSE + 0.1 * CV loss).

    Returns:
        (v_cos, v_cv) from the last epoch's validation pass (NaN if epochs=0).
    """
    optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=0.01)
    n_train = train_ids.shape[0]
    v_cos, v_cv = float("nan"), float("nan")
    for epoch in range(epochs):
        student.train()
        perm = torch.randperm(n_train, device=train_ids.device)
        t_loss, t_acc, t_cos, n = 0.0, 0.0, 0.0, 0
        t0 = time.time()

        for i in range(0, n_train, batch):
            idx = perm[i:i + batch]
            if len(idx) < 8:  # skip a tiny trailing batch
                continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=consensus_cv)
            loss = l_nce + l_mse + 0.1 * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            t_loss += loss.item()
            t_acc += acc
            t_cos += cos
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)
        student.eval()
        with torch.no_grad():
            v_emb = student(val_ids, val_mask)
            _, v_acc = infonce(v_emb[:1000], val_targets[:1000])
            v_cos = F.cosine_similarity(v_emb, val_targets, dim=-1).mean().item()
            v_cv = cv_metric(v_emb[:1000])

        print(f" E{epoch+1}: {elapsed:.0f}s loss={t_loss/d:.4f} "
              f"t_acc={t_acc/d:.3f} t_cos={t_cos/d:.3f} "
              f"v_acc={v_acc:.3f} v_cos={v_cos:.3f} v_cv={v_cv:.3f}")
    return v_cos, v_cv


def _encode_frozen(student, ids, mask, chunk=512):
    """Run ``student`` over ``ids``/``mask`` in chunks without autograd."""
    with torch.no_grad():
        parts = [student(ids[i:i + chunk], mask[i:i + chunk])
                 for i in range(0, ids.shape[0], chunk)]
    return torch.cat(parts)


def _train_bank(bank, student_embs, val_student_embs, cv_target, epochs=20, batch=256):
    """Train ``bank`` on precomputed (frozen) student embeddings with ``bank_loss``."""
    bank_opt = torch.optim.AdamW(bank.parameters(), lr=1e-3, weight_decay=0.01)
    n_train = student_embs.shape[0]
    for epoch in range(epochs):
        bank.train()
        perm = torch.randperm(n_train, device=student_embs.device)
        total_loss = 0.0
        stats = {"expert_agreement": 0.0, "rotation_ortho": 0.0,
                 "anchor_spread": 0.0, "bank_cv": 0.0}
        n = 0
        t0 = time.time()

        for i in range(0, n_train, batch):
            idx = perm[i:i + batch]
            if len(idx) < 16:  # skip a tiny trailing batch
                continue
            _, aux = bank(student_embs[idx])
            loss = bank.bank_loss(aux, cv_target=cv_target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(bank.parameters(), 1.0)
            bank_opt.step()
            bank_opt.zero_grad(set_to_none=True)

            total_loss += loss.item()
            for k in stats:
                if k in aux:
                    v = aux[k]
                    stats[k] += v.item() if torch.is_tensor(v) else v
            n += 1

        elapsed = time.time() - t0
        d = max(n, 1)
        bank.eval()
        with torch.no_grad():
            _, v_aux = bank(val_student_embs)
            v_loss = bank.bank_loss(v_aux, cv_target=cv_target).item()

        print(f" E{epoch+1:2d}: {elapsed:.0f}s loss={total_loss/d:.4f} "
              f"v_loss={v_loss:.4f} "
              f"expert_agr={stats['expert_agreement']/d:.5f} "
              f"ortho={stats['rotation_ortho']/d:.5f} "
              f"spread={stats['anchor_spread']/d:.5f} "
              f"cv={stats['bank_cv']/d:.4f} "
              f"anchor_max={v_aux['anchor_max_cos']:.3f} "
              f"expert_cos={v_aux['expert_cos_mean']:.3f}±{v_aux['expert_cos_std']:.3f}")


def _verify_bank(bank, val_student_embs):
    """Print passthrough integrity, CV and effective rank of the bank's context."""
    bank.eval()
    with torch.no_grad():
        enriched_val, _ = bank(val_student_embs)
        original = enriched_val[:, :bank.d_embed]
        geo_context = enriched_val[:, bank.d_embed:]

        # The bank must pass its input through unchanged.
        passthrough_cos = F.cosine_similarity(
            original[:100], val_student_embs[:100], dim=-1).mean().item()

        geo_cv = cv_metric(F.normalize(geo_context[:1000], dim=-1))
        sample = geo_context[:1000].float()
        sv2 = torch.linalg.svdvals(sample - sample.mean(0)).pow(2)
        # Participation-ratio effective dimension of the context.
        geo_eff_dim = ((sv2.sum() ** 2) / (sv2.pow(2).sum() + 1e-12)).item()

    print(f" Passthrough integrity: {passthrough_cos:.6f} (should be ~1.000)")
    print(f" Geo context CV: {geo_cv:.4f}")
    print(f" Geo context eff_dim: {geo_eff_dim:.1f}")
    print(f" Geo context shape: {geo_context.shape}")


def _classifier_stability_test(bank, val_student_embs, n_pool=1000, n_pairs=3000,
                               n_clf_train=2400, epochs=20):
    """Compare a pair classifier trained with vs. without bank features.

    Random pairs of distinct validation embeddings are labelled into
    similarity tertiles (0 = most similar, 2 = least) and a small MLP is
    trained on each feature set; train/val accuracy and their gap are printed.
    """
    n_pool = min(n_pool, val_student_embs.shape[0])
    if n_pool < 2:
        print(" Skipped: need at least 2 validation embeddings")
        return

    with torch.no_grad():
        embs = val_student_embs[:n_pool]
        device = embs.device
        # Draw distinct partners: idx_b = idx_a + offset (mod n_pool) with
        # offset in [1, n_pool), so no pair compares a row with itself.
        idx_a = torch.randint(0, n_pool, (n_pairs,))
        idx_b = (idx_a + torch.randint(1, n_pool, (n_pairs,))) % n_pool
        idx_a, idx_b = idx_a.to(device), idx_b.to(device)
        pair_cos = (embs[idx_a] * embs[idx_b]).sum(-1)

        sorted_cos, _ = pair_cos.sort()
        t1 = sorted_cos[n_pairs // 3].item()
        t2 = sorted_cos[2 * n_pairs // 3].item()
        labels = torch.ones(n_pairs, dtype=torch.long, device=device)
        labels[pair_cos > t2] = 0
        labels[pair_cos <= t1] = 2

        enriched_a, _ = bank(embs[idx_a])
        enriched_b, _ = bank(embs[idx_b])
        feature_sets = {
            "with_bank": torch.cat([enriched_a, enriched_b], dim=-1),
            "without_bank": torch.cat([embs[idx_a], embs[idx_b]], dim=-1),
        }

    # Classifier training runs with autograd enabled (outside no_grad).
    train_l, val_l = labels[:n_clf_train], labels[n_clf_train:]
    for mode, features in feature_sets.items():
        clf = nn.Sequential(
            nn.Linear(features.shape[-1], 128), nn.GELU(),
            nn.Linear(128, 3),
        ).to(DEVICE)
        clf_opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
        train_f = features[:n_clf_train].detach()
        val_f = features[n_clf_train:].detach()

        for _ in range(epochs):
            clf.train()
            loss = F.cross_entropy(clf(train_f), train_l)
            loss.backward()
            clf_opt.step()
            clf_opt.zero_grad()

        clf.eval()
        with torch.no_grad():
            val_acc = (clf(val_f).argmax(-1) == val_l).float().mean().item()
            train_acc = (clf(train_f).argmax(-1) == train_l).float().mean().item()

        print(f" {mode:15s}: train_acc={train_acc:.3f} val_acc={val_acc:.3f} "
              f"gap={train_acc-val_acc:.3f}")


def run():
    """Run the full prototype pipeline (phases 0-4) end to end.

    Side effects: streams the caption dataset, downloads the expert models,
    prints progress, and writes ``mini_student.pt`` and ``alignment_bank.pt``
    to the current working directory.

    Raises:
        RuntimeError: if fewer than N_VAL + 1 captions are collected.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    N_SAMPLES = 20000
    N_VAL = 2000
    MAX_LEN = 128
    BATCH = 256

    _banner("PHASE 0: EXTRACTION")
    captions = _load_captions(N_SAMPLES)
    print(f" Captions: {len(captions):,}")
    if len(captions) <= N_VAL:
        raise RuntimeError(
            f"need more than {N_VAL} captions for the train/val split, "
            f"got {len(captions)}")
    embeds = _extract_expert_embeddings(captions)

    _banner("PHASE 0b: PROCRUSTES ALIGNMENT")
    names, procrustes_results, consensus, consensus_cv = _build_consensus(embeds, ref="bert")
    del embeds
    gc.collect()
    torch.cuda.empty_cache()

    _banner("PHASE 1: TRAIN STUDENT (2 experts, 20K captions)")
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    tokens = tokenizer(captions, max_length=MAX_LEN, padding="max_length",
                       truncation=True, return_tensors="pt")
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    # The last N_VAL rows are validation; the split is derived from the
    # captions actually collected, so a short stream cannot empty the val set.
    n_train = len(captions) - N_VAL
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)

    d_embed = consensus.shape[1]
    student = MiniStudent(
        vocab_size=tokenizer.vocab_size, max_len=MAX_LEN,
        d_model=256, n_heads=4, n_layers=4, d_ff=1024,
        output_dim=d_embed, dropout=0.1, pad_token_id=tokenizer.pad_token_id,
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Student: {n_params:,} params")

    v_cos, v_cv = _train_student(student, train_ids, train_mask, train_targets,
                                 val_ids, val_mask, val_targets, consensus_cv, BATCH)
    torch.save(student.state_dict(), "mini_student.pt")
    print(f"\n Student saved. v_cos={v_cos:.3f}, v_cv={v_cv:.3f}")

    _banner("PHASE 2: TRAIN ALIGNMENT BANK (student frozen)")
    student.eval()
    for p in student.parameters():
        p.requires_grad = False

    print(" Pre-encoding through frozen student...")
    student_embs = _encode_frozen(student, train_ids, train_mask)
    val_student_embs = _encode_frozen(student, val_ids, val_mask)
    print(f" Student embeddings: {student_embs.shape}")

    bank = AlignmentBank(
        d_embed=d_embed, n_experts=len(EXPERTS),
        n_anchors=128, d_bank=64,
    ).to(DEVICE)
    bank.init_from_procrustes(procrustes_results, names, consensus[:n_train])
    bank_params = sum(p.numel() for p in bank.parameters())
    print(f" Bank: {bank_params:,} params")

    _train_bank(bank, student_embs, val_student_embs,
                cv_target=consensus_cv + 0.02, epochs=20, batch=256)
    torch.save(bank.state_dict(), "alignment_bank.pt")

    _banner("PHASE 3: GEOMETRIC VERIFICATION")
    _verify_bank(bank, val_student_embs)

    _banner("PHASE 4: CLASSIFIER STABILITY TEST")
    _classifier_stability_test(bank, val_student_embs)

    _banner("DONE")
    print("\n Student: mini_student.pt")
    print(" Bank: alignment_bank.pt")
    print(f" Consensus CV: {consensus_cv:.4f}")
    print(f" Student v_cos: {v_cos:.3f}")
|
|
|
|
# Execute the whole prototype only when run as a script, never on import.
if __name__ == "__main__":
    run()