# Source header (Hugging Face page artifact): uploaded by AbstractPhil,
# "Create trainer.py", commit 38cce53 (verified).
# ============================================================================
# TRAINER: MEMORY-CLIP-SEQ β€” Sequence Reconstruction
#
# Extends the v2 trainer with:
# - Teacher full sequence capture (ModernBERT last_hidden_state)
# - Sequence reconstruction loss (reconstructed 77 vs teacher projected 77)
# - Two-phase training:
# Phase 1: freeze v1 memory weights, train only seq head
# Phase 2: unfreeze all, joint fine-tune
# - v1 checkpoint loading at startup
#
# Core training loop (InfoNCE + Procrustes + CV) UNCHANGED from v2.
# Sequence loss is ADDED alongside existing losses.
# ============================================================================
import gc
import math
import os
import json
import time
from dataclasses import dataclass, asdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from safetensors.torch import save_file as safetensors_save, load_file
# ══════════════════════════════════════════════════════════════════
# CONFIG
# ══════════════════════════════════════════════════════════════════
@dataclass
class TrainSeqConfig:
    """Hyperparameters for the two-phase MEMORY-CLIP-SEQ trainer.

    Phase 1 trains only the new sequence head (v1 memory weights frozen);
    phase 2 unfreezes everything for a joint fine-tune at reduced LRs.
    """
    # Data
    max_train_samples: int = 50000
    max_val_samples: int = 2000
    min_caption_length: int = 100  # characters, not tokens
    # Training — phase 1 (seq head only)
    phase1_epochs: int = 5
    phase1_lr_seq: float = 2e-3
    phase1_lr_proj: float = 1e-3
    # Training — phase 2 (joint fine-tune)
    phase2_epochs: int = 5
    phase2_lr_bank: float = 5e-4    # reduced from v2's 2e-3
    phase2_lr_output: float = 2e-4  # reduced
    phase2_lr_proj: float = 5e-4
    phase2_lr_seq: float = 1e-3
    # Shared
    batch_size: int = 64
    min_lr: float = 1e-6
    weight_decay: float = 0.01
    grad_clip: float = 1.0
    warmup_steps: int = 200
    # Loss weights — existing (unchanged from v2)
    modern_weight: float = 1.0
    procrustes_weight: float = 0.3
    cv_weight: float = 0.05
    temperature: float = 0.07
    # Loss weights — sequence (NEW)
    sequence_weight: float = 1.0         # MSE between reconstructed and teacher seq
    sequence_cosine_weight: float = 0.5  # per-position cosine similarity
    # Teacher
    modern_max_len: int = 4096
    procrustes_n_samples: int = 300
    # v1 checkpoint — local path or HuggingFace URL
    v1_checkpoint: str = ""
    v1_repo_id: str = "AbstractPhil/geolip-clip-vit-large-patch14-ctx576"
    v1_filename: str = "model.safetensors"
    # Logging
    checkpoint_dir: str = "/home/claude/memory_clip_seq_checkpoints"
    tensorboard_dir: str = "/home/claude/memory_clip_seq_tb"
    metrics_file: str = "/home/claude/memory_clip_seq_checkpoints/metrics.json"
    log_every: int = 20
    eval_every: int = 200  # NOTE(review): not referenced by train_phase — confirm intent
# Module-level singleton read by every function below.
TCFG = TrainSeqConfig()
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC UTILITIES β€” IDENTICAL to v2
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Signed squared simplex volume per batch via the Cayley-Menger determinant.

    pts: (B, V, D) batch of V points each.  Returns a (B,) tensor; values can
    be slightly negative for degenerate/noisy point sets.
    """
    # Determinants are numerically fragile — force fp32 outside autocast.
    with torch.amp.autocast("cuda", enabled=False):
        points = pts.float()
        # Pairwise squared distances, shape (B, V, V).
        deltas = points.unsqueeze(-2) - points.unsqueeze(-3)
        sq_dist = deltas.pow(2).sum(dim=-1)
        n_batch, n_verts, _ = sq_dist.shape
        # Bordered (V+1)x(V+1) Cayley-Menger matrix: zero corner, ones border,
        # squared distances in the body.
        cm = torch.zeros(n_batch, n_verts + 1, n_verts + 1,
                         device=sq_dist.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = sq_dist
        # vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM)
        sign = (-1.0) ** n_verts
        fact = math.factorial(n_verts - 1)
        coeff = sign / ((2.0 ** (n_verts - 1)) * fact * fact)
        return coeff * torch.linalg.det(cm)
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation of sampled pentachoron (5-point) volumes.

    Draws `n_samples` random 5-point subsets of `embeddings` (N, D), computes
    each subset's simplex volume, and returns std/mean of those volumes.
    Returns 0.0 when fewer than 5 points are available.
    """
    n_points = embeddings.shape[0]
    if n_points < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        choice = torch.randperm(n_points, device=embeddings.device)[:5]
        sq_vol = cayley_menger_vol2(embeddings[choice].unsqueeze(0))
        # relu clamps tiny negative determinants before the sqrt.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vol_t = torch.stack(volumes)
    return vol_t.std() / (vol_t.mean() + 1e-8)
def procrustes_alignment_loss(emb_a, emb_b):
    """Rotation-invariant misalignment between two embedding clouds.

    Both sets are L2-normalized and mean-centered; the loss measures how far
    the nuclear norm of their cross-covariance falls short of the
    perfect-alignment ceiling sqrt(N) * D.  Runs in fp32 (SVD under autocast
    is unstable).
    """
    with torch.amp.autocast("cuda", enabled=False):
        a = F.normalize(emb_a.float(), dim=-1)
        b = F.normalize(emb_b.float(), dim=-1)
        a = a - a.mean(dim=0, keepdim=True)
        b = b - b.mean(dim=0, keepdim=True)
        singular_vals = torch.linalg.svdvals(a.T @ b)
        n_rows, n_dims = a.shape
        return 1.0 - singular_vals.sum() / (math.sqrt(n_rows) * n_dims)
# ══════════════════════════════════════════════════════════════════
# LOSSES β€” v2 existing + NEW sequence loss
# ══════════════════════════════════════════════════════════════════
def infonce_loss(emb_a, emb_b, temperature=0.07):
    """Symmetric InfoNCE over a batch of paired embeddings.

    The diagonal of the similarity matrix holds the positive pairs.
    Returns (loss, top-1 retrieval accuracy, top-5 retrieval accuracy);
    the accuracies are plain floats computed without grad.
    """
    za = F.normalize(emb_a, dim=-1)
    zb = F.normalize(emb_b, dim=-1)
    sim = (za @ zb.T) / temperature
    targets = torch.arange(sim.shape[0], device=sim.device)
    loss_ab = F.cross_entropy(sim, targets)
    loss_ba = F.cross_entropy(sim.T, targets)
    loss = 0.5 * (loss_ab + loss_ba)
    with torch.no_grad():
        acc = (sim.argmax(-1) == targets).float().mean().item()
        k = min(5, sim.shape[1])
        topk_idx = sim.topk(k, dim=-1).indices
        acc5 = (topk_idx == targets.unsqueeze(-1)).any(-1).float().mean().item()
    return loss, acc, acc5
def batch_cv_loss(all_anchors, n_reals, cv_target=0.20):
    """Penalize deviation of each sample's anchor-volume CV from `cv_target`.

    all_anchors: (B, max_segs, D) padded anchor tensor; n_reals holds the
    count of real (non-pad) anchors per sample (ints or 0-d tensors).
    Samples with fewer than 5 real anchors are skipped.  Returns the mean
    |cv - target| over valid samples plus a stats dict.
    """
    device = all_anchors.device
    batch = all_anchors.shape[0]
    loss_sum = torch.tensor(0.0, device=device)
    cv_values = []
    for row in range(batch):
        count = n_reals[row]
        if isinstance(count, torch.Tensor):
            count = count.item()
        if count < 5:
            continue  # pentachoron CV needs at least 5 points
        cv = pentachoron_cv(all_anchors[row, :count], n_samples=16)
        loss_sum = loss_sum + (cv - cv_target).abs()
        cv_values.append(cv.item())
    n_valid = len(cv_values)
    stats = {
        "cv_raw": sum(cv_values) / max(n_valid, 1),
        "cv_std": float(np.std(cv_values)) if cv_values else 0.0,
        "cv_n_valid": n_valid,
    }
    return loss_sum / max(n_valid, 1), stats
def sequence_reconstruction_loss(pred_seq, target_seq):
    """Compare a reconstructed token sequence against its teacher target.

    pred_seq:   (B, 77, 768) reconstructed sequence
    target_seq: (B, 77, 768) teacher projected sequence

    Returns:
        mse:      mean squared error (differentiable)
        cos_loss: 1 - mean per-position cosine similarity (differentiable)
        mean_cos: mean cosine similarity as a plain float metric
    """
    mse = F.mse_loss(pred_seq, target_seq)
    # Per-position cosine similarity along the feature axis -> (B, 77).
    unit_pred = F.normalize(pred_seq, dim=-1)
    unit_tgt = F.normalize(target_seq, dim=-1)
    cos_per_pos = (unit_pred * unit_tgt).sum(dim=-1)
    cos_loss = 1.0 - cos_per_pos.mean()
    mean_cos = cos_per_pos.mean().detach().item()
    return mse, cos_loss, mean_cos
# ══════════════════════════════════════════════════════════════════
# TEACHER β€” returns BOTH pooled AND full sequence
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def teacher_forward(model, tokenizer, texts, device, max_len):
    """Masked mean-pool of the teacher's last hidden state over `texts`.

    Returns pooled (B, hidden) embeddings from ModernBERT; the sequence
    reconstruction target comes from CLIP elsewhere, not from this teacher.
    """
    batch = tokenizer(texts, max_length=max_len, padding=True,
                      truncation=True, return_tensors="pt").to(device)
    hidden = model(**batch).last_hidden_state
    # Mean over real (unmasked) tokens only; clamp avoids div-by-zero on
    # an all-pad row.
    mask = batch.attention_mask.unsqueeze(-1).float()
    return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
# ══════════════════════════════════════════════════════════════════
# PARAM GROUPS
# ══════════════════════════════════════════════════════════════════
def make_param_groups_phase1(model):
    """Phase 1 param groups: train only the sequence head + pooled projector.

    Freezes every parameter by default and re-enables gradients only for the
    `sequence_reconstructor` (new seq head) and `proj_modern` (pooled
    projector) modules.  Returns AdamW-style param-group dicts and prints a
    per-group size summary.

    Improvement over the original: one pass over named_parameters() instead
    of three, and no redundant rebuild of the projector list — same selection,
    same group order, same output.
    """
    seq_params, proj_params = [], []
    for name, param in model.named_parameters():
        if "sequence_reconstructor" in name:
            param.requires_grad = True
            seq_params.append(param)
        elif "proj_modern" in name:
            param.requires_grad = True
            proj_params.append(param)
        else:
            param.requires_grad = False  # freeze everything else (v1 weights)
    groups = [
        {"params": seq_params, "lr": TCFG.phase1_lr_seq, "name": "seq_head",
         "weight_decay": TCFG.weight_decay},
        {"params": proj_params, "lr": TCFG.phase1_lr_proj, "name": "proj",
         "weight_decay": TCFG.weight_decay},
    ]
    for g in groups:
        n = sum(p.numel() for p in g["params"])
        print(f" {g['name']:12s}: {n:>10,} params @ lr={g['lr']}")
    return groups
def make_param_groups_phase2(model):
    """Phase 2 param groups: joint fine-tune with differential learning rates.

    Unfreezes every non-CLIP parameter, then buckets trainable params into
    bank / seq head / projector / output groups, each with its own phase-2 LR.
    """
    # Unfreeze everything except the CLIP text tower.
    for name, param in model.named_parameters():
        if "clip_text" not in name and "_clip_text" not in name:
            param.requires_grad = True
    bank_keys = {"bank.", "clip_cross_attn", "clip_cross_norms",
                 "clip_cross_ffns", "clip_cross_ffn_norms"}
    seq_keys = {"sequence_reconstructor"}
    proj_keys = {"proj_modern"}

    def _matches(param_name, keys):
        # Match a key either as a prefix or anywhere inside the name.
        return any(param_name.startswith(k) or k in param_name for k in keys)

    buckets = {"bank": [], "seq_head": [], "proj": [], "output": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # CLIP tower stays frozen
        if _matches(name, seq_keys):
            buckets["seq_head"].append(param)
        elif _matches(name, proj_keys):
            buckets["proj"].append(param)
        elif _matches(name, bank_keys):
            buckets["bank"].append(param)
        else:
            buckets["output"].append(param)
    lrs = {"bank": TCFG.phase2_lr_bank, "seq_head": TCFG.phase2_lr_seq,
           "proj": TCFG.phase2_lr_proj, "output": TCFG.phase2_lr_output}
    groups = [{"params": buckets[key], "lr": lrs[key], "name": key,
               "weight_decay": TCFG.weight_decay}
              for key in ("bank", "seq_head", "proj", "output")]
    for g in groups:
        n = sum(p.numel() for p in g["params"])
        print(f" {g['name']:12s}: {n:>10,} params @ lr={g['lr']}")
    return groups
# ══════════════════════════════════════════════════════════════════
# PROCRUSTES INIT β€” IDENTICAL to v2
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def compute_and_init_procrustes(student_model, modern_model, modern_tok,
                                captions, device):
    """Fit a static orthogonal Procrustes map from CLIP pooled space to the
    teacher's pooled space, and use it to initialize the pooled projector.

    Embeds `captions` with both the student CLIP text tower (pooler_output)
    and the teacher (masked mean pool via teacher_forward), solves the
    orthogonal Procrustes problem on the centered clouds, and — when the
    projector exposes `init_from_procrustes` — initializes it with the
    rotation and the two means.  Returns before/after cosine diagnostics.
    """
    print(f"\n Computing static Procrustes on {len(captions)} captions...")
    student_embs, modern_embs = [], []
    clip_tok = student_model.clip_tokenizer
    # Embed in small chunks of 16 to bound memory.
    for i in range(0, len(captions), 16):
        batch = captions[i:i+16]
        tokens = clip_tok(batch, max_length=77, padding=True,
                          truncation=True, return_tensors="pt").to(device)
        clip_out = student_model.clip_text(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask,
            output_hidden_states=False)
        student_embs.append(clip_out.pooler_output.cpu())
        pooled = teacher_forward(modern_model, modern_tok, batch,
                                 device, TCFG.modern_max_len)
        modern_embs.append(pooled.cpu())
    student_all = torch.cat(student_embs)
    modern_all = torch.cat(modern_embs)
    print(f" Student: {student_all.shape}, Teacher: {modern_all.shape}")
    X = student_all.float(); Y = modern_all.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y
    # If the student dim is smaller than the teacher dim, zero-pad the
    # student side so the rotation is square in teacher space.
    if Xc.shape[1] < Yc.shape[1]:
        pad = torch.zeros(Xc.shape[0], Yc.shape[1] - Xc.shape[1])
        Xc = torch.cat([Xc, pad], dim=1)
        mu_x = torch.cat([mu_x, torch.zeros(Yc.shape[1] - mu_x.shape[0])])
    # Orthogonal Procrustes: SVD of the cross-covariance; R is stored
    # transposed, so Xc @ R.T is the aligned student cloud.
    U, S, Vt = torch.linalg.svd(Xc.T @ Yc)
    R = (U @ Vt).T
    cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc @ R.T), Yc, dim=-1).mean()
    # NOTE(review): "β†’" below is mojibake for "→" — fix file encoding upstream.
    print(f" Procrustes: cos {cos_before:.4f} β†’ {cos_after:.4f}")
    # Init the pooled projector if it has init_from_procrustes
    if hasattr(student_model.proj_modern, 'init_from_procrustes'):
        student_model.proj_modern.init_from_procrustes(R, mu_x, mu_y)
    return {"cos_before": cos_before.item(), "cos_after": cos_after.item()}
# ══════════════════════════════════════════════════════════════════
# V1 WEIGHT LOADING
# ══════════════════════════════════════════════════════════════════
def load_v1_weights(model, device):
    """Load v1 memory system weights into the expanded seq model.
    Tries local path first, then downloads from HuggingFace.

    Loads with strict=False: keys for the new sequence_reconstructor are
    expected to be missing, and any extra v1-only tensors are ignored.
    Always returns True (failures raise from hf_hub_download/load_file).
    """
    checkpoint_path = TCFG.v1_checkpoint
    # Try local path first
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f" Loading v1 weights (local): {checkpoint_path}")
    else:
        # Download from HuggingFace (cached by hf_hub_download).
        from huggingface_hub import hf_hub_download
        print(f" Downloading v1 weights from {TCFG.v1_repo_id}/{TCFG.v1_filename}...")
        checkpoint_path = hf_hub_download(
            repo_id=TCFG.v1_repo_id,
            filename=TCFG.v1_filename)
        print(f" Downloaded to: {checkpoint_path}")
    state = load_file(checkpoint_path, device=str(device))
    missing, unexpected = model.load_state_dict(state, strict=False)
    n_loaded = len(state) - len(unexpected)
    print(f" Loaded: {n_loaded} tensors from v1")
    print(f" Missing (new modules): {len(missing)}")
    if missing:
        # Only sequence_reconstructor keys should be missing; anything else
        # signals a v1/v2 architecture mismatch worth investigating.
        new_module_keys = [k for k in missing if "sequence_reconstructor" in k
                           ]
        other_missing = [k for k in missing if k not in new_module_keys]
        print(f" Seq head (expected new): {len(new_module_keys)}")
        if other_missing:
            print(f" Other (check!): {other_missing[:5]}")
    if unexpected:
        print(f" Unexpected (v1 buffers, ignorable): {len(unexpected)}")
    return True
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
def train_phase(model, modern_model, modern_tok, train_captions, val_captions,
                param_groups, n_epochs, phase_name, writer, all_metrics,
                global_step=0):
    """
    Single training phase. Used for both phase 1 and phase 2.

    Runs `n_epochs` over `train_captions` with AdamW on `param_groups`,
    a linear-warmup + cosine schedule, and AMP grad scaling.  Per batch it
    combines: InfoNCE vs the ModernBERT teacher, Procrustes alignment,
    pentachoron CV regularization, and the NEW sequence reconstruction
    losses (MSE + cosine vs CLIP's own 77-token last_hidden_state).
    Logs to TensorBoard, appends to `all_metrics`, checkpoints each epoch,
    and returns the updated global_step.

    NOTE(review): val_captions, best_val_loss and TCFG.eval_every are never
    used — there is no validation loop here; confirm whether eval was
    intended to be wired in.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(param_groups)
    all_params = [p for g in param_groups for p in g["params"]]
    n_batches = len(train_captions) // TCFG.batch_size
    total_steps = n_batches * n_epochs
    # Linear warmup for warmup_steps, then cosine decay to min_lr.
    # NOTE(review): CosineAnnealingLR gets T_max=total_steps although only
    # total_steps - warmup_steps remain after warmup — minor schedule skew.
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=TCFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps, 1), eta_min=TCFG.min_lr)],
        milestones=[TCFG.warmup_steps])
    scaler = torch.amp.GradScaler()
    clip_tokenizer = model.clip_tokenizer
    best_val_loss = float("inf")  # NOTE(review): unused — see docstring
    print(f"\n {phase_name}: {sum(p.numel() for p in all_params):,} trainable params")
    # NOTE(review): "Γ—" below is mojibake for "×" — fix file encoding upstream.
    print(f" {n_batches} batches/epoch Γ— {TCFG.batch_size}")
    # segment_text is in notebook namespace from architecture cell
    _segment_text = segment_text
    for epoch in range(n_epochs):
        model.train()
        # Fresh random caption order each epoch.
        perm = np.random.permutation(len(train_captions))
        losses = {"total": 0, "modern": 0, "procrustes": 0, "cv": 0,
                  "seq_mse": 0, "seq_cos": 0}
        metrics = {"modern_acc": 0, "modern_acc5": 0,
                   "cv_raw": 0, "seq_cos_sim": 0, "n_segments_avg": 0}
        n = 0
        t0 = time.time()
        pbar = tqdm(range(0, len(train_captions), TCFG.batch_size),
                    desc=f"{phase_name} E{epoch+1}/{n_epochs}", unit="batch")
        for batch_start in pbar:
            idx = perm[batch_start:batch_start + TCFG.batch_size]
            if len(idx) < 2:
                continue  # InfoNCE needs at least 2 samples for negatives
            batch_captions = [train_captions[i] for i in idx]
            B = len(batch_captions)
            # ── Teacher: ModernBERT pooled (sequence target comes from CLIP) ──
            with torch.no_grad():
                with torch.amp.autocast("cuda"):
                    modern_cls = teacher_forward(
                        modern_model, modern_tok, batch_captions,
                        device, TCFG.modern_max_len)
            # ── Student: segment-by-segment processing ──
            state = model.init_state(B, device)
            all_segments = [_segment_text(cap, clip_tokenizer,
                                          model.config.max_content_tokens,
                                          model.config.segment_overlap,
                                          model.config.max_segments)
                            for cap in batch_captions]
            max_segs = max(len(s) for s in all_segments)
            n_segs = [len(s) for s in all_segments]
            # Feed segments in lockstep across the batch, zero-padding
            # captions that have fewer segments than max_segs.
            for seg_k in range(max_segs):
                batch_ids, batch_masks = [], []
                for b in range(B):
                    if seg_k < len(all_segments[b]):
                        batch_ids.append(all_segments[b][seg_k]["input_ids"])
                        batch_masks.append(all_segments[b][seg_k]["attention_mask"])
                    else:
                        batch_ids.append(torch.zeros(77, dtype=torch.long))
                        batch_masks.append(torch.zeros(77, dtype=torch.long))
                ids = torch.stack(batch_ids).to(device)
                masks = torch.stack(batch_masks).to(device)
                with torch.amp.autocast("cuda"):
                    fused_output, state = model.forward_segment(ids, masks, state)
            student_cls = fused_output  # pooled output from last segment
            # Bank anchors for CV loss — accumulated in state during segment processing
            bank_anchors = state["bank"]["anchors"]  # (B, N_written, 768)
            # Pad to max_segs for batch CV computation
            all_anchors = torch.zeros(B, max_segs, model.config.anchor_dim, device=device)
            n_written = min(bank_anchors.shape[1], max_segs)
            all_anchors[:, :n_written] = bank_anchors[:, :n_written]
            # ── Existing losses (UNCHANGED from v2) ──
            with torch.amp.autocast("cuda"):
                proj_m = model.proj_modern(student_cls)
                l_modern, acc_m, acc5_m = infonce_loss(
                    proj_m, modern_cls, TCFG.temperature)
                # Procrustes compares student pooled vs the first clip_hidden
                # dims of the teacher pooled vector.
                l_procrustes = procrustes_alignment_loss(
                    student_cls, modern_cls[:, :model.config.clip_hidden])
                n_reals_t = torch.tensor(n_segs, device=device)
                l_cv, cv_stats = batch_cv_loss(
                    all_anchors, n_reals_t, model.config.cv_target)
            # ── NEW: Sequence reconstruction loss ──
            # Target: CLIP's own last_hidden_state on the truncated caption.
            # This is what the UNet was trained on — the reconstructor must
            # produce sequences in CLIP's distribution.
            with torch.no_grad():
                clip_inputs = clip_tokenizer(
                    batch_captions, max_length=77, padding="max_length",
                    truncation=True, return_tensors="pt").to(device)
                with torch.amp.autocast("cuda"):
                    clip_target_out = model.clip_text(
                        input_ids=clip_inputs.input_ids,
                        attention_mask=clip_inputs.attention_mask,
                        output_hidden_states=False, return_dict=True)
                clip_target_seq = clip_target_out.last_hidden_state  # (B, 77, 768)
            with torch.amp.autocast("cuda"):
                # Reconstruct sequence from memory state
                recon_seq = model.reconstruct_sequence(state)  # (B, 77, 768)
                l_seq_mse, l_seq_cos, seq_cos_metric = sequence_reconstruction_loss(
                    recon_seq, clip_target_seq.detach())
            # ── Combined loss ──
            with torch.amp.autocast("cuda"):
                loss = (TCFG.modern_weight * l_modern +
                        TCFG.procrustes_weight * l_procrustes +
                        TCFG.cv_weight * l_cv +
                        TCFG.sequence_weight * l_seq_mse +
                        TCFG.sequence_cosine_weight * l_seq_cos)
            # Standard AMP step: scale → backward → unscale → clip → step.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(all_params, TCFG.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1
            # ── Metrics ──
            losses["total"] += loss.item()
            losses["modern"] += l_modern.item()
            losses["procrustes"] += l_procrustes.item()
            losses["cv"] += l_cv.item()
            losses["seq_mse"] += l_seq_mse.item()
            losses["seq_cos"] += l_seq_cos.item()
            metrics["modern_acc"] += acc_m
            metrics["modern_acc5"] += acc5_m
            metrics["cv_raw"] += cv_stats.get("cv_raw", 0)
            metrics["seq_cos_sim"] += seq_cos_metric
            metrics["n_segments_avg"] += np.mean(n_segs)
            n += 1
            d = max(n, 1)
            pbar.set_postfix(
                loss=f"{losses['total']/d:.3f}",
                m_acc=f"{metrics['modern_acc']/d:.3f}",
                s_cos=f"{metrics['seq_cos_sim']/d:.3f}",
                cv=f"{metrics['cv_raw']/d:.3f}")
            # Tensorboard — running epoch averages, not instantaneous values.
            if global_step % TCFG.log_every == 0:
                writer.add_scalar(f"{phase_name}/loss", losses["total"]/d, global_step)
                writer.add_scalar(f"{phase_name}/modern_loss", losses["modern"]/d, global_step)
                writer.add_scalar(f"{phase_name}/seq_mse", losses["seq_mse"]/d, global_step)
                writer.add_scalar(f"{phase_name}/seq_cos_loss", losses["seq_cos"]/d, global_step)
                writer.add_scalar(f"{phase_name}/seq_cos_sim", metrics["seq_cos_sim"]/d, global_step)
                writer.add_scalar(f"{phase_name}/m_acc", metrics["modern_acc"]/d, global_step)
                writer.add_scalar(f"{phase_name}/cv_raw", metrics["cv_raw"]/d, global_step)
                all_metrics["steps"].append({
                    "step": global_step, "phase": phase_name,
                    "epoch": epoch + 1,
                    "loss": losses["total"]/d,
                    "m_acc": metrics["modern_acc"]/d,
                    "seq_cos": metrics["seq_cos_sim"]/d,
                    "cv_raw": metrics["cv_raw"]/d,
                })
        pbar.close()
        elapsed = time.time() - t0
        d = max(n, 1)
        epoch_summary = {
            "phase": phase_name, "epoch": epoch + 1,
            "elapsed_s": elapsed,
            "loss": losses["total"]/d,
            "modern_loss": losses["modern"]/d,
            "seq_mse": losses["seq_mse"]/d,
            "seq_cos_loss": losses["seq_cos"]/d,
            "m_acc": metrics["modern_acc"]/d,
            "m_acc5": metrics["modern_acc5"]/d,
            "seq_cos_sim": metrics["seq_cos_sim"]/d,
            "cv_raw": metrics["cv_raw"]/d,
            "global_step": global_step,
        }
        all_metrics["epochs"].append(epoch_summary)
        print(f"\n {phase_name} E{epoch+1}: {elapsed:.0f}s "
              f"loss={epoch_summary['loss']:.4f} "
              f"m_acc={epoch_summary['m_acc']:.3f} "
              f"seq_cos={epoch_summary['seq_cos_sim']:.3f} "
              f"cv={epoch_summary['cv_raw']:.3f}")
        # Save a checkpoint and flush metrics at every epoch boundary.
        save_checkpoint(model, optimizer, epoch + 1, global_step, phase_name,
                        os.path.join(TCFG.checkpoint_dir, f"{phase_name}_e{epoch+1:02d}"))
        with open(TCFG.metrics_file, "w") as f:
            json.dump(all_metrics, f, indent=2, default=str)
    return global_step
def save_checkpoint(model, optimizer, epoch, global_step, phase, path):
    """Write model weights, optimizer/training state, and config to `path`.

    Saves only parameters with requires_grad=True (so phase-1 checkpoints
    contain just the seq head + projector) plus all buffers.

    NOTE(review): buffers are saved under a "buffer." key prefix, which does
    not match named_buffers() names on reload — load_v1_weights(strict=False)
    would report them as unexpected; confirm the intended round-trip.
    """
    os.makedirs(path, exist_ok=True)
    state = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            state[name] = param.data.contiguous().cpu()
    for name, buf in model.named_buffers():
        state[f"buffer.{name}"] = buf.contiguous().cpu()
    safetensors_save(state, os.path.join(path, "memory_system.safetensors"))
    # Optimizer + bookkeeping go in a separate torch pickle.
    torch.save({"optimizer": optimizer.state_dict() if optimizer else {},
                "epoch": epoch,
                "global_step": global_step, "phase": phase},
               os.path.join(path, "training_state.pt"))
    model_cfg = model.config.to_dict() if hasattr(model.config, 'to_dict') else {}
    config_data = {"model": model_cfg, "training": asdict(TCFG)}
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(config_data, f, indent=2, default=str)
# ══════════════════════════════════════════════════════════════════
# DATA
# ══════════════════════════════════════════════════════════════════
def load_long_captions(max_train, max_val, min_length=100):
    """Stream long llava captions from CC12M until train+val quotas are met.

    Keeps only string captions longer than `min_length` characters and
    returns (train_captions, val_captions) split in order of arrival.
    """
    from datasets import load_dataset
    print(f" Loading CaptionEmporium/conceptual-captions-cc12m-llavanext...")
    stream = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
    needed = max_train + max_val
    kept = []
    for row in stream:
        text = row.get("caption_llava", "")
        if isinstance(text, str) and len(text) > min_length:
            kept.append(text)
            if len(kept) >= needed:
                break
    train_caps = kept[:max_train]
    val_caps = kept[max_train:max_train + max_val]
    print(f" Train: {len(train_caps)}, Val: {len(val_caps)}")
    return train_caps, val_caps
# ══════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════
def main():
    """End-to-end driver: build model, load v1 weights, run phase 1 then 2.

    NOTE(review): MemoryCLIPSeqConfig, MemoryCLIPSeqModel and segment_text
    are expected from a separate "architecture" notebook cell — they are not
    defined in this file, so this script only runs in that notebook context.
    """
    print("=" * 70)
    print("TRAINING: MEMORY-CLIP-SEQ (SEQUENCE RECONSTRUCTION)")
    print("=" * 70)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # ── Model (classes loaded from architecture cell) ──
    config = MemoryCLIPSeqConfig()
    model = MemoryCLIPSeqModel(config).to(device)
    # Trigger CLIP lazy load
    _ = model.clip_text
    _ = model.clip_tokenizer
    # ── Load v1 weights ──
    load_v1_weights(model, device)
    print(" v1 memory system weights loaded")
    # ── Teacher ──
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Loading ModernBERT-large...")
    modern_model = AutoModel.from_pretrained(
        config.teacher_model, torch_dtype=torch.float16).to(device)
    modern_model.eval()
    for p in modern_model.parameters():
        p.requires_grad = False
    modern_tok = AutoTokenizer.from_pretrained(config.teacher_model)
    print(f" {sum(p.numel() for p in modern_model.parameters()):,} params (frozen)")
    # ── Data ──
    train_captions, val_captions = load_long_captions(
        TCFG.max_train_samples, TCFG.max_val_samples, TCFG.min_caption_length)
    # Quick token-length profile of a 500-caption sample to show how many
    # captions overflow CLIP's 77-token window.
    from transformers import CLIPTokenizer
    tok_temp = CLIPTokenizer.from_pretrained(config.clip_model)
    lengths = [len(tok_temp.encode(c)) for c in train_captions[:500]]
    print(f" Caption tokens (sample 500): mean={np.mean(lengths):.0f} "
          f"median={np.median(lengths):.0f} max={max(lengths)} "
          f">77: {sum(1 for l in lengths if l > 77)/len(lengths):.1%}")
    del tok_temp
    # ── Procrustes init ──
    compute_and_init_procrustes(
        model, modern_model, modern_tok,
        train_captions[:TCFG.procrustes_n_samples], device)
    # ── Setup logging ──
    os.makedirs(TCFG.checkpoint_dir, exist_ok=True)
    os.makedirs(TCFG.tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=TCFG.tensorboard_dir)
    all_metrics = {
        "config": {**{k: v for k, v in config.to_dict().items()
                      if not k.startswith("_")},
                   **asdict(TCFG)},
        "epochs": [], "steps": [],
    }
    global_step = 0
    # ═══════════════════════════════════════════════════
    # PHASE 1: Train sequence head only (v1 weights frozen)
    # ═══════════════════════════════════════════════════
    print(f"\n{'='*70}")
    print(f"PHASE 1: Sequence head training ({TCFG.phase1_epochs} epochs)")
    print(f" v1 memory system: FROZEN")
    print(f" Sequence reconstructor: TRAINING")
    print(f"{'='*70}")
    phase1_groups = make_param_groups_phase1(model)
    global_step = train_phase(
        model, modern_model, modern_tok,
        train_captions, val_captions,
        phase1_groups, TCFG.phase1_epochs,
        "phase1", writer, all_metrics, global_step)
    save_checkpoint(model, None, TCFG.phase1_epochs, global_step, "phase1",
                    os.path.join(TCFG.checkpoint_dir, "phase1_final"))
    # ═══════════════════════════════════════════════════
    # PHASE 2: Joint fine-tune (everything unfrozen)
    # ═══════════════════════════════════════════════════
    print(f"\n{'='*70}")
    print(f"PHASE 2: Joint fine-tune ({TCFG.phase2_epochs} epochs)")
    print(f" All trainable modules: TRAINING")
    print(f" v1 components: reduced LR")
    print(f"{'='*70}")
    phase2_groups = make_param_groups_phase2(model)
    global_step = train_phase(
        model, modern_model, modern_tok,
        train_captions, val_captions,
        phase2_groups, TCFG.phase2_epochs,
        "phase2", writer, all_metrics, global_step)
    # ── Final save ──
    save_checkpoint(model, None, TCFG.phase1_epochs + TCFG.phase2_epochs,
                    global_step, "final",
                    os.path.join(TCFG.checkpoint_dir, "final"))
    all_metrics["final"] = {
        "total_steps": global_step,
        "final_m_acc": all_metrics["epochs"][-1]["m_acc"],
        "final_seq_cos": all_metrics["epochs"][-1]["seq_cos_sim"],
        "final_cv": all_metrics["epochs"][-1]["cv_raw"],
    }
    with open(TCFG.metrics_file, "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)
    writer.flush()
    writer.close()
    print(f"\n{'='*70}")
    print(f"FINAL:")
    final = all_metrics["epochs"][-1]
    print(f" m_acc: {final['m_acc']:.4f}")
    print(f" seq_cos: {final['seq_cos_sim']:.4f}")
    print(f" CV: {final['cv_raw']:.4f}")
    print(f"{'='*70}")
if __name__ == "__main__":
    main()