#!/usr/bin/env python3
"""
MR-JEPA Phase 3.1 Training — Improved Generative Decoder

Loads the Phase 3.0 checkpoint (with partially-trained gen_head) and applies
four targeted improvements to break through the 0% generative metrics:

  1. gen_weight: 0.5 → 2.0 (4× stronger generative gradient signal)
  2. max_gen_len: 64 → 32 (shorter targets, less padding noise)
  3. Scheduled sampling (100% teacher forcing → 50% free-running, linear)
  4. Beam search evaluation (beam_width=5 instead of greedy argmax)

Resumes from: checkpoints/hybrid_main_phase3_best.pt (gen_head pre-trained)
Training data: same as Phase 3.0 (ScienceQA MC + DocVQA/ChartQA/TextVQA open-ended)

Usage:
    python train_phase3_1.py
    python train_phase3_1.py --gen_weight 2.0 --max_gen_len 32 --beam_width 5
"""

import os
import sys
import json
import math
import copy
import random
import logging
import argparse
from collections import Counter, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("mrjepa-p3.1")


# ══════════════════════════════════════════════════════════════════════════
# OPEN-ENDED DATASET (same as Phase 3.0)
# ══════════════════════════════════════════════════════════════════════════

class OpenEndedDataset(Dataset):
    """Open-ended VQA samples from DocVQA, ChartQA or TextVQA.

    Each item is a dict of raw (un-tensorised) fields; tokenisation and image
    preprocessing are left to the collate function.

    Args:
        benchmark: One of "docvqa", "chartqa", "textvqa".
        split: Split name forwarded to ``datasets.load_dataset``.
        max_samples: If > 0, keep only the first ``max_samples`` rows.
        transform, tokenizer, max_len, max_gen_len: Stored on the instance;
            not used by ``__getitem__``.

    Raises:
        ValueError: If ``benchmark`` is not one of the supported names.
    """

    # benchmark name -> positional arguments for ``datasets.load_dataset``
    _HUB_SOURCES = {
        "docvqa": ("lmms-lab/DocVQA", "DocVQA"),
        "chartqa": ("lmms-lab/ChartQA",),
        "textvqa": ("lmms-lab/textvqa",),
    }

    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192, max_gen_len=32):
        # Imported lazily so the module can be imported without `datasets`.
        from datasets import load_dataset

        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_gen_len = max_gen_len

        log.info("Loading %s %s...", benchmark, split)
        try:
            source = self._HUB_SOURCES[benchmark]
        except KeyError:
            raise ValueError(f"Unknown benchmark: {benchmark}") from None
        ds = load_dataset(*source, split=split)

        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))
        self.data = ds
        log.info("Loaded %s samples from %s %s", len(ds), benchmark, split)

    def __len__(self):
        return len(self.data)

    def _reference_answers(self, row):
        """Return ``(training_answer, all_answers)`` for one dataset row."""
        if self.benchmark == "docvqa":
            # `or` (not a .get default) so an explicit None value is handled too.
            answers = row.get("answers") or [""]
            return answers[0], answers
        if self.benchmark == "chartqa":
            answer = str(row.get("answer", ""))
            return answer, [answer]
        if self.benchmark == "textvqa":
            answers = row.get("answers") or [""]
            # Train on the most frequent normalised annotator answer.
            counts = Counter(a.lower().strip() for a in answers)
            answer = counts.most_common(1)[0][0] if counts else ""
            return answer, answers
        return "", [""]

    def __getitem__(self, idx):
        row = self.data[idx]

        img = row.get("image")
        if img is None:
            # Placeholder so downstream image preprocessing still gets an image.
            img = Image.new("RGB", (256, 256), "white")
        else:
            img = img.convert("RGB")

        answer, all_answers = self._reference_answers(row)

        # At most the first 50 OCR tokens are appended to the question text.
        ocr_tokens = row.get("ocr_tokens", [])
        ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""
        text = row["question"]
        if ocr_text:
            text += f" [OCR: {ocr_text}]"

        return {
            "image": img,
            "text": text,
            "answer": answer,
            "all_answers": all_answers,
            "benchmark": self.benchmark,
            "ocr_text": ocr_text,
            "question_type": row.get("type", row.get("question_types", [""])),
        }
def collate_open_ended(batch, transform, tokenizer, max_len, max_gen_len):
    """Collate open-ended VQA samples into a single batch dict.

    Args:
        batch: List of sample dicts with "image", "text", "answer",
            "benchmark" and "all_answers" keys ("question_type" is optional).
        transform: Either a per-image callable whose outputs are stacked, or
            an object exposing ``feature_extractor`` that is called as
            ``transform(images=..., return_tensors="pt")``.
        tokenizer: Callable returning "input_ids" and "attention_mask".
        max_len: Padded/truncated length of the question text.
        max_gen_len: Padded/truncated length of the answer target.

    Returns:
        Dict with "pixel_values", "input_ids", "attention_mask",
        "gen_target_ids", "gen_attention_mask", "batch_size", "benchmarks",
        "all_answers" and "question_types".
    """
    images = [sample["image"] for sample in batch]
    texts = [sample["text"] for sample in batch]
    # Empty answers are replaced by a single space before tokenisation.
    answer_texts = [sample["answer"] or " " for sample in batch]

    if callable(transform) and not hasattr(transform, "feature_extractor"):
        pixel_values = torch.stack([transform(img) for img in images])
    else:
        pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]

    tok = tokenizer(texts, padding="max_length", truncation=True,
                    max_length=max_len, return_tensors="pt")
    gen_tok = tokenizer(answer_texts, padding="max_length", truncation=True,
                        max_length=max_gen_len, return_tensors="pt")

    return {
        "pixel_values": pixel_values,
        "input_ids": tok["input_ids"],
        "attention_mask": tok["attention_mask"],
        "gen_target_ids": gen_tok["input_ids"],
        "gen_attention_mask": gen_tok["attention_mask"],
        "batch_size": len(batch),
        "benchmarks": [sample["benchmark"] for sample in batch],
        "all_answers": [sample["all_answers"] for sample in batch],
        "question_types": [sample.get("question_type", "") for sample in batch],
    }


# ══════════════════════════════════════════════════════════════════════════
# GENERATIVE HEAD with SCHEDULED SAMPLING + BEAM SEARCH
# ══════════════════════════════════════════════════════════════════════════

class GenerativeDecoderLayer(nn.Module):
    """Pre-norm decoder layer with two cross-attention stages.

    Sub-layers, each wrapped in a residual connection: self-attention over
    the token stream, cross-attention over ``z_final``, cross-attention over
    ``evidence``, then a 4x-wide GELU feed-forward block.  All attention
    modules are batch-first.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.self_attn_norm = nn.LayerNorm(hidden_dim)
        self.state_cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.state_cross_norm = nn.LayerNorm(hidden_dim)
        self.evidence_cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.evidence_cross_norm = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout),
        )
        self.ffn_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, z_final, evidence, causal_mask=None):
        """Apply the layer to ``x``; ``causal_mask`` is passed as ``attn_mask``
        to the self-attention only."""
        normed = self.self_attn_norm(x)
        attended, _ = self.self_attn(normed, normed, normed, attn_mask=causal_mask)
        x = x + attended

        normed = self.state_cross_norm(x)
        attended, _ = self.state_cross_attn(normed, z_final, z_final)
        x = x + attended

        normed = self.evidence_cross_norm(x)
        attended, _ = self.evidence_cross_attn(normed, evidence, evidence)
        x = x + attended

        return x + self.ffn(self.ffn_norm(x))


class GenerativeHead(nn.Module):
    """
    Phase 3.1 generative decoder with:
      - Scheduled sampling during training (teacher forcing warmup)
      - Beam search during evaluation

    The output projection shares its weight with the token embedding.
    Sequences are limited to ``max_gen_len`` positions (size of the learned
    position table).
    """

    def __init__(self, hidden_dim, vocab_size, num_layers=4, num_heads=12,
                 max_gen_len=32, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.max_gen_len = max_gen_len

        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(max_gen_len, hidden_dim)
        self.layers = nn.ModuleList([
            GenerativeDecoderLayer(hidden_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        self.output_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight  # weight tying

    def _decode_step(self, token_ids, z_final, evidence):
        """Run the decoder over ``token_ids`` and return logits for every
        position, shape (B, T, vocab_size).  Callers that only need the next
        token take the last position."""
        seq_len = token_ids.size(1)
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        x = self.token_embedding(token_ids) + self.pos_embedding(positions)
        # True above the diagonal = position may not attend to later tokens.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool),
            diagonal=1,
        )
        for layer in self.layers:
            x = layer(x, z_final, evidence, causal_mask)
        return self.lm_head(self.output_norm(x))

    def forward(self, z_final, evidence, target_ids, pad_token_id=0,
                teacher_forcing_ratio=1.0):
        """
        Training forward with scheduled sampling.

        teacher_forcing_ratio=1.0 → pure teacher forcing (ground truth at every step)
        teacher_forcing_ratio=0.5 → each step feeds the model's own argmax
                                    prediction with probability 0.5 (one coin
                                    flip per step, shared by the whole batch)

        Returns:
            ``(logits, loss)`` where ``logits`` is (B, T, vocab_size) and
            ``loss`` is the mean next-token cross-entropy over non-pad labels.
            The loss is 0 (not NaN) when every label is ``pad_token_id``.
        """
        B, seq_len = target_ids.shape
        device = target_ids.device

        if teacher_forcing_ratio >= 1.0:
            # Pure teacher forcing: one batched pass over the full target.
            logits = self._decode_step(target_ids, z_final, evidence)
        else:
            # Scheduled sampling: grow the input one token at a time.
            logits = torch.zeros(B, seq_len, self.vocab_size, device=device)
            current_input = target_ids[:, :1]
            for t in range(seq_len):
                last_logits = self._decode_step(current_input, z_final, evidence)[:, -1]
                logits[:, t] = last_logits
                if t < seq_len - 1:
                    if random.random() < teacher_forcing_ratio:
                        next_token = target_ids[:, t + 1:t + 2]
                    else:
                        next_token = last_logits.argmax(dim=-1, keepdim=True)
                    current_input = torch.cat([current_input, next_token], dim=1)

        # Next-token prediction: logits at position t predict token t + 1.
        shift_logits = logits[:, :-1].contiguous().view(-1, self.vocab_size)
        shift_labels = target_ids[:, 1:].contiguous().view(-1)
        # sum / count equals the "mean" reduction but stays finite (0) when no
        # label survives ignore_index, instead of producing NaN.
        n_valid = (shift_labels != pad_token_id).sum().clamp(min=1)
        loss = F.cross_entropy(
            shift_logits, shift_labels,
            ignore_index=pad_token_id, reduction="sum",
        ) / n_valid
        return logits, loss

    @torch.no_grad()
    def generate_greedy(self, z_final, evidence, start_token_id, max_length=32,
                        eos_token_id=None, pad_token_id=0):
        """Greedy autoregressive generation (fallback).

        Each sequence stops at its own EOS: positions after it are filled with
        ``pad_token_id`` and decoding ends once every sequence has finished.
        ``max_length`` is capped at ``self.max_gen_len``.

        Returns:
            LongTensor (B, L) with L <= max_length, starting with
            ``start_token_id``.
        """
        B = z_final.size(0)
        device = z_final.device
        max_length = min(max_length, self.max_gen_len)

        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        for _ in range(max_length - 1):
            logits = self._decode_step(generated, z_final, evidence)
            next_token = logits[:, -1].argmax(dim=-1)
            if eos_token_id is not None:
                # Already-finished rows only receive padding.
                next_token = torch.where(
                    finished, torch.full_like(next_token, pad_token_id), next_token)
                finished = finished | (next_token == eos_token_id)
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            if eos_token_id is not None and bool(finished.all()):
                break
        return generated

    def _beam_search_single(self, z_b, ev_b, start_token_id, max_length,
                            eos_token_id, beam_width):
        """Beam-search one sample; returns the best (1, L) token tensor."""
        device = z_b.device

        def has_ended(seq):
            # Length check: the start token itself may share the EOS id.
            return (eos_token_id is not None and seq.size(1) > 1
                    and seq[0, -1].item() == eos_token_id)

        def normalised(entry):
            score, seq = entry
            return score / seq.size(1)

        # Each beam: (cumulative log-prob, token tensor of shape (1, T)).
        beams = [(0.0, torch.tensor([[start_token_id]], dtype=torch.long, device=device))]
        completed = []
        for _ in range(max_length - 1):
            candidates = []
            for score, seq in beams:
                if has_ended(seq):
                    completed.append((score, seq))
                    continue
                logits = self._decode_step(seq, z_b, ev_b)
                log_probs = F.log_softmax(logits[0, -1].float(), dim=-1)
                top_lp, top_ids = log_probs.topk(beam_width)
                for lp, tok in zip(top_lp.tolist(), top_ids.tolist()):
                    candidates.append(
                        (score + lp, torch.cat([seq, seq.new_tensor([[tok]])], dim=1)))
            if not candidates:
                break
            # Keep the best beams by length-normalised score.
            candidates.sort(key=normalised, reverse=True)
            beams = candidates[:beam_width]
            if all(has_ended(seq) for _, seq in beams):
                break

        return max(completed + beams, key=normalised)[1]

    @torch.no_grad()
    def generate_beam(self, z_final, evidence, start_token_id, max_length=32,
                      eos_token_id=None, beam_width=5, pad_token_id=0):
        """
        Beam search generation.

        Processes each sample in the batch independently and returns the
        sequence with the highest length-normalised log-probability per
        sample, right-padded with ``pad_token_id`` to a common length.
        ``max_length`` is capped at ``self.max_gen_len`` and ``beam_width`` at
        the vocabulary size.
        """
        B = z_final.size(0)
        device = z_final.device
        max_length = min(max_length, self.max_gen_len)
        beam_width = min(beam_width, self.vocab_size)

        if B == 0:
            return torch.empty((0, 1), dtype=torch.long, device=device)

        results = [
            self._beam_search_single(
                z_final[b:b + 1], evidence[b:b + 1],
                start_token_id, max_length, eos_token_id, beam_width)
            for b in range(B)
        ]

        longest = max(r.size(1) for r in results)
        padded = torch.full((B, longest), pad_token_id, dtype=torch.long, device=device)
        for i, r in enumerate(results):
            padded[i, :r.size(1)] = r[0]
        return padded


# ══════════════════════════════════════════════════════════════════════════
# EVALUATION METRICS (same as Phase 3.0)
# ══════════════════════════════════════════════════════════════════════════

def normalized_levenshtein(s1, s2):
    """Return the edit distance between ``s1`` and ``s2`` divided by the
    longer length, after lower-casing and stripping both strings.

    0.0 means identical, 1.0 means nothing in common (or exactly one of the
    two is empty).
    """
    s1, s2 = s1.lower().strip(), s2.lower().strip()
    if s1 == s2:
        return 0.0
    if not s1 or not s2:
        return 1.0

    # Two-row dynamic programme: `prev` is the distance row for s1[:i-1].
    prev = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1, 1):
        cur = [i]
        for j, c2 in enumerate(s2, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (c1 != c2)))
        prev = cur
    return prev[-1] / max(len(s1), len(s2))
s1.lower().strip(), s2.lower().strip() if s1 == s2: return 0.0 l1, l2 = len(s1), len(s2) if l1 == 0 or l2 == 0: return 1.0 m = [[0]*(l2+1) for _ in range(l1+1)] for i in range(l1+1): m[i][0] = i for j in range(l2+1): m[0][j] = j for i in range(1,l1+1): for j in range(1,l2+1): c = 0 if s1[i-1]==s2[j-1] else 1 m[i][j] = min(m[i-1][j]+1, m[i][j-1]+1, m[i-1][j-1]+c) return m[l1][l2]/max(l1,l2) def compute_anls(predictions, ground_truths, threshold=0.5): scores = [] for pred, gts in zip(predictions, ground_truths): mx = max((1.0-normalized_levenshtein(str(pred),str(gt)) if normalized_levenshtein(str(pred),str(gt)) 0 else 0 train_mc_ds = p1.ScienceQADataset("train", max_samples=mc_max, transform=transform, tokenizer=tokenizer, max_len=cfg.max_text_len, max_opts=cfg.max_options) eval_mc_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples, transform=transform, tokenizer=tokenizer, max_len=cfg.max_text_len, max_opts=cfg.max_options) mc_coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options) train_mc_dl = DataLoader(train_mc_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True) eval_mc_dl = DataLoader(eval_mc_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2, collate_fn=mc_coll, pin_memory=True) max_open = args.max_train_samples if args.max_train_samples > 0 else 5000 open_coll = lambda batch: collate_open_ended(batch, transform, tokenizer, cfg.max_text_len, args.max_gen_len) train_open_dls = {} eval_open_dls = {} for bm, tr_split, ev_split in [("docvqa","validation","validation"), ("chartqa","test","test"), ("textvqa","train","validation")]: train_open_dls[bm] = DataLoader( OpenEndedDataset(bm, tr_split, max_samples=max_open, transform=transform, tokenizer=tokenizer, max_len=cfg.max_text_len, max_gen_len=args.max_gen_len), batch_size=cfg.batch_size, shuffle=True, num_workers=2, collate_fn=open_coll, pin_memory=True, drop_last=True) eval_open_dls[bm] 
= DataLoader( OpenEndedDataset(bm, ev_split, max_samples=args.max_eval_samples, transform=transform, tokenizer=tokenizer, max_len=cfg.max_text_len, max_gen_len=args.max_gen_len), batch_size=cfg.batch_size, shuffle=False, num_workers=2, collate_fn=open_coll, pin_memory=True) # ── Optimizer ── backbone_params = [p for p in model.vis.parameters() if p.requires_grad] text_params = [p for p in model.txt.parameters() if p.requires_grad] bb_txt_ids = {id(p) for p in backbone_params + text_params} core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids] param_groups = [ {"params": core_params, "lr": args.core_lr}, {"params": backbone_params, "lr": args.backbone_lr}, {"params": text_params, "lr": args.text_lr}, ] optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay) mc_steps = len(train_mc_dl) open_steps = sum(len(dl) for dl in train_open_dls.values()) total_steps = cfg.epochs * (mc_steps + open_steps) // cfg.grad_accum warmup_steps = int(total_steps * 0.1) def lr_lambda(step): if step < warmup_steps: return step / max(warmup_steps, 1) progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) pad_token_id = tokenizer.pad_token_id if pad_token_id is None: pad_token_id = tokenizer.eos_token_id or 0 log.info(f"Phase 3.1: {cfg.epochs} epochs | gen_weight={args.gen_weight} | " f"max_gen_len={args.max_gen_len} | beam_width={args.beam_width}") log.info(f" Teacher forcing: {args.tf_start:.0%} → {args.tf_end:.0%}") log.info(f" MC batches/epoch: {mc_steps} | Open batches/epoch: {open_steps}") log.info(f" Total opt steps: ~{total_steps} | Warmup: {warmup_steps}") global_step = 0 best_composite = 0.0 amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32 trainable = [p for p in model.parameters() if p.requires_grad] try: for epoch in range(cfg.epochs): model.train() epoch_losses = defaultdict(list) 
epoch_mc_correct, epoch_mc_total = 0, 0 optimizer.zero_grad() batch_count = 0 # ── Scheduled sampling ratio for this epoch ── tf_ratio = get_teacher_forcing_ratio(epoch, cfg.epochs, args.tf_start, args.tf_end) log.info(f"Phase 3.1 Epoch {epoch}: teacher_forcing={tf_ratio:.2f}") # ── MC training ── log.info(f" MC training on ScienceQA...") for bi, batch in enumerate(train_mc_dl): batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"): losses, preds = model(**batch) loss = losses["total"] / cfg.grad_accum loss.backward() batch_count += 1 if batch_count % cfg.grad_accum == 0: nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm) optimizer.step(); scheduler.step(); optimizer.zero_grad() model.update_target(global_step, total_steps) global_step += 1 for k, v in losses.items(): if isinstance(v, torch.Tensor): epoch_losses[f"mc_{k}"].append(v.item()) epoch_mc_correct += (preds == batch["labels"]).sum().item() epoch_mc_total += batch["batch_size"] if bi % 100 == 0: avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")} acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100 log.info(f" E{epoch} MC B{bi}/{mc_steps} | loss={avg.get('mc_total',0):.4f} | acc={acc:.1f}%") trackio.log({"train/mc_loss": avg.get("mc_total",0), "train/mc_accuracy": acc, "train/lr": scheduler.get_last_lr()[0], "train/epoch": epoch, "train/step": global_step, "train/tf_ratio": tf_ratio}) # ── Open-ended training (with scheduled sampling) ── log.info(f" Open-ended training (tf_ratio={tf_ratio:.2f})...") gen_losses = defaultdict(list) open_iters = {n: iter(dl) for n, dl in train_open_dls.items()} open_active = set(open_iters.keys()) obi = 0 while open_active: for name in list(open_active): try: batch = next(open_iters[name]) except StopIteration: open_active.discard(name); continue bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v 
for k, v in batch.items()} with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"): vis_tok = model.vis(bt["pixel_values"]).float() txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float() evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"]) if model._use_rollout: traj, z_final, z_proj = model.rollout(evidence) else: B2 = bt["batch_size"] z0 = model.rollout.init_tokens.expand(B2,-1,-1) + \ model.rollout.z0_proj(F.adaptive_avg_pool1d( evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1)) z_final, z_proj = z0, model.rollout.out_proj(z0).unsqueeze(1) jepa_loss_val = torch.tensor(0.0, device=device) if model._use_jepa: target_proj = model.target(vis_tok.detach(), txt_tok.detach(), bt["attention_mask"].detach()) jl = model.jepa_loss(z_proj, target_proj, torch.tensor(0.0, device=device)) jepa_loss_val = jl["jepa"] + jl["reg"] # ── Generative loss with scheduled sampling ── _, gen_loss = model.gen_head( z_final, evidence, bt["gen_target_ids"], pad_token_id=pad_token_id, teacher_forcing_ratio=tf_ratio, ) total_loss = cfg.jepa_weight * jepa_loss_val + args.gen_weight * gen_loss loss = total_loss / cfg.grad_accum loss.backward() batch_count += 1 if batch_count % cfg.grad_accum == 0: nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm) optimizer.step(); scheduler.step(); optimizer.zero_grad() model.update_target(global_step, total_steps); global_step += 1 gen_losses[f"{name}_gen"].append(gen_loss.item()) gen_losses[f"{name}_total"].append(total_loss.item()) obi += 1 if obi % 100 == 0: avg = {k: np.mean(v[-100:]) for k, v in gen_losses.items()} log.info(f" E{epoch} OPEN B{obi} | " + " | ".join(f"{k}={v:.4f}" for k,v in avg.items())) trackio.log({f"train/{k}": v for k, v in avg.items()}) # ── Evaluation (with beam search) ── log.info(f" Evaluating (beam_width={args.beam_width})...") mc_eval_acc = p1.evaluate(model, eval_mc_dl, device, cfg) log.info(f" ScienceQA eval accuracy: 
{mc_eval_acc:.1f}%") eval_results = evaluate_generative_beam( model, eval_open_dls, device, cfg, tokenizer, args.max_gen_len, amp_dtype, args.beam_width ) for bm, metrics in eval_results.items(): for mk, mv in metrics.items(): log.info(f" {bm} {mk}: {mv:.2f}") all_scores = [mc_eval_acc] + [v for m in eval_results.values() for v in m.values()] composite = np.mean(all_scores) log.info(f"=== Phase 3.1 Epoch {epoch} | MC: {mc_eval_acc:.1f}% | " f"Composite: {composite:.1f} | tf={tf_ratio:.2f} ===") trackio.log({ "eval/scienceqa_accuracy": mc_eval_acc, "eval/composite_score": composite, "eval/epoch": epoch, "eval/tf_ratio": tf_ratio, **{f"eval/{bm}_{mk}": mv for bm, m in eval_results.items() for mk, mv in m.items()}, }) if composite > best_composite: best_composite = composite save_checkpoint(model, cfg, epoch, mc_eval_acc, eval_results, composite) log.info(f" ★ New best composite: {best_composite:.1f}") log.info(f"Phase 3.1 complete. Best composite: {best_composite:.1f}") finally: trackio.log({"final/best_composite": best_composite, "final/phase": "3.1", "final/total_steps": global_step}) trackio.finish() if cfg.push_to_hub: push_results(cfg, args, best_composite, eval_results) # ══════════════════════════════════════════════════════════════════════════ # BEAM SEARCH EVALUATION # ══════════════════════════════════════════════════════════════════════════ @torch.no_grad() def evaluate_generative_beam(model, eval_dls, device, cfg, tokenizer, max_gen_len, amp_dtype, beam_width): """Evaluate open-ended benchmarks using beam search decoding.""" model.eval() results = {} start_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id or 1 eos_token_id = tokenizer.eos_token_id for benchmark, dl in eval_dls.items(): predictions, ground_truths = [], [] for batch in dl: bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"): vis_tok = 
def _first_token_id(tokenizer, attr_names, default=None):
    """Return the first tokenizer attribute in ``attr_names`` that is not None.

    Explicit ``is not None`` checks are required because token id 0 is a valid
    id and would be skipped by an ``a or b`` chain.
    """
    for attr in attr_names:
        value = getattr(tokenizer, attr, None)
        if value is not None:
            return value
    return default


def _cut_at_eos(token_ids, eos_token_id):
    """Return ``token_ids`` as a list, truncated before the first EOS.

    Position 0 is never treated as EOS because it holds the start token,
    which may share its id with EOS.
    """
    ids = token_ids.tolist()
    if eos_token_id is None:
        return ids
    for pos in range(1, len(ids)):
        if ids[pos] == eos_token_id:
            return ids[:pos]
    return ids


@torch.no_grad()
def evaluate_generative_beam(model, eval_dls, device, cfg, tokenizer, max_gen_len,
                             amp_dtype, beam_width):
    """Evaluate open-ended benchmarks using beam search decoding.

    Args:
        model: Model exposing ``vis``, ``txt``, ``evidence``, ``rollout``,
            ``gen_head`` and the ``_use_rollout`` flag.
        eval_dls: Mapping benchmark name -> DataLoader of collated batches.
        device: Device the batches are moved to.
        cfg: Config; only ``cfg.bf16`` is read here.
        tokenizer: Tokenizer used to pick start/EOS ids and decode outputs.
        max_gen_len: Maximum generated length passed to ``generate_beam``.
        amp_dtype: Autocast dtype.
        beam_width: Beam width passed to ``generate_beam``.

    Returns:
        Dict benchmark -> {metric_name: value}; "docvqa" → "anls",
        "chartqa" → "relaxed_accuracy", "textvqa" → "vqa_accuracy".
        The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    results = {}

    start_token_id = _first_token_id(tokenizer, ("bos_token_id", "cls_token_id"), default=1)
    # BERT-style tokenizers have no EOS; their targets end with the SEP token.
    eos_token_id = _first_token_id(tokenizer, ("eos_token_id", "sep_token_id"))
    use_amp = cfg.bf16 and device.type == "cuda"

    try:
        for benchmark, dl in eval_dls.items():
            predictions, ground_truths = [], []
            for batch in dl:
                bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                      for k, v in batch.items()}
                # NOTE(review): the original indentation was lost, so the exact
                # extent of this autocast block is assumed — confirm against VCS.
                with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                    vis_tok = model.vis(bt["pixel_values"]).float()
                    txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
                    evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
                    if model._use_rollout:
                        _, z_final, _ = model.rollout(evidence)
                    else:
                        # No rollout: initial tokens plus pooled, projected evidence.
                        B2 = bt["batch_size"]
                        pooled = F.adaptive_avg_pool1d(
                            evidence.permute(0, 2, 1), model.rollout.num_tokens
                        ).permute(0, 2, 1)
                        z_final = (model.rollout.init_tokens.expand(B2, -1, -1)
                                   + model.rollout.z0_proj(pooled))
                    gen_ids = model.gen_head.generate_beam(
                        z_final, evidence, start_token_id,
                        max_length=max_gen_len, eos_token_id=eos_token_id,
                        beam_width=beam_width,
                    )

                for row in gen_ids:
                    # Shorter sequences are right-padded by generate_beam; drop
                    # everything from EOS on so padding is never decoded as text.
                    kept = _cut_at_eos(row, eos_token_id)
                    predictions.append(
                        tokenizer.decode(kept, skip_special_tokens=True).strip())
                ground_truths.extend(batch["all_answers"])

            # Log a few sample predictions for debugging
            for j, pred in enumerate(predictions[:3]):
                gt_sample = ground_truths[j] if j < len(ground_truths) else "?"
                log.info(f" [{benchmark}] pred: '{pred}' | gt: '{gt_sample}'")

            if benchmark == "docvqa":
                results[benchmark] = {"anls": compute_anls(predictions, ground_truths)}
            elif benchmark == "chartqa":
                gt_flat = [g[0] if isinstance(g, list) else g for g in ground_truths]
                results[benchmark] = {
                    "relaxed_accuracy": compute_relaxed_accuracy(predictions, gt_flat)}
            elif benchmark == "textvqa":
                results[benchmark] = {
                    "vqa_accuracy": compute_vqa_accuracy(predictions, ground_truths)}
    finally:
        model.train(was_training)
    return results


# ══════════════════════════════════════════════════════════════════════════
# CHECKPOINT & HUB
# ══════════════════════════════════════════════════════════════════════════

def save_checkpoint(model, cfg, epoch, mc_acc, open_results, composite):
    """Write the current best checkpoint to ``<cfg.output_dir>/checkpoint_best.pt``.

    Saves the state dicts of the evidence, rollout, disc, gen_head and target
    sub-modules together with the config dict and evaluation scores.  The
    output directory is created if it does not exist.
    """
    os.makedirs(cfg.output_dir, exist_ok=True)
    path = os.path.join(cfg.output_dir, "checkpoint_best.pt")
    torch.save({
        "evidence": model.evidence.state_dict(),
        "rollout": model.rollout.state_dict(),
        "disc": model.disc.state_dict(),
        "gen_head": model.gen_head.state_dict(),
        "target_ev": model.target.t_ev.state_dict(),
        "target_ro": model.target.t_ro.state_dict(),
        "config": cfg.__dict__,
        "epoch": epoch,
        "mc_eval_acc": mc_acc,
        "open_results": open_results,
        "composite_score": composite,
        "phase": "3.1",
    }, path)
    log.info(f"Saved checkpoint: {path} (composite={composite:.1f})")


def _json_default(obj):
    """``json.dump`` fallback: convert NumPy / torch values to plain Python."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def push_results(cfg, args, best_composite, eval_results):
    """Write a results JSON and upload it (and the best checkpoint, if present)
    to ``cfg.hub_model_id``.

    Best-effort: any failure is logged with its traceback and swallowed so a
    failed upload never aborts the caller.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        results = {
            "run_name": cfg.run_name,
            "phase": "3.1",
            "backbone": cfg.backbone,
            "K": cfg.K,
            "best_composite_score": best_composite,
            "gen_weight": args.gen_weight,
            "max_gen_len": args.max_gen_len,
            "beam_width": args.beam_width,
            "tf_start": args.tf_start,
            "tf_end": args.tf_end,
            "epochs": cfg.epochs,
            "core_lr": args.core_lr,
            "open_results": dict(eval_results or {}),
            "improvements": ["gen_weight_2.0", "gen_len_32",
                             "scheduled_sampling", "beam_search"],
        }

        os.makedirs(cfg.output_dir, exist_ok=True)
        rp = os.path.join(cfg.output_dir, f"results_{cfg.run_name}.json")
        with open(rp, "w", encoding="utf-8") as f:
            # Metric values may be NumPy scalars, which json cannot encode natively.
            json.dump(results, f, indent=2, default=_json_default)
        api.upload_file(path_or_fileobj=rp,
                        path_in_repo=f"results/{cfg.run_name}.json",
                        repo_id=cfg.hub_model_id, repo_type="model")

        best_ckpt = os.path.join(cfg.output_dir, "checkpoint_best.pt")
        if os.path.exists(best_ckpt):
            api.upload_file(path_or_fileobj=best_ckpt,
                            path_in_repo=f"checkpoints/{cfg.run_name}_best.pt",
                            repo_id=cfg.hub_model_id, repo_type="model")
        log.info(f"Pushed Phase 3.1 results to {cfg.hub_model_id}")
    except Exception as e:  # deliberate best-effort boundary
        log.exception("Push failed: %s", e)


if __name__ == "__main__":
    main()