| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import gc |
| import math |
| import os |
| import time |
| from dataclasses import dataclass, asdict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from tqdm import tqdm |
| from safetensors.torch import save_file as safetensors_save |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class TrainConfig:
    """Hyperparameters for training the memory-extended CLIP text encoder."""

    # --- Data limits ---
    max_train_samples: int = 50000   # cap on streamed training captions
    max_val_samples: int = 1000      # cap on validation captions
    min_caption_length: int = 20     # minimum caption length to keep (chars)

    # --- Optimization ---
    epochs: int = 10
    batch_size: int = 64
    lr_bank: float = 2e-3            # LR for memory-bank / cross-attention groups
    lr_output: float = 5e-4          # LR for all remaining trainable params
    lr_proj: float = 1e-3            # LR for the teacher projection head
    min_lr: float = 1e-6             # cosine-anneal floor
    weight_decay: float = 0.01
    grad_clip: float = 1.0           # global grad-norm clip value
    warmup_steps: int = 200          # linear LR warmup duration (steps)

    # --- Loss weights ---
    modern_weight: float = 1.0       # InfoNCE vs. ModernBERT teacher
    clip_vision_weight: float = 0.5  # CLIP-vision alignment (not used in this loop)
    procrustes_weight: float = 0.3   # rotational-alignment regularizer
    cv_weight: float = 0.05          # pentachoron volume-CV regularizer
    temperature: float = 0.07        # InfoNCE temperature

    # --- Teacher / geometry ---
    modern_max_len: int = 4096       # teacher tokenizer truncation length
    procrustes_n_samples: int = 300  # captions used for the static Procrustes init

    # --- Logging / checkpoints ---
    checkpoint_dir: str = "/home/claude/memory_clip_checkpoints"
    log_every: int = 20
    eval_every: int = 200
|
|
|
|
| TCFG = TrainConfig() |
|
|
|
|
| |
| |
| |
|
|
| def cayley_menger_vol2(pts): |
| with torch.amp.autocast("cuda", enabled=False): |
| pts = pts.float() |
| diff = pts.unsqueeze(-2) - pts.unsqueeze(-3) |
| d2 = (diff * diff).sum(-1) |
| B, V, _ = d2.shape |
| cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32) |
| cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2 |
| s = (-1.0)**V; f = math.factorial(V-1) |
| return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm) |
|
|
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation of sampled pentachoron (4-simplex) volumes.

    Draws `n_samples` random 5-point subsets of `embeddings` and measures
    how much their simplex volumes vary (std / mean). Returns 0 when fewer
    than 5 points are available.
    """
    n_points = embeddings.shape[0]
    if n_points < 5:
        return torch.tensor(0.0, device=embeddings.device)
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(n_points, device=embeddings.device)[:5]
        sq_vol = cayley_menger_vol2(embeddings[pick].unsqueeze(0))
        # Clamp to non-negative before sqrt; eps guards the degenerate case.
        volumes.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(volumes)
    return vols.std() / (vols.mean() + 1e-8)
|
|
def procrustes_alignment_loss(emb_a, emb_b):
    """Rotation-invariant misalignment between two embedding sets.

    Normalizes and centers both sets, then returns
    1 - ||A^T B||_* / (sqrt(N) * D), i.e. one minus the (scaled) nuclear
    norm of the cross-covariance — 0 only under perfect rotational match.
    """
    with torch.amp.autocast("cuda", enabled=False):
        left = F.normalize(emb_a.float(), dim=-1)
        right = F.normalize(emb_b.float(), dim=-1)
        left = left - left.mean(dim=0, keepdim=True)
        right = right - right.mean(dim=0, keepdim=True)
        singular_vals = torch.linalg.svdvals(left.T @ right)
        n_rows, dim = left.shape
        return 1.0 - singular_vals.sum() / (math.sqrt(n_rows) * dim)
|
|
|
|
| |
| |
| |
|
|
def infonce_loss(emb_a, emb_b, temperature=0.07):
    """Symmetric InfoNCE between paired embeddings.

    Rows of `emb_a` and `emb_b` are positive pairs; all other rows in the
    batch serve as negatives. Returns (loss, retrieval accuracy a->b).
    """
    za = F.normalize(emb_a, dim=-1)
    zb = F.normalize(emb_b, dim=-1)
    sims = za @ zb.T / temperature
    targets = torch.arange(sims.shape[0], device=sims.device)
    # Average both retrieval directions (a->b and b->a).
    forward_ce = F.cross_entropy(sims, targets)
    backward_ce = F.cross_entropy(sims.T, targets)
    loss = (forward_ce + backward_ce) / 2
    with torch.no_grad():
        acc = (sims.argmax(-1) == targets).float().mean().item()
    return loss, acc
|
|
|
|
def batch_cv_loss(all_anchors, n_reals, cv_target=0.20):
    """Penalize deviation of per-sequence anchor-geometry CV from cv_target.

    Args:
        all_anchors: (B, S, D) anchor embeddings, zero-padded along S.
        n_reals: per-row counts of real (non-padded) anchors; ints or tensors.
        cv_target: desired coefficient of variation.

    Returns:
        (mean |cv - cv_target| over rows with >= 5 real anchors,
         {"cv_raw": mean raw cv over those rows}).
    """
    device = all_anchors.device
    loss_sum = torch.tensor(0.0, device=device)
    cv_sum = 0.0
    valid = 0
    for row in range(all_anchors.shape[0]):
        count = n_reals[row]
        if isinstance(count, torch.Tensor):
            count = count.item()
        if count < 5:
            # Too few points to form a pentachoron — skip this row.
            continue
        cv = pentachoron_cv(all_anchors[row, :count], n_samples=16)
        loss_sum = loss_sum + (cv - cv_target).abs()
        cv_sum += cv.item()
        valid += 1
    denom = max(valid, 1)
    return loss_sum / denom, {"cv_raw": cv_sum / denom}
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def teacher_forward_modern(model, tokenizer, texts, device, max_len):
    """Attention-mask-weighted mean pooling of teacher hidden states.

    Tokenizes `texts` (truncated to `max_len`), runs the frozen teacher,
    and averages last-layer states over non-padding positions.
    """
    encoded = tokenizer(texts, max_length=max_len, padding=True,
                        truncation=True, return_tensors="pt").to(device)
    hidden = model(**encoded).last_hidden_state
    weights = encoded.attention_mask.unsqueeze(-1).float()
    pooled = (hidden * weights).sum(1)
    return pooled / weights.sum(1).clamp(min=1)
|
|
|
|
@torch.no_grad()
def vision_forward(clip_model, processor, images, device):
    """Get CLIP vision embeddings for alignment preservation."""
    batch = processor(images=images, return_tensors="pt").to(device)
    features = clip_model.get_image_features(**batch)
    return features
|
|
|
|
| |
| |
| |
|
|
class CaptionDataset(Dataset):
    """Dataset over caption strings, optionally paired with images.

    For POC: COCO captions (short, ~10-20 tokens).
    For production: ShareGPT4V or LAION-COCO (long, 100-500 tokens).
    """

    def __init__(self, captions, images=None):
        self.captions = captions
        self.images = images

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        sample = {"caption": self.captions[idx]}
        if self.images is not None:
            sample["image"] = self.images[idx]
        return sample
|
|
|
|
| |
| |
| |
|
|
def make_param_groups(model):
    """Split trainable parameters into three LR groups by name prefix.

    Groups: "proj" (teacher projection head), "bank" (memory-bank and
    cross-attention modules), and "output" (everything else). Prints a
    per-group parameter count and returns AdamW-style group dicts.
    """
    bank_prefixes = {"bank.depth_compressor", "bank.temporal_proj",
                     "bank.cross_attn", "bank.cross_norms",
                     "bank.cross_ffns", "bank.ffn_norms",
                     "clip_cross_attn", "clip_cross_norms",
                     "clip_cross_ffns", "clip_cross_ffn_norms"}
    proj_prefixes = {"proj_modern"}

    bank_params, proj_params, output_params = [], [], []
    for pname, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # proj takes priority, then bank; everything else is "output".
        if any(pname.startswith(pre) for pre in proj_prefixes):
            proj_params.append(p)
        elif any(pname.startswith(pre) for pre in bank_prefixes):
            bank_params.append(p)
        else:
            output_params.append(p)

    groups = [
        {"params": bank_params, "lr": TCFG.lr_bank, "name": "bank",
         "weight_decay": TCFG.weight_decay},
        {"params": proj_params, "lr": TCFG.lr_proj, "name": "proj",
         "weight_decay": TCFG.weight_decay},
        {"params": output_params, "lr": TCFG.lr_output, "name": "output",
         "weight_decay": TCFG.weight_decay},
    ]
    for grp in groups:
        count = sum(p.numel() for p in grp["params"])
        print(f" {grp['name']:8s}: {count:>10,} params @ lr={grp['lr']}")
    return groups
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def compute_and_init_procrustes(student_model, modern_model, modern_tok,
                                captions, device):
    """Initialize the student's projection head from a static Procrustes fit.

    Embeds `captions` with both the student's CLIP text tower (pooler
    output) and the ModernBERT teacher (masked mean pooling), solves the
    orthogonal Procrustes problem between the two sets, and loads the
    resulting rotation/means into `proj_modern`.
    """
    print(f"\n Computing static Procrustes on {len(captions)} captions...")
    clip_tok = student_model.clip_tokenizer
    student_chunks, teacher_chunks = [], []

    # Embed in mini-batches of 16 to bound peak memory.
    for start in range(0, len(captions), 16):
        chunk = captions[start:start + 16]

        encoded = clip_tok(chunk, max_length=77, padding=True,
                           truncation=True, return_tensors="pt").to(device)
        clip_out = student_model.clip_text(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            output_hidden_states=False)
        student_chunks.append(clip_out.pooler_output.cpu())

        teacher_chunks.append(
            teacher_forward_modern(modern_model, modern_tok, chunk,
                                   device, TCFG.modern_max_len).cpu())

    student_all = torch.cat(student_chunks)
    modern_all = torch.cat(teacher_chunks)

    print(f" Student: {student_all.shape}, Teacher: {modern_all.shape}")
    R, mu_s, mu_t = compute_static_procrustes(student_all, modern_all)
    student_model.proj_modern.init_from_procrustes(R, mu_s, mu_t)
|
|
|
|
@torch.no_grad()
def compute_static_procrustes(student_embs, teacher_embs):
    """Solve the orthogonal Procrustes problem mapping student onto teacher.

    Args:
        student_embs: (N, Ds) student embeddings.
        teacher_embs: (N, Dt) teacher embeddings, with Dt >= Ds.

    Returns:
        (R, mu_x, mu_y): orthogonal rotation R of shape (Dt, Dt) such that
        (x - mu_x) @ R.T best approximates (y - mu_y), plus the two means
        (mu_x zero-padded to Dt when the student is lower-dimensional).
    """
    X = student_embs.float()
    Y = teacher_embs.float()
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y

    # Zero-pad the student side up to the teacher width so the cross-term
    # is square. Fix: allocate the padding on the same device/dtype as the
    # inputs — the original used torch.zeros(...) on the default device,
    # which raised a device-mismatch error for GPU-resident embeddings.
    if Xc.shape[1] < Yc.shape[1]:
        extra = Yc.shape[1] - Xc.shape[1]
        pad = torch.zeros(Xc.shape[0], extra,
                          device=Xc.device, dtype=Xc.dtype)
        Xc = torch.cat([Xc, pad], dim=1)
        mu_x = torch.cat(
            [mu_x, torch.zeros(extra, device=mu_x.device, dtype=mu_x.dtype)])

    # Closed-form solution: for X^T Y = U S V^T, the optimal map is U V^T.
    U, S, Vt = torch.linalg.svd(Xc.T @ Yc)
    R = (U @ Vt).T
    cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc @ R.T), Yc, dim=-1).mean()
    print(f" Procrustes: cos {cos_before:.4f} β {cos_after:.4f}")
    return R, mu_x, mu_y
|
|
|
|
| |
| |
| |
|
|
def train(model, modern_model, modern_tok, train_captions, val_captions=None):
    """Run the distillation training loop for the memory-extended student.

    Args:
        model: memory-extended CLIP student (trainable; provides init_state,
            clip_tokenizer, proj_modern, config, num_trainable_params).
        modern_model: frozen ModernBERT teacher.
        modern_tok: teacher tokenizer.
        train_captions: list of caption strings.
        val_captions: unused here; reserved for future eval hooks.

    Side effects: writes per-epoch and final checkpoints under
    TCFG.checkpoint_dir and prints progress to stdout.
    """
    device = next(model.parameters()).device
    os.makedirs(TCFG.checkpoint_dir, exist_ok=True)

    # Three LR groups (bank / proj / output) keyed on parameter-name prefixes.
    param_groups = make_param_groups(model)
    optimizer = torch.optim.AdamW(param_groups)
    all_params = [p for g in param_groups for p in g["params"]]

    # Schedule: linear warmup, then cosine annealing down to min_lr.
    n_batches_per_epoch = len(train_captions) // TCFG.batch_size
    total_steps = n_batches_per_epoch * TCFG.epochs
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=TCFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps, 1), eta_min=TCFG.min_lr)],
        milestones=[TCFG.warmup_steps])

    scaler = torch.amp.GradScaler()
    global_step = 0
    clip_tokenizer = model.clip_tokenizer

    print(f"\n Training: {model.num_trainable_params():,} params")
    print(f" {n_batches_per_epoch} batches/epoch Γ {TCFG.batch_size}")
    print(f" Losses: modern({TCFG.modern_weight}) + "
          f"procrustes({TCFG.procrustes_weight}) + cv({TCFG.cv_weight})")

    for epoch in range(TCFG.epochs):
        model.train()
        # Fresh random caption order each epoch.
        perm = np.random.permutation(len(train_captions))
        losses = {"total": 0, "modern": 0, "procrustes": 0, "cv": 0}
        metrics = {"modern_acc": 0, "cv_raw": 0}
        n = 0
        t0 = time.time()

        pbar = tqdm(range(0, len(train_captions), TCFG.batch_size),
                    desc=f"Epoch {epoch+1}/{TCFG.epochs}")

        for batch_start in pbar:
            idx = perm[batch_start:batch_start + TCFG.batch_size]
            if len(idx) < 2:
                # InfoNCE needs at least one in-batch negative.
                continue
            batch_captions = [train_captions[i] for i in idx]
            B = len(batch_captions)

            # Teacher targets: frozen forward under autocast, no grads.
            with torch.no_grad():
                with torch.amp.autocast("cuda"):
                    modern_cls = teacher_forward_modern(
                        modern_model, modern_tok, batch_captions,
                        device, TCFG.modern_max_len)

            # Fresh recurrent memory state for each batch.
            state = model.init_state(B, device)

            # Split each caption into CLIP-window segments (defined elsewhere
            # in this project as segment_text; presumably returns a list of
            # {"input_ids", "attention_mask"} dicts per caption).
            all_segments = [segment_text(cap, clip_tokenizer,
                                         model.config.max_content_tokens,
                                         model.config.segment_overlap,
                                         model.config.max_segments)
                            for cap in batch_captions]

            max_segs = max(len(s) for s in all_segments)
            n_segs = [len(s) for s in all_segments]

            # Per-segment anchors collected for the geometry (CV) loss.
            all_anchors = torch.zeros(
                B, max_segs, model.config.anchor_dim, device=device)

            for seg_k in range(max_segs):
                # Pad ragged rows with all-zero token/mask segments.
                batch_ids = []
                batch_masks = []
                for b in range(B):
                    if seg_k < len(all_segments[b]):
                        batch_ids.append(all_segments[b][seg_k]["input_ids"])
                        batch_masks.append(all_segments[b][seg_k]["attention_mask"])
                    else:
                        batch_ids.append(torch.zeros(77, dtype=torch.long))
                        batch_masks.append(torch.zeros(77, dtype=torch.long))

                ids = torch.stack(batch_ids).to(device)
                masks = torch.stack(batch_masks).to(device)

                with torch.amp.autocast("cuda"):
                    outputs, state = model(ids, masks, state)
                all_anchors[:, seg_k] = outputs["live_anchor"]

            # Memory-conditioned embedding after the final segment.
            student_cls = outputs["memory_output"]

            with torch.amp.autocast("cuda"):
                # (1) Contrastive distillation against the teacher.
                proj_m = model.proj_modern(student_cls)
                l_modern, acc_m = infonce_loss(proj_m, modern_cls, TCFG.temperature)

                # (2) Rotational alignment on the shared leading dims.
                l_procrustes = procrustes_alignment_loss(
                    student_cls, modern_cls[:, :model.config.clip_hidden])

                # (3) Anchor-geometry regularizer (pentachoron volume CV).
                n_reals_t = torch.tensor(n_segs, device=device)
                l_cv, cv_stats = batch_cv_loss(
                    all_anchors, n_reals_t, model.config.cv_target)

                loss = (TCFG.modern_weight * l_modern +
                        TCFG.procrustes_weight * l_procrustes +
                        TCFG.cv_weight * l_cv)

            # AMP step: unscale before clipping so the grad norm is real.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(all_params, TCFG.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

            # Running sums for the epoch summary / progress bar.
            losses["total"] += loss.item()
            losses["modern"] += l_modern.item()
            losses["procrustes"] += l_procrustes.item()
            losses["cv"] += l_cv.item()
            metrics["modern_acc"] += acc_m
            metrics["cv_raw"] += cv_stats.get("cv_raw", 0)
            n += 1

            d = max(n, 1)
            pbar.set_postfix(
                loss=f"{losses['total']/d:.3f}",
                m_acc=f"{metrics['modern_acc']/d:.3f}",
                cv=f"{metrics['cv_raw']/d:.3f}")

        elapsed = time.time() - t0
        d = max(n, 1)
        print(f"\n Epoch {epoch+1}: {elapsed:.0f}s "
              f"loss={losses['total']/d:.4f} "
              f"m_acc={metrics['modern_acc']/d:.3f} "
              f"cv={metrics['cv_raw']/d:.3f}")

        # Checkpoint every epoch; a "final" copy is written after the loop.
        save_checkpoint(model, optimizer, epoch + 1, global_step,
                        os.path.join(TCFG.checkpoint_dir, f"epoch_{epoch+1:03d}"))

    save_checkpoint(model, optimizer, TCFG.epochs, global_step,
                    os.path.join(TCFG.checkpoint_dir, "final"))
|
|
|
|
def save_checkpoint(model, optimizer, epoch, global_step, path):
    """Save trainable params + buffers (safetensors) and optimizer state."""
    os.makedirs(path, exist_ok=True)
    tensors = {}
    # Only trainable parameters are persisted; frozen weights are reloadable
    # from their pretrained sources.
    for pname, p in model.named_parameters():
        if p.requires_grad:
            tensors[pname] = p.data.contiguous().cpu()
    for bname, b in model.named_buffers():
        tensors[f"buffer.{bname}"] = b.contiguous().cpu()
    safetensors_save(tensors, os.path.join(path, "memory_system.safetensors"))
    training_state = {"optimizer": optimizer.state_dict(), "epoch": epoch,
                      "global_step": global_step}
    torch.save(training_state, os.path.join(path, "training_state.pt"))
|
|
|
|
| |
| |
| |
|
|
def main():
    """Entry point: build models, stream long-caption data, then train.

    Pipeline: construct the memory-extended CLIP student, load the frozen
    fp16 ModernBERT teacher, stream long captions from HF datasets (with
    two fallbacks), Procrustes-initialize the projection head, and run
    the training loop. Requires network access for model/dataset downloads.
    """
    print("=" * 70)
    print("TRAINING: MEMORY-EXTENDED CLIP-L TEXT ENCODER")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Student model (MemoryCLIPConfig / MemoryExtendedCLIP are defined
    # elsewhere in this project).
    config = MemoryCLIPConfig()
    model = MemoryExtendedCLIP(config)
    model.setup(device)
    model = model.to(device)

    # Teacher: frozen fp16 ModernBERT-large.
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Loading ModernBERT-large...")
    modern_model = AutoModel.from_pretrained(
        config.teacher_model, torch_dtype=torch.float16).to(device)
    modern_model.eval()
    for p in modern_model.parameters():
        p.requires_grad = False
    modern_tok = AutoTokenizer.from_pretrained(config.teacher_model)
    print(f" {sum(p.numel() for p in modern_model.parameters()):,} params (frozen)")

    # Long-caption data: primary streaming source plus two fallbacks.
    print(f"\n Loading long captions...")
    from datasets import load_dataset

    train_captions = []

    # Primary: CC12M re-captioned with LLaVA-NeXT (long captions).
    try:
        print(" Loading CaptionEmporium/conceptual-captions-cc12m-llavanext...")
        ds = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
        for row in ds:
            cap = row.get("caption_llava", "")
            if isinstance(cap, str) and len(cap) > 100:
                train_captions.append(cap)
            if len(train_captions) >= TCFG.max_train_samples:
                break
        print(f" Got {len(train_captions)} long captions from CC12M-llavanext")
    except Exception as e:
        print(f" CC12M-llavanext failed: {e}")

    # Fallback 1: ShareGPT4V conversation turns (skip <image> turns).
    if len(train_captions) < 1000:
        try:
            print(" Trying Lin-Chen/ShareGPT4V...")
            ds = load_dataset("Lin-Chen/ShareGPT4V",
                              "ShareGPT4V", split="train", streaming=True)
            for row in ds:
                convs = row.get("conversations", [])
                for c in convs:
                    val = c.get("value", "")
                    if isinstance(val, str) and len(val) > 100 and "<image>" not in val:
                        train_captions.append(val)
                if len(train_captions) >= TCFG.max_train_samples:
                    break
            print(f" Got {len(train_captions)} from ShareGPT4V")
        except Exception as e:
            print(f" ShareGPT4V failed: {e}")

    # Fallback 2: LLaVA-recaptioned CC3M.
    if len(train_captions) < 1000:
        try:
            print(" Trying lmms-lab/LLaVA-ReCap-CC3M...")
            ds = load_dataset("lmms-lab/LLaVA-ReCap-CC3M",
                              split="train", streaming=True)
            for row in ds:
                cap = row.get("caption", row.get("text", ""))
                if isinstance(cap, str) and len(cap) > 100:
                    train_captions.append(cap)
                if len(train_captions) >= TCFG.max_train_samples:
                    break
            print(f" Got {len(train_captions)} from LLaVA-ReCap-CC3M")
        except Exception as e:
            print(f" All fallbacks failed: {e}")

    train_captions = train_captions[:TCFG.max_train_samples]

    # Report CLIP-token length stats on a sample (how many exceed the
    # 77-token CLIP window).
    if train_captions:
        from transformers import CLIPTokenizer
        tok_temp = CLIPTokenizer.from_pretrained(config.clip_model)
        lengths = [len(tok_temp.encode(c)) for c in train_captions[:500]]
        print(f" Caption token lengths (CLIP tokenizer, sample of {len(lengths)}):")
        print(f" mean={np.mean(lengths):.0f} median={np.median(lengths):.0f} "
              f"max={max(lengths)} >77: {sum(1 for l in lengths if l > 77)/len(lengths):.1%}")
        del tok_temp

    print(f" {len(train_captions)} captions loaded")
    print(f" Example: {train_captions[0][:100]}...")

    # Static Procrustes initialization of proj_modern before training.
    align_captions = train_captions[:TCFG.procrustes_n_samples]
    compute_and_init_procrustes(
        model, modern_model, modern_tok, align_captions, device)

    train(model, modern_model, modern_tok, train_captions)

    print("\nDone.")
|
|
|
|
# Script entry point guard.
if __name__ == "__main__":
    main()