# Source: AbstractPhil — renamed trainer.py to cell2_trainer.py (commit fd506cf, verified).
# ============================================================================
# TRAINER: MEMORY-EXTENDED CLIP-L TEXT ENCODER
#
# Cell 2: Training loop. Requires Cell 1 (memory_clip_l.py) loaded.
#
# Data: image-caption pairs with long captions.
# Uses COCO Captions for proof of concept (short captions, but validates
# the architecture works). For production: ShareGPT4V, LAION-COCO, etc.
#
# Training flow per batch:
# 1. ModernBERT (frozen): full caption → teacher_cls (1024)
# 2. CLIP text + memory (trainable): segmented caption → student_cls (768)
# 3. CLIP vision (frozen): image → vision_cls (768) [alignment anchor]
# 4. Losses: InfoNCE(proj(student), teacher) + Procrustes + CV
# + CLIP_contrastive(student, vision) [preserve vision alignment]
# ============================================================================
import gc
import math
import os
import time
from dataclasses import dataclass, asdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from safetensors.torch import save_file as safetensors_save
# ══════════════════════════════════════════════════════════════════
# TRAIN CONFIG
# ══════════════════════════════════════════════════════════════════
@dataclass
class TrainConfig:
    """Hyperparameters for Cell 2: training the memory-extended CLIP-L text encoder.

    Groups: data limits, optimization, loss weights, teacher settings,
    logging/checkpointing. A module-level singleton `TCFG` is created below
    and read throughout this cell.
    """
    # Data
    max_train_samples: int = 50000   # cap on captions kept for training
    max_val_samples: int = 1000      # cap on validation captions (val loop not in this cell)
    min_caption_length: int = 20     # presumably a minimum caption length filter — unused in visible code, TODO confirm
    # Training
    epochs: int = 10
    batch_size: int = 64
    lr_bank: float = 2e-3            # LR for memory-bank / cross-attention params
    lr_output: float = 5e-4          # LR for remaining trainable params
    lr_proj: float = 1e-3            # LR for the proj_modern head
    min_lr: float = 1e-6             # cosine-annealing floor
    weight_decay: float = 0.01
    grad_clip: float = 1.0           # global grad-norm clip
    warmup_steps: int = 200          # linear warmup length before cosine decay
    # Loss weights
    modern_weight: float = 1.0       # InfoNCE: student ↔ ModernBERT teacher
    clip_vision_weight: float = 0.5  # preserve CLIP vision alignment — not used by train() in this cell
    procrustes_weight: float = 0.3   # geometric regularizer
    cv_weight: float = 0.05          # pentachoron CV pulled toward target (→ 0.20)
    temperature: float = 0.07        # InfoNCE temperature
    # Teacher
    modern_max_len: int = 4096       # ModernBERT tokenizer truncation length
    procrustes_n_samples: int = 300  # captions used for the static Procrustes init
    # Logging
    checkpoint_dir: str = "/home/claude/memory_clip_checkpoints"
    log_every: int = 20              # unused in visible code — TODO confirm
    eval_every: int = 200            # unused in visible code — TODO confirm
# Singleton config read by the rest of this cell.
TCFG = TrainConfig()
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC UTILITIES (duplicated for cell independence)
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared (V-1)-simplex volume via the Cayley-Menger determinant.

    Args:
        pts: (B, V, D) batch of V vertices in D dimensions.

    Returns:
        (B,) tensor of squared simplex volumes, computed in float32 with
        autocast disabled for numerical stability of the determinant.
    """
    with torch.amp.autocast("cuda", enabled=False):
        coords = pts.float()
        # Pairwise squared Euclidean distances between vertices.
        delta = coords.unsqueeze(-2) - coords.unsqueeze(-3)
        sq = delta.pow(2).sum(dim=-1)
        n_batch, n_verts = sq.shape[0], sq.shape[1]
        # Bordered (V+1)x(V+1) Cayley-Menger matrix: 0 corner, 1-border, d^2 block.
        bordered = torch.zeros(n_batch, n_verts + 1, n_verts + 1,
                               device=sq.device, dtype=torch.float32)
        bordered[:, 0, 1:] = 1
        bordered[:, 1:, 0] = 1
        bordered[:, 1:, 1:] = sq
        sign = (-1.0) ** n_verts
        fact = math.factorial(n_verts - 1)
        scale = sign / ((2.0 ** (n_verts - 1)) * fact * fact)
        return scale * torch.linalg.det(bordered)
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation of volumes of random 5-point simplices.

    Samples `n_samples` random pentachora (5 embeddings each) and returns
    std/mean of their volumes; 0 when fewer than 5 embeddings are available.
    Non-deterministic: uses torch.randperm.
    """
    count = embeddings.shape[0]
    if count < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(count, device=embeddings.device)[:5]
        sq_vol = cayley_menger_vol2(embeddings[pick].unsqueeze(0))[0]
        # Clamp negatives from numerical noise before the sqrt.
        volumes.append(torch.sqrt(F.relu(sq_vol) + 1e-12))
    vols = torch.stack(volumes)
    return vols.std() / (vols.mean() + 1e-8)
def procrustes_alignment_loss(emb_a, emb_b):
    """Differentiable Procrustes-style alignment loss between two embedding sets.

    Normalizes and centers both sets, then penalizes the shortfall of the
    nuclear norm of the cross-covariance against its sqrt(N)*D normalizer.
    Computed in float32 with autocast disabled (SVD stability).
    """
    with torch.amp.autocast("cuda", enabled=False):
        a = F.normalize(emb_a.float(), dim=-1)
        b = F.normalize(emb_b.float(), dim=-1)
        a = a - a.mean(0, keepdim=True)
        b = b - b.mean(0, keepdim=True)
        singulars = torch.linalg.svdvals(a.T @ b)
        n_rows, n_dims = a.shape
        return 1.0 - singulars.sum() / (math.sqrt(n_rows) * n_dims)
# ══════════════════════════════════════════════════════════════════
# LOSSES
# ══════════════════════════════════════════════════════════════════
def infonce_loss(emb_a, emb_b, temperature=0.07):
    """Symmetric InfoNCE between two embedding batches.

    Matched rows of `emb_a`/`emb_b` are positives; all other rows in the
    batch are negatives. Returns (loss, top-1 retrieval accuracy a→b).
    """
    za = F.normalize(emb_a, dim=-1)
    zb = F.normalize(emb_b, dim=-1)
    sim = (za @ zb.T) / temperature
    targets = torch.arange(sim.shape[0], device=sim.device)
    # Average the a→b and b→a cross-entropies.
    loss_ab = F.cross_entropy(sim, targets)
    loss_ba = F.cross_entropy(sim.T, targets)
    loss = (loss_ab + loss_ba) / 2
    with torch.no_grad():
        acc = (sim.argmax(-1) == targets).float().mean().item()
    return loss, acc
def batch_cv_loss(all_anchors, n_reals, cv_target=0.20):
    """Mean |CV - cv_target| over batch items with at least 5 real anchors.

    Args:
        all_anchors: (B, S, D) anchor tensor, padded along S.
        n_reals: per-item count of real (non-padding) anchors; ints or tensors.
        cv_target: desired coefficient of variation.

    Returns:
        (loss, {"cv_raw": mean raw CV}); loss is 0 when no item has >= 5 anchors.
    """
    device = all_anchors.device
    loss_sum = torch.tensor(0.0, device=device)
    cv_sum = 0.0
    valid = 0
    for row in range(all_anchors.shape[0]):
        count = n_reals[row]
        if isinstance(count, torch.Tensor):
            count = count.item()
        if count < 5:
            continue
        cv = pentachoron_cv(all_anchors[row, :count], n_samples=16)
        loss_sum = loss_sum + (cv - cv_target).abs()
        cv_sum += cv.item()
        valid += 1
    denom = max(valid, 1)
    return loss_sum / denom, {"cv_raw": cv_sum / denom}
# ══════════════════════════════════════════════════════════════════
# TEACHER UTILITIES
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def teacher_forward_modern(model, tokenizer, texts, device, max_len):
    """Embed texts with the frozen teacher via masked mean-pooling.

    Tokenizes up to `max_len` tokens, runs the model, and averages the last
    hidden states over real (attention-masked) token positions.
    Returns a (B, H) tensor.
    """
    enc = tokenizer(texts, max_length=max_len, padding=True,
                    truncation=True, return_tensors="pt").to(device)
    hidden = model(**enc).last_hidden_state
    weights = enc.attention_mask.unsqueeze(-1).float()
    # clamp avoids division by zero for pathological all-pad rows.
    return (hidden * weights).sum(1) / weights.sum(1).clamp(min=1)
@torch.no_grad()
def vision_forward(clip_model, processor, images, device):
    """CLIP image embeddings (frozen tower) used as a vision-alignment anchor."""
    batch = processor(images=images, return_tensors="pt").to(device)
    return clip_model.get_image_features(**batch)
# ══════════════════════════════════════════════════════════════════
# DATASET
# ══════════════════════════════════════════════════════════════════
class CaptionDataset(Dataset):
    """Dataset over caption strings, optionally paired element-wise with images.

    For POC: COCO captions (short, ~10-20 tokens).
    For production: ShareGPT4V or LAION-COCO (long, 100-500 tokens).

    Each item is a dict with key "caption" and, when images were supplied,
    an additional "image" entry.
    """

    def __init__(self, captions, images=None):
        # `images`: optional PIL images aligned with `captions` by index.
        self.captions = captions
        self.images = images

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        sample = {"caption": self.captions[idx]}
        if self.images is not None:
            sample["image"] = self.images[idx]
        return sample
# ══════════════════════════════════════════════════════════════════
# PARAM GROUPS
# ══════════════════════════════════════════════════════════════════
def make_param_groups(model):
    """Partition trainable parameters into bank / proj / output optimizer groups.

    Parameters are routed by name prefix: proj_modern first, then memory-bank
    and CLIP cross-attention modules, everything else into "output". Each
    group gets its own learning rate from TCFG; prints a per-group summary.
    """
    bank_prefixes = ("bank.depth_compressor", "bank.temporal_proj",
                     "bank.cross_attn", "bank.cross_norms",
                     "bank.cross_ffns", "bank.ffn_norms",
                     "clip_cross_attn", "clip_cross_norms",
                     "clip_cross_ffns", "clip_cross_ffn_norms")
    proj_prefixes = ("proj_modern",)
    buckets = {"bank": [], "proj": [], "output": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # proj takes priority over bank, mirroring the routing order.
        if name.startswith(proj_prefixes):
            buckets["proj"].append(param)
        elif name.startswith(bank_prefixes):
            buckets["bank"].append(param)
        else:
            buckets["output"].append(param)
    rates = {"bank": TCFG.lr_bank, "proj": TCFG.lr_proj, "output": TCFG.lr_output}
    groups = [{"params": buckets[key], "lr": rates[key], "name": key,
               "weight_decay": TCFG.weight_decay}
              for key in ("bank", "proj", "output")]
    for g in groups:
        n = sum(p.numel() for p in g["params"])
        print(f" {g['name']:8s}: {n:>10,} params @ lr={g['lr']}")
    return groups
# ══════════════════════════════════════════════════════════════════
# PROCRUSTES PRE-ALIGNMENT
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def compute_and_init_procrustes(student_model, modern_model, modern_tok,
                                captions, device):
    """Fit a static Procrustes map (student → teacher) and load it into proj_modern.

    Embeds `captions` with both towers in chunks of 16, solves the orthogonal
    Procrustes problem via compute_static_procrustes, and initializes
    `student_model.proj_modern` with the resulting rotation and means.
    """
    print(f"\n Computing static Procrustes on {len(captions)} captions...")
    tok = student_model.clip_tokenizer
    student_parts, teacher_parts = [], []
    for start in range(0, len(captions), 16):
        chunk = captions[start:start + 16]
        # Student side: plain CLIP text encoder, single 77-token segment, no memory.
        enc = tok(chunk, max_length=77, padding=True,
                  truncation=True, return_tensors="pt").to(device)
        clip_out = student_model.clip_text(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            output_hidden_states=False)
        # pooler_output is the EOS-token embedding.
        student_parts.append(clip_out.pooler_output.cpu())
        # Teacher side: masked mean-pool over ModernBERT hidden states.
        teacher_parts.append(
            teacher_forward_modern(modern_model, modern_tok, chunk,
                                   device, TCFG.modern_max_len).cpu())
    student_all = torch.cat(student_parts)
    modern_all = torch.cat(teacher_parts)
    print(f" Student: {student_all.shape}, Teacher: {modern_all.shape}")
    R, mu_s, mu_t = compute_static_procrustes(student_all, modern_all)
    student_model.proj_modern.init_from_procrustes(R, mu_s, mu_t)
@torch.no_grad()
def compute_static_procrustes(student_embs, teacher_embs):
    """Solve the orthogonal Procrustes problem mapping student onto teacher embeddings.

    Centers both sets, zero-pads the student space when it is narrower than
    the teacher's (e.g. 768 → 1024), then finds the rotation R (via SVD of
    the cross-covariance M = Xc^T Yc) so that Xc @ R.T best matches Yc.

    Args:
        student_embs: (N, Ds) student embeddings.
        teacher_embs: (N, Dt) teacher embeddings, Dt >= Ds.

    Returns:
        (R, mu_x, mu_y): rotation (Dt, Dt), student mean zero-padded to Dt,
        and teacher mean (Dt,) — consumed by proj_modern.init_from_procrustes
        (defined in Cell 1).
    """
    X = student_embs.float()
    Y = teacher_embs.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y
    extra = Yc.shape[1] - Xc.shape[1]
    if extra > 0:
        # Zero-pad student dims up to teacher width for the SVD.
        # Bug fix: allocate padding on the inputs' device — the original always
        # allocated on CPU, which would crash for CUDA inputs.
        pad = torch.zeros(Xc.shape[0], extra, device=Xc.device)
        Xc = torch.cat([Xc, pad], dim=1)
        mu_x = torch.cat([mu_x, torch.zeros(extra, device=mu_x.device)])
    # Orthogonal Procrustes: M = U S V^T  =>  optimal orthogonal map is U V^T.
    U, S, Vt = torch.linalg.svd(Xc.T @ Yc)
    R = (U @ Vt).T
    cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc @ R.T), Yc, dim=-1).mean()
    # Bug fix: log line previously contained the mojibake sequence for "→".
    print(f" Procrustes: cos {cos_before:.4f} → {cos_after:.4f}")
    return R, mu_x, mu_y
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
def train(model, modern_model, modern_tok, train_captions, val_captions=None):
    """Run the training loop: distill ModernBERT into the memory-extended CLIP.

    Per batch: (1) frozen ModernBERT embeds the full caption, (2) the student
    consumes the caption segment-by-segment while updating its memory state,
    (3) InfoNCE + Procrustes + pentachoron-CV losses are combined and
    optimized with AMP. A checkpoint is saved after every epoch and once more
    at the end.

    Args:
        model: MemoryExtendedCLIP student (Cell 1); must expose init_state,
            clip_tokenizer, proj_modern, config, num_trainable_params.
        modern_model / modern_tok: frozen teacher model and tokenizer.
        train_captions: list of caption strings.
        val_captions: accepted but unused in this cell — TODO confirm intent.
    """
    device = next(model.parameters()).device
    os.makedirs(TCFG.checkpoint_dir, exist_ok=True)
    param_groups = make_param_groups(model)
    optimizer = torch.optim.AdamW(param_groups)
    # Flat parameter list for global grad-norm clipping.
    all_params = [p for g in param_groups for p in g["params"]]
    n_batches_per_epoch = len(train_captions) // TCFG.batch_size
    total_steps = n_batches_per_epoch * TCFG.epochs
    # Linear warmup then cosine decay down to min_lr; stepped once per batch.
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=TCFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps, 1), eta_min=TCFG.min_lr)],
        milestones=[TCFG.warmup_steps])
    scaler = torch.amp.GradScaler()
    global_step = 0
    clip_tokenizer = model.clip_tokenizer
    print(f"\n Training: {model.num_trainable_params():,} params")
    # NOTE(review): "Γ—" in the next string is mojibake for "×" — fix at source.
    print(f" {n_batches_per_epoch} batches/epoch Γ— {TCFG.batch_size}")
    print(f" Losses: modern({TCFG.modern_weight}) + "
          f"procrustes({TCFG.procrustes_weight}) + cv({TCFG.cv_weight})")
    for epoch in range(TCFG.epochs):
        model.train()
        # Fresh random caption order each epoch.
        perm = np.random.permutation(len(train_captions))
        losses = {"total": 0, "modern": 0, "procrustes": 0, "cv": 0}
        metrics = {"modern_acc": 0, "cv_raw": 0}
        n = 0
        t0 = time.time()
        pbar = tqdm(range(0, len(train_captions), TCFG.batch_size),
                    desc=f"Epoch {epoch+1}/{TCFG.epochs}")
        for batch_start in pbar:
            idx = perm[batch_start:batch_start + TCFG.batch_size]
            # InfoNCE needs at least 2 items (in-batch negatives).
            if len(idx) < 2:
                continue
            batch_captions = [train_captions[i] for i in idx]
            B = len(batch_captions)
            # Teacher: ModernBERT (frozen) on the full caption.
            with torch.no_grad():
                with torch.amp.autocast("cuda"):
                    modern_cls = teacher_forward_modern(
                        modern_model, modern_tok, batch_captions,
                        device, TCFG.modern_max_len)
            # Student: CLIP + memory, fed segment by segment.
            state = model.init_state(B, device)
            # segment_text is provided by Cell 1 (memory_clip_l.py).
            all_segments = [segment_text(cap, clip_tokenizer,
                                         model.config.max_content_tokens,
                                         model.config.segment_overlap,
                                         model.config.max_segments)
                            for cap in batch_captions]
            max_segs = max(len(s) for s in all_segments)
            n_segs = [len(s) for s in all_segments]
            # Live anchors per segment position, padded to max_segs.
            all_anchors = torch.zeros(
                B, max_segs, model.config.anchor_dim, device=device)
            for seg_k in range(max_segs):
                # Assemble one (B, 77) batch for this segment position.
                batch_ids = []
                batch_masks = []
                for b in range(B):
                    if seg_k < len(all_segments[b]):
                        batch_ids.append(all_segments[b][seg_k]["input_ids"])
                        batch_masks.append(all_segments[b][seg_k]["attention_mask"])
                    else:
                        # Caption has fewer segments: pad with an empty segment.
                        batch_ids.append(torch.zeros(77, dtype=torch.long))
                        batch_masks.append(torch.zeros(77, dtype=torch.long))
                ids = torch.stack(batch_ids).to(device)
                masks = torch.stack(batch_masks).to(device)
                with torch.amp.autocast("cuda"):
                    outputs, state = model(ids, masks, state)
                    all_anchors[:, seg_k] = outputs["live_anchor"]
            # Student output comes from the last segment's forward pass.
            student_cls = outputs["memory_output"]
            with torch.amp.autocast("cuda"):
                # InfoNCE: projected student vs. teacher embedding.
                proj_m = model.proj_modern(student_cls)
                l_modern, acc_m = infonce_loss(proj_m, modern_cls, TCFG.temperature)
                # Procrustes regularizer on raw student embeddings
                # (teacher truncated to the student's width).
                l_procrustes = procrustes_alignment_loss(
                    student_cls, modern_cls[:, :model.config.clip_hidden])
                # CV regularizer over the real (non-padded) live anchors.
                n_reals_t = torch.tensor(n_segs, device=device)
                l_cv, cv_stats = batch_cv_loss(
                    all_anchors, n_reals_t, model.config.cv_target)
                loss = (TCFG.modern_weight * l_modern +
                        TCFG.procrustes_weight * l_procrustes +
                        TCFG.cv_weight * l_cv)
            # AMP backward: unscale before clipping so the clip threshold applies.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(all_params, TCFG.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1
            losses["total"] += loss.item()
            losses["modern"] += l_modern.item()
            losses["procrustes"] += l_procrustes.item()
            losses["cv"] += l_cv.item()
            metrics["modern_acc"] += acc_m
            metrics["cv_raw"] += cv_stats.get("cv_raw", 0)
            n += 1
            d = max(n, 1)
            pbar.set_postfix(
                loss=f"{losses['total']/d:.3f}",
                m_acc=f"{metrics['modern_acc']/d:.3f}",
                cv=f"{metrics['cv_raw']/d:.3f}")
        elapsed = time.time() - t0
        d = max(n, 1)
        print(f"\n Epoch {epoch+1}: {elapsed:.0f}s "
              f"loss={losses['total']/d:.4f} "
              f"m_acc={metrics['modern_acc']/d:.3f} "
              f"cv={metrics['cv_raw']/d:.3f}")
        # Per-epoch checkpoint.
        save_checkpoint(model, optimizer, epoch + 1, global_step,
                        os.path.join(TCFG.checkpoint_dir, f"epoch_{epoch+1:03d}"))
    # Final checkpoint after all epochs.
    save_checkpoint(model, optimizer, TCFG.epochs, global_step,
                    os.path.join(TCFG.checkpoint_dir, "final"))
def save_checkpoint(model, optimizer, epoch, global_step, path):
    """Write a checkpoint directory at `path`.

    Saves trainable parameters plus all buffers (prefixed "buffer.") as
    memory_system.safetensors, and the optimizer state / epoch / step
    counters as training_state.pt.
    """
    os.makedirs(path, exist_ok=True)
    tensors = {name: p.data.contiguous().cpu()
               for name, p in model.named_parameters() if p.requires_grad}
    tensors.update({f"buffer.{name}": buf.contiguous().cpu()
                    for name, buf in model.named_buffers()})
    safetensors_save(tensors, os.path.join(path, "memory_system.safetensors"))
    torch.save({"optimizer": optimizer.state_dict(), "epoch": epoch,
                "global_step": global_step},
               os.path.join(path, "training_state.pt"))
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def main():
    """Entry point: build models, load long captions, pre-align, then train.

    Pipeline:
      1. Instantiate the student (MemoryCLIPConfig / MemoryExtendedCLIP from Cell 1).
      2. Load the frozen fp16 ModernBERT-large teacher.
      3. Stream long captions from HF datasets, with two fallback sources.
      4. Static-Procrustes-initialize the student→teacher projection.
      5. Run the training loop (train()).
    """
    print("=" * 70)
    print("TRAINING: MEMORY-EXTENDED CLIP-L TEXT ENCODER")
    print("=" * 70)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # Student model (classes come from Cell 1, memory_clip_l.py).
    config = MemoryCLIPConfig()
    model = MemoryExtendedCLIP(config)
    model.setup(device)
    model = model.to(device)
    # Teacher: ModernBERT-large, frozen, fp16.
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Loading ModernBERT-large...")
    modern_model = AutoModel.from_pretrained(
        config.teacher_model, torch_dtype=torch.float16).to(device)
    modern_model.eval()
    for p in modern_model.parameters():
        p.requires_grad = False
    modern_tok = AutoTokenizer.from_pretrained(config.teacher_model)
    print(f" {sum(p.numel() for p in modern_model.parameters()):,} params (frozen)")
    # Data: long detailed captions that exceed CLIP's 77-token limit.
    print(f"\n Loading long captions...")
    from datasets import load_dataset
    train_captions = []
    # Primary: CC12M with LLaVA-NeXT detailed captions (~40-150 tokens each).
    try:
        print(" Loading CaptionEmporium/conceptual-captions-cc12m-llavanext...")
        ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
        for row in ds:
            # Use the detailed LLaVA caption (not the shortened one).
            cap = row.get("caption_llava", "")
            if isinstance(cap, str) and len(cap) > 100:  # want genuinely long
                train_captions.append(cap)
            if len(train_captions) >= TCFG.max_train_samples:
                break
        print(f" Got {len(train_captions)} long captions from CC12M-llavanext")
    except Exception as e:
        print(f" CC12M-llavanext failed: {e}")
    # Fallback: ShareGPT4V (100K GPT4-Vision captions, very detailed).
    if len(train_captions) < 1000:
        try:
            print(" Trying Lin-Chen/ShareGPT4V...")
            ds = load_dataset("Lin-Chen/ShareGPT4V",
                              "ShareGPT4V", split="train", streaming=True)
            for row in ds:
                convs = row.get("conversations", [])
                for c in convs:
                    val = c.get("value", "")
                    # Skip conversation turns containing the image placeholder.
                    if isinstance(val, str) and len(val) > 100 and "<image>" not in val:
                        train_captions.append(val)
                if len(train_captions) >= TCFG.max_train_samples:
                    break
            print(f" Got {len(train_captions)} from ShareGPT4V")
        except Exception as e:
            print(f" ShareGPT4V failed: {e}")
    # Last fallback: LLaVA-ReCap-CC3M.
    if len(train_captions) < 1000:
        try:
            print(" Trying lmms-lab/LLaVA-ReCap-CC3M...")
            ds = load_dataset("lmms-lab/LLaVA-ReCap-CC3M",
                              split="train", streaming=True)
            for row in ds:
                cap = row.get("caption", row.get("text", ""))
                if isinstance(cap, str) and len(cap) > 100:
                    train_captions.append(cap)
                if len(train_captions) >= TCFG.max_train_samples:
                    break
            print(f" Got {len(train_captions)} from LLaVA-ReCap-CC3M")
        except Exception as e:
            print(f" All fallbacks failed: {e}")
    train_captions = train_captions[:TCFG.max_train_samples]
    # Report caption-length statistics on a 500-caption sample.
    if train_captions:
        from transformers import CLIPTokenizer
        tok_temp = CLIPTokenizer.from_pretrained(config.clip_model)
        lengths = [len(tok_temp.encode(c)) for c in train_captions[:500]]
        print(f" Caption token lengths (CLIP tokenizer, sample of {len(lengths)}):")
        print(f" mean={np.mean(lengths):.0f} median={np.median(lengths):.0f} "
              f"max={max(lengths)} >77: {sum(1 for l in lengths if l > 77)/len(lengths):.1%}")
        del tok_temp
        print(f" {len(train_captions)} captions loaded")
        print(f" Example: {train_captions[0][:100]}...")
    # Procrustes pre-alignment of the student→teacher projection.
    align_captions = train_captions[:TCFG.procrustes_n_samples]
    compute_and_init_procrustes(
        model, modern_model, modern_tok, align_captions, device)
    # Train.
    train(model, modern_model, modern_tok, train_captions)
    print("\nDone.")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()