| |
| """ |
| MR-JEPA Phase 3.1 Training — Improved Generative Decoder |
| |
| Loads the Phase 3.0 checkpoint (with partially-trained gen_head) and applies |
| four targeted improvements to break through the 0% generative metrics: |
| |
| 1. gen_weight: 0.5 → 2.0 (4× stronger generative gradient signal) |
| 2. max_gen_len: 64 → 32 (shorter targets, less padding noise) |
| 3. Scheduled sampling (100% teacher forcing → 50% free-running, linear) |
| 4. Beam search evaluation (beam_width=5 instead of greedy argmax) |
| |
| Resumes from: checkpoints/hybrid_main_phase3_best.pt (gen_head pre-trained) |
| Training data: same as Phase 3.0 (ScienceQA MC + DocVQA/ChartQA/TextVQA open-ended) |
| |
| Usage: |
| python train_phase3_1.py |
| python train_phase3_1.py --gen_weight 2.0 --max_gen_len 32 --beam_width 5 |
| """ |
|
|
| import os |
| import sys |
| import json |
| import math |
| import copy |
| import random |
| import logging |
| import argparse |
| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.utils.data import Dataset, DataLoader |
| from PIL import Image |
|
|
# Configure root logging once at import time: INFO level, wall-clock timestamp,
# level name and message separated by pipes. All script output goes through the
# named logger below.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("mrjepa-p3.1")
|
|
|
|
| |
| |
| |
|
|
class OpenEndedDataset(Dataset):
    """Open-ended VQA samples (DocVQA / ChartQA / TextVQA) loaded with ``datasets``.

    Each item is a dict holding the RGB PIL image (a white 256x256 placeholder
    when the row has none), the question text (suffixed with up to 50 OCR
    tokens when present), a single training target ``answer``, the reference
    answers ``all_answers`` used for scoring, and some metadata.

    ``transform``, ``tokenizer``, ``max_len`` and ``max_gen_len`` are stored on
    the instance but not used by ``__getitem__``; raw images/strings are
    returned and left to the collate function.
    """

    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192, max_gen_len=32):
        """Load ``split`` of ``benchmark`` from the Hub, optionally truncated to ``max_samples``.

        Raises:
            ValueError: if ``benchmark`` is not one of docvqa / chartqa / textvqa.
        """
        from datasets import load_dataset

        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_gen_len = max_gen_len
        log.info(f"Loading {benchmark} {split}...")
        if benchmark == "docvqa":
            ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
        elif benchmark == "chartqa":
            ds = load_dataset("lmms-lab/ChartQA", split=split)
        elif benchmark == "textvqa":
            ds = load_dataset("lmms-lab/textvqa", split=split)
        else:
            raise ValueError(f"Unknown benchmark: {benchmark}")
        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))
        self.data = ds
        log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        from collections import Counter

        row = self.data[idx]
        img = row.get("image")
        img = Image.new("RGB", (256, 256), "white") if img is None else img.convert("RGB")
        question = row["question"]

        if self.benchmark == "docvqa":
            answers = row.get("answers")
            # A present-but-null column must behave like a missing one, not leak None downstream.
            if answers is None:
                answers = [""]
            answer = answers[0] if answers else ""
            all_answers = answers
        elif self.benchmark == "chartqa":
            raw_answer = row.get("answer")
            # str(None) would otherwise become the literal training target "None".
            answer = "" if raw_answer is None else str(raw_answer)
            all_answers = [answer]
        elif self.benchmark == "textvqa":
            answers = row.get("answers")
            if answers is None:
                answers = [""]
            # Train on the most common normalized annotator answer; keep all of
            # them (un-normalized) as references for VQA accuracy.
            answer_counts = Counter(str(a).lower().strip() for a in answers)
            answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
            all_answers = answers
        else:
            answer = ""
            all_answers = [""]

        ocr_tokens = row.get("ocr_tokens") or []
        ocr_text = " ".join(ocr_tokens[:50])
        text = question + f" [OCR: {ocr_text}]" if ocr_text else question
        return {
            "image": img, "text": text, "answer": answer,
            "all_answers": all_answers, "benchmark": self.benchmark,
            "ocr_text": ocr_text,
            "question_type": row.get("type", row.get("question_types", [""])),
        }
|
|
|
|
def collate_open_ended(batch, transform, tokenizer, max_len, max_gen_len):
    """Turn a list of open-ended samples into a batch of tensors plus metadata.

    Images are transformed one by one and stacked when ``transform`` is a plain
    callable; otherwise ``transform`` is called once on the whole image list
    with ``return_tensors="pt"`` and its ``"pixel_values"`` entry is used.
    Questions are padded/truncated to ``max_len`` tokens and answers to
    ``max_gen_len`` tokens; an empty answer is replaced by a single space so
    every sample yields a tokenizable string.
    """
    images, texts, answer_texts = [], [], []
    for sample in batch:
        images.append(sample["image"])
        texts.append(sample["text"])
        answer_texts.append(sample["answer"] or " ")

    per_image = hasattr(transform, "__call__") and not hasattr(transform, "feature_extractor")
    if not per_image:
        pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]
    else:
        pixel_values = torch.stack([transform(image) for image in images])

    question_enc = tokenizer(texts, padding="max_length", truncation=True,
                             max_length=max_len, return_tensors="pt")
    answer_enc = tokenizer(answer_texts, padding="max_length", truncation=True,
                           max_length=max_gen_len, return_tensors="pt")

    collated = {
        "pixel_values": pixel_values,
        "input_ids": question_enc["input_ids"],
        "attention_mask": question_enc["attention_mask"],
        "gen_target_ids": answer_enc["input_ids"],
        "gen_attention_mask": answer_enc["attention_mask"],
        "batch_size": len(batch),
    }
    collated["benchmarks"] = [sample["benchmark"] for sample in batch]
    collated["all_answers"] = [sample["all_answers"] for sample in batch]
    collated["question_types"] = [sample.get("question_type", "") for sample in batch]
    return collated
|
|
|
|
| |
| |
| |
|
|
class GenerativeDecoderLayer(nn.Module):
    """Pre-norm Transformer decoder layer used by ``GenerativeHead``.

    Four residual sub-layers, each applied as ``x = x + sublayer(norm(x))``:
      1. self-attention over the answer tokens (masked by ``causal_mask``),
      2. cross-attention to the reasoning state ``z_final``,
      3. cross-attention to the ``evidence`` tokens,
      4. a 4x-wide GELU feed-forward network.

    All attention modules are ``batch_first=True``, so inputs are
    ``(batch, seq, hidden_dim)``.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                               dropout=dropout, batch_first=True)
        self.self_attn_norm = nn.LayerNorm(hidden_dim)
        self.state_cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                                      dropout=dropout, batch_first=True)
        self.state_cross_norm = nn.LayerNorm(hidden_dim)
        self.evidence_cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                                         dropout=dropout, batch_first=True)
        self.evidence_cross_norm = nn.LayerNorm(hidden_dim)
        # Attribute names and Sequential indices must stay stable: checkpointed
        # gen_head weights are matched by state_dict key.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim), nn.Dropout(dropout),
        )
        self.ffn_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, z_final, evidence, causal_mask=None):
        """Apply the layer to ``x``; ``causal_mask`` is passed as the self-attention ``attn_mask``."""
        # need_weights=False: attention maps are never used, so don't compute /
        # average the per-head weight tensors.
        h = self.self_attn_norm(x)
        x = x + self.self_attn(h, h, h, attn_mask=causal_mask, need_weights=False)[0]
        h = self.state_cross_norm(x)
        x = x + self.state_cross_attn(h, z_final, z_final, need_weights=False)[0]
        h = self.evidence_cross_norm(x)
        x = x + self.evidence_cross_attn(h, evidence, evidence, need_weights=False)[0]
        return x + self.ffn(self.ffn_norm(x))


class GenerativeHead(nn.Module):
    """Autoregressive answer decoder conditioned on the reasoning state and evidence.

    Phase 3.1 additions:
      - scheduled sampling during training (``teacher_forcing_ratio`` < 1),
      - beam search during evaluation (``generate_beam``).

    Input and output embeddings are tied. Learned position embeddings cover
    ``max_gen_len`` positions, so a decoded sequence (start token included)
    can be at most ``max_gen_len`` long; the generation methods clamp
    ``max_length`` to that bound.
    """

    def __init__(self, hidden_dim, vocab_size, num_layers=4, num_heads=12,
                 max_gen_len=32, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.max_gen_len = max_gen_len
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_gen_len, hidden_dim)
        self.layers = nn.ModuleList([
            GenerativeDecoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)
        ])
        self.output_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

    def _hidden_states(self, token_ids, z_final, evidence):
        """Run the causal decoder stack over ``token_ids`` (B, T); return normed states (B, T, H)."""
        seq_len = token_ids.size(1)
        device = token_ids.device
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        x = self.token_embedding(token_ids) + self.pos_embedding(positions)
        # True above the diagonal = position may not attend to later tokens.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, z_final, evidence, causal_mask)
        return self.output_norm(x)

    def _decode_step(self, token_ids, z_final, evidence):
        """Return vocabulary logits for every position of ``token_ids``: (B, T, V)."""
        return self.lm_head(self._hidden_states(token_ids, z_final, evidence))

    def _next_token_logits(self, token_ids, z_final, evidence):
        """Return logits for the last position only: (B, V).

        Projecting just the final hidden state avoids a (B, T, V) vocabulary
        matmul per step when only the next-token distribution is needed.
        """
        return self.lm_head(self._hidden_states(token_ids, z_final, evidence)[:, -1])

    def forward(self, z_final, evidence, target_ids, pad_token_id=0,
                teacher_forcing_ratio=1.0):
        """Training forward pass with optional scheduled sampling.

        Args:
            z_final: reasoning-state tokens used as cross-attention memory.
            evidence: evidence tokens used as cross-attention memory.
            target_ids: (B, T) target ids; ``target_ids[:, :1]`` seeds decoding.
            pad_token_id: label id ignored by the loss.
            teacher_forcing_ratio: ``>= 1.0`` runs a single parallel
                teacher-forced pass. Otherwise decoding is sequential and, for
                each sequence and step independently, the ground-truth token is
                fed with this probability, else the model's own argmax.

        Returns:
            ``(logits, loss)``: logits (B, T, V) where position t predicts
            token t + 1, and the mean next-token cross-entropy over non-pad
            labels. When every label is padding the loss is 0.0 (still attached
            to the graph) instead of NaN.
        """
        batch, seq_len = target_ids.shape

        if teacher_forcing_ratio >= 1.0:
            logits = self._decode_step(target_ids, z_final, evidence)
        else:
            step_logits = []
            current_input = target_ids[:, :1]
            for t in range(seq_len):
                last = self._next_token_logits(current_input, z_final, evidence)
                step_logits.append(last)
                if t < seq_len - 1:
                    # Per-sequence coin flips (Bengio et al., 2015) rather than
                    # one flip for the whole batch.
                    use_teacher = torch.rand(batch, 1, device=target_ids.device) < teacher_forcing_ratio
                    predicted = last.detach().argmax(dim=-1, keepdim=True)
                    next_token = torch.where(use_teacher, target_ids[:, t + 1:t + 2], predicted)
                    current_input = torch.cat([current_input, next_token], dim=1)
            logits = torch.stack(step_logits, dim=1)

        shift_logits = logits[:, :-1].reshape(-1, self.vocab_size)
        shift_labels = target_ids[:, 1:].reshape(-1)
        if not bool((shift_labels != pad_token_id).any()):
            # cross_entropy averages over zero valid labels -> NaN, which would
            # poison every gradient in the accumulated step.
            loss = shift_logits.sum() * 0.0
        else:
            loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate_greedy(self, z_final, evidence, start_token_id,
                        max_length=32, eos_token_id=None):
        """Greedy autoregressive generation (fallback to ``generate_beam``).

        Returns a (B, L) LongTensor beginning with ``start_token_id``, with
        ``L <= min(max_length, max_gen_len)``. Stops early once every sequence
        has emitted ``eos_token_id`` at least once; tokens after a sequence's
        first EOS are not meaningful and should be cut by the caller.
        """
        max_length = min(max_length, self.max_gen_len)
        batch = z_final.size(0)
        device = z_final.device
        generated = torch.full((batch, 1), start_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch, dtype=torch.bool, device=device)
        for _ in range(max_length - 1):
            next_token = self._next_token_logits(generated, z_final, evidence).argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            if eos_token_id is not None:
                finished |= next_token.squeeze(1) == eos_token_id
                if bool(finished.all()):
                    break
        return generated

    @torch.no_grad()
    def generate_beam(self, z_final, evidence, start_token_id,
                      max_length=32, eos_token_id=None, beam_width=5, pad_token_id=0):
        """Beam search generation, run independently for each batch element.

        Beam scores are summed log-probabilities; beams are ranked by score
        divided by sequence length (start token included). A beam finishes
        when it emits ``eos_token_id`` after the start token. The best
        sequence per element is returned in a (B, L) LongTensor, right-padded
        with ``pad_token_id`` (0 by default); ``L <= min(max_length, max_gen_len)``.
        """
        max_length = min(max_length, self.max_gen_len)
        beam_width = max(1, min(beam_width, self.vocab_size))
        best_seqs = [
            self._beam_search_single(z_final[b:b + 1], evidence[b:b + 1], start_token_id,
                                     max_length, eos_token_id, beam_width)
            for b in range(z_final.size(0))
        ]
        width = max((seq.size(0) for seq in best_seqs), default=1)
        padded = torch.full((len(best_seqs), width), pad_token_id,
                            dtype=torch.long, device=z_final.device)
        for i, seq in enumerate(best_seqs):
            padded[i, :seq.size(0)] = seq
        return padded

    def _beam_search_single(self, z_b, ev_b, start_token_id, max_length, eos_token_id, beam_width):
        """Beam search for one example (batch dim 1); return the best 1-D token sequence."""

        def norm_score(entry):
            # Length-normalized log-probability used for ranking.
            return entry[0] / entry[1].size(1)

        def ended(seq):
            # Only generated tokens can end a beam, so start_token_id == eos_token_id
            # (one id reused for BOS and EOS) does not stop decoding immediately.
            return (eos_token_id is not None and seq.size(1) > 1
                    and seq[0, -1].item() == eos_token_id)

        start = torch.full((1, 1), start_token_id, dtype=torch.long, device=z_b.device)
        beams = [(0.0, start)]
        completed = []
        for _ in range(max_length - 1):
            candidates = []
            for score, seq in beams:
                if ended(seq):
                    completed.append((score, seq))
                    continue
                log_probs = F.log_softmax(self._next_token_logits(seq, z_b, ev_b)[0].float(), dim=-1)
                top_lp, top_ids = log_probs.topk(beam_width)
                for lp, tok in zip(top_lp.tolist(), top_ids.tolist()):
                    candidates.append((score + lp, torch.cat([seq, seq.new_full((1, 1), tok)], dim=1)))
            if not candidates:
                break
            candidates.sort(key=norm_score, reverse=True)
            beams = candidates[:beam_width]
            if all(ended(seq) for _, seq in beams):
                completed.extend(beams)
                beams = []
                break

        pool = completed + beams
        if not pool:
            return start[0]
        return max(pool, key=norm_score)[1][0]
|
|
|
|
| |
| |
| |
|
|
def normalized_levenshtein(s1, s2):
    """Edit distance between case-folded, stripped strings, divided by the longer length.

    Returns 0.0 when the normalized strings are equal and 1.0 when exactly
    one of them is empty; otherwise a value in (0, 1].
    """
    a = s1.lower().strip()
    b = s2.lower().strip()
    if a == b:
        return 0.0
    if not a or not b:
        return 1.0

    # Wagner-Fischer dynamic programme, keeping only the previous row.
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,                       # deletion
                current[j - 1] + 1,                    # insertion
                previous[j - 1] + (char_a != char_b),  # substitution / match
            ))
        previous = current
    return previous[-1] / max(len(a), len(b))
|
|
def compute_anls(predictions, ground_truths, threshold=0.5):
    """Average Normalized Levenshtein Similarity (ANLS), scaled to 0-100.

    For each prediction the best similarity ``1 - NL(pred, gt)`` over its
    reference answers is taken, where references with normalized distance
    ``>= threshold`` contribute 0. Samples with no references score 0, and
    a bare string reference is treated as a single answer (not iterated
    character by character). Returns 0.0 for empty input.
    """
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        if isinstance(gts, str):
            gts = [gts]
        pred_text = str(pred)
        best = 0.0
        for gt in gts or ():
            # Compute the distance once per reference (previously done twice).
            dist = normalized_levenshtein(pred_text, str(gt))
            if dist < threshold:
                best = max(best, 1.0 - dist)
        scores.append(best)
    return np.mean(scores) * 100 if scores else 0.0
|
|
def compute_vqa_accuracy(predictions, ground_truths):
    """Simplified VQA accuracy on a 0-100 scale.

    A prediction scores ``min(#matching references / 3, 1)``, comparing
    case-folded, stripped strings; the result is the mean over samples
    (0.0 for empty input).
    """
    per_sample = []
    for pred, gts in zip(predictions, ground_truths):
        wanted = str(pred).lower().strip()
        hits = sum(str(gt).lower().strip() == wanted for gt in gts)
        per_sample.append(min(hits / 3.0, 1.0))
    if not per_sample:
        return 0.0
    return np.mean(per_sample) * 100
|
|
def compute_relaxed_accuracy(predictions, ground_truths, tolerance=0.05):
    """ChartQA-style relaxed accuracy on a 0-100 scale.

    When both strings parse as numbers (after removing ``,`` and ``%``), a
    prediction is correct if its relative error is within ``tolerance``
    (absolute error when the reference is 0). Otherwise the case-folded,
    stripped strings must match exactly. Returns 0.0 for empty input.
    """
    def as_number(text):
        return float(text.replace(',', '').replace('%', ''))

    verdicts = []
    for pred, gt in zip(predictions, ground_truths):
        pred_text = str(pred).strip().lower()
        gt_text = str(gt).strip().lower()
        try:
            gt_value = as_number(gt_text)
            pred_value = as_number(pred_text)
            if gt_value == 0:
                ok = abs(pred_value) <= tolerance
            else:
                ok = abs(pred_value - gt_value) / abs(gt_value) <= tolerance
        except (ValueError, ZeroDivisionError):
            ok = pred_text == gt_text
        verdicts.append(ok)
    return np.mean(verdicts) * 100 if verdicts else 0.0
|
|
|
|
| |
| |
| |
|
|
def get_teacher_forcing_ratio(epoch, total_epochs, start_ratio=1.0, end_ratio=0.5):
    """Linearly interpolate the teacher-forcing ratio across epochs.

    Epoch 0 returns ``start_ratio`` and epoch ``total_epochs - 1`` returns
    ``end_ratio``. With a single epoch (or fewer), ``start_ratio`` is returned.
    Lowering the ratio exposes the decoder to its own predictions during
    training, closing the gap with free-running generation at evaluation.
    """
    if total_epochs > 1:
        fraction_done = epoch / (total_epochs - 1)
        return start_ratio - fraction_done * (start_ratio - end_ratio)
    return start_ratio
|
|
|
|
| |
| |
| |
|
|
def download_checkpoint(hub_model_id, filename):
    """Download ``filename`` from the Hub model repo ``hub_model_id`` and return its local path."""
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=hub_model_id,
        filename=filename,
        repo_type="model",
    )
    log.info(f"Downloaded checkpoint: {local_path}")
    return local_path
|
|
|
|
def main():
    """Run MR-JEPA Phase 3.1 training end to end.

    Steps:
      1. Parse CLI flags.
      2. Load the Phase 1 training script from the Hub as module ``p1``. It
         provides ``Config``, ``MRJEPAModel``, ``ScienceQADataset``,
         ``collate_fn`` and ``evaluate``.
      3. Restore the Phase 3.0 checkpoint and attach a ``GenerativeHead``,
         reusing every shape-compatible Phase 3.0 tensor.
      4. Each epoch, train on ScienceQA multiple choice, then round-robin over
         the open-ended loaders with scheduled sampling.
      5. Evaluate and keep the checkpoint with the best composite score.

    The best epoch's metrics, not the last epoch's, are what gets pushed.
    """
    parser = argparse.ArgumentParser(description="MR-JEPA Phase 3.1 Training")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Local path to checkpoint. Default: download Phase 3.0 from Hub.")
    parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
    parser.add_argument("--run_name", default="hybrid_main_phase3_1")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--core_lr", type=float, default=5e-5)
    parser.add_argument("--backbone_lr", type=float, default=5e-6)
    parser.add_argument("--text_lr", type=float, default=5e-6)
    # Phase 3.1 generative-decoder knobs.
    parser.add_argument("--gen_weight", type=float, default=2.0,
                        help="Generative loss weight (was 0.5 in 3.0)")
    parser.add_argument("--max_gen_len", type=int, default=32,
                        help="Max generation length (was 64 in 3.0)")
    parser.add_argument("--beam_width", type=int, default=5,
                        help="Beam search width for evaluation (was greedy in 3.0)")
    parser.add_argument("--tf_start", type=float, default=1.0,
                        help="Teacher forcing ratio at epoch 0")
    parser.add_argument("--tf_end", type=float, default=0.5,
                        help="Teacher forcing ratio at final epoch")
    parser.add_argument("--max_eval_samples", type=int, default=200)
    parser.add_argument("--max_train_samples", type=int, default=0)
    parser.add_argument("--output_dir", default="./outputs/mrjepa_phase3_1")
    parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio")
    args = parser.parse_args()

    # Shared model/dataset definitions live in the Phase 1 script on the Hub.
    # NOTE(review): this executes downloaded Python code; only point
    # --hub_model_id at a repository you trust.
    log.info("Downloading Phase 1 training script for model definitions...")
    from huggingface_hub import hf_hub_download
    import importlib.util
    p1_script = hf_hub_download(repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model")
    spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
    p1 = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(p1)

    if args.checkpoint and os.path.exists(args.checkpoint):
        ckpt_path = args.checkpoint
    else:
        if args.checkpoint:
            log.warning(f"Checkpoint {args.checkpoint} not found — downloading Phase 3.0 from the Hub instead")
        ckpt_path = download_checkpoint(args.hub_model_id,
                                        "checkpoints/hybrid_main_phase3_best.pt")

    log.info(f"Loading Phase 3.0 checkpoint: {ckpt_path}")
    # weights_only=False unpickles arbitrary Python objects (needed for the
    # stored config dict) — only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Phase 1 defaults, overlaid with the saved config, then Phase 3.1 overrides.
    cfg = p1.Config()
    for k, v in ckpt["config"].items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
    cfg.phase = 3
    cfg.epochs = args.epochs
    cfg.batch_size = args.batch_size
    cfg.grad_accum = args.grad_accum
    cfg.lr = args.core_lr
    cfg.backbone_lr = args.backbone_lr
    cfg.output_dir = args.output_dir
    cfg.run_name = args.run_name
    cfg.freeze_backbone = True
    cfg.freeze_text = True
    cfg.max_eval_samples = args.max_eval_samples
    cfg.resolve()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Device: {device}")
    os.makedirs(cfg.output_dir, exist_ok=True)

    import trackio
    trackio.init(
        name=args.run_name, project="MR-JEPA", space_id=args.trackio_space,
        config={
            "phase": "3.1", "epochs": args.epochs,
            "core_lr": args.core_lr, "backbone_lr": args.backbone_lr,
            "text_lr": args.text_lr, "gen_weight": args.gen_weight,
            "max_gen_len": args.max_gen_len, "beam_width": args.beam_width,
            "tf_start": args.tf_start, "tf_end": args.tf_end,
            "batch_size": args.batch_size, "grad_accum": args.grad_accum,
            "backbone": cfg.backbone, "K": cfg.K,
            "improvements": "gen_weight_2.0, gen_len_32, scheduled_sampling, beam_search",
        }
    )
    log.info(f"Trackio → https://huggingface.co/spaces/{args.trackio_space}")

    log.info("Building model...")
    model = p1.MRJEPAModel(cfg)
    model.evidence.load_state_dict(ckpt["evidence"])
    model.rollout.load_state_dict(ckpt["rollout"])
    model.disc.load_state_dict(ckpt["disc"])
    model.target.t_ev.load_state_dict(ckpt["target_ev"])
    model.target.t_ro.load_state_dict(ckpt["target_ro"])
    log.info(f"Loaded core weights from Phase 3.0 (epoch={ckpt.get('epoch','?')}, "
             f"composite={ckpt.get('composite_score','?')})")

    tokenizer = model.txt.tokenizer
    actual_vocab_size = len(tokenizer)
    gen_head = GenerativeHead(
        hidden_dim=cfg.rollout_dim,
        vocab_size=actual_vocab_size,
        num_layers=4,
        num_heads=cfg.predictor_heads,
        max_gen_len=args.max_gen_len,
        dropout=0.1,
    )

    # Reuse every Phase 3.0 gen_head tensor whose shape still matches; e.g. the
    # position embedding is sized by max_gen_len, which changed from 3.0.
    if "gen_head" in ckpt:
        new_sd = gen_head.state_dict()
        loaded, skipped = 0, 0
        for k, v in ckpt["gen_head"].items():
            if k in new_sd and new_sd[k].shape == v.shape:
                new_sd[k] = v
                loaded += 1
            elif k in new_sd:
                skipped += 1
                log.info(f" Shape mismatch for {k}: ckpt {v.shape} vs new {new_sd[k].shape}")
            else:
                skipped += 1
        gen_head.load_state_dict(new_sd)
        log.info(f"Loaded {loaded} gen_head params from Phase 3.0 ({skipped} skipped)")
    else:
        log.warning("No gen_head in checkpoint — starting from scratch")

    model.gen_head = gen_head

    log.info("Unfreezing last 6 visual layers, last 4 text layers")
    model.vis.unfreeze_last(6)
    model.txt.unfreeze_last(4)

    model = model.to(device)
    total_p = sum(p.numel() for p in model.parameters())
    train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)")

    # ── Data ──
    transform = model.vis.get_transform()
    mc_max = max(args.max_train_samples, 0)
    train_mc_ds = p1.ScienceQADataset("train", max_samples=mc_max, transform=transform,
                                      tokenizer=tokenizer, max_len=cfg.max_text_len,
                                      max_opts=cfg.max_options)
    eval_mc_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples,
                                     transform=transform, tokenizer=tokenizer,
                                     max_len=cfg.max_text_len, max_opts=cfg.max_options)
    # NOTE(review): lambda collate functions cannot be pickled, so num_workers > 0
    # relies on the "fork" multiprocessing start method (Linux default).
    mc_coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options)
    train_mc_dl = DataLoader(train_mc_ds, batch_size=cfg.batch_size, shuffle=True,
                             num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True)
    eval_mc_dl = DataLoader(eval_mc_ds, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=2, collate_fn=mc_coll, pin_memory=True)

    max_open = args.max_train_samples if args.max_train_samples > 0 else 5000
    open_coll = lambda batch: collate_open_ended(batch, transform, tokenizer,
                                                 cfg.max_text_len, args.max_gen_len)

    train_open_dls = {}
    eval_open_dls = {}
    for bm, tr_split, ev_split in [("docvqa", "validation", "validation"),
                                   ("chartqa", "test", "test"),
                                   ("textvqa", "train", "validation")]:
        train_open_dls[bm] = DataLoader(
            OpenEndedDataset(bm, tr_split, max_samples=max_open, transform=transform,
                             tokenizer=tokenizer, max_len=cfg.max_text_len,
                             max_gen_len=args.max_gen_len),
            batch_size=cfg.batch_size, shuffle=True, num_workers=2,
            collate_fn=open_coll, pin_memory=True, drop_last=True)
        eval_open_dls[bm] = DataLoader(
            OpenEndedDataset(bm, ev_split, max_samples=args.max_eval_samples,
                             transform=transform, tokenizer=tokenizer,
                             max_len=cfg.max_text_len, max_gen_len=args.max_gen_len),
            batch_size=cfg.batch_size, shuffle=False, num_workers=2,
            collate_fn=open_coll, pin_memory=True)

    # ── Optimisation ──
    backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
    text_params = [p for p in model.txt.parameters() if p.requires_grad]
    bb_txt_ids = {id(p) for p in backbone_params + text_params}
    core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids]
    param_groups = [
        {"params": core_params, "lr": args.core_lr},
        {"params": backbone_params, "lr": args.backbone_lr},
        {"params": text_params, "lr": args.text_lr},
    ]
    optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)

    mc_steps = len(train_mc_dl)
    open_steps = sum(len(dl) for dl in train_open_dls.values())
    total_steps = cfg.epochs * (mc_steps + open_steps) // cfg.grad_accum
    warmup_steps = int(total_steps * 0.1)

    def lr_lambda(step):
        # Linear warmup, then cosine decay to 1% of the base learning rate.
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # `is not None` checks: a valid special-token id may be 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    log.info(f"Phase 3.1: {cfg.epochs} epochs | gen_weight={args.gen_weight} | "
             f"max_gen_len={args.max_gen_len} | beam_width={args.beam_width}")
    log.info(f" Teacher forcing: {args.tf_start:.0%} → {args.tf_end:.0%}")
    log.info(f" MC batches/epoch: {mc_steps} | Open batches/epoch: {open_steps}")
    log.info(f" Total opt steps: ~{total_steps} | Warmup: {warmup_steps}")

    global_step = 0
    best_composite = 0.0
    best_epoch = None
    # Initialised up front so push_results never sees an undefined name
    # (e.g. --epochs 0).
    eval_results = {}
    best_eval_results = {}
    amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
    use_amp = cfg.bf16 and device.type == "cuda"
    trainable = [p for p in model.parameters() if p.requires_grad]

    try:
        for epoch in range(cfg.epochs):
            model.train()
            epoch_losses = defaultdict(list)
            epoch_mc_correct, epoch_mc_total = 0, 0
            optimizer.zero_grad()
            batch_count = 0

            tf_ratio = get_teacher_forcing_ratio(epoch, cfg.epochs, args.tf_start, args.tf_end)
            log.info(f"Phase 3.1 Epoch {epoch}: teacher_forcing={tf_ratio:.2f}")

            # ── Multiple-choice (ScienceQA) ──
            log.info(" MC training on ScienceQA...")
            for bi, batch in enumerate(train_mc_dl):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                    losses, preds = model(**batch)
                loss = losses["total"] / cfg.grad_accum
                loss.backward()
                batch_count += 1
                if batch_count % cfg.grad_accum == 0:
                    nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    model.update_target(global_step, total_steps)
                    global_step += 1
                for k, v in losses.items():
                    if isinstance(v, torch.Tensor):
                        epoch_losses[f"mc_{k}"].append(v.item())
                epoch_mc_correct += (preds == batch["labels"]).sum().item()
                epoch_mc_total += batch["batch_size"]
                if bi % 100 == 0:
                    avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")}
                    acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100
                    log.info(f" E{epoch} MC B{bi}/{mc_steps} | loss={avg.get('mc_total',0):.4f} | acc={acc:.1f}%")
                    trackio.log({"train/mc_loss": avg.get("mc_total", 0), "train/mc_accuracy": acc,
                                 "train/lr": scheduler.get_last_lr()[0], "train/epoch": epoch,
                                 "train/step": global_step, "train/tf_ratio": tf_ratio})

            # ── Open-ended generation (round-robin over benchmarks) ──
            log.info(f" Open-ended training (tf_ratio={tf_ratio:.2f})...")
            gen_losses = defaultdict(list)
            open_iters = {n: iter(dl) for n, dl in train_open_dls.items()}
            open_active = set(open_iters.keys())
            obi = 0
            while open_active:
                for name in list(open_active):
                    try:
                        batch = next(open_iters[name])
                    except StopIteration:
                        open_active.discard(name)
                        continue
                    bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                        vis_tok = model.vis(bt["pixel_values"]).float()
                        txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
                        evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
                        if model._use_rollout:
                            traj, z_final, z_proj = model.rollout(evidence)
                        else:
                            B2 = bt["batch_size"]
                            z0 = model.rollout.init_tokens.expand(B2, -1, -1) + \
                                model.rollout.z0_proj(F.adaptive_avg_pool1d(
                                    evidence.permute(0, 2, 1), model.rollout.num_tokens).permute(0, 2, 1))
                            z_final, z_proj = z0, model.rollout.out_proj(z0).unsqueeze(1)

                        jepa_loss_val = torch.tensor(0.0, device=device)
                        if model._use_jepa:
                            target_proj = model.target(vis_tok.detach(), txt_tok.detach(),
                                                       bt["attention_mask"].detach())
                            jl = model.jepa_loss(z_proj, target_proj, torch.tensor(0.0, device=device))
                            jepa_loss_val = jl["jepa"] + jl["reg"]

                        _, gen_loss = model.gen_head(
                            z_final, evidence, bt["gen_target_ids"],
                            pad_token_id=pad_token_id,
                            teacher_forcing_ratio=tf_ratio,
                        )

                    total_loss = cfg.jepa_weight * jepa_loss_val + args.gen_weight * gen_loss
                    loss = total_loss / cfg.grad_accum
                    loss.backward()
                    batch_count += 1
                    if batch_count % cfg.grad_accum == 0:
                        nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        model.update_target(global_step, total_steps)
                        global_step += 1

                    gen_losses[f"{name}_gen"].append(gen_loss.item())
                    gen_losses[f"{name}_total"].append(total_loss.item())
                    obi += 1
                    if obi % 100 == 0:
                        avg = {k: np.mean(v[-100:]) for k, v in gen_losses.items()}
                        log.info(f" E{epoch} OPEN B{obi} | " + " | ".join(f"{k}={v:.4f}" for k, v in avg.items()))
                        trackio.log({f"train/{k}": v for k, v in avg.items()})

            # ── Evaluation ──
            log.info(f" Evaluating (beam_width={args.beam_width})...")
            mc_eval_acc = p1.evaluate(model, eval_mc_dl, device, cfg)
            log.info(f" ScienceQA eval accuracy: {mc_eval_acc:.1f}%")

            eval_results = evaluate_generative_beam(
                model, eval_open_dls, device, cfg, tokenizer,
                args.max_gen_len, amp_dtype, args.beam_width
            )
            for bm, metrics in eval_results.items():
                for mk, mv in metrics.items():
                    log.info(f" {bm} {mk}: {mv:.2f}")

            all_scores = [mc_eval_acc] + [v for m in eval_results.values() for v in m.values()]
            composite = np.mean(all_scores)
            log.info(f"=== Phase 3.1 Epoch {epoch} | MC: {mc_eval_acc:.1f}% | "
                     f"Composite: {composite:.1f} | tf={tf_ratio:.2f} ===")

            trackio.log({
                "eval/scienceqa_accuracy": mc_eval_acc,
                "eval/composite_score": composite,
                "eval/epoch": epoch, "eval/tf_ratio": tf_ratio,
                **{f"eval/{bm}_{mk}": mv for bm, m in eval_results.items() for mk, mv in m.items()},
            })

            # Always save after the first evaluated epoch so a checkpoint exists
            # even if every composite is 0.0.
            if best_epoch is None or composite > best_composite:
                best_composite = composite
                best_epoch = epoch
                best_eval_results = eval_results
                save_checkpoint(model, cfg, epoch, mc_eval_acc, eval_results, composite)
                log.info(f" ★ New best composite: {best_composite:.1f}")

        log.info(f"Phase 3.1 complete. Best composite: {best_composite:.1f}")

    finally:
        trackio.log({"final/best_composite": best_composite, "final/phase": "3.1",
                     "final/total_steps": global_step})
        trackio.finish()

    if cfg.push_to_hub:
        # Report the metrics of the epoch whose checkpoint was saved.
        push_results(cfg, args, best_composite, best_eval_results)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate_generative_beam(model, eval_dls, device, cfg, tokenizer,
                             max_gen_len, amp_dtype, beam_width):
    """Evaluate the open-ended benchmarks with beam-search decoding.

    Args:
        model: the MR-JEPA model with an attached ``gen_head``.
        eval_dls: mapping benchmark name -> DataLoader of collated open-ended batches.
        device, cfg, amp_dtype: placement and autocast settings (``cfg.bf16``).
        tokenizer: used for the start/stop token ids and for decoding.
        max_gen_len, beam_width: forwarded to ``gen_head.generate_beam``.

    Returns:
        dict benchmark -> {metric: score in 0-100}: docvqa -> "anls",
        chartqa -> "relaxed_accuracy", textvqa -> "vqa_accuracy".

    The model's previous train/eval mode is restored on exit, even on error.
    """
    def first_defined(*candidates):
        return next(c for c in candidates if c is not None)

    # `is not None` instead of `or`: a tokenizer may use id 0 for BOS/CLS, and
    # `or` would skip it and fall through to an unrelated id.
    start_token_id = first_defined(tokenizer.bos_token_id, tokenizer.cls_token_id, 1)
    # Tokenizers without an EOS token (presumably BERT-style, whose encodings
    # end in [SEP]) fall back to the separator id as the stop token.
    eos_token_id = (tokenizer.eos_token_id if tokenizer.eos_token_id is not None
                    else tokenizer.sep_token_id)

    was_training = model.training
    model.eval()
    use_amp = cfg.bf16 and device.type == "cuda"
    results = {}
    try:
        for benchmark, dl in eval_dls.items():
            predictions, ground_truths = [], []
            for batch in dl:
                bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                    vis_tok = model.vis(bt["pixel_values"]).float()
                    txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
                    evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
                    if model._use_rollout:
                        _, z_final, _ = model.rollout(evidence)
                    else:
                        B2 = bt["batch_size"]
                        z_final = model.rollout.init_tokens.expand(B2, -1, -1) + model.rollout.z0_proj(
                            F.adaptive_avg_pool1d(evidence.permute(0, 2, 1),
                                                  model.rollout.num_tokens).permute(0, 2, 1))
                    # Decode under the same autocast settings as the rest of the
                    # forward pass, so dtypes stay consistent inside gen_head.
                    gen_ids = model.gen_head.generate_beam(
                        z_final, evidence, start_token_id,
                        max_length=max_gen_len, eos_token_id=eos_token_id,
                        beam_width=beam_width,
                    )
                for seq in gen_ids.tolist():
                    # Cut at the first *generated* stop token (index 0 is the start
                    # token, which may share the id); anything after it is noise.
                    if eos_token_id is not None and eos_token_id in seq[1:]:
                        seq = seq[:seq.index(eos_token_id, 1)]
                    predictions.append(tokenizer.decode(seq, skip_special_tokens=True).strip())
                ground_truths.extend(batch["all_answers"])

            for j in range(min(3, len(predictions))):
                gt_sample = ground_truths[j] if j < len(ground_truths) else "?"
                log.info(f" [{benchmark}] pred: '{predictions[j]}' | gt: '{gt_sample}'")

            if benchmark == "docvqa":
                results[benchmark] = {"anls": compute_anls(predictions, ground_truths)}
            elif benchmark == "chartqa":
                gt_flat = [(g[0] if g else "") if isinstance(g, list) else g for g in ground_truths]
                results[benchmark] = {"relaxed_accuracy": compute_relaxed_accuracy(predictions, gt_flat)}
            elif benchmark == "textvqa":
                results[benchmark] = {"vqa_accuracy": compute_vqa_accuracy(predictions, ground_truths)}
    finally:
        model.train(was_training)
    return results
|
|
|
|
| |
| |
| |
|
|
def save_checkpoint(model, cfg, epoch, mc_acc, open_results, composite):
    """Serialize the trainable sub-modules plus run metadata to ``<output_dir>/checkpoint_best.pt``.

    Any existing file at that path is overwritten. The saved dict holds
    state_dicts for evidence/rollout/disc/gen_head and the two target
    networks, the config as a plain dict, and the evaluation scores.
    """
    payload = {
        "evidence": model.evidence.state_dict(),
        "rollout": model.rollout.state_dict(),
        "disc": model.disc.state_dict(),
        "gen_head": model.gen_head.state_dict(),
        "target_ev": model.target.t_ev.state_dict(),
        "target_ro": model.target.t_ro.state_dict(),
        "config": cfg.__dict__,
        "epoch": epoch,
        "mc_eval_acc": mc_acc,
        "open_results": open_results,
        "composite_score": composite,
        "phase": "3.1",
    }
    destination = os.path.join(cfg.output_dir, "checkpoint_best.pt")
    torch.save(payload, destination)
    log.info(f"Saved checkpoint: {destination} (composite={composite:.1f})")
|
|
|
|
def push_results(cfg, args, best_composite, eval_results):
    """Best-effort upload of the run summary JSON and best checkpoint to ``cfg.hub_model_id``.

    The summary is first written to ``<output_dir>/results_<run_name>.json``.
    The checkpoint is uploaded only if ``checkpoint_best.pt`` exists. Any
    failure is logged rather than raised, so training results are never
    lost to a network error.
    """
    try:
        from huggingface_hub import HfApi

        api = HfApi()
        summary = {
            "run_name": cfg.run_name, "phase": "3.1",
            "backbone": cfg.backbone, "K": cfg.K,
            "best_composite_score": best_composite,
            "gen_weight": args.gen_weight, "max_gen_len": args.max_gen_len,
            "beam_width": args.beam_width,
            "tf_start": args.tf_start, "tf_end": args.tf_end,
            "epochs": cfg.epochs, "core_lr": args.core_lr,
            "open_results": dict(eval_results or {}),
            "improvements": ["gen_weight_2.0", "gen_len_32", "scheduled_sampling", "beam_search"],
        }
        summary_path = os.path.join(cfg.output_dir, f"results_{cfg.run_name}.json")
        with open(summary_path, "w") as fh:
            json.dump(summary, fh, indent=2)

        repo = {"repo_id": cfg.hub_model_id, "repo_type": "model"}
        api.upload_file(path_or_fileobj=summary_path,
                        path_in_repo=f"results/{cfg.run_name}.json", **repo)
        ckpt_file = os.path.join(cfg.output_dir, "checkpoint_best.pt")
        if os.path.exists(ckpt_file):
            api.upload_file(path_or_fileobj=ckpt_file,
                            path_in_repo=f"checkpoints/{cfg.run_name}_best.pt", **repo)
        log.info(f"Pushed Phase 3.1 results to {cfg.hub_model_id}")
    except Exception as e:
        log.error(f"Push failed: {e}")
|
|
|
|
# Script entry point: parse CLI arguments and run Phase 3.1 training.
if __name__ == "__main__":
    main()
|
|