# geolip-captionbert-8192 / trainers / trainer_alignment_base.py
# (Hugging Face page header preserved as a comment: "AbstractPhil's picture",
#  "Update trainers/trainer_alignment_base.py", commit f8f95e1 verified)
# ============================================================================
# DISTILLED CONSENSUS BERT - 500K Scale (scale is set by Config.n_samples)
#
# Self-contained pipeline:
# 1. Extract 5 BERT-family embeddings on CC12M captions (CFG.n_samples of them)
# 2. Whitened Procrustes alignment
# 3. Generate consensus targets (centroid of aligned embeddings)
# 4. Train small standalone transformer from scratch
# 5. No expert models needed at inference
# ============================================================================
import math
import os
import time
import json
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
# Compute device: CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Teacher models: (HF hub id, short name, tokenizer max length for extraction).
# All five are 768-hidden BERT-family encoders; ModernBERT supports 8192 tokens.
MODELS = [
    ("google-bert/bert-base-uncased", "bert", 512),
    ("answerdotai/ModernBERT-base", "modern", 8192),
    ("FacebookAI/roberta-base", "roberta", 512),
    ("albert/albert-base-v2", "albert", 512),
    ("distilbert/distilbert-base-uncased", "distil", 512),
]
@dataclass
class Config:
    """Hyperparameters for extraction, consensus generation, and training."""
    # Data
    n_samples: int = 500000   # captions to stream from CC12M
    n_val: int = 5000         # held-out tail of the corpus used for validation
    min_caption_len: int = 50  # minimum caption length in characters
    extract_batch: int = 1024  # batch size for teacher embedding extraction
    cache_dir: str = "/home/claude/consensus_500k"  # embeddings/captions cache
    # Student architecture
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 6
    d_ff: int = 1536
    max_len: int = 8192  # position embedding capacity
    tokenize_len: int = 512  # actual padding length (captions avg ~100 tokens)
    output_dim: int = 768  # must match the 768-dim consensus space
    dropout: float = 0.1
    # Training
    epochs: int = 30
    batch_size: int = 128  # sequences are tokenize_len=512, not max_len=8192
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 1000  # linear warmup before cosine decay
    grad_clip: float = 1.0
    seed: int = 42
    # Loss weights
    nce_weight: float = 1.0   # InfoNCE (contrastive) term
    mse_weight: float = 1.0   # direct regression term
    cv_weight: float = 0.1    # simplex-volume CV regularizer
    cv_target: float = 0.084  # target coefficient of variation of volumes
# Module-level singleton used throughout the script.
CFG = Config()
# Startup banner. NOTE(review): the "200K Scale" label predates the move to
# CFG.n_samples=500000; the printed sample count below reflects the real scale.
print("=" * 65)
print("DISTILLED CONSENSUS BERT β€” 200K Scale")
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Samples: {CFG.n_samples:,}")
# ══════════════════════════════════════════════════════════════════
# EXTRACTION
# ══════════════════════════════════════════════════════════════════
def load_captions(n, min_len=50):
    """Stream CC12M-LLaVA captions until `n` usable ones are collected.

    A caption is usable when it is a string strictly longer than `min_len`
    characters. Returns the collected list (may be shorter than `n` if the
    stream is exhausted first).
    """
    from datasets import load_dataset
    print(f"\n Loading captions (n={n:,})...")
    stream = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
    collected = []
    for record in stream:
        text = record.get("caption_llava", "")
        if isinstance(text, str) and len(text) > min_len:
            collected.append(text)
        if len(collected) >= n:
            break
    print(f" Got {len(collected):,} captions")
    return collected
@torch.no_grad()
def extract_one(model_name, short_name, captions, max_len, batch_size):
    """Compute masked-mean-pooled embeddings of `captions` with one HF model.

    Loads the model onto DEVICE, mean-pools the last hidden state over
    non-pad positions, and returns a CPU tensor of shape
    (len(captions), hidden_size). The model is freed before returning.
    """
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Extracting: {short_name} ({model_name})...")
    net = AutoModel.from_pretrained(model_name).to(DEVICE).eval()
    tok = AutoTokenizer.from_pretrained(model_name)
    hidden = net.config.hidden_size
    total = sum(p.numel() for p in net.parameters())
    print(f" dim={hidden}, {total:,} params")
    chunks = []
    for start in tqdm(range(0, len(captions), batch_size), desc=f" {short_name}"):
        enc = tok(captions[start:start + batch_size], max_length=max_len,
                  padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        states = net(**enc).last_hidden_state
        weights = enc.attention_mask.unsqueeze(-1).float()
        # Masked mean over the sequence dimension (padding contributes zero).
        pooled = (states * weights).sum(1) / weights.sum(1).clamp(min=1)
        chunks.append(pooled.cpu())
    stacked = torch.cat(chunks)
    print(f" Shape: {stacked.shape}")
    del net
    torch.cuda.empty_cache()
    return stacked
def extract_all():
    """Return ({short_name: (N, 768) tensor}, captions), using a disk cache.

    On a full cache hit, embeddings and captions are loaded from CFG.cache_dir.
    Otherwise captions are streamed, each teacher is extracted in turn, every
    embedding matrix is padded/truncated to 768 columns, and results are saved.
    """
    os.makedirs(CFG.cache_dir, exist_ok=True)
    caps_path = os.path.join(CFG.cache_dir, "captions.json")
    have_all = all(
        os.path.exists(os.path.join(CFG.cache_dir, f"{short}.pt"))
        for _, short, _ in MODELS)
    if have_all and os.path.exists(caps_path):
        print("\n Loading cached embeddings...")
        embeds = {}
        for _, short, _ in MODELS:
            embeds[short] = torch.load(
                os.path.join(CFG.cache_dir, f"{short}.pt"), weights_only=True)
            print(f" {short}: {embeds[short].shape}")
        with open(caps_path) as f:
            captions = json.load(f)
        return embeds, captions
    captions = load_captions(CFG.n_samples, CFG.min_caption_len)
    embeds = {}
    for model_name, short, model_max_len in MODELS:
        emb = extract_one(model_name, short, captions,
                          model_max_len, CFG.extract_batch)
        # Force a common 768-dim space: zero-pad narrow teachers, clip wide.
        if emb.shape[1] < 768:
            emb = F.pad(emb, (0, 768 - emb.shape[1]))
        elif emb.shape[1] > 768:
            emb = emb[:, :768]
        embeds[short] = emb
        torch.save(emb, os.path.join(CFG.cache_dir, f"{short}.pt"))
    with open(caps_path, "w") as f:
        json.dump(captions, f)
    return embeds, captions
# ══════════════════════════════════════════════════════════════════
# WHITENED PROCRUSTES + CONSENSUS
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric PSD matrix.

    Eigendecomposes `cov`, clamps eigenvalues at `eps` for numerical safety,
    and reassembles Q diag(1/sqrt(w)) Q^T (column-scaling avoids torch.diag).
    """
    w, Q = torch.linalg.eigh(cov)
    w = w.clamp(min=eps)
    return (Q * w.rsqrt()) @ Q.T
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal Procrustes map from source to target space.

    Uses at most `n_align` paired rows: both sides are centered, whitened to
    identity covariance, row-normalized, and an orthogonal rotation R is found
    by SVD of the cross-covariance. Returns the transform components consumed
    by `apply_align` plus mean cosine similarity before/after as diagnostics.
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()
    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    denom = max(src_c.shape[0] - 1, 1)
    src_whiten = symmetric_inv_sqrt(src_c.T @ src_c / denom)
    tgt_whiten = symmetric_inv_sqrt(tgt_c.T @ tgt_c / denom)
    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)
    cos_before = F.cosine_similarity(src_c, tgt_c, dim=-1).mean().item()
    # Orthogonal Procrustes solution: R = U V^T from SVD of T^T S.
    U, _, Vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    R = U @ Vt
    cos_after = F.cosine_similarity(src_w @ R.T, tgt_w, dim=-1).mean().item()
    return {
        "rotation": R, "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
        "cos_before": cos_before, "cos_after": cos_after,
    }
def apply_align(emb, a):
    """Map embeddings into the reference space: center, whiten, rotate, unwhiten.

    `a` is the dict returned by `procrustes_align`. Note the target mean is not
    added back here; downstream consensus works with cosine-compared vectors.
    """
    centered = emb.float() - a["source_mean"]
    rotated = (centered @ a["source_whitener"]) @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
def generate_consensus(embeds):
    """Align every teacher into the `bert` reference space; return the
    L2-normalized centroid of the aligned embeddings as the consensus target.
    """
    print(f"\n{'='*65}")
    print("WHITENED PROCRUSTES ALIGNMENT + CONSENSUS")
    print(f"{'='*65}")
    ref_name = "bert"
    names = [short for _, short, _ in MODELS]
    aligned = {}
    for name in names:
        info = procrustes_align(embeds[name], embeds[ref_name])
        aligned[name] = apply_align(embeds[name], info)
        label = " (ref)" if name == ref_name else ""
        print(f" {name:10s}: cos {info['cos_before']:.4f} β†’ {info['cos_after']:.4f}{label}")
    # The centroid of the five aligned embeddings is taken as the consensus;
    # per the original five-BERT experiment it was stable to three decimal
    # places across seeds, so no learned combiner is needed.
    centroid = sum(aligned[name] for name in names) / len(names)
    consensus = F.normalize(centroid, dim=-1)
    # Diagnostic: cosine between consensus and each aligned teacher on a slice.
    n_check = min(5000, consensus.shape[0])
    for name in names:
        cos = F.cosine_similarity(
            consensus[:n_check], aligned[name][:n_check], dim=-1).mean().item()
        print(f" cos(consensus, {name:10s}): {cos:.4f}")
    return consensus
# ══════════════════════════════════════════════════════════════════
# STUDENT MODEL
# ══════════════════════════════════════════════════════════════════
class CaptionEncoder(nn.Module):
    """Small transformer encoder distilled to predict consensus embeddings.

    Embeds tokens + learned positions, runs a pre-norm transformer encoder,
    mean-pools over non-pad positions, and projects to a unit-normalized
    `output_dim` vector. Attribute names are part of the checkpoint format.
    """

    def __init__(self, vocab_size=30522, max_len=128, d_model=384,
                 n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
                 dropout=0.1, pad_token_id=0):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.emb_norm = nn.LayerNorm(d_model)
        self.emb_drop = nn.Dropout(dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, output_dim),
        )

    def forward(self, input_ids, attention_mask=None):
        """Encode (B, L) token ids into unit-norm (B, output_dim) embeddings.

        When `attention_mask` is None, padding is inferred from pad_token_id.
        """
        seq_len = input_ids.shape[1]
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        h = self.token_emb(input_ids) + self.pos_emb(pos)
        h = self.emb_drop(self.emb_norm(h))
        # True entries in the key-padding mask are ignored by attention.
        if attention_mask is None:
            padding = input_ids == self.pad_token_id
        else:
            padding = ~attention_mask.bool()
        h = self.encoder(h, src_key_padding_mask=padding)
        keep = (~padding).unsqueeze(-1).float()
        pooled = (h * keep).sum(1) / keep.sum(1).clamp(min=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
# ══════════════════════════════════════════════════════════════════
# GEOMETRY
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared simplex volume per batch item via the Cayley-Menger determinant.

    pts: (B, V, D) tensor of V vertices per batch item. Returns a (B,) tensor
    of squared (V-1)-simplex volumes; may be slightly negative from float
    error, so callers clamp before taking the square root.
    """
    pts = pts.float()
    pairwise = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = pairwise.pow(2).sum(-1)  # (B, V, V) squared distances
    B, V = d2.shape[0], d2.shape[1]
    # Bordered matrix: first row/column of ones, zero corner, distances inside.
    cm = torch.zeros(B, V + 1, V + 1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    sign = (-1.0) ** V
    fact = math.factorial(V - 1)
    coeff = sign / ((2.0 ** (V - 1)) * fact * fact)
    return coeff * torch.linalg.det(cm)
def cv_loss(emb, target=0.084, n_samples=16):
    """Differentiable penalty: |CV of random 5-point simplex volumes - target|.

    Draws `n_samples` random 5-subsets of the batch, computes each simplex
    volume, and penalizes deviation of their coefficient of variation from
    `target`. Returns zero when the batch has fewer than 5 rows.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(B, device=emb.device)[:5]
        sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu + epsilon keeps sqrt finite when float error goes negative.
        volumes.append(torch.sqrt(F.relu(sq[0]) + 1e-12))
    vols = torch.stack(volumes)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()
def cv_metric(emb, n=200):
    """Non-differentiable CV estimate over `n` random 5-point simplex volumes.

    Returns 0.0 when the batch is too small or fewer than 10 strictly
    positive volumes were observed.
    """
    B = emb.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n):
        pick = torch.randperm(B, device=emb.device)[:5]
        sq = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq[0]) + 1e-12).item()
        if vol > 0:
            vols.append(vol)
    if len(vols) < 10:
        return 0.0
    arr = np.array(vols)
    return float(arr.std() / (arr.mean() + 1e-8))
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE loss between row-aligned embedding batches.

    Rows of `a` and `b` at the same index are positives; all other pairs are
    negatives. Returns (loss, top-1 retrieval accuracy of a against b).
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = a @ b.T / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    forward = F.cross_entropy(logits, labels)
    backward = F.cross_entropy(logits.T, labels)
    loss = 0.5 * (forward + backward)
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
def train():
    """End-to-end pipeline: extract teachers, build consensus, distill student.

    Steps: seed RNGs; extract/load teacher embeddings and captions; generate
    the consensus target; pre-tokenize all captions with the BERT tokenizer;
    train CaptionEncoder with InfoNCE + MSE + CV-regularizer; evaluate the
    best checkpoint. Side effects: caches under CFG.cache_dir, writes
    checkpoints / tokenizer / metrics.json under CFG.cache_dir/student.
    """
    # Seed every RNG source for reproducible shuffling and initialization.
    torch.manual_seed(CFG.seed)
    torch.cuda.manual_seed_all(CFG.seed)
    np.random.seed(CFG.seed)
    # -- Extract + Align + Consensus --
    embeds, captions = extract_all()
    consensus = generate_consensus(embeds)
    # Free the raw per-teacher embeddings; only the consensus target is needed.
    del embeds
    torch.cuda.empty_cache()
    import gc; gc.collect()
    # -- Tokenize (student always uses the bert-base-uncased vocabulary) --
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    print(f"\n Tokenizer: bert-base-uncased (vocab={tokenizer.vocab_size})")
    print(" Pre-tokenizing...")
    # Tokenize in chunks to bound peak memory of the tokenizer call.
    all_ids, all_masks = [], []
    chunk = 50000
    for i in tqdm(range(0, len(captions), chunk), desc=" Tokenizing"):
        j = min(i + chunk, len(captions))
        tokens = tokenizer(captions[i:j], max_length=CFG.tokenize_len,
                           padding="max_length", truncation=True,
                           return_tensors="pt")
        all_ids.append(tokens["input_ids"])
        all_masks.append(tokens["attention_mask"])
    input_ids = torch.cat(all_ids)
    attention_mask = torch.cat(all_masks)
    # Length statistics confirm tokenize_len is a sane padding budget.
    real_lens = attention_mask.sum(1).float()
    print(f" Token lengths: mean={real_lens.mean():.0f} "
          f"median={real_lens.median():.0f} "
          f">{CFG.tokenize_len}: {(real_lens >= CFG.tokenize_len).float().mean():.1%}")
    print(f" Padded to: {CFG.tokenize_len} (model supports up to {CFG.max_len})")
    # Split: the last CFG.n_val captions are held out for validation.
    n_train = len(captions) - CFG.n_val
    print(f" Train: {n_train:,}, Val: {CFG.n_val:,}")
    # Move everything to GPU up front (assumes the tokenized corpus fits in
    # device memory -- TODO confirm for larger n_samples).
    train_ids = input_ids[:n_train].to(DEVICE)
    train_mask = attention_mask[:n_train].to(DEVICE)
    train_targets = consensus[:n_train].to(DEVICE)
    val_ids = input_ids[n_train:].to(DEVICE)
    val_mask = attention_mask[n_train:].to(DEVICE)
    val_targets = consensus[n_train:].to(DEVICE)
    # -- Student --
    print(f"\n{'='*65}")
    print("STUDENT MODEL")
    print(f"{'='*65}")
    student = CaptionEncoder(
        vocab_size=tokenizer.vocab_size,
        max_len=CFG.max_len,
        d_model=CFG.d_model,
        n_heads=CFG.n_heads,
        n_layers=CFG.n_layers,
        d_ff=CFG.d_ff,
        output_dim=CFG.output_dim,
        dropout=CFG.dropout,
        pad_token_id=tokenizer.pad_token_id,
    ).to(DEVICE)
    n_params = sum(p.numel() for p in student.parameters())
    print(f" Architecture: {CFG.n_layers}L, {CFG.d_model}d, {CFG.n_heads}h, {CFG.d_ff} FFN")
    print(f" Output: {CFG.output_dim}-dim (consensus space)")
    print(f" Parameters: {n_params:,}")
    size_mb = sum(p.numel() * p.element_size() for p in student.parameters()) / 1e6
    print(f" Size: {size_mb:.1f} MB")
    # -- Warm-start from previous checkpoint if available --
    # for/else: the else branch runs only if no candidate directory hit break.
    # Shape-matching tensors are copied; shorter pos_emb tables are extended
    # (old rows kept, new rows re-initialized); everything else is skipped.
    for prev_dir in ["/home/claude/consensus_200k/student",
                     "/home/claude/distilled_consensus"]:
        prev_ckpt = os.path.join(prev_dir, "best_model.pt")
        if os.path.exists(prev_ckpt):
            print(f"\n Warm-starting from: {prev_ckpt}")
            prev_state = torch.load(prev_ckpt, weights_only=True, map_location=DEVICE)
            current_state = student.state_dict()
            loaded, extended, skipped = 0, 0, 0
            for name, param in prev_state.items():
                if name not in current_state:
                    skipped += 1
                    continue
                if param.shape == current_state[name].shape:
                    current_state[name] = param
                    loaded += 1
                elif "pos_emb" in name and param.shape[0] < current_state[name].shape[0]:
                    # Extend position embeddings: copy old positions, init new ones
                    old_len = param.shape[0]
                    current_state[name][:old_len] = param
                    nn.init.normal_(current_state[name][old_len:], std=0.02)
                    extended += 1
                    print(f" Extended {name}: {param.shape[0]}β†’{current_state[name].shape[0]}")
                else:
                    skipped += 1
            student.load_state_dict(current_state)
            print(f" Loaded: {loaded}, Extended: {extended}, Skipped: {skipped}")
            break
    else:
        print("\n Training from scratch (no previous checkpoint found)")
    # -- Optimizer: AdamW with linear warmup then cosine decay to eta_min --
    optimizer = torch.optim.AdamW(student.parameters(), lr=CFG.lr,
                                  weight_decay=CFG.weight_decay)
    # NOTE(review): n_batches is an estimate for the schedule; the loop below
    # skips batches smaller than 8, so actual step count can differ slightly.
    n_batches = n_train // CFG.batch_size
    total_steps = n_batches * CFG.epochs
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=CFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps - CFG.warmup_steps, 1),
             eta_min=1e-6)],
        milestones=[CFG.warmup_steps])
    os.makedirs(CFG.cache_dir, exist_ok=True)
    save_dir = os.path.join(CFG.cache_dir, "student")
    os.makedirs(save_dir, exist_ok=True)
    # -- Train --
    print(f"\n{'='*65}")
    print(f"TRAINING ({CFG.epochs} epochs, {n_batches} batches/epoch)")
    print(f"{'='*65}")
    all_metrics = {"config": {k: str(v) for k, v in vars(CFG).items()}, "epochs": []}
    best_val_cos = 0.0
    for epoch in range(CFG.epochs):
        student.train()
        perm = torch.randperm(n_train, device=DEVICE)
        losses = {"total": 0, "nce": 0, "mse": 0}
        metrics = {"acc": 0, "cos": 0}
        n = 0  # batches actually processed this epoch
        t0 = time.time()
        for i in range(0, n_train, CFG.batch_size):
            idx = perm[i:i+CFG.batch_size]
            # Skip tiny trailing batches: InfoNCE needs enough negatives.
            if len(idx) < 8: continue
            emb = student(train_ids[idx], train_mask[idx])
            tgt = train_targets[idx]
            l_nce, acc = infonce(emb, tgt)
            l_mse = F.mse_loss(emb, tgt)
            l_cv = cv_loss(emb, target=CFG.cv_target)
            loss = CFG.nce_weight * l_nce + CFG.mse_weight * l_mse + CFG.cv_weight * l_cv
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), CFG.grad_clip)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            with torch.no_grad():
                cos = F.cosine_similarity(emb, tgt, dim=-1).mean().item()
            losses["total"] += loss.item()
            losses["nce"] += l_nce.item()
            losses["mse"] += l_mse.item()
            metrics["acc"] += acc
            metrics["cos"] += cos
            n += 1
        elapsed = time.time() - t0
        d = max(n, 1)  # guard against division by zero if all batches skipped
        # Validation: embed held-out captions in chunks of 512.
        student.eval()
        with torch.no_grad():
            val_embs = []
            for vi in range(0, CFG.n_val, 512):
                vj = min(vi + 512, CFG.n_val)
                ve = student(val_ids[vi:vj], val_mask[vi:vj])
                val_embs.append(ve)
            val_emb = torch.cat(val_embs)
            # Accuracy/CV on a 2K subset to keep the similarity matrix small.
            _, val_acc = infonce(val_emb[:2000], val_targets[:2000])
            val_cos = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
            val_cv = cv_metric(val_emb[:2000])
        summary = {
            "epoch": epoch + 1, "elapsed": elapsed,
            "loss": losses["total"] / d,
            "train_acc": metrics["acc"] / d,
            "train_cos": metrics["cos"] / d,
            "val_acc": val_acc, "val_cos": val_cos, "val_cv": val_cv,
        }
        all_metrics["epochs"].append(summary)
        print(f" E{epoch+1:2d}: {elapsed:.0f}s "
              f"loss={summary['loss']:.4f} "
              f"t_acc={summary['train_acc']:.3f} t_cos={summary['train_cos']:.3f} "
              f"v_acc={summary['val_acc']:.3f} v_cos={summary['val_cos']:.3f} "
              f"v_cv={summary['val_cv']:.3f}")
        # Best checkpoint is selected by validation cosine to consensus.
        if val_cos > best_val_cos:
            best_val_cos = val_cos
            torch.save(student.state_dict(), os.path.join(save_dir, "best_model.pt"))
        # Periodic snapshot every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(student.state_dict(),
                       os.path.join(save_dir, f"model_e{epoch+1:02d}.pt"))
    # Final save: last-epoch weights, tokenizer, and full metric history.
    torch.save(student.state_dict(), os.path.join(save_dir, "final_model.pt"))
    tokenizer.save_pretrained(os.path.join(save_dir, "tokenizer"))
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)
    # ----------------------------------------------------------------
    # FINAL EVAL (uses the best checkpoint, not the final-epoch weights)
    # ----------------------------------------------------------------
    print(f"\n{'='*65}")
    print("FINAL EVALUATION")
    print(f"{'='*65}")
    student.load_state_dict(
        torch.load(os.path.join(save_dir, "best_model.pt"),
                   weights_only=True, map_location=DEVICE))
    student.eval()
    with torch.no_grad():
        val_embs = []
        for vi in range(0, CFG.n_val, 512):
            vj = min(vi + 512, CFG.n_val)
            ve = student(val_ids[vi:vj], val_mask[vi:vj])
            val_embs.append(ve)
        val_emb = torch.cat(val_embs)
        # Retrieval (on 2K subset for memory)
        sub = min(2000, CFG.n_val)
        sim = val_emb[:sub] @ val_targets[:sub].T
        labels = torch.arange(sub, device=DEVICE)
        r1 = (sim.argmax(1) == labels).float().mean().item()
        r5 = (sim.topk(5, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        r10 = (sim.topk(10, dim=1).indices == labels.unsqueeze(1)).any(1).float().mean().item()
        cos_match = F.cosine_similarity(val_emb, val_targets, dim=-1).mean().item()
        final_cv = cv_metric(val_emb[:2000])
    print(f" Retrieval (student β†’ consensus):")
    print(f" R@1: {r1:.4f}")
    print(f" R@5: {r5:.4f}")
    print(f" R@10: {r10:.4f}")
    print(f" Cosine: {cos_match:.4f}")
    print(f" CV: {final_cv:.4f} (target: {CFG.cv_target})")
    print(f" Model: {n_params:,} params, {size_mb:.1f} MB")
    # Smoke test: pairwise similarities for a few hand-written captions.
    print(f"\n Standalone similarity test:")
    test = [
        "A cat sitting on a windowsill watching birds",
        "A golden retriever playing fetch on the beach",
        "A still life painting with flowers and fruit",
        "An aerial photograph of a city skyline at night",
        "A child riding a bicycle through autumn leaves",
    ]
    with torch.no_grad():
        tok = tokenizer(test, max_length=CFG.tokenize_len, padding="max_length",
                        truncation=True, return_tensors="pt").to(DEVICE)
        embs = student(tok["input_ids"], tok["attention_mask"])
        sim = embs @ embs.T
    for i in range(len(test)):
        for j in range(i+1, len(test)):
            print(f" [{i}]↔[{j}]: {sim[i,j]:.3f} "
                  f"({test[i][:35]}↔{test[j][:35]})")
    print(f"\n Saved to: {save_dir}/")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
# Script entry point: run the full extraction -> alignment -> training pipeline.
if __name__ == "__main__":
    train()