# Renamed: trainers/hli_head_trainer_alignbanked_conv5d_76pct.py
#      ->  trainers/nli_head_trainer_alignbanked_conv5d_76pct.py
# (commit 802f97a, verified)
# ============================================================================
# NLI HEAD: Compositional Convolution (conv5d)
#
# Decomposes the premise-hypothesis relationship into 5 feature components,
# processes through all 2^4 = 16 integer composition paths.
#
# Components:
#   1. raw_a (768)     — premise consensus embedding
#   2. raw_b (768)     — hypothesis consensus embedding
#   3. |a-b| (768)     — element-wise difference
#   4. a*b (768)       — element-wise interaction
#   5. bank_diff (128) — bank geometric context difference
#
# Each path is a different factorization of the 5-way relationship.
#   Path (5,):        process all 5 holistically -> global view
#   Path (1,1,1,1,1): process each alone then combine -> independent view
#   Path (2,3):       first two then last three -> partial groupings
#   ... 16 total paths, learned weights per path
# ============================================================================
import gc
import math
import os
import time
from itertools import product as iter_product

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Runtime configuration: prefer GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Frozen encoder checkpoint that produces the enriched embeddings.
REPO_ID = "AbstractPhil/geolip-captionbert-8192"

_RULE = "=" * 65
print(_RULE)
print("NLI HEAD: Compositional Convolution (conv5d)")
print(_RULE)
print(f" Device: {DEVICE}")
# ------------------------------------------------------------------
# INTEGER COMPOSITIONS
# ------------------------------------------------------------------
def integer_compositions(n):
    """Yield every ordered composition of the integer n.

    A composition is an ordered tuple of positive integers summing to n;
    there are exactly 2^(n-1) of them.  n == 0 yields the empty tuple.
    """
    if n <= 1:
        # Base cases: () for n == 0, (1,) for n == 1.
        yield (1,) * n
        return
    for head in range(1, n + 1):
        for tail in integer_compositions(n - head):
            yield (head,) + tail
# ------------------------------------------------------------------
# COMPOSITIONAL CONV NLI HEAD
# ------------------------------------------------------------------
class CompConvNLI(nn.Module):
    """Compositional convolution NLI head.

    Takes 5 feature components of a premise-hypothesis pair, projects
    each to a shared dimension, processes them through all 2^(n-1)
    integer-composition paths (16 for n=5), and classifies the weighted
    path combination by cosine similarity to learned class prototypes.

    Args:
        d_raw: width of the raw consensus slice of an enriched embedding.
        d_bank: width of the bank-context slice.
        d_path: shared per-path feature dimension.
        n_components: number of feature components (2^(n-1) paths).
        n_classes: number of NLI classes.
        dropout: dropout probability applied to the projected components.
    """

    def __init__(self, d_raw=768, d_bank=128, d_path=256,
                 n_components=5, n_classes=3, dropout=0.1):
        super().__init__()
        self.d_path = d_path
        self.n_components = n_components
        # Generalization: remember the raw/bank split point and the class
        # count instead of hard-coding 768 and 3 inside forward().
        self.d_raw = d_raw
        self.n_classes = n_classes
        # Project each component to the shared path dimension.
        self.proj_raw_a = nn.Linear(d_raw, d_path)
        self.proj_raw_b = nn.Linear(d_raw, d_path)
        self.proj_diff = nn.Linear(d_raw, d_path)
        self.proj_prod = nn.Linear(d_raw, d_path)
        self.proj_bank = nn.Linear(d_bank, d_path)
        self.proj_norm = nn.LayerNorm(d_path)
        # FIX: `dropout` was accepted but never used.  nn.Dropout adds no
        # parameters, so existing checkpoints still load, and it is an
        # identity in eval mode, so inference output is unchanged.
        self.drop = nn.Dropout(dropout)
        # Enumerate all ordered compositions of the component count.
        self.compositions = list(integer_compositions(n_components))
        n_paths = len(self.compositions)
        print(f" Compositions of {n_components}: {n_paths} paths")
        for c in self.compositions:
            print(f" {c}")
        # Per-group fusion: for each possible group size k, one fusion
        # layer; k projected features are concatenated and fused to d_path.
        max_group = n_components
        self.group_fusions = nn.ModuleDict()
        for k in range(1, max_group + 1):
            self.group_fusions[str(k)] = nn.Sequential(
                nn.Linear(k * d_path, d_path),
                nn.GELU(),
                nn.LayerNorm(d_path),
            )
        # Learned per-path mixing weights (softmax-normalized in forward).
        self.path_weights = nn.Parameter(torch.ones(n_paths) / n_paths)
        # Geometric classifier: prototypes on the hypersphere.
        # No MLP, no dense layers — just distance to class centers.
        self.class_prototypes = nn.Parameter(
            F.normalize(torch.randn(n_classes, d_path), dim=-1))
        self.temperature = nn.Parameter(torch.tensor(10.0))

    def forward(self, enriched_a, enriched_b):
        """
        Args:
            enriched_a: (B, d_raw + d_bank) premise enriched embedding
            enriched_b: (B, d_raw + d_bank) hypothesis enriched embedding
        Returns:
            (logits, diagnostics): logits is (B, n_classes); diagnostics is
            a dict of path weights, prototype spread, and temperature.
        """
        # Split enriched into raw content and bank geometric context.
        # (Split point was hard-coded at 768; now uses configured d_raw.)
        raw_a = enriched_a[:, :self.d_raw]
        bank_a = enriched_a[:, self.d_raw:]
        raw_b = enriched_b[:, :self.d_raw]
        bank_b = enriched_b[:, self.d_raw:]
        # 5 components — ORDER MATTERS for compositions:
        #   geometry first (sets the frame of reference),
        #   structure second (how the pair relates),
        #   content last (interpreted within the geometric frame).
        # So (1,4) = "geo frame alone, then all semantic content",
        #    (2,3) = "geo+diff frame, then interaction+content", etc.
        c1 = self.proj_norm(self.proj_bank(bank_a - bank_b))           # geometric frame
        c2 = self.proj_norm(self.proj_diff(torch.abs(raw_a - raw_b)))  # structural difference
        c3 = self.proj_norm(self.proj_prod(raw_a * raw_b))             # structural interaction
        c4 = self.proj_norm(self.proj_raw_a(raw_a))                    # premise content
        c5 = self.proj_norm(self.proj_raw_b(raw_b))                    # hypothesis content
        # Regularize each projected component (identity in eval mode).
        components = [self.drop(c) for c in (c1, c2, c3, c4, c5)]
        # Process each path: walk the composition, fusing contiguous
        # groups, then average the group outputs.
        path_outputs = []
        for composition in self.compositions:
            idx = 0
            group_outputs = []
            for group_size in composition:
                group = components[idx:idx + group_size]
                fused = torch.cat(group, dim=-1)                    # (B, group_size * d_path)
                fused = self.group_fusions[str(group_size)](fused)  # (B, d_path)
                group_outputs.append(fused)
                idx += group_size
            path_out = torch.stack(group_outputs).mean(dim=0)       # (B, d_path)
            path_outputs.append(path_out)
        # Weighted combination of all paths.
        weights = F.softmax(self.path_weights, dim=0)               # (n_paths,)
        stacked = torch.stack(path_outputs, dim=0)                  # (n_paths, B, d_path)
        combined = (weights.unsqueeze(-1).unsqueeze(-1) * stacked).sum(0)  # (B, d_path)
        # Geometric classification: cosine similarity to learned prototypes,
        # scaled by a positive learned temperature.
        combined_n = F.normalize(combined, dim=-1)                  # (B, d_path)
        protos_n = F.normalize(self.class_prototypes, dim=-1)       # (n_classes, d_path)
        logits = combined_n @ protos_n.T * self.temperature.abs()   # (B, n_classes)
        # Diagnostics only — no gradients flow through these.
        with torch.no_grad():
            w = weights.cpu().tolist()
            top_paths = sorted(enumerate(w), key=lambda x: -x[1])[:5]
            proto_sim = protos_n @ protos_n.T
            # FIX: off-diagonal mask was hard-coded to 3 classes;
            # generalized to the configured class count.
            off_diag = ~torch.eye(self.n_classes, dtype=torch.bool,
                                  device=proto_sim.device)
            proto_spread = proto_sim[off_diag].mean().item()
        return logits, {
            "path_weights": w,
            "top_paths": [(self.compositions[i], w) for i, w in top_paths],
            "weight_spread": max(w) - min(w),
            "proto_spread": proto_spread,
            "temperature": self.temperature.abs().item(),
        }
# ------------------------------------------------------------------
# GEOMETRY
# ------------------------------------------------------------------
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each batch of points.

    Uses the bordered Cayley-Menger determinant: for V points the
    (V-1)-simplex volume satisfies
        vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM),
    where CM is the (V+1)x(V+1) matrix of pairwise squared distances
    bordered by a row/column of ones.  Input pts is (B, V, D); the
    return value is a (B,) float32 tensor of squared volumes.
    """
    pts = pts.float()
    delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = delta.pow(2).sum(-1)                       # (B, V, V) squared distances
    batch, n_pts, _ = sq_dist.shape
    bordered = torch.zeros(batch, n_pts + 1, n_pts + 1,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = sq_dist
    fact = math.factorial(n_pts - 1)
    coeff = ((-1.0) ** n_pts) / ((2.0 ** (n_pts - 1)) * fact * fact)
    return coeff * torch.linalg.det(bordered)
def cv_metric(emb, n=200):
    """Coefficient of variation of random 5-point simplex volumes.

    Samples n random 5-point subsets of the (N, D) embedding matrix,
    computes each simplex volume via the Cayley-Menger determinant, and
    returns std/mean over the positive volumes.  Returns 0.0 when there
    are fewer than 5 embeddings or fewer than 10 usable volumes.
    """
    count = emb.shape[0]
    if count < 5:
        return 0.0
    volumes = []
    for _ in range(n):
        pick = torch.randperm(count, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # Clamp tiny negatives from determinant round-off before the sqrt.
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    if len(volumes) < 10:
        return 0.0
    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
# ------------------------------------------------------------------
# MAIN
# ------------------------------------------------------------------
def run():
    """End-to-end trainer: pre-encode SNLI with the frozen encoder, train
    the CompConvNLI head on cached embeddings, analyze learned path
    weights, and probe word-order sensitivity on hand-written pairs."""
    torch.manual_seed(42)
    np.random.seed(42)
    # Heavy third-party deps imported lazily so module import stays cheap.
    from transformers import AutoModel, AutoTokenizer
    from datasets import load_dataset
    # -- Load model --
    print(f"\n{'='*65}")
    print("LOADING MODEL")
    print(f"{'='*65}")
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True).to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
    # Freeze the encoder: only the NLI head will receive gradients.
    for p in model.parameters():
        p.requires_grad = False
    # NOTE(review): `.bank` (and `.enriched` below) are custom attributes of
    # this trust_remote_code model, not standard transformers API — confirm
    # against the repo's modeling file.
    has_bank = model.bank is not None
    print(f" Model: {sum(p.numel() for p in model.parameters()):,} params (frozen)")
    print(f" Bank: {'present' if has_bank else 'absent'}")
    # -- Load SNLI --
    print(f"\n{'='*65}")
    print("LOADING SNLI")
    print(f"{'='*65}")
    ds = load_dataset("stanfordnlp/snli")
    # SNLI marks un-annotated pairs with label -1; drop them.
    train_ds = ds["train"].filter(lambda x: x["label"] >= 0)
    val_ds = ds["validation"].filter(lambda x: x["label"] >= 0)
    print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
    MAX_TRAIN = 549000 # full SNLI
    MAX_VAL = 9800
    # -- Pre-encode --
    # Embeddings are computed once with the frozen encoder and cached, so
    # the training loop never touches the encoder again.
    print(f"\n{'='*65}")
    print("PRE-ENCODING")
    print(f"{'='*65}")
    def encode_pairs(dataset, max_n, batch_size=1024):
        """Encode up to max_n premise/hypothesis pairs; returns a dict of
        stacked CPU tensors: 'p', 'h' (features) and 'labels'."""
        dataset = dataset.select(range(min(max_n, len(dataset))))
        all_p_enr, all_h_enr = [], []
        all_labels = []
        for i in tqdm(range(0, len(dataset), batch_size), desc=" Encoding"):
            j = min(i + batch_size, len(dataset))
            batch = dataset[i:j]
            p_in = tokenizer(batch["premise"], max_length=128,
                             padding="max_length", truncation=True,
                             return_tensors="pt").to(DEVICE)
            p_out = model(**p_in)
            h_in = tokenizer(batch["hypothesis"], max_length=128,
                             padding="max_length", truncation=True,
                             return_tensors="pt").to(DEVICE)
            h_out = model(**h_in)
            # Prefer the bank-enriched embedding; fall back to the plain
            # encoder output when the model has no bank.
            # NOTE(review): downstream code assumes these are pooled
            # (B, D) vectors, not (B, L, D) sequences — TODO confirm.
            p_feat = p_out.enriched if p_out.enriched is not None else p_out.last_hidden_state
            h_feat = h_out.enriched if h_out.enriched is not None else h_out.last_hidden_state
            all_p_enr.append(p_feat.cpu())
            all_h_enr.append(h_feat.cpu())
            all_labels.append(torch.tensor(batch["label"]))
        return {
            "p": torch.cat(all_p_enr),
            "h": torch.cat(all_h_enr),
            "labels": torch.cat(all_labels),
        }
    train_data = encode_pairs(train_ds, MAX_TRAIN)
    val_data = encode_pairs(val_ds, MAX_VAL)
    # Infer the enriched width from the cached features; everything past
    # the first 768 dims is treated as bank context.
    d_enriched = train_data["p"].shape[1]
    d_raw = 768
    d_bank = d_enriched - d_raw
    print(f" Enriched: {d_enriched} (raw={d_raw} + bank={d_bank})")
    print(f" Train: {train_data['labels'].shape[0]:,} Val: {val_data['labels'].shape[0]:,}")
    # Report class balance over the cached training labels.
    for label, name in [(0, "entailment"), (1, "neutral"), (2, "contradiction")]:
        n_l = (train_data["labels"] == label).sum().item()
        print(f" {name}: {n_l:,} ({n_l/len(train_data['labels']):.1%})")
    # Encoder no longer needed during training; free its GPU memory.
    del model; gc.collect(); torch.cuda.empty_cache()
    # Move cached features to GPU for fast index-select batching.
    train_p = train_data["p"].to(DEVICE)
    train_h = train_data["h"].to(DEVICE)
    train_labels = train_data["labels"].to(DEVICE)
    val_p = val_data["p"].to(DEVICE)
    val_h = val_data["h"].to(DEVICE)
    val_labels = val_data["labels"].to(DEVICE)
    # -- Build CompConv NLI --
    print(f"\n{'='*65}")
    print("COMPOSITIONAL CONV NLI HEAD")
    print(f"{'='*65}")
    # max(d_bank, 1) guards against a bank-less encoder (d_bank == 0),
    # which would make nn.Linear's input width invalid.
    nli = CompConvNLI(
        d_raw=d_raw, d_bank=max(d_bank, 1),
        d_path=256, n_components=5, n_classes=3, dropout=0.1
    ).to(DEVICE)
    n_head_params = sum(p.numel() for p in nli.parameters())
    print(f" Head params: {n_head_params:,}")
    # -- Training --
    EPOCHS = 20
    BATCH = 128
    LR = 1e-4
    n_train = train_labels.shape[0]
    n_batches = n_train // BATCH
    optimizer = torch.optim.AdamW(nli.parameters(), lr=LR, weight_decay=0.01)
    # Cosine decay over the full run, stepped per batch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_batches * EPOCHS, eta_min=1e-6)
    print(f"\n{'='*65}")
    print(f"TRAINING ({EPOCHS} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")
    best_val_acc = 0.0
    for epoch in range(EPOCHS):
        nli.train()
        # Fresh shuffle of the cached training set each epoch.
        perm = torch.randperm(n_train, device=DEVICE)
        total_loss, total_correct, n = 0, 0, 0
        t0 = time.time()
        for i in range(0, n_train, BATCH):
            idx = perm[i:i+BATCH]
            # Skip a tiny trailing batch; too small for stable statistics.
            if len(idx) < 8: continue
            logits, _ = nli(train_p[idx], train_h[idx])
            labels = train_labels[idx]
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(nli.parameters(), 1.0)
            optimizer.step(); optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            total_correct += (logits.argmax(-1) == labels).sum().item()
            total_loss += loss.item()
            n += len(idx)
        elapsed = time.time() - t0
        train_acc = total_correct / max(n, 1)
        # Mean per-batch loss (approximate: uses n // BATCH as batch count).
        train_loss = total_loss / max(n // BATCH, 1)
        # Validation
        nli.eval()
        with torch.no_grad():
            val_n = val_labels.shape[0]
            val_correct = 0
            val_loss_sum = 0
            all_preds, all_labs = [], []
            path_info = None
            for i in range(0, val_n, 512):
                j = min(i + 512, val_n)
                logits, info = nli(val_p[i:j], val_h[i:j])
                labs = val_labels[i:j]
                val_correct += (logits.argmax(-1) == labs).sum().item()
                # Sum reduction so the epoch loss is a true per-example mean.
                val_loss_sum += F.cross_entropy(logits, labs, reduction="sum").item()
                all_preds.append(logits.argmax(-1).cpu())
                all_labs.append(labs.cpu())
                # Diagnostics are batch-independent; keep the first only.
                if path_info is None:
                    path_info = info
        val_acc = val_correct / val_n
        val_loss = val_loss_sum / val_n
        preds = torch.cat(all_preds)
        labs_all = torch.cat(all_labs)
        # Per-class accuracy (guarding against absent classes).
        acc_ent = (preds[labs_all == 0] == 0).float().mean().item() if (labs_all == 0).sum() > 0 else 0
        acc_neu = (preds[labs_all == 1] == 1).float().mean().item() if (labs_all == 1).sum() > 0 else 0
        acc_con = (preds[labs_all == 2] == 2).float().mean().item() if (labs_all == 2).sum() > 0 else 0
        print(f"\n E{epoch+1:2d}: {elapsed:.0f}s")
        print(f" Task: loss={train_loss:.4f} t_acc={train_acc:.4f} v_acc={val_acc:.4f} v_loss={val_loss:.4f}")
        print(f" Per-class: ent={acc_ent:.3f} neu={acc_neu:.3f} con={acc_con:.3f}")
        if path_info:
            top3 = path_info["top_paths"][:3]
            path_str = " ".join(f"{comp}={w:.3f}" for comp, w in top3)
            print(f" Paths: {path_str} spread={path_info['weight_spread']:.4f}")
            print(f" Protos: sim={path_info.get('proto_spread', 0):.4f} "
                  f"temp={path_info.get('temperature', 0):.2f}")
        # Checkpoint on best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(nli.state_dict(), "nli_conv5d_best.pt")
            print(f" β New best: {val_acc:.4f}")
    # --------------------------------------------------------------
    # PATH ANALYSIS
    # --------------------------------------------------------------
    print(f"\n{'='*65}")
    print("PATH WEIGHT ANALYSIS")
    print(f"{'='*65}")
    # Restore the best checkpoint before analysis and the probe below.
    nli.load_state_dict(torch.load("nli_conv5d_best.pt", weights_only=True))
    nli.eval()
    weights = F.softmax(nli.path_weights, dim=0).cpu().tolist()
    ranked = sorted(zip(nli.compositions, weights), key=lambda x: -x[1])
    print(f"\n {'Path':<25} {'Weight':>8} {'Type':<15}")
    print(f" {'-'*50}")
    for comp, w in ranked:
        # Label each composition by how it groups the ordered components
        # (c1 = geometry, c2/c3 = structure, c4/c5 = content).
        if len(comp) == 1:
            ptype = "holistic"
        elif all(c == 1 for c in comp):
            ptype = "independent"
        elif comp[0] >= 3:
            ptype = "geo-first"
        elif comp[0] == 1 and sum(comp[1:]) == 4:
            ptype = "geoβrest"
        elif comp[0] == 2:
            ptype = "geo+structβ..."
        else:
            ptype = "mixed"
        bar = "β" * int(w * 100)
        print(f" {str(comp):<25} {w:>8.4f} {ptype:<15} {bar}")
    # --------------------------------------------------------------
    # COMPOSITIONAL ORDER TEST
    # --------------------------------------------------------------
    print(f"\n{'='*65}")
    print("COMPOSITIONAL ORDER TEST")
    print(f"{'='*65}")
    # Reload the (previously freed) encoder for the hand-written probe.
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True).to(DEVICE).eval()
    label_names = ["entailment", "neutral", "contradiction"]
    # Mostly word-order-swapped pairs: same bag of words, different meaning.
    test_pairs = [
        ("a potato on top of a table", "a table on top of a potato"),
        ("a potato on top of a table", "there is a potato"),
        ("a cat is sitting on a mat", "a mat is sitting on a cat"),
        ("a dog chased the cat", "the cat chased the dog"),
        ("a woman is holding a baby", "a baby is holding a woman"),
        ("the boy kicked the ball", "the ball kicked the boy"),
        ("a man is riding a horse", "a horse is riding a man"),
        ("a girl is painting a picture", "a girl is creating art"),
        ("two dogs are playing in a park", "animals are outdoors"),
        ("a person is swimming in the ocean", "nobody is in the water"),
    ]
    with torch.no_grad():
        for premise, hypothesis in test_pairs:
            p_in = tokenizer([premise], max_length=128, padding="max_length",
                             truncation=True, return_tensors="pt").to(DEVICE)
            h_in = tokenizer([hypothesis], max_length=128, padding="max_length",
                             truncation=True, return_tensors="pt").to(DEVICE)
            p_out = model(**p_in)
            h_out = model(**h_in)
            p_feat = p_out.enriched if p_out.enriched is not None else p_out.last_hidden_state
            h_feat = h_out.enriched if h_out.enriched is not None else h_out.last_hidden_state
            logits, _ = nli(p_feat, h_feat)
            probs = F.softmax(logits, dim=-1)[0]
            pred = label_names[probs.argmax()]
            # Baseline similarity of the raw pooled embeddings, to contrast
            # with the head's prediction.
            # NOTE(review): assumes last_hidden_state is (B, D) — if it is
            # (B, L, D) this cosine is computed over the wrong dim; confirm.
            cos = F.cosine_similarity(
                p_out.last_hidden_state, h_out.last_hidden_state).item()
            print(f"\n P: {premise}")
            print(f" H: {hypothesis}")
            print(f" Pooled cos: {cos:.3f}")
            print(f" NLI: {pred} [E={probs[0]:.3f} N={probs[1]:.3f} C={probs[2]:.3f}]")
    print(f"\n{'='*65}")
    print("SUMMARY")
    print(f"{'='*65}")
    print(f" Best val accuracy: {best_val_acc:.4f}")
    print(f" Head params: {n_head_params:,}")
    print(f" Paths: {len(nli.compositions)}")
    print(f" Components: {nli.n_components} β d_path={nli.d_path}")
    print(f" Bank present: {has_bank}")
    print(f" Saved: nli_conv5d_best.pt")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
if __name__ == "__main__":
    run()