| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import gc |
| import math |
| import os |
| import json |
| import time |
| from dataclasses import dataclass, asdict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
| from safetensors.torch import save_file as safetensors_save, load_file |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class TrainSeqConfig:
    """Hyperparameters for the two-phase memory-CLIP sequence training run."""

    # --- Data ---
    max_train_samples: int = 50000   # captions drawn for training
    max_val_samples: int = 2000      # held-out captions (NOTE(review): no val loop consumes them)
    min_caption_length: int = 100    # minimum caption length, in characters

    # --- Phase 1: sequence head only (everything else frozen) ---
    phase1_epochs: int = 5
    phase1_lr_seq: float = 2e-3      # LR for the sequence reconstructor
    phase1_lr_proj: float = 1e-3     # LR for the teacher projection head

    # --- Phase 2: joint fine-tune with differential LRs ---
    phase2_epochs: int = 5
    phase2_lr_bank: float = 5e-4     # memory bank / cross-attention modules
    phase2_lr_output: float = 2e-4   # remaining (output-side) modules
    phase2_lr_proj: float = 5e-4
    phase2_lr_seq: float = 1e-3

    # --- Optimization ---
    batch_size: int = 64
    min_lr: float = 1e-6             # cosine-anneal floor
    weight_decay: float = 0.01
    grad_clip: float = 1.0           # global grad-norm clip
    warmup_steps: int = 200          # linear warmup length, in steps

    # --- Loss weights (pooled-embedding objectives) ---
    modern_weight: float = 1.0       # InfoNCE vs. ModernBERT teacher
    procrustes_weight: float = 0.3   # batch Procrustes alignment term
    cv_weight: float = 0.05          # pentachoron volume-CV regularizer
    temperature: float = 0.07        # InfoNCE temperature

    # --- Loss weights (sequence reconstruction) ---
    sequence_weight: float = 1.0         # MSE term
    sequence_cosine_weight: float = 0.5  # per-position cosine term

    # --- Teacher / Procrustes init ---
    modern_max_len: int = 4096       # teacher tokenizer truncation length
    procrustes_n_samples: int = 300  # captions used for the static Procrustes fit

    # --- v1 checkpoint source ---
    v1_checkpoint: str = ""          # local path; empty -> download from HF hub
    v1_repo_id: str = "AbstractPhil/geolip-clip-vit-large-patch14-ctx576"
    v1_filename: str = "model.safetensors"

    # --- Checkpoints / logging ---
    checkpoint_dir: str = "/home/claude/memory_clip_seq_checkpoints"
    tensorboard_dir: str = "/home/claude/memory_clip_seq_tb"
    metrics_file: str = "/home/claude/memory_clip_seq_checkpoints/metrics.json"
    log_every: int = 20              # steps between TensorBoard writes
    eval_every: int = 200            # NOTE(review): currently unused by train_phase


# Module-level singleton read throughout this file.
TCFG = TrainSeqConfig()
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley-Menger determinant.

    Args:
        pts: (B, V, D) tensor — V vertices per batch item.

    Returns:
        (B,) tensor of squared (V-1)-simplex volumes. Values can be
        slightly negative due to floating-point round-off; callers are
        expected to clamp before taking the square root.
    """
    # Determinants are numerically fragile in fp16, so force fp32.
    with torch.amp.autocast("cuda", enabled=False):
        pts = pts.float()
        delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
        sq_dist = (delta * delta).sum(-1)           # pairwise squared distances
        B, V, _ = sq_dist.shape
        # Border the distance matrix with a row/column of ones (zero corner).
        cm = torch.zeros(B, V + 1, V + 1, device=sq_dist.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = sq_dist
        # vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM)
        fact = math.factorial(V - 1)
        coeff = (-1.0) ** V / ((2.0 ** (V - 1)) * fact * fact)
        return coeff * torch.linalg.det(cm)
|
|
def pentachoron_cv(embeddings, n_samples=16):
    """Coefficient of variation of random 5-point simplex volumes.

    Samples ``n_samples`` random 5-subsets of rows from ``embeddings``
    (B, D), computes each pentachoron's volume, and returns std/mean of
    those volumes (a scalar tensor). Returns 0.0 when B < 5.

    Improvement over the original: the n_samples Cayley-Menger
    determinants are computed in ONE batched ``cayley_menger_vol2`` call
    instead of a Python loop of single-determinant calls — identical
    numerics and RNG consumption, far fewer kernel launches.
    """
    B = embeddings.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=embeddings.device)
    # Draw subsets with the same RNG calls/order as the original loop.
    idx = torch.stack([torch.randperm(B, device=embeddings.device)[:5]
                       for _ in range(n_samples)])      # (n_samples, 5)
    v2 = cayley_menger_vol2(embeddings[idx])            # (n_samples,)
    vols = torch.sqrt(F.relu(v2) + 1e-12)               # clamp round-off negatives
    return vols.std() / (vols.mean() + 1e-8)
|
|
| def procrustes_alignment_loss(emb_a, emb_b): |
| with torch.amp.autocast("cuda", enabled=False): |
| A = F.normalize(emb_a.float(), dim=-1) |
| B_e = F.normalize(emb_b.float(), dim=-1) |
| A = A - A.mean(0, keepdim=True) |
| B_e = B_e - B_e.mean(0, keepdim=True) |
| S = torch.linalg.svdvals(A.T @ B_e) |
| N, D = A.shape |
| return 1.0 - S.sum() / (math.sqrt(N) * D) |
|
|
|
|
| |
| |
| |
|
|
def infonce_loss(emb_a, emb_b, temperature=0.07):
    """Symmetric InfoNCE with in-batch negatives.

    Matched pairs sit on the diagonal of the (B, B) similarity matrix.

    Returns:
        (loss, top1_acc, top5_acc) — loss is differentiable; the two
        accuracies are plain floats computed under no_grad.
    """
    za = F.normalize(emb_a, dim=-1)
    zb = F.normalize(emb_b, dim=-1)
    sim = za @ zb.T / temperature
    targets = torch.arange(sim.shape[0], device=sim.device)
    # Average the a->b and b->a directions.
    loss = (F.cross_entropy(sim, targets) + F.cross_entropy(sim.T, targets)) / 2
    with torch.no_grad():
        acc = (sim.argmax(-1) == targets).float().mean().item()
        k = min(5, sim.shape[1])
        in_topk = (sim.topk(k, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
        acc5 = in_topk.float().mean().item()
    return loss, acc, acc5
|
|
|
|
def batch_cv_loss(all_anchors, n_reals, cv_target=0.20):
    """Penalize deviation of each sample's anchor-volume CV from a target.

    Args:
        all_anchors: (B, S, D) padded anchor tensor.
        n_reals: per-sample count of real (non-pad) anchors; items with
            fewer than 5 real anchors are skipped (CV needs a pentachoron).
        cv_target: desired coefficient of variation.

    Returns:
        (loss, stats) — mean |CV - target| over valid samples, plus a
        stats dict with the raw CV mean/std and valid-sample count.
    """
    device = all_anchors.device
    total_loss = torch.tensor(0.0, device=device)
    cv_values = []
    for b in range(all_anchors.shape[0]):
        count = n_reals[b]
        n = count.item() if isinstance(count, torch.Tensor) else count
        if n < 5:
            continue
        cv = pentachoron_cv(all_anchors[b, :n], n_samples=16)
        total_loss = total_loss + (cv - cv_target).abs()
        cv_values.append(cv.item())
    n_valid = len(cv_values)
    stats = {
        "cv_raw": sum(cv_values) / max(n_valid, 1),
        "cv_std": float(np.std(cv_values)) if cv_values else 0.0,
        "cv_n_valid": n_valid,
    }
    return total_loss / max(n_valid, 1), stats
|
|
|
|
def sequence_reconstruction_loss(pred_seq, target_seq):
    """Score a reconstructed token sequence against the teacher sequence.

    Args:
        pred_seq:   (B, 77, 768) reconstructed sequence.
        target_seq: (B, 77, 768) teacher-projected sequence.

    Returns:
        mse_loss: mean squared error (differentiable).
        cos_loss: 1 - mean per-position cosine similarity (differentiable).
        mean_cos: plain-float cosine metric for logging.
    """
    mse = F.mse_loss(pred_seq, target_seq)

    # Per-position cosine similarity across the feature dimension.
    cos_per_pos = (F.normalize(pred_seq, dim=-1) *
                   F.normalize(target_seq, dim=-1)).sum(-1)
    mean_sim = cos_per_pos.mean()
    cos_loss = 1.0 - mean_sim

    with torch.no_grad():
        mean_cos = mean_sim.item()

    return mse, cos_loss, mean_cos
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def teacher_forward(model, tokenizer, texts, device, max_len):
    """Mean-pool the teacher's last hidden state over non-pad tokens.

    Returns a (B, hidden) pooled embedding — (B, 1024) for the
    ModernBERT-large teacher used here. The sequence-level training
    target comes from CLIP, not from this function.
    """
    enc = tokenizer(texts, max_length=max_len, padding=True,
                    truncation=True, return_tensors="pt").to(device)
    hidden = model(**enc).last_hidden_state
    # Masked mean over the token axis; clamp avoids divide-by-zero on
    # pathological all-pad rows.
    weights = enc.attention_mask.unsqueeze(-1).float()
    return (hidden * weights).sum(1) / weights.sum(1).clamp(min=1)
|
|
|
|
| |
| |
| |
|
|
def make_param_groups_phase1(model):
    """Phase 1: train only the sequence head and the teacher projector.

    Freezes every other parameter, then returns AdamW param groups with
    per-group learning rates, printing a size summary per group.
    """
    seq_params, proj_params = [], []
    # Single pass: classify each parameter and set requires_grad as we go.
    for name, param in model.named_parameters():
        if "sequence_reconstructor" in name:
            param.requires_grad = True
            seq_params.append(param)
        elif "proj_modern" in name:
            param.requires_grad = True
            proj_params.append(param)
        else:
            param.requires_grad = False

    groups = [
        {"params": seq_params, "lr": TCFG.phase1_lr_seq, "name": "seq_head",
         "weight_decay": TCFG.weight_decay},
        {"params": proj_params, "lr": TCFG.phase1_lr_proj, "name": "proj",
         "weight_decay": TCFG.weight_decay},
    ]
    for g in groups:
        n = sum(p.numel() for p in g["params"])
        print(f" {g['name']:12s}: {n:>10,} params @ lr={g['lr']}")
    return groups
|
|
|
|
def make_param_groups_phase2(model):
    """Phase 2: unfreeze everything except the CLIP text tower.

    Splits trainable parameters into four groups (bank / seq head /
    projector / everything else) with differential learning rates, and
    prints a size summary per group.
    """
    # Unfreeze all non-CLIP parameters; CLIP text tower stays as-is.
    for name, param in model.named_parameters():
        if "clip_text" not in name and "_clip_text" not in name:
            param.requires_grad = True

    bank_names = {"bank.", "clip_cross_attn", "clip_cross_norms",
                  "clip_cross_ffns", "clip_cross_ffn_norms"}
    seq_names = {"sequence_reconstructor"}
    proj_names = {"proj_modern"}

    def _matches(name, keys):
        # Prefix OR substring match, mirroring the grouping convention used
        # across this file.
        return any(name.startswith(k) or k in name for k in keys)

    bank_p, seq_p, proj_p, output_p = [], [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if _matches(name, seq_names):
            seq_p.append(param)
        elif _matches(name, proj_names):
            proj_p.append(param)
        elif _matches(name, bank_names):
            bank_p.append(param)
        else:
            output_p.append(param)

    groups = [
        {"params": bank_p, "lr": TCFG.phase2_lr_bank, "name": "bank",
         "weight_decay": TCFG.weight_decay},
        {"params": seq_p, "lr": TCFG.phase2_lr_seq, "name": "seq_head",
         "weight_decay": TCFG.weight_decay},
        {"params": proj_p, "lr": TCFG.phase2_lr_proj, "name": "proj",
         "weight_decay": TCFG.weight_decay},
        {"params": output_p, "lr": TCFG.phase2_lr_output, "name": "output",
         "weight_decay": TCFG.weight_decay},
    ]
    for g in groups:
        n = sum(p.numel() for p in g["params"])
        print(f" {g['name']:12s}: {n:>10,} params @ lr={g['lr']}")
    return groups
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def compute_and_init_procrustes(student_model, modern_model, modern_tok,
                                captions, device):
    """Fit a static orthogonal (Procrustes) map from student CLIP pooled
    embeddings to teacher ModernBERT pooled embeddings, and use it to
    warm-start the student's ``proj_modern`` head when that head exposes
    ``init_from_procrustes``.

    Returns a dict with mean cosine similarity before/after alignment.
    """
    print(f"\n Computing static Procrustes on {len(captions)} captions...")
    student_embs, modern_embs = [], []
    clip_tok = student_model.clip_tokenizer
    # Encode in chunks of 16 and park results on CPU to bound GPU memory.
    for i in range(0, len(captions), 16):
        batch = captions[i:i+16]
        tokens = clip_tok(batch, max_length=77, padding=True,
                          truncation=True, return_tensors="pt").to(device)
        clip_out = student_model.clip_text(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask,
            output_hidden_states=False)
        student_embs.append(clip_out.pooler_output.cpu())
        pooled = teacher_forward(modern_model, modern_tok, batch,
                                 device, TCFG.modern_max_len)
        modern_embs.append(pooled.cpu())
    student_all = torch.cat(student_embs)
    modern_all = torch.cat(modern_embs)
    print(f" Student: {student_all.shape}, Teacher: {modern_all.shape}")
    X = student_all.float(); Y = modern_all.float()
    # Center both clouds; the means are passed along so the projection head
    # can undo/redo the centering.
    mu_x, mu_y = X.mean(0), Y.mean(0)
    Xc, Yc = X - mu_x, Y - mu_y
    # Zero-pad the student space when it is narrower than the teacher's
    # (e.g. 768 -> 1024) so the rotation matrix is square.
    if Xc.shape[1] < Yc.shape[1]:
        pad = torch.zeros(Xc.shape[0], Yc.shape[1] - Xc.shape[1])
        Xc = torch.cat([Xc, pad], dim=1)
        mu_x = torch.cat([mu_x, torch.zeros(Yc.shape[1] - mu_x.shape[0])])
    # Orthogonal Procrustes: with U S Vt = svd(Xc.T Yc), Xc @ (U Vt) best
    # aligns to Yc; R is stored transposed so that Xc @ R.T applies it.
    U, S, Vt = torch.linalg.svd(Xc.T @ Yc)
    R = (U @ Vt).T
    cos_before = F.cosine_similarity(Xc, Yc, dim=-1).mean()
    cos_after = F.cosine_similarity((Xc @ R.T), Yc, dim=-1).mean()
    print(f" Procrustes: cos {cos_before:.4f} β {cos_after:.4f}")
    # Hand rotation + means to the projection head if it supports warm-start.
    if hasattr(student_model.proj_modern, 'init_from_procrustes'):
        student_model.proj_modern.init_from_procrustes(R, mu_x, mu_y)
    return {"cos_before": cos_before.item(), "cos_after": cos_after.item()}
|
|
|
|
| |
| |
| |
|
|
def load_v1_weights(model, device):
    """Load v1 memory-system weights into the expanded seq model.

    Prefers the local checkpoint path from TCFG; otherwise downloads the
    file from the HuggingFace Hub. Loads non-strictly and reports what
    matched, what is new (the sequence head), and what was ignored.
    Returns True on completion.
    """
    checkpoint_path = TCFG.v1_checkpoint
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f" Loading v1 weights (local): {checkpoint_path}")
    else:
        from huggingface_hub import hf_hub_download
        print(f" Downloading v1 weights from {TCFG.v1_repo_id}/{TCFG.v1_filename}...")
        checkpoint_path = hf_hub_download(
            repo_id=TCFG.v1_repo_id,
            filename=TCFG.v1_filename)
        print(f" Downloaded to: {checkpoint_path}")

    state = load_file(checkpoint_path, device=str(device))
    # strict=False: v2 adds the sequence reconstructor, v1 has extra buffers.
    missing, unexpected = model.load_state_dict(state, strict=False)

    print(f" Loaded: {len(state) - len(unexpected)} tensors from v1")
    print(f" Missing (new modules): {len(missing)}")
    if missing:
        seq_keys = [k for k in missing if "sequence_reconstructor" in k
                    ]
        leftovers = [k for k in missing if k not in seq_keys]
        print(f" Seq head (expected new): {len(seq_keys)}")
        if leftovers:
            print(f" Other (check!): {leftovers[:5]}")
    if unexpected:
        print(f" Unexpected (v1 buffers, ignorable): {len(unexpected)}")
    return True
|
|
|
|
| |
| |
| |
|
|
def train_phase(model, modern_model, modern_tok, train_captions, val_captions,
                param_groups, n_epochs, phase_name, writer, all_metrics,
                global_step=0):
    """
    Run one training phase (shared by phase 1 and phase 2).

    Per batch: (1) pool a frozen ModernBERT teacher embedding, (2) stream
    each caption through the student segment-by-segment to accumulate
    memory state, (3) compute InfoNCE + Procrustes + CV + sequence
    reconstruction losses, and (4) take one AdamW step under AMP with a
    warmup->cosine LR schedule. Metrics go to TensorBoard and
    ``all_metrics``; a checkpoint is written after every epoch.

    Returns the updated ``global_step``.

    NOTE(review): ``val_captions`` and ``TCFG.eval_every`` are accepted but
    never used — no validation loop runs here. Confirm that is intentional.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(param_groups)
    # Flat list of every trainable tensor, used for gradient clipping below.
    all_params = [p for g in param_groups for p in g["params"]]

    n_batches = len(train_captions) // TCFG.batch_size
    total_steps = n_batches * n_epochs
    # Linear warmup, then cosine decay.
    # NOTE(review): T_max counts all steps but the cosine schedule only
    # starts after the warmup milestone, so the LR never quite reaches
    # min_lr — confirm this is acceptable.
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=TCFG.warmup_steps),
         torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer, T_max=max(total_steps, 1), eta_min=TCFG.min_lr)],
        milestones=[TCFG.warmup_steps])

    scaler = torch.amp.GradScaler()
    clip_tokenizer = model.clip_tokenizer
    best_val_loss = float("inf")  # NOTE(review): never updated (no val loop).

    print(f"\n {phase_name}: {sum(p.numel() for p in all_params):,} trainable params")
    print(f" {n_batches} batches/epoch Γ— {TCFG.batch_size}")

    # ``segment_text`` splits a long caption into overlapping 77-token CLIP
    # windows. NOTE(review): it is not defined in this file — presumably
    # imported/defined elsewhere in the project; confirm.
    _segment_text = segment_text

    for epoch in range(n_epochs):
        model.train()
        perm = np.random.permutation(len(train_captions))
        # Running sums for the epoch; divided by the batch count when logged.
        losses = {"total": 0, "modern": 0, "procrustes": 0, "cv": 0,
                  "seq_mse": 0, "seq_cos": 0}
        metrics = {"modern_acc": 0, "modern_acc5": 0,
                   "cv_raw": 0, "seq_cos_sim": 0, "n_segments_avg": 0}
        n = 0
        t0 = time.time()

        pbar = tqdm(range(0, len(train_captions), TCFG.batch_size),
                    desc=f"{phase_name} E{epoch+1}/{n_epochs}", unit="batch")

        for batch_start in pbar:
            idx = perm[batch_start:batch_start + TCFG.batch_size]
            if len(idx) < 2:
                continue  # InfoNCE needs at least one in-batch negative
            batch_captions = [train_captions[i] for i in idx]
            B = len(batch_captions)

            # --- Teacher target: pooled ModernBERT embedding (frozen) ---
            with torch.no_grad():
                with torch.amp.autocast("cuda"):
                    modern_cls = teacher_forward(
                        modern_model, modern_tok, batch_captions,
                        device, TCFG.modern_max_len)

            # --- Student: stream caption segments through the memory model ---
            state = model.init_state(B, device)
            all_segments = [_segment_text(cap, clip_tokenizer,
                                          model.config.max_content_tokens,
                                          model.config.segment_overlap,
                                          model.config.max_segments)
                            for cap in batch_captions]
            max_segs = max(len(s) for s in all_segments)
            n_segs = [len(s) for s in all_segments]

            for seg_k in range(max_segs):
                batch_ids, batch_masks = [], []
                for b in range(B):
                    if seg_k < len(all_segments[b]):
                        batch_ids.append(all_segments[b][seg_k]["input_ids"])
                        batch_masks.append(all_segments[b][seg_k]["attention_mask"])
                    else:
                        # Caption b ran out of segments: feed a zero pad
                        # segment (zero attention mask) to keep the batch
                        # aligned across captions of different lengths.
                        batch_ids.append(torch.zeros(77, dtype=torch.long))
                        batch_masks.append(torch.zeros(77, dtype=torch.long))
                ids = torch.stack(batch_ids).to(device)
                masks = torch.stack(batch_masks).to(device)
                with torch.amp.autocast("cuda"):
                    fused_output, state = model.forward_segment(ids, masks, state)

            # Pooled student embedding after the final segment.
            student_cls = fused_output

            # --- Gather memory-bank anchors for the CV (volume-spread) loss ---
            bank_anchors = state["bank"]["anchors"]
            all_anchors = torch.zeros(B, max_segs, model.config.anchor_dim, device=device)
            n_written = min(bank_anchors.shape[1], max_segs)
            all_anchors[:, :n_written] = bank_anchors[:, :n_written]

            # --- Pooled-embedding losses ---
            with torch.amp.autocast("cuda"):
                proj_m = model.proj_modern(student_cls)
                l_modern, acc_m, acc5_m = infonce_loss(
                    proj_m, modern_cls, TCFG.temperature)
                # Procrustes compares the student against the first
                # clip_hidden dims of the teacher embedding.
                l_procrustes = procrustes_alignment_loss(
                    student_cls, modern_cls[:, :model.config.clip_hidden])
                n_reals_t = torch.tensor(n_segs, device=device)
                l_cv, cv_stats = batch_cv_loss(
                    all_anchors, n_reals_t, model.config.cv_target)

            # --- Sequence target: frozen CLIP text encoder on the (77-token
            # truncated) caption; no gradient flows into the target. ---
            with torch.no_grad():
                clip_inputs = clip_tokenizer(
                    batch_captions, max_length=77, padding="max_length",
                    truncation=True, return_tensors="pt").to(device)
                with torch.amp.autocast("cuda"):
                    clip_target_out = model.clip_text(
                        input_ids=clip_inputs.input_ids,
                        attention_mask=clip_inputs.attention_mask,
                        output_hidden_states=False, return_dict=True)
                clip_target_seq = clip_target_out.last_hidden_state

            with torch.amp.autocast("cuda"):
                # Reconstruct the full token sequence from memory state.
                recon_seq = model.reconstruct_sequence(state)

            l_seq_mse, l_seq_cos, seq_cos_metric = sequence_reconstruction_loss(
                recon_seq, clip_target_seq.detach())

            # --- Weighted total loss + AMP optimizer step ---
            with torch.amp.autocast("cuda"):
                loss = (TCFG.modern_weight * l_modern +
                        TCFG.procrustes_weight * l_procrustes +
                        TCFG.cv_weight * l_cv +
                        TCFG.sequence_weight * l_seq_mse +
                        TCFG.sequence_cosine_weight * l_seq_cos)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale grads before clipping
            torch.nn.utils.clip_grad_norm_(all_params, TCFG.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            global_step += 1

            # --- Accumulate running epoch statistics ---
            losses["total"] += loss.item()
            losses["modern"] += l_modern.item()
            losses["procrustes"] += l_procrustes.item()
            losses["cv"] += l_cv.item()
            losses["seq_mse"] += l_seq_mse.item()
            losses["seq_cos"] += l_seq_cos.item()
            metrics["modern_acc"] += acc_m
            metrics["modern_acc5"] += acc5_m
            metrics["cv_raw"] += cv_stats.get("cv_raw", 0)
            metrics["seq_cos_sim"] += seq_cos_metric
            metrics["n_segments_avg"] += np.mean(n_segs)
            n += 1

            d = max(n, 1)
            pbar.set_postfix(
                loss=f"{losses['total']/d:.3f}",
                m_acc=f"{metrics['modern_acc']/d:.3f}",
                s_cos=f"{metrics['seq_cos_sim']/d:.3f}",
                cv=f"{metrics['cv_raw']/d:.3f}")

            # NOTE: logged values are running epoch averages, not the
            # instantaneous step values.
            if global_step % TCFG.log_every == 0:
                writer.add_scalar(f"{phase_name}/loss", losses["total"]/d, global_step)
                writer.add_scalar(f"{phase_name}/modern_loss", losses["modern"]/d, global_step)
                writer.add_scalar(f"{phase_name}/seq_mse", losses["seq_mse"]/d, global_step)
                writer.add_scalar(f"{phase_name}/seq_cos_loss", losses["seq_cos"]/d, global_step)
                writer.add_scalar(f"{phase_name}/seq_cos_sim", metrics["seq_cos_sim"]/d, global_step)
                writer.add_scalar(f"{phase_name}/m_acc", metrics["modern_acc"]/d, global_step)
                writer.add_scalar(f"{phase_name}/cv_raw", metrics["cv_raw"]/d, global_step)
                all_metrics["steps"].append({
                    "step": global_step, "phase": phase_name,
                    "epoch": epoch + 1,
                    "loss": losses["total"]/d,
                    "m_acc": metrics["modern_acc"]/d,
                    "seq_cos": metrics["seq_cos_sim"]/d,
                    "cv_raw": metrics["cv_raw"]/d,
                })

        pbar.close()
        elapsed = time.time() - t0
        d = max(n, 1)

        epoch_summary = {
            "phase": phase_name, "epoch": epoch + 1,
            "elapsed_s": elapsed,
            "loss": losses["total"]/d,
            "modern_loss": losses["modern"]/d,
            "seq_mse": losses["seq_mse"]/d,
            "seq_cos_loss": losses["seq_cos"]/d,
            "m_acc": metrics["modern_acc"]/d,
            "m_acc5": metrics["modern_acc5"]/d,
            "seq_cos_sim": metrics["seq_cos_sim"]/d,
            "cv_raw": metrics["cv_raw"]/d,
            "global_step": global_step,
        }

        all_metrics["epochs"].append(epoch_summary)

        print(f"\n {phase_name} E{epoch+1}: {elapsed:.0f}s "
              f"loss={epoch_summary['loss']:.4f} "
              f"m_acc={epoch_summary['m_acc']:.3f} "
              f"seq_cos={epoch_summary['seq_cos_sim']:.3f} "
              f"cv={epoch_summary['cv_raw']:.3f}")

        # Checkpoint + metrics dump after every epoch.
        save_checkpoint(model, optimizer, epoch + 1, global_step, phase_name,
                        os.path.join(TCFG.checkpoint_dir, f"{phase_name}_e{epoch+1:02d}"))

        with open(TCFG.metrics_file, "w") as f:
            json.dump(all_metrics, f, indent=2, default=str)

    return global_step
|
|
|
|
def save_checkpoint(model, optimizer, epoch, global_step, phase, path):
    """Write a checkpoint directory: trainable params + all buffers
    (safetensors), optimizer/training state (torch.save), and the merged
    model + training config as JSON.

    Only parameters with requires_grad=True are saved; buffers get a
    ``buffer.`` key prefix.
    """
    os.makedirs(path, exist_ok=True)
    tensors = {name: param.data.contiguous().cpu()
               for name, param in model.named_parameters()
               if param.requires_grad}
    for name, buf in model.named_buffers():
        tensors[f"buffer.{name}"] = buf.contiguous().cpu()
    safetensors_save(tensors, os.path.join(path, "memory_system.safetensors"))
    torch.save({"optimizer": optimizer.state_dict() if optimizer else {},
                "epoch": epoch,
                "global_step": global_step, "phase": phase},
               os.path.join(path, "training_state.pt"))
    model_cfg = model.config.to_dict() if hasattr(model.config, 'to_dict') else {}
    config_data = {"model": model_cfg, "training": asdict(TCFG)}
    with open(os.path.join(path, "config.json"), "w") as f:
        json.dump(config_data, f, indent=2, default=str)
|
|
|
|
| |
| |
| |
|
|
def load_long_captions(max_train, max_val, min_length=100):
    """Stream long LLaVA captions from CC12M until the quota is filled.

    Keeps string captions longer than ``min_length`` characters; stops as
    soon as ``max_train + max_val`` have been collected. Returns
    (train_captions, val_captions).
    """
    from datasets import load_dataset
    print(f" Loading CaptionEmporium/conceptual-captions-cc12m-llavanext...")
    stream = load_dataset("CaptionEmporium/conceptual-captions-cc12m-llavanext",
                          split="train", streaming=True)
    needed = max_train + max_val
    captions = []
    for row in stream:
        cap = row.get("caption_llava", "")
        if isinstance(cap, str) and len(cap) > min_length:
            captions.append(cap)
            if len(captions) >= needed:
                break
    train_caps = captions[:max_train]
    val_caps = captions[max_train:needed]
    print(f" Train: {len(train_caps)}, Val: {len(val_caps)}")
    return train_caps, val_caps
|
|
|
|
| |
| |
| |
|
|
def main():
    """Two-phase training driver.

    Loads the student + frozen ModernBERT teacher, warm-starts from the
    v1 checkpoint, Procrustes-initializes the projection head, trains the
    sequence head (phase 1), then jointly fine-tunes (phase 2), writing
    checkpoints, TensorBoard logs, and a metrics JSON throughout.
    """
    print("=" * 70)
    print("TRAINING: MEMORY-CLIP-SEQ (SEQUENCE RECONSTRUCTION)")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Student model. NOTE(review): MemoryCLIPSeqConfig / MemoryCLIPSeqModel
    # are not defined in this file — presumably imported elsewhere; confirm.
    config = MemoryCLIPSeqConfig()
    model = MemoryCLIPSeqModel(config).to(device)

    # Touch lazily-loaded attributes so CLIP weights/tokenizer load up front.
    _ = model.clip_text
    _ = model.clip_tokenizer

    # Warm-start the memory system from the v1 checkpoint.
    load_v1_weights(model, device)
    print(" v1 memory system weights loaded")

    # Frozen fp16 teacher.
    from transformers import AutoModel, AutoTokenizer
    print(f"\n Loading ModernBERT-large...")
    modern_model = AutoModel.from_pretrained(
        config.teacher_model, torch_dtype=torch.float16).to(device)
    modern_model.eval()
    for p in modern_model.parameters():
        p.requires_grad = False
    modern_tok = AutoTokenizer.from_pretrained(config.teacher_model)
    print(f" {sum(p.numel() for p in modern_model.parameters()):,} params (frozen)")

    # Long-caption dataset.
    train_captions, val_captions = load_long_captions(
        TCFG.max_train_samples, TCFG.max_val_samples, TCFG.min_caption_length)

    # Sanity check: report token-length stats on a 500-caption sample to
    # confirm the captions actually exceed CLIP's 77-token window.
    from transformers import CLIPTokenizer
    tok_temp = CLIPTokenizer.from_pretrained(config.clip_model)
    lengths = [len(tok_temp.encode(c)) for c in train_captions[:500]]
    print(f" Caption tokens (sample 500): mean={np.mean(lengths):.0f} "
          f"median={np.median(lengths):.0f} max={max(lengths)} "
          f">77: {sum(1 for l in lengths if l > 77)/len(lengths):.1%}")
    del tok_temp

    # Static Procrustes warm-start for proj_modern.
    compute_and_init_procrustes(
        model, modern_model, modern_tok,
        train_captions[:TCFG.procrustes_n_samples], device)

    # Logging / metrics setup.
    os.makedirs(TCFG.checkpoint_dir, exist_ok=True)
    os.makedirs(TCFG.tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=TCFG.tensorboard_dir)
    all_metrics = {
        "config": {**{k: v for k, v in config.to_dict().items()
                      if not k.startswith("_")},
                   **asdict(TCFG)},
        "epochs": [], "steps": [],
    }

    global_step = 0

    # ---- Phase 1: sequence head only, everything else frozen ----
    print(f"\n{'='*70}")
    print(f"PHASE 1: Sequence head training ({TCFG.phase1_epochs} epochs)")
    print(f" v1 memory system: FROZEN")
    print(f" Sequence reconstructor: TRAINING")
    print(f"{'='*70}")

    phase1_groups = make_param_groups_phase1(model)
    global_step = train_phase(
        model, modern_model, modern_tok,
        train_captions, val_captions,
        phase1_groups, TCFG.phase1_epochs,
        "phase1", writer, all_metrics, global_step)

    save_checkpoint(model, None, TCFG.phase1_epochs, global_step, "phase1",
                    os.path.join(TCFG.checkpoint_dir, "phase1_final"))

    # ---- Phase 2: joint fine-tune with differential learning rates ----
    print(f"\n{'='*70}")
    print(f"PHASE 2: Joint fine-tune ({TCFG.phase2_epochs} epochs)")
    print(f" All trainable modules: TRAINING")
    print(f" v1 components: reduced LR")
    print(f"{'='*70}")

    phase2_groups = make_param_groups_phase2(model)
    global_step = train_phase(
        model, modern_model, modern_tok,
        train_captions, val_captions,
        phase2_groups, TCFG.phase2_epochs,
        "phase2", writer, all_metrics, global_step)

    # Final checkpoint + summary metrics.
    save_checkpoint(model, None, TCFG.phase1_epochs + TCFG.phase2_epochs,
                    global_step, "final",
                    os.path.join(TCFG.checkpoint_dir, "final"))

    all_metrics["final"] = {
        "total_steps": global_step,
        "final_m_acc": all_metrics["epochs"][-1]["m_acc"],
        "final_seq_cos": all_metrics["epochs"][-1]["seq_cos_sim"],
        "final_cv": all_metrics["epochs"][-1]["cv_raw"],
    }
    with open(TCFG.metrics_file, "w") as f:
        json.dump(all_metrics, f, indent=2, default=str)

    writer.flush()
    writer.close()

    print(f"\n{'='*70}")
    print(f"FINAL:")
    final = all_metrics["epochs"][-1]
    print(f" m_acc: {final['m_acc']:.4f}")
    print(f" seq_cos: {final['seq_cos_sim']:.4f}")
    print(f" CV: {final['cv_raw']:.4f}")
    print(f"{'='*70}")
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()