# Source: AbstractPhil/geolip-captionbert-8192 — trainers/nli_head_trainer_alignbanked_conv5d_76pct.py
# (Renamed from trainers/hli_head_trainer_alignbanked_conv5d_76pct.py; commit 802f97a, verified.)
# ============================================================================
# NLI HEAD: Compositional Convolution (conv5d)
#
# Decomposes the premise-hypothesis relationship into 5 feature components,
# processes through all 2^4 = 16 integer partition paths.
#
# Components:
# 1. raw_a (768) β€” premise consensus embedding
# 2. raw_b (768) β€” hypothesis consensus embedding
# 3. |a-b| (768) β€” element-wise difference
# 4. a*b (768) β€” element-wise interaction
# 5. bank_diff (128) β€” bank geometric context difference
#
# Each path is a different factorization of the 5-way relationship.
# Path (5,): process all 5 holistically β€” global view
# Path (1,1,1,1,1): process each alone then combine β€” independent view
# Path (2,3): first two then last three β€” partial groupings
# ... 16 total paths, learned weights per path
# ============================================================================
import gc
import math
import os
import time
from itertools import product as iter_product
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Runtime configuration: run on GPU when one is visible, CPU otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face repo hosting the frozen encoder checkpoint loaded in run().
REPO_ID = "AbstractPhil/geolip-captionbert-8192"

_RULE = "=" * 65
print(_RULE)
print("NLI HEAD: Compositional Convolution (conv5d)")
print(_RULE)
print(f" Device: {DEVICE}")
# ══════════════════════════════════════════════════════════════════
# INTEGER COMPOSITIONS
# ══════════════════════════════════════════════════════════════════
def integer_compositions(n):
    """Yield every ordered composition of n; there are 2**(n-1) for n >= 1.

    Enumeration order is significant: downstream code indexes learned
    path weights by position in this sequence.
    """
    if n == 0:
        yield ()
    elif n == 1:
        yield (1,)
    else:
        # Choose the first part, then recurse on the remainder.
        for head in range(1, n + 1):
            for tail in integer_compositions(n - head):
                yield (head,) + tail
# ══════════════════════════════════════════════════════════════════
# COMPOSITIONAL CONV NLI HEAD
# ══════════════════════════════════════════════════════════════════
class CompConvNLI(nn.Module):
    """
    Compositional convolution NLI head.

    Takes 5 feature components of a premise-hypothesis pair, projects
    each to a shared dimension, processes them through all
    2^(n_components-1) integer-composition paths, then classifies the
    learned weighted combination by scaled cosine similarity to class
    prototypes on the hypersphere.

    Args:
        d_raw: width of each raw consensus embedding slice.
        d_bank: width of the bank geometric-context slice.
        d_path: shared projection width for every component.
        n_components: number of feature components (forward() builds
            exactly 5, so this should stay 5 unless forward() changes).
        n_classes: number of NLI labels.
        dropout: dropout probability applied to the combined path
            representation before classification (train mode only).
    """

    def __init__(self, d_raw=768, d_bank=128, d_path=256,
                 n_components=5, n_classes=3, dropout=0.1):
        super().__init__()
        self.d_path = d_path
        self.n_components = n_components
        # Project each component to the shared path dimension.
        self.proj_raw_a = nn.Linear(d_raw, d_path)
        self.proj_raw_b = nn.Linear(d_raw, d_path)
        self.proj_diff = nn.Linear(d_raw, d_path)
        self.proj_prod = nn.Linear(d_raw, d_path)
        self.proj_bank = nn.Linear(d_bank, d_path)
        self.proj_norm = nn.LayerNorm(d_path)
        # Enumerate all ordered compositions of n_components; the index of
        # each composition is also its index into path_weights.
        self.compositions = list(integer_compositions(n_components))
        n_paths = len(self.compositions)
        print(f" Compositions of {n_components}: {n_paths} paths")
        for c in self.compositions:
            print(f" {c}")
        # Per-group fusion: for each possible group size k, a fusion layer
        # that takes k concatenated projected features back to d_path.
        max_group = n_components
        self.group_fusions = nn.ModuleDict()
        for k in range(1, max_group + 1):
            self.group_fusions[str(k)] = nn.Sequential(
                nn.Linear(k * d_path, d_path),
                nn.GELU(),
                nn.LayerNorm(d_path),
            )
        # Learned path weights (softmax-normalized in forward()).
        self.path_weights = nn.Parameter(torch.ones(n_paths) / n_paths)
        # FIX: `dropout` was accepted but never used. Regularize the
        # combined path representation; nn.Dropout is parameter-free, so
        # previously saved state dicts still load unchanged.
        self.dropout = nn.Dropout(dropout)
        # Geometric classifier: n_classes prototypes on the hypersphere.
        # No MLP. No dense layers. Just distance to class centers.
        self.class_prototypes = nn.Parameter(
            F.normalize(torch.randn(n_classes, d_path), dim=-1))
        self.temperature = nn.Parameter(torch.tensor(10.0))

    def forward(self, enriched_a, enriched_b):
        """
        Args:
            enriched_a: (B, d_raw + d_bank) premise enriched embedding
            enriched_b: (B, d_raw + d_bank) hypothesis enriched embedding

        Returns:
            (logits, diagnostics): logits is (B, n_classes); diagnostics
            is a dict of path weights / prototype stats computed under
            no_grad for logging.
        """
        # Split enriched into raw (768) and bank context (remainder).
        # NOTE(review): the split point assumes d_raw=768 — confirm if the
        # head is ever constructed with a different d_raw.
        raw_a = enriched_a[:, :768]
        bank_a = enriched_a[:, 768:]
        raw_b = enriched_b[:, :768]
        bank_b = enriched_b[:, 768:]
        # 5 components — ORDER MATTERS for compositions:
        # geometry first (sets the frame of reference), structure second
        # (how the pair relates), content last (interpreted in the frame).
        c1 = self.proj_norm(self.proj_bank(bank_a - bank_b))           # geometric frame
        c2 = self.proj_norm(self.proj_diff(torch.abs(raw_a - raw_b)))  # structural difference
        c3 = self.proj_norm(self.proj_prod(raw_a * raw_b))             # structural interaction
        c4 = self.proj_norm(self.proj_raw_a(raw_a))                    # premise content
        c5 = self.proj_norm(self.proj_raw_b(raw_b))                    # hypothesis content
        components = [c1, c2, c3, c4, c5]  # each (B, d_path)
        # Process each path: walk the composition, fusing consecutive
        # groups, then average the group outputs into one path vector.
        path_outputs = []
        for composition in self.compositions:
            idx = 0
            group_outputs = []
            for group_size in composition:
                group = components[idx:idx + group_size]
                fused = torch.cat(group, dim=-1)                    # (B, group_size * d_path)
                fused = self.group_fusions[str(group_size)](fused)  # (B, d_path)
                group_outputs.append(fused)
                idx += group_size
            path_out = torch.stack(group_outputs).mean(dim=0)       # (B, d_path)
            path_outputs.append(path_out)
        # Weighted combination of all paths.
        weights = F.softmax(self.path_weights, dim=0)               # (n_paths,)
        stacked = torch.stack(path_outputs, dim=0)                  # (n_paths, B, d_path)
        combined = (weights.unsqueeze(-1).unsqueeze(-1) * stacked).sum(0)  # (B, d_path)
        # FIX: apply the configured dropout (identity in eval mode).
        combined = self.dropout(combined)
        # Geometric classification: scaled cosine to learned prototypes.
        combined_n = F.normalize(combined, dim=-1)                  # (B, d_path)
        protos_n = F.normalize(self.class_prototypes, dim=-1)       # (n_classes, d_path)
        logits = combined_n @ protos_n.T * self.temperature.abs()   # (B, n_classes)
        # Diagnostics (no grad; purely for logging).
        with torch.no_grad():
            w = weights.cpu().tolist()
            top_paths = sorted(enumerate(w), key=lambda x: -x[1])[:5]
            proto_sim = protos_n @ protos_n.T
            # FIX: mask size follows the actual prototype count instead of
            # a hardcoded 3, so n_classes != 3 no longer breaks here.
            off_diag = ~torch.eye(proto_sim.shape[0], dtype=torch.bool,
                                  device=proto_sim.device)
            proto_spread = proto_sim[off_diag].mean().item()
        return logits, {
            "path_weights": w,
            "top_paths": [(self.compositions[i], pw) for i, pw in top_paths],
            "weight_spread": max(w) - min(w),
            "proto_spread": proto_spread,
            "temperature": self.temperature.abs().item(),
        }
# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared simplex volume via the bordered Cayley-Menger determinant.

    pts: (B, V, D) — B simplices, each with V vertices in D dimensions.
    Returns a (B,) float32 tensor of squared (V-1)-volumes.
    """
    pts = pts.float()
    # Pairwise squared distances between vertices: (B, V, V).
    deltas = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = deltas.pow(2).sum(dim=-1)
    batch, n_vert, _ = sq_dist.shape
    # Border the distance matrix with a row/column of ones (zero corner).
    bordered = torch.zeros(batch, n_vert + 1, n_vert + 1,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist
    # vol^2 = (-1)^V / (2^(V-1) ((V-1)!)^2) * det(CM)
    sign = (-1.0) ** n_vert
    fact = math.factorial(n_vert - 1)
    return sign / ((2.0 ** (n_vert - 1)) * fact * fact) * torch.linalg.det(bordered)
def cv_metric(emb, n=200):
    """Coefficient of variation of random 5-point simplex volumes.

    Samples n random 5-vertex subsets of emb, measures each simplex's
    volume, and returns std/mean over the collected volumes. Returns 0.0
    when emb has fewer than 5 rows or too few usable samples accumulate.
    """
    n_points = emb.shape[0]
    if n_points < 5:
        return 0.0
    volumes = []
    for _ in range(n):
        pick = torch.randperm(n_points, device=emb.device)[:5]
        vol_sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # Clamp tiny negative determinants to zero before the sqrt.
        vol = torch.sqrt(F.relu(vol_sq[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    if len(volumes) < 10:
        return 0.0
    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def run():
    """End-to-end pipeline: load the frozen encoder, pre-encode SNLI,
    train the CompConvNLI head, then analyze learned path weights and
    probe word-order sensitivity on hand-written pairs.

    Side effects: downloads model/dataset from the Hub, allocates GPU
    memory, writes "nli_conv5d_best.pt" to the working directory.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    # Heavy third-party imports kept local so importing this module is cheap.
    from transformers import AutoModel, AutoTokenizer
    from datasets import load_dataset
    # ── Load model ──
    print(f"\n{'='*65}")
    print("LOADING MODEL")
    print(f"{'='*65}")
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True).to(DEVICE).eval()
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID, trust_remote_code=True)
    # Freeze the encoder: only the NLI head receives gradients.
    for p in model.parameters():
        p.requires_grad = False
    # NOTE(review): `model.bank` and the `.enriched` output below are
    # attributes of this repo's custom remote-code model — confirm against
    # its modeling file.
    has_bank = model.bank is not None
    print(f" Model: {sum(p.numel() for p in model.parameters()):,} params (frozen)")
    print(f" Bank: {'present' if has_bank else 'absent'}")
    # ── Load SNLI ──
    print(f"\n{'='*65}")
    print("LOADING SNLI")
    print(f"{'='*65}")
    ds = load_dataset("stanfordnlp/snli")
    # SNLI uses label == -1 for pairs with no annotator consensus; drop them.
    train_ds = ds["train"].filter(lambda x: x["label"] >= 0)
    val_ds = ds["validation"].filter(lambda x: x["label"] >= 0)
    print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
    MAX_TRAIN = 549000 # full SNLI
    MAX_VAL = 9800
    # ── Pre-encode ──
    # The encoder is frozen, so every pair is embedded exactly once up
    # front; training then runs purely on cached tensors.
    print(f"\n{'='*65}")
    print("PRE-ENCODING")
    print(f"{'='*65}")
    @torch.no_grad()
    def encode_pairs(dataset, max_n, batch_size=1024):
        """Embed up to max_n premise/hypothesis pairs; returns CPU tensors."""
        dataset = dataset.select(range(min(max_n, len(dataset))))
        all_p_enr, all_h_enr = [], []
        all_labels = []
        for i in tqdm(range(0, len(dataset), batch_size), desc=" Encoding"):
            j = min(i + batch_size, len(dataset))
            # Slicing a datasets.Dataset yields a dict of column lists.
            batch = dataset[i:j]
            p_in = tokenizer(batch["premise"], max_length=128,
                             padding="max_length", truncation=True,
                             return_tensors="pt").to(DEVICE)
            p_out = model(**p_in)
            h_in = tokenizer(batch["hypothesis"], max_length=128,
                             padding="max_length", truncation=True,
                             return_tensors="pt").to(DEVICE)
            h_out = model(**h_in)
            # Prefer the bank-enriched embedding; fall back to the plain
            # encoder output when the model has no bank.
            p_feat = p_out.enriched if p_out.enriched is not None else p_out.last_hidden_state
            h_feat = h_out.enriched if h_out.enriched is not None else h_out.last_hidden_state
            # Stage on CPU to keep GPU memory bounded during encoding.
            all_p_enr.append(p_feat.cpu())
            all_h_enr.append(h_feat.cpu())
            all_labels.append(torch.tensor(batch["label"]))
        return {
            "p": torch.cat(all_p_enr),
            "h": torch.cat(all_h_enr),
            "labels": torch.cat(all_labels),
        }
    train_data = encode_pairs(train_ds, MAX_TRAIN)
    val_data = encode_pairs(val_ds, MAX_VAL)
    # Infer the bank width from the encoded feature size (raw is 768).
    d_enriched = train_data["p"].shape[1]
    d_raw = 768
    d_bank = d_enriched - d_raw
    print(f" Enriched: {d_enriched} (raw={d_raw} + bank={d_bank})")
    print(f" Train: {train_data['labels'].shape[0]:,} Val: {val_data['labels'].shape[0]:,}")
    # Report class balance over the training labels.
    for label, name in [(0, "entailment"), (1, "neutral"), (2, "contradiction")]:
        n_l = (train_data["labels"] == label).sum().item()
        print(f" {name}: {n_l:,} ({n_l/len(train_data['labels']):.1%})")
    # The encoder is no longer needed; free its GPU memory before training.
    del model; gc.collect(); torch.cuda.empty_cache()
    # Move the cached features to GPU for fast batched indexing.
    train_p = train_data["p"].to(DEVICE)
    train_h = train_data["h"].to(DEVICE)
    train_labels = train_data["labels"].to(DEVICE)
    val_p = val_data["p"].to(DEVICE)
    val_h = val_data["h"].to(DEVICE)
    val_labels = val_data["labels"].to(DEVICE)
    # ── Build CompConv NLI ──
    print(f"\n{'='*65}")
    print("COMPOSITIONAL CONV NLI HEAD")
    print(f"{'='*65}")
    nli = CompConvNLI(
        d_raw=d_raw, d_bank=max(d_bank, 1),
        d_path=256, n_components=5, n_classes=3, dropout=0.1
    ).to(DEVICE)
    n_head_params = sum(p.numel() for p in nli.parameters())
    print(f" Head params: {n_head_params:,}")
    # ── Training ──
    EPOCHS = 20
    BATCH = 128
    LR = 1e-4
    n_train = train_labels.shape[0]
    n_batches = n_train // BATCH
    optimizer = torch.optim.AdamW(nli.parameters(), lr=LR, weight_decay=0.01)
    # Cosine decay over the full run, stepped per batch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_batches * EPOCHS, eta_min=1e-6)
    print(f"\n{'='*65}")
    print(f"TRAINING ({EPOCHS} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")
    best_val_acc = 0.0
    for epoch in range(EPOCHS):
        nli.train()
        # Fresh shuffle of the cached training set each epoch.
        perm = torch.randperm(n_train, device=DEVICE)
        total_loss, total_correct, n = 0, 0, 0
        t0 = time.time()
        for i in range(0, n_train, BATCH):
            idx = perm[i:i+BATCH]
            # Skip a tiny trailing batch: too small for a stable gradient.
            if len(idx) < 8: continue
            logits, _ = nli(train_p[idx], train_h[idx])
            labels = train_labels[idx]
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(nli.parameters(), 1.0)
            optimizer.step(); optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            total_correct += (logits.argmax(-1) == labels).sum().item()
            total_loss += loss.item()
            n += len(idx)
        elapsed = time.time() - t0
        train_acc = total_correct / max(n, 1)
        # Approximate mean per-batch loss (n // BATCH ignores any partial batch).
        train_loss = total_loss / max(n // BATCH, 1)
        # Validation
        nli.eval()
        with torch.no_grad():
            val_n = val_labels.shape[0]
            val_correct = 0
            val_loss_sum = 0
            all_preds, all_labs = [], []
            path_info = None
            for i in range(0, val_n, 512):
                j = min(i + 512, val_n)
                logits, info = nli(val_p[i:j], val_h[i:j])
                labs = val_labels[i:j]
                val_correct += (logits.argmax(-1) == labs).sum().item()
                val_loss_sum += F.cross_entropy(logits, labs, reduction="sum").item()
                all_preds.append(logits.argmax(-1).cpu())
                all_labs.append(labs.cpu())
                # Keep the diagnostics dict from the first val batch only.
                if path_info is None:
                    path_info = info
            val_acc = val_correct / val_n
            val_loss = val_loss_sum / val_n
            preds = torch.cat(all_preds)
            labs_all = torch.cat(all_labs)
            # Per-class recall (guard against an empty class slice).
            acc_ent = (preds[labs_all == 0] == 0).float().mean().item() if (labs_all == 0).sum() > 0 else 0
            acc_neu = (preds[labs_all == 1] == 1).float().mean().item() if (labs_all == 1).sum() > 0 else 0
            acc_con = (preds[labs_all == 2] == 2).float().mean().item() if (labs_all == 2).sum() > 0 else 0
        print(f"\n E{epoch+1:2d}: {elapsed:.0f}s")
        print(f" Task: loss={train_loss:.4f} t_acc={train_acc:.4f} v_acc={val_acc:.4f} v_loss={val_loss:.4f}")
        print(f" Per-class: ent={acc_ent:.3f} neu={acc_neu:.3f} con={acc_con:.3f}")
        if path_info:
            top3 = path_info["top_paths"][:3]
            path_str = " ".join(f"{comp}={w:.3f}" for comp, w in top3)
            print(f" Paths: {path_str} spread={path_info['weight_spread']:.4f}")
            print(f" Protos: sim={path_info.get('proto_spread', 0):.4f} "
                  f"temp={path_info.get('temperature', 0):.2f}")
        # Checkpoint on best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(nli.state_dict(), "nli_conv5d_best.pt")
            print(f" ★ New best: {val_acc:.4f}")
    # ══════════════════════════════════════════════════════════════
    # PATH ANALYSIS
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*65}")
    print("PATH WEIGHT ANALYSIS")
    print(f"{'='*65}")
    # Restore the best checkpoint before analysis and the probe below.
    nli.load_state_dict(torch.load("nli_conv5d_best.pt", weights_only=True))
    nli.eval()
    weights = F.softmax(nli.path_weights, dim=0).cpu().tolist()
    ranked = sorted(zip(nli.compositions, weights), key=lambda x: -x[1])
    print(f"\n {'Path':<25} {'Weight':>8} {'Type':<15}")
    print(f" {'-'*50}")
    for comp, w in ranked:
        # Label each composition by how it groups the 5 ordered components
        # (geometry, diff, interaction, premise, hypothesis).
        if len(comp) == 1:
            ptype = "holistic"
        elif all(c == 1 for c in comp):
            ptype = "independent"
        elif comp[0] >= 3:
            ptype = "geo-first"
        elif comp[0] == 1 and sum(comp[1:]) == 4:
            ptype = "geo→rest"
        elif comp[0] == 2:
            ptype = "geo+struct→..."
        else:
            ptype = "mixed"
        bar = "█" * int(w * 100)
        print(f" {str(comp):<25} {w:>8.4f} {ptype:<15} {bar}")
    # ══════════════════════════════════════════════════════════════
    # COMPOSITIONAL ORDER TEST
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*65}")
    print("COMPOSITIONAL ORDER TEST")
    print(f"{'='*65}")
    # Reload the encoder (it was deleted after pre-encoding to free memory).
    model = AutoModel.from_pretrained(REPO_ID, trust_remote_code=True).to(DEVICE).eval()
    label_names = ["entailment", "neutral", "contradiction"]
    # Mostly word-order reversals: pooled cosine should be high while the
    # NLI head should still separate entailment from contradiction.
    test_pairs = [
        ("a potato on top of a table", "a table on top of a potato"),
        ("a potato on top of a table", "there is a potato"),
        ("a cat is sitting on a mat", "a mat is sitting on a cat"),
        ("a dog chased the cat", "the cat chased the dog"),
        ("a woman is holding a baby", "a baby is holding a woman"),
        ("the boy kicked the ball", "the ball kicked the boy"),
        ("a man is riding a horse", "a horse is riding a man"),
        ("a girl is painting a picture", "a girl is creating art"),
        ("two dogs are playing in a park", "animals are outdoors"),
        ("a person is swimming in the ocean", "nobody is in the water"),
    ]
    with torch.no_grad():
        for premise, hypothesis in test_pairs:
            p_in = tokenizer([premise], max_length=128, padding="max_length",
                             truncation=True, return_tensors="pt").to(DEVICE)
            h_in = tokenizer([hypothesis], max_length=128, padding="max_length",
                             truncation=True, return_tensors="pt").to(DEVICE)
            p_out = model(**p_in)
            h_out = model(**h_in)
            p_feat = p_out.enriched if p_out.enriched is not None else p_out.last_hidden_state
            h_feat = h_out.enriched if h_out.enriched is not None else h_out.last_hidden_state
            logits, _ = nli(p_feat, h_feat)
            probs = F.softmax(logits, dim=-1)[0]
            pred = label_names[probs.argmax()]
            # NOTE(review): .item() implies last_hidden_state is a single
            # pooled vector per input for this custom model — confirm.
            cos = F.cosine_similarity(
                p_out.last_hidden_state, h_out.last_hidden_state).item()
            print(f"\n P: {premise}")
            print(f" H: {hypothesis}")
            print(f" Pooled cos: {cos:.3f}")
            print(f" NLI: {pred} [E={probs[0]:.3f} N={probs[1]:.3f} C={probs[2]:.3f}]")
    print(f"\n{'='*65}")
    print("SUMMARY")
    print(f"{'='*65}")
    print(f" Best val accuracy: {best_val_acc:.4f}")
    print(f" Head params: {n_head_params:,}")
    print(f" Paths: {len(nli.compositions)}")
    print(f" Components: {nli.n_components} → d_path={nli.d_path}")
    print(f" Bank present: {has_bank}")
    print(f" Saved: nli_conv5d_best.pt")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
# Script entry point: run the full encode/train/analyze pipeline.
if __name__ == "__main__":
    run()