"""
MR-JEPA Phase 3 Training — Enriched Evidence + Generative Decoder

Loads the best Phase 2 checkpoint and:
1. Enables OCR token injection (from TextVQA ocr_tokens or simple extraction)
2. Trains the generative head on open-ended benchmarks (DocVQA, ChartQA, TextVQA)
3. Continues JEPA + discriminative training on ScienceQA
4. Fine-tunes all components end to end

Training data:
- ScienceQA train (MC, JEPA + task loss)
- DocVQA validation (open-ended, generative loss)
- ChartQA test (open-ended, generative loss)
- TextVQA train (open-ended, generative loss, OCR tokens available)

Eval:
- ScienceQA test (accuracy)
- DocVQA validation (ANLS)
- ChartQA test (relaxed accuracy)
- TextVQA validation (VQA accuracy)

Phase 3 hyperparameters (from ARCHITECTURE.md):
    LR: 5e-5 (core), 5e-6 (backbone)
    Batch: 16, grad_accum: 8
    Epochs: 10
    Cosine schedule + warmup (10%)

Usage:
    python train_phase3.py
    python train_phase3.py --epochs 10 --core_lr 5e-5
"""
|
|
import os
import json
import math
import logging
import argparse
from collections import defaultdict, Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image
|
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("mrjepa-p3")
|
|
|
|
# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------
|
|
class OpenEndedDataset(Dataset):
    """Dataset for open-ended VQA benchmarks (Phase 3 generative training)."""

    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192, max_gen_len=64):
        from datasets import load_dataset

        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_gen_len = max_gen_len

        log.info(f"Loading {benchmark} {split}...")

        if benchmark == "docvqa":
            ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
        elif benchmark == "chartqa":
            ds = load_dataset("lmms-lab/ChartQA", split=split)
        elif benchmark == "textvqa":
            ds = load_dataset("lmms-lab/textvqa", split=split)
        else:
            raise ValueError(f"Unknown benchmark: {benchmark}")

        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))

        self.data = ds
        log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]

        # Image (fall back to a blank canvas if the row has none).
        img = row.get("image")
        if img is None:
            img = Image.new("RGB", (256, 256), "white")
        else:
            img = img.convert("RGB")

        question = row["question"]

        # Answer extraction differs per benchmark.
        if self.benchmark == "docvqa":
            answers = row.get("answers", [""])
            answer = answers[0] if answers else ""
            all_answers = answers
        elif self.benchmark == "chartqa":
            answer = str(row.get("answer", ""))
            all_answers = [answer]
        elif self.benchmark == "textvqa":
            answers = row.get("answers", [""])
            # TextVQA ships multiple annotator answers; train on the majority one.
            answer_counts = Counter(a.lower().strip() for a in answers)
            answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
            all_answers = answers
        else:
            answer = ""
            all_answers = [""]

        # OCR token injection: append detected scene text to the question.
        ocr_tokens = row.get("ocr_tokens", [])
        ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""

        text = question
        if ocr_text:
            text += f" [OCR: {ocr_text}]"

        return {
            "image": img,
            "text": text,
            "answer": answer,
            "all_answers": all_answers,
            "benchmark": self.benchmark,
            "ocr_text": ocr_text,
            "question_type": row.get("type", row.get("question_types", [""])),
        }
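# Illustrative usage sketch (assumes the lmms-lab datasets above are reachable;
# the sample strings are made up for illustration):
#   ds = OpenEndedDataset("textvqa", "train", max_samples=4)
#   sample = ds[0]
#   sample["text"]    -> 'what brand is the bottle? [OCR: coca cola ...]'
#   sample["answer"]  -> majority annotator answer, lowercased and stripped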
|
|
|
|
def collate_open_ended(batch, transform, tokenizer, max_len, max_gen_len):
    """Collate function for open-ended VQA batches."""
    images = [s["image"] for s in batch]
    texts = [s["text"] for s in batch]
    answers = [s["answer"] for s in batch]

    # torchvision-style transforms map one PIL image to a tensor; HF image
    # processors take the whole list and return a "pixel_values" batch.
    if callable(transform) and not hasattr(transform, 'feature_extractor'):
        pixel_values = torch.stack([transform(img) for img in images])
    else:
        pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]

    # Question (+ injected OCR) tokens for the text encoder.
    tok = tokenizer(
        texts, padding="max_length", truncation=True,
        max_length=max_len, return_tensors="pt"
    )

    # Target answer tokens for the generative head; empty answers become a
    # single space so the tokenizer still emits a valid sequence.
    answer_texts = [a if a else " " for a in answers]
    gen_tok = tokenizer(
        answer_texts, padding="max_length", truncation=True,
        max_length=max_gen_len, return_tensors="pt"
    )

    return {
        "pixel_values": pixel_values,
        "input_ids": tok["input_ids"],
        "attention_mask": tok["attention_mask"],
        "gen_target_ids": gen_tok["input_ids"],
        "gen_attention_mask": gen_tok["attention_mask"],
        "batch_size": len(batch),
        "benchmarks": [s["benchmark"] for s in batch],
        "all_answers": [s["all_answers"] for s in batch],
        "question_types": [s.get("question_type", "") for s in batch],
    }
|
|
|
|
# ---------------------------------------------------------------------------
# Generative decoder
# ---------------------------------------------------------------------------
|
|
class GenerativeDecoderLayer(nn.Module):
    """Transformer decoder layer with cross-attention to latent state and evidence."""

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads,
            dropout=dropout, batch_first=True,
        )
        self.self_attn_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention to the rolled-out latent state.
        self.state_cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads,
            dropout=dropout, batch_first=True,
        )
        self.state_cross_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention to the evidence memory.
        self.evidence_cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads,
            dropout=dropout, batch_first=True,
        )
        self.evidence_cross_norm = nn.LayerNorm(hidden_dim)

        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )
        self.ffn_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, z_final, evidence, causal_mask=None):
        # Causal self-attention over the generated prefix (pre-norm residual).
        r = x
        x2 = self.self_attn_norm(x)
        x2, _ = self.self_attn(x2, x2, x2, attn_mask=causal_mask)
        x = r + x2

        # Cross-attention to the final latent state.
        r = x
        x2 = self.state_cross_norm(x)
        x2, _ = self.state_cross_attn(x2, z_final, z_final)
        x = r + x2

        # Cross-attention to the evidence memory.
        r = x
        x2 = self.evidence_cross_norm(x)
        x2, _ = self.evidence_cross_attn(x2, evidence, evidence)
        x = r + x2

        # Feed-forward block.
        x = x + self.ffn(self.ffn_norm(x))
        return x
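# Design note: each sub-block above is pre-norm residual, x = x + Attn(LN(x), ...).
# The causal mask only gates self-attention; both cross-attentions are free to
# look at every latent-state and evidence token.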
|
|
|
|
class GenerativeHead(nn.Module):
    """
    Lightweight generative decoder for Phase 3.

    Cross-attends to z_K and the evidence memory to generate short answers.
    Uses the text encoder's tokenizer vocabulary.
    """

    def __init__(self, hidden_dim, vocab_size, num_layers=4, num_heads=12,
                 max_gen_len=64, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.max_gen_len = max_gen_len

        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_gen_len, hidden_dim)

        self.layers = nn.ModuleList([
            GenerativeDecoderLayer(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        self.output_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

        # Tie the output projection to the input embedding.
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, z_final, evidence, target_ids, pad_token_id=0):
        """Teacher-forced forward pass."""
        B, seq_len = target_ids.shape
        device = target_ids.device

        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        x = self.token_embedding(target_ids) + self.pos_embedding(positions)

        # Boolean mask: True above the diagonal blocks attention to the future.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
            diagonal=1
        )

        for layer in self.layers:
            x = layer(x, z_final, evidence, causal_mask)

        logits = self.lm_head(self.output_norm(x))

        # Next-token prediction: position t predicts token t+1.
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = target_ids[:, 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, self.vocab_size),
            shift_labels.view(-1),
            ignore_index=pad_token_id,
        )

        return logits, loss

    @torch.no_grad()
    def generate(self, z_final, evidence, start_token_id, max_length=64, eos_token_id=None):
        """Greedy autoregressive generation."""
        B = z_final.size(0)
        device = z_final.device

        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)

        for _ in range(max_length - 1):
            seq_len = generated.size(1)
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            x = self.token_embedding(generated) + self.pos_embedding(positions)

            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
                diagonal=1
            )

            for layer in self.layers:
                x = layer(x, z_final, evidence, causal_mask)

            logits = self.lm_head(self.output_norm(x[:, -1:]))
            next_token = logits.argmax(dim=-1)
            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated
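# Minimal shape sketch (illustrative; these dims are placeholders, not the
# trained configuration):
#   head = GenerativeHead(hidden_dim=768, vocab_size=30522, num_layers=2)
#   z_final  = torch.randn(2, 8, 768)     # (B, num_latent_tokens, D)
#   evidence = torch.randn(2, 64, 768)    # (B, num_evidence_tokens, D)
#   targets  = torch.randint(0, 30522, (2, 16))
#   logits, loss = head(z_final, evidence, targets)   # logits: (2, 16, 30522)
#   ids = head.generate(z_final, evidence, start_token_id=101, max_length=8)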
|
|
|
|
# ---------------------------------------------------------------------------
# Metrics
# ---------------------------------------------------------------------------
|
|
def normalized_levenshtein(s1, s2):
    """Levenshtein edit distance divided by the longer string's length (0 = identical)."""
    s1 = s1.lower().strip()
    s2 = s2.lower().strip()
    if s1 == s2:
        return 0.0
    len1, len2 = len(s1), len(s2)
    if len1 == 0 or len2 == 0:
        return 1.0
    # Standard dynamic-programming edit-distance table.
    matrix = [[0] * (len2 + 1) for _ in range(len1 + 1)]
    for i in range(len1 + 1):
        matrix[i][0] = i
    for j in range(len2 + 1):
        matrix[0][j] = j
    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            cost = 0 if s1[i-1] == s2[j-1] else 1
            matrix[i][j] = min(matrix[i-1][j]+1, matrix[i][j-1]+1, matrix[i-1][j-1]+cost)
    return matrix[len1][len2] / max(len1, len2)
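# Worked example: "kitten" vs "sitting" needs 3 edits and the longer string has
# 7 characters, so normalized_levenshtein("kitten", "sitting") == 3/7 ≈ 0.429.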
|
|
|
|
def compute_anls(predictions, ground_truths, threshold=0.5):
    """ANLS (Average Normalized Levenshtein Similarity) metric for DocVQA."""
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        max_score = 0.0
        for gt in gts:
            nl_dist = normalized_levenshtein(str(pred), str(gt))
            # Similarity is 1 - distance, zeroed out when the strings are too far apart.
            score = 1.0 - nl_dist if nl_dist < threshold else 0.0
            max_score = max(max_score, score)
        scores.append(max_score)
    return np.mean(scores) * 100 if scores else 0.0
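# Illustrative scores: an exact (case-insensitive) match contributes 1.0; a
# prediction at normalized distance 0.2 contributes 0.8; anything at distance
# >= 0.5 contributes 0.0. compute_anls returns the mean times 100.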
|
|
|
|
def compute_vqa_accuracy(predictions, ground_truths):
    """Standard VQA accuracy for TextVQA: min(#matching annotators / 3, 1)."""
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        pred_norm = str(pred).lower().strip()
        matching = sum(1 for gt in gts if str(gt).lower().strip() == pred_norm)
        scores.append(min(matching / 3.0, 1.0))
    return np.mean(scores) * 100 if scores else 0.0
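# Example: with annotator answers ["cat", "cat", "cat", "dog"], predicting
# "cat" scores min(3/3, 1) = 1.0, while "dog" scores min(1/3, 1) ≈ 0.33.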
|
|
|
|
def compute_relaxed_accuracy(predictions, ground_truths, tolerance=0.05):
    """Relaxed accuracy for ChartQA: numeric answers may be off by 5% relative error."""
    correct = []
    for pred, gt in zip(predictions, ground_truths):
        pred_str = str(pred).strip().lower()
        gt_str = str(gt).strip().lower()
        try:
            gt_val = float(gt_str.replace(',', '').replace('%', ''))
            pred_val = float(pred_str.replace(',', '').replace('%', ''))
            if gt_val == 0:
                is_correct = abs(pred_val) <= tolerance
            else:
                is_correct = abs(pred_val - gt_val) / abs(gt_val) <= tolerance
        except (ValueError, ZeroDivisionError):
            # Non-numeric answers fall back to exact string match.
            is_correct = pred_str == gt_str
        correct.append(is_correct)
    return np.mean(correct) * 100 if correct else 0.0
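# Example: predicting "42.5" against ground truth "42" is within 5% relative
# error (0.5/42 ≈ 1.2%) and counts as correct; "45" (≈7.1% off) does not.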
|
|
|
|
# ---------------------------------------------------------------------------
# Checkpoint loading and training
# ---------------------------------------------------------------------------
|
|
def download_phase2_checkpoint(hub_model_id, run_name="hybrid_main"):
    """Fetch the best Phase 2 checkpoint from the Hub."""
    from huggingface_hub import hf_hub_download
    path = hf_hub_download(
        repo_id=hub_model_id,
        filename=f"checkpoints/{run_name}_best.pt",
        repo_type="model"
    )
    log.info(f"Downloaded Phase 2 checkpoint: {path}")
    return path
|
|
|
|
def main():
    parser = argparse.ArgumentParser(description="MR-JEPA Phase 3 Training")
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
    parser.add_argument("--run_name", default="hybrid_main_phase3")
    parser.add_argument("--phase2_run", default="hybrid_main")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--core_lr", type=float, default=5e-5)
    parser.add_argument("--backbone_lr", type=float, default=5e-6)
    parser.add_argument("--text_lr", type=float, default=5e-6)
    parser.add_argument("--gen_weight", type=float, default=0.5,
                        help="Weight for generative loss relative to task loss")
    parser.add_argument("--max_eval_samples", type=int, default=500)
    parser.add_argument("--max_gen_len", type=int, default=64)
    parser.add_argument("--max_train_samples", type=int, default=0,
                        help="0 = all samples")
    parser.add_argument("--output_dir", default="./outputs/mrjepa_phase3")
    parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio")
    args = parser.parse_args()
| log.info("Downloading Phase 1 training script for model definitions...") |
| from huggingface_hub import hf_hub_download |
| p1_script = hf_hub_download( |
| repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model" |
| ) |
| import importlib.util |
| spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script) |
| p1 = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(p1) |
|
|
| |
    # Resolve the Phase 2 checkpoint: explicit path if given, else download.
    if args.checkpoint and os.path.exists(args.checkpoint):
        ckpt_path = args.checkpoint
    else:
        ckpt_path = download_phase2_checkpoint(args.hub_model_id, args.phase2_run)

    log.info(f"Loading Phase 2 checkpoint: {ckpt_path}")
    # weights_only=False: the checkpoint also stores the config dict.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Rebuild the Phase 2 config, then apply the Phase 3 overrides.
    saved_cfg = ckpt["config"]
    cfg = p1.Config()
    for k, v in saved_cfg.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)

    cfg.phase = 3
    cfg.epochs = args.epochs
    cfg.batch_size = args.batch_size
    cfg.grad_accum = args.grad_accum
    cfg.lr = args.core_lr
    cfg.backbone_lr = args.backbone_lr
    cfg.output_dir = args.output_dir
    cfg.run_name = args.run_name
    cfg.freeze_backbone = True
    cfg.freeze_text = True
    cfg.max_eval_samples = args.max_eval_samples
    cfg.resolve()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Device: {device}")
    os.makedirs(cfg.output_dir, exist_ok=True)
    # Experiment tracking.
    import trackio
    trackio.init(
        name=args.run_name,
        project="MR-JEPA",
        space_id=args.trackio_space,
        config={
            "phase": 3, "epochs": args.epochs,
            "core_lr": args.core_lr, "backbone_lr": args.backbone_lr,
            "text_lr": args.text_lr, "gen_weight": args.gen_weight,
            "batch_size": args.batch_size, "grad_accum": args.grad_accum,
            "backbone": cfg.backbone, "K": cfg.K,
            "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn,
            "max_gen_len": args.max_gen_len,
            "phase2_best_acc": ckpt.get("eval_acc", "unknown"),
        }
    )
    log.info(f"Trackio initialized → https://huggingface.co/spaces/{args.trackio_space}")
| log.info("Building model...") |
| model = p1.MRJEPAModel(cfg) |
| model.evidence.load_state_dict(ckpt["evidence"]) |
| model.rollout.load_state_dict(ckpt["rollout"]) |
| model.disc.load_state_dict(ckpt["disc"]) |
| model.target.t_ev.load_state_dict(ckpt["target_ev"]) |
| model.target.t_ro.load_state_dict(ckpt["target_ro"]) |
| log.info(f"Loaded Phase 2 weights (epoch={ckpt.get('epoch','?')}, " |
| f"eval_acc={ckpt.get('eval_acc','?')}%)") |
|
|
| |
    # Attach the new generative head; len(tokenizer) (unlike tokenizer.vocab_size)
    # also counts any added special tokens.
    tokenizer = model.txt.tokenizer

    actual_vocab_size = len(tokenizer)
    log.info(f"Adding generative head: actual_vocab_size={actual_vocab_size}, "
             f"hidden_dim={cfg.rollout_dim}, layers=4")

    gen_head = GenerativeHead(
        hidden_dim=cfg.rollout_dim,
        vocab_size=actual_vocab_size,
        num_layers=4,
        num_heads=cfg.predictor_heads,
        max_gen_len=args.max_gen_len,
        dropout=0.1,
    )
    model.gen_head = gen_head
| log.info("Unfreezing last 6 visual layers, last 4 text layers") |
| model.vis.unfreeze_last(6) |
| model.txt.unfreeze_last(4) |
|
|
| model = model.to(device) |
| total_p = sum(p.numel() for p in model.parameters()) |
| train_p = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)") |
| trackio.log({ |
| "model/total_params": total_p, |
| "model/trainable_params": train_p, |
| "model/trainable_pct": 100 * train_p / total_p |
| }) |
|
|
| |
    # ---- Data ----
    transform = model.vis.get_transform()

    # ScienceQA (multiple-choice) keeps the Phase 1/2 JEPA + task objective.
    mc_max = args.max_train_samples if args.max_train_samples > 0 else 0
    train_mc_ds = p1.ScienceQADataset(
        "train", max_samples=mc_max, transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_opts=cfg.max_options
    )
    eval_mc_ds = p1.ScienceQADataset(
        "test", max_samples=cfg.max_eval_samples, transform=transform,
        tokenizer=tokenizer, max_len=cfg.max_text_len, max_opts=cfg.max_options
    )

    mc_coll = lambda batch: p1.collate_fn(
        batch, transform, tokenizer, cfg.max_text_len, cfg.max_options
    )
    train_mc_dl = DataLoader(
        train_mc_ds, batch_size=cfg.batch_size, shuffle=True,
        num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True
    )
    eval_mc_dl = DataLoader(
        eval_mc_ds, batch_size=cfg.batch_size, shuffle=False,
        num_workers=2, collate_fn=mc_coll, pin_memory=True
    )

    # Open-ended training sets (capped at 5000 samples each by default).
    max_open_train = args.max_train_samples if args.max_train_samples > 0 else 5000

    train_docvqa_ds = OpenEndedDataset(
        "docvqa", "validation", max_samples=max_open_train,
        transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_gen_len=args.max_gen_len
    )
    train_chartqa_ds = OpenEndedDataset(
        "chartqa", "test", max_samples=max_open_train,
        transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_gen_len=args.max_gen_len
    )
    train_textvqa_ds = OpenEndedDataset(
        "textvqa", "train", max_samples=max_open_train,
        transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_gen_len=args.max_gen_len
    )

    # Open-ended eval sets.
    eval_docvqa_ds = OpenEndedDataset(
        "docvqa", "validation", max_samples=args.max_eval_samples,
        transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_gen_len=args.max_gen_len
    )
    eval_chartqa_ds = OpenEndedDataset(
        "chartqa", "test", max_samples=args.max_eval_samples,
        transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_gen_len=args.max_gen_len
    )
    eval_textvqa_ds = OpenEndedDataset(
        "textvqa", "validation", max_samples=args.max_eval_samples,
        transform=transform, tokenizer=tokenizer,
        max_len=cfg.max_text_len, max_gen_len=args.max_gen_len
    )

    open_coll = lambda batch: collate_open_ended(
        batch, transform, tokenizer, cfg.max_text_len, args.max_gen_len
    )

    train_open_dls = {
        "docvqa": DataLoader(
            train_docvqa_ds, batch_size=cfg.batch_size, shuffle=True,
            num_workers=2, collate_fn=open_coll, pin_memory=True, drop_last=True
        ),
        "chartqa": DataLoader(
            train_chartqa_ds, batch_size=cfg.batch_size, shuffle=True,
            num_workers=2, collate_fn=open_coll, pin_memory=True, drop_last=True
        ),
        "textvqa": DataLoader(
            train_textvqa_ds, batch_size=cfg.batch_size, shuffle=True,
            num_workers=2, collate_fn=open_coll, pin_memory=True, drop_last=True
        ),
    }

    eval_open_dls = {
        "docvqa": DataLoader(
            eval_docvqa_ds, batch_size=cfg.batch_size, shuffle=False,
            num_workers=2, collate_fn=open_coll, pin_memory=True
        ),
        "chartqa": DataLoader(
            eval_chartqa_ds, batch_size=cfg.batch_size, shuffle=False,
            num_workers=2, collate_fn=open_coll, pin_memory=True
        ),
        "textvqa": DataLoader(
            eval_textvqa_ds, batch_size=cfg.batch_size, shuffle=False,
            num_workers=2, collate_fn=open_coll, pin_memory=True
        ),
    }
    # ---- Optimizer: three LR groups (core / visual backbone / text encoder) ----
    backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
    text_params = [p for p in model.txt.parameters() if p.requires_grad]
    bb_txt_ids = {id(p) for p in backbone_params + text_params}
    core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids]

    param_groups = [
        {"params": core_params, "lr": args.core_lr},
        {"params": backbone_params, "lr": args.backbone_lr},
        {"params": text_params, "lr": args.text_lr},
    ]
    log.info(f"Optimizer: core={len(core_params)} @ {args.core_lr}, "
             f"backbone={len(backbone_params)} @ {args.backbone_lr}, "
             f"text={len(text_params)} @ {args.text_lr}")

    optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)

    # ---- Cosine schedule with 10% warmup over all optimizer steps ----
    mc_steps_per_epoch = len(train_mc_dl)
    open_steps_per_epoch = sum(len(dl) for dl in train_open_dls.values())
    total_batches_per_epoch = mc_steps_per_epoch + open_steps_per_epoch
    total_steps = cfg.epochs * total_batches_per_epoch // cfg.grad_accum
    warmup_steps = int(total_steps * 0.1)

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        # Cosine decay from 1.0 down to a 1% floor, never fully to zero.
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
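    # Illustrative multiplier values: with total_steps=1000 and warmup_steps=100,
    # lr_lambda(0)=0.0, lr_lambda(100)=1.0, lr_lambda(550)≈0.505, and the
    # multiplier bottoms out at the 0.01 floor by step 1000.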
    # Pad token used as ignore_index in the generative loss.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0
    log.info(f"Pad token ID for gen loss: {pad_token_id}")

    log.info(f"Phase 3: {cfg.epochs} epochs")
    log.info(f"  MC batches/epoch: {mc_steps_per_epoch}")
    log.info(f"  Open batches/epoch: {open_steps_per_epoch}")
    log.info(f"  Total opt steps: ~{total_steps}, warmup: {warmup_steps}")

    global_step = 0
    best_composite = 0.0
    eval_results = {}  # populated each epoch; initialized so the final push never sees it unbound
    amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
    trainable = [p for p in model.parameters() if p.requires_grad]
    try:
        for epoch in range(cfg.epochs):
            model.train()
            epoch_losses = defaultdict(list)
            epoch_mc_correct = 0
            epoch_mc_total = 0
            optimizer.zero_grad()
            batch_count = 0

            # ---- Part 1: multiple-choice training on ScienceQA ----
            log.info(f"Phase 3 Epoch {epoch}: MC training on ScienceQA...")
            for batch_idx, batch in enumerate(train_mc_dl):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                with torch.autocast(device_type="cuda", dtype=amp_dtype,
                                    enabled=cfg.bf16 and device.type == "cuda"):
                    losses, preds = model(**batch)
                    loss = losses["total"] / cfg.grad_accum

                loss.backward()
                batch_count += 1

                if batch_count % cfg.grad_accum == 0:
                    nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
                    optimizer.step(); scheduler.step(); optimizer.zero_grad()
                    model.update_target(global_step, total_steps)
                    global_step += 1

                for k, v in losses.items():
                    if isinstance(v, torch.Tensor):
                        epoch_losses[f"mc_{k}"].append(v.item())
                epoch_mc_correct += (preds == batch["labels"]).sum().item()
                epoch_mc_total += batch["batch_size"]

                if batch_idx % 100 == 0:
                    avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")}
                    mc_acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100
                    log.info(f"P3 E{epoch} MC B{batch_idx}/{mc_steps_per_epoch} | "
                             f"loss={avg.get('mc_total',0):.4f} | acc={mc_acc:.1f}%")
                    trackio.log({
                        "train/mc_loss": avg.get("mc_total", 0),
                        "train/mc_jepa": avg.get("mc_jepa", 0),
                        "train/mc_task": avg.get("mc_task", 0),
                        "train/mc_accuracy": mc_acc,
                        "train/lr": scheduler.get_last_lr()[0],
                        "train/epoch": epoch, "train/step": global_step,
                    })
| log.info(f"Phase 3 Epoch {epoch}: Open-ended training...") |
| epoch_gen_losses = defaultdict(list) |
| |
| |
| open_iters = {name: iter(dl) for name, dl in train_open_dls.items()} |
| open_active = set(open_iters.keys()) |
| open_batch_idx = 0 |
| |
| while open_active: |
| for name in list(open_active): |
| try: |
| batch = next(open_iters[name]) |
| except StopIteration: |
| open_active.discard(name) |
| continue |
| |
| batch_t = {k: v.to(device) if isinstance(v, torch.Tensor) else v |
| for k, v in batch.items()} |
| |
| with torch.autocast(device_type="cuda", dtype=amp_dtype, |
| enabled=cfg.bf16 and device.type == "cuda"): |
| |
| vis_tok = model.vis(batch_t["pixel_values"]).float() |
| txt_tok = model.txt(batch_t["input_ids"], |
| batch_t["attention_mask"]).float() |
| evidence, _, ev_mask = model.evidence(vis_tok, txt_tok, |
| batch_t["attention_mask"]) |
| |
| if model._use_rollout: |
| traj, z_final, z_proj = model.rollout(evidence) |
| else: |
| B = batch_t["batch_size"] |
| z0 = model.rollout.init_tokens.expand(B, -1, -1) + \ |
| model.rollout.z0_proj(F.adaptive_avg_pool1d( |
| evidence.permute(0,2,1), |
| model.rollout.num_tokens |
| ).permute(0,2,1)) |
| z_final = z0 |
| z_proj = model.rollout.out_proj(z0).unsqueeze(1) |
| |
| |
| jepa_loss_val = torch.tensor(0.0, device=device) |
| if model._use_jepa: |
| target_proj = model.target( |
| vis_tok.detach(), txt_tok.detach(), |
| batch_t["attention_mask"].detach() |
| ) |
| jepa_losses = model.jepa_loss( |
| z_proj, target_proj, |
| torch.tensor(0.0, device=device) |
| ) |
| jepa_loss_val = jepa_losses["jepa"] + jepa_losses["reg"] |
| |
| |
| gen_logits, gen_loss = model.gen_head( |
| z_final, evidence, batch_t["gen_target_ids"], |
| pad_token_id=pad_token_id |
| ) |
| |
| |
| total_loss = (cfg.jepa_weight * jepa_loss_val + |
| args.gen_weight * gen_loss) |
| loss = total_loss / cfg.grad_accum |
| |
| loss.backward() |
| batch_count += 1 |
| |
| if batch_count % cfg.grad_accum == 0: |
| nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm) |
| optimizer.step(); scheduler.step(); optimizer.zero_grad() |
| model.update_target(global_step, total_steps) |
| global_step += 1 |
| |
| epoch_gen_losses[f"{name}_gen"].append(gen_loss.item()) |
| epoch_gen_losses[f"{name}_total"].append(total_loss.item()) |
| epoch_losses["gen_total"].append(total_loss.item()) |
| |
| open_batch_idx += 1 |
| if open_batch_idx % 100 == 0: |
| avg_gen = {k: np.mean(v[-100:]) for k, v in epoch_gen_losses.items()} |
| log.info(f"P3 E{epoch} OPEN B{open_batch_idx} | " + |
| " | ".join(f"{k}={v:.4f}" for k, v in avg_gen.items())) |
| trackio.log({ |
| f"train/{k}": v for k, v in avg_gen.items() |
| }) |
| |
| |
| log.info(f"Phase 3 Epoch {epoch}: Evaluating...") |
| |
| |
| mc_eval_acc = p1.evaluate(model, eval_mc_dl, device, cfg) |
| log.info(f" ScienceQA eval accuracy: {mc_eval_acc:.1f}%") |
| |
| |
| eval_results = evaluate_generative( |
| model, eval_open_dls, device, cfg, tokenizer, |
| pad_token_id, args.max_gen_len, amp_dtype |
| ) |
| |
| for bm, metrics in eval_results.items(): |
| for mk, mv in metrics.items(): |
| log.info(f" {bm} {mk}: {mv:.2f}") |
| |
| |
| all_scores = [mc_eval_acc] |
| for bm, metrics in eval_results.items(): |
| all_scores.extend(metrics.values()) |
| composite = np.mean(all_scores) |
| |
| log.info(f"=== Phase 3 Epoch {epoch} | MC: {mc_eval_acc:.1f}% | " |
| f"Composite: {composite:.1f} ===") |
| |
| trackio.log({ |
| "eval/scienceqa_accuracy": mc_eval_acc, |
| "eval/composite_score": composite, |
| "eval/epoch": epoch, |
| **{f"eval/{bm}_{mk}": mv |
| for bm, metrics in eval_results.items() |
| for mk, mv in metrics.items()}, |
| }) |
| |
| |
| if composite > best_composite: |
| best_composite = composite |
| save_phase3_checkpoint( |
| model, cfg, epoch, mc_eval_acc, eval_results, |
| composite, is_best=True |
| ) |
| log.info(f"New best composite: {best_composite:.1f}") |
| |
| log.info(f"Phase 3 complete. Best composite score: {best_composite:.1f}") |
| |
| finally: |
| trackio.log({ |
| "final/best_composite": best_composite, |
| "final/phase": 3, |
| "final/total_steps": global_step |
| }) |
| log.info("Finishing Trackio...") |
| trackio.finish() |
| |
| |
| if cfg.push_to_hub: |
| push_phase3_results(cfg, args, best_composite, eval_results) |
|
|
|
|
@torch.no_grad()
def evaluate_generative(model, eval_dls, device, cfg, tokenizer,
                        pad_token_id, max_gen_len, amp_dtype):
    """Evaluate on open-ended benchmarks via greedy generation."""
    model.eval()
    results = {}

    # Start decoding from BOS if the tokenizer has one, else CLS, else token 1.
    # (Explicit None checks: `or` would wrongly discard a CLS id of 0.)
    start_token_id = tokenizer.bos_token_id
    if start_token_id is None:
        start_token_id = tokenizer.cls_token_id
    if start_token_id is None:
        start_token_id = 1
    eos_token_id = tokenizer.eos_token_id

    for benchmark, dl in eval_dls.items():
        predictions = []
        ground_truths = []

        for batch in dl:
            batch_t = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                       for k, v in batch.items()}

            with torch.autocast(device_type="cuda", dtype=amp_dtype,
                                enabled=cfg.bf16 and device.type == "cuda"):
                vis_tok = model.vis(batch_t["pixel_values"]).float()
                txt_tok = model.txt(batch_t["input_ids"],
                                    batch_t["attention_mask"]).float()
                evidence, _, _ = model.evidence(vis_tok, txt_tok,
                                                batch_t["attention_mask"])

                if model._use_rollout:
                    _, z_final, _ = model.rollout(evidence)
                else:
                    B = batch_t["batch_size"]
                    z_final = model.rollout.init_tokens.expand(B, -1, -1) + \
                        model.rollout.z0_proj(F.adaptive_avg_pool1d(
                            evidence.permute(0, 2, 1),
                            model.rollout.num_tokens
                        ).permute(0, 2, 1))

            gen_ids = model.gen_head.generate(
                z_final, evidence, start_token_id,
                max_length=max_gen_len, eos_token_id=eos_token_id
            )

            for i in range(gen_ids.size(0)):
                pred_text = tokenizer.decode(
                    gen_ids[i], skip_special_tokens=True
                ).strip()
                predictions.append(pred_text)

            ground_truths.extend(batch["all_answers"])

        # Benchmark-specific metric.
        if benchmark == "docvqa":
            score = compute_anls(predictions, ground_truths)
            results[benchmark] = {"anls": score}
        elif benchmark == "chartqa":
            # ChartQA ground truths are single answers; unwrap the lists.
            gt_flat = [gt[0] if isinstance(gt, list) else gt for gt in ground_truths]
            score = compute_relaxed_accuracy(predictions, gt_flat)
            results[benchmark] = {"relaxed_accuracy": score}
        elif benchmark == "textvqa":
            score = compute_vqa_accuracy(predictions, ground_truths)
            results[benchmark] = {"vqa_accuracy": score}

        log.info(f"  {benchmark}: {results[benchmark]}")

    model.train()
    return results
|
|
|
|
def save_phase3_checkpoint(model, cfg, epoch, mc_acc, open_results, composite, is_best=False):
    """Save Phase 3 checkpoint."""
    tag = "best" if is_best else f"epoch{epoch}"
    path = os.path.join(cfg.output_dir, f"checkpoint_{tag}.pt")

    state = {
        "evidence": model.evidence.state_dict(),
        "rollout": model.rollout.state_dict(),
        "disc": model.disc.state_dict(),
        "gen_head": model.gen_head.state_dict(),
        "target_ev": model.target.t_ev.state_dict(),
        "target_ro": model.target.t_ro.state_dict(),
        "config": cfg.__dict__,
        "epoch": epoch,
        "mc_eval_acc": mc_acc,
        "open_results": open_results,
        "composite_score": composite,
        "phase": 3,
    }
    torch.save(state, path)
    log.info(f"Saved Phase 3 checkpoint: {path} (composite={composite:.1f})")
|
|
|
|
def push_phase3_results(cfg, args, best_composite, eval_results):
    """Push Phase 3 results and the best checkpoint to the Hub."""
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        results = {
            "run_name": cfg.run_name,
            "phase": 3,
            "backbone": cfg.backbone,
            "K": cfg.K,
            "use_jepa": cfg.use_jepa,
            "loss_fn": cfg.loss_fn,
            "best_composite_score": best_composite,
            "epochs": cfg.epochs,
            "core_lr": args.core_lr,
            "backbone_lr": args.backbone_lr,
            "text_lr": args.text_lr,
            "gen_weight": args.gen_weight,
            "batch_size": cfg.batch_size,
            "grad_accum": cfg.grad_accum,
            "open_results": dict(eval_results or {}),
        }

        result_path = os.path.join(cfg.output_dir, f"results_{cfg.run_name}.json")
        with open(result_path, "w") as f:
            json.dump(results, f, indent=2)

        api.upload_file(
            path_or_fileobj=result_path,
            path_in_repo=f"results/{cfg.run_name}.json",
            repo_id=cfg.hub_model_id,
            repo_type="model",
        )

        best_ckpt = os.path.join(cfg.output_dir, "checkpoint_best.pt")
        if os.path.exists(best_ckpt):
            api.upload_file(
                path_or_fileobj=best_ckpt,
                path_in_repo=f"checkpoints/{cfg.run_name}_best.pt",
                repo_id=cfg.hub_model_id,
                repo_type="model",
            )

        log.info(f"Pushed Phase 3 results to {cfg.hub_model_id}")
    except Exception as e:
        log.error(f"Push failed: {e}")
|
|
|
|
if __name__ == "__main__":
    main()
|
|