MR-JEPA / train_phase3_1.py
JorgeAV's picture
Add Phase 3.1 training: gen_weight 2.0, gen_len 32, scheduled sampling, beam search
206e1ad verified
#!/usr/bin/env python3
"""
MR-JEPA Phase 3.1 Training — Improved Generative Decoder
Loads the Phase 3.0 checkpoint (with partially-trained gen_head) and applies
four targeted improvements to break through the 0% generative metrics:
1. gen_weight: 0.5 → 2.0 (4× stronger generative gradient signal)
2. max_gen_len: 64 → 32 (shorter targets, less padding noise)
3. Scheduled sampling (100% teacher forcing → 50% free-running, linear)
4. Beam search evaluation (beam_width=5 instead of greedy argmax)
Resumes from: checkpoints/hybrid_main_phase3_best.pt (gen_head pre-trained)
Training data: same as Phase 3.0 (ScienceQA MC + DocVQA/ChartQA/TextVQA open-ended)
Usage:
python train_phase3_1.py
python train_phase3_1.py --gen_weight 2.0 --max_gen_len 32 --beam_width 5
"""
import os
import sys
import json
import math
import copy
import random
import logging
import argparse
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image
# Root logging configuration for the whole script: INFO level,
# time-stamped, pipe-separated records.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Logger used by every component of this script.
log = logging.getLogger("mrjepa-p3.1")
# ══════════════════════════════════════════════════════════════════════════
# OPEN-ENDED DATASET (same as Phase 3.0)
# ══════════════════════════════════════════════════════════════════════════
class OpenEndedDataset(Dataset):
    """Open-ended VQA samples from DocVQA, ChartQA or TextVQA (lmms-lab Hub copies).

    Each item is a dict with:
    - the RGB image,
    - the question text, with up to 50 OCR tokens appended when the row has them,
    - one training target ``answer``,
    - every reference answer (``all_answers``) for evaluation,
    - the benchmark name, the OCR text and the row's question type.
    """

    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192, max_gen_len=32):
        """Load ``split`` of ``benchmark`` from the Hugging Face Hub.

        Args:
            benchmark: one of "docvqa", "chartqa", "textvqa".
            split: dataset split name.
            max_samples: when > 0, keep only the first ``max_samples`` rows.
            transform, tokenizer, max_len, max_gen_len: stored only.
                ``__getitem__`` returns raw images and text; transforming and
                tokenizing happen in ``collate_open_ended``.

        Raises:
            ValueError: if ``benchmark`` is not one of the supported names.
        """
        from datasets import load_dataset
        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_gen_len = max_gen_len
        log.info(f"Loading {benchmark} {split}...")
        if benchmark == "docvqa":
            ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
        elif benchmark == "chartqa":
            ds = load_dataset("lmms-lab/ChartQA", split=split)
        elif benchmark == "textvqa":
            ds = load_dataset("lmms-lab/textvqa", split=split)
        else:
            raise ValueError(f"Unknown benchmark: {benchmark}")
        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))
        self.data = ds
        log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _answer_list(row):
        """Return ``row["answers"]``, treating a missing *or null* field as ``[""]``."""
        answers = row.get("answers")
        return [""] if answers is None else answers

    def __getitem__(self, idx):
        """Build one raw sample dict for ``idx`` (see the class docstring)."""
        row = self.data[idx]
        img = row.get("image")
        if img is None:
            # Rows without an image get a blank white placeholder.
            img = Image.new("RGB", (256, 256), "white")
        else:
            img = img.convert("RGB")

        question = row["question"]
        if self.benchmark == "docvqa":
            answers = self._answer_list(row)
            answer = answers[0] if answers else ""
            all_answers = answers
        elif self.benchmark == "chartqa":
            raw_answer = row.get("answer")
            # A null answer must not become the literal string "None".
            answer = "" if raw_answer is None else str(raw_answer)
            all_answers = [answer]
        elif self.benchmark == "textvqa":
            from collections import Counter
            answers = self._answer_list(row)
            # Train on the most frequent annotator answer, after lower-casing and stripping.
            answer_counts = Counter(a.lower().strip() for a in answers)
            answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
            all_answers = answers
        else:
            answer = ""
            all_answers = [""]

        ocr_tokens = row.get("ocr_tokens", [])
        ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""
        text = question
        if ocr_text:
            text += f" [OCR: {ocr_text}]"
        return {
            "image": img, "text": text, "answer": answer,
            "all_answers": all_answers, "benchmark": self.benchmark,
            "ocr_text": ocr_text,
            "question_type": row.get("type", row.get("question_types", [""])),
        }
def collate_open_ended(batch, transform, tokenizer, max_len, max_gen_len):
    """Turn a list of ``OpenEndedDataset`` samples into one model-ready batch.

    Images:
    - A plain callable ``transform`` (one without a ``feature_extractor``
      attribute) is applied to each image, and the results are stacked.
    - Otherwise ``transform`` is called once on the whole image list,
      processor-style (``images=..., return_tensors="pt"``).

    Text:
    - Questions are padded/truncated to ``max_len``.
    - Answers are padded/truncated to ``max_gen_len``.
    - Empty answers are replaced by a single space before tokenization.

    Returns:
        dict with tensors ``pixel_values``, ``input_ids``, ``attention_mask``,
        ``gen_target_ids`` and ``gen_attention_mask``, plus the per-sample
        Python lists ``benchmarks``, ``all_answers`` and ``question_types``
        and the int ``batch_size``.
    """
    images, texts, answers = [], [], []
    for sample in batch:
        images.append(sample["image"])
        texts.append(sample["text"])
        answers.append(sample["answer"])

    per_image_transform = (hasattr(transform, '__call__')
                           and not hasattr(transform, 'feature_extractor'))
    if per_image_transform:
        pixel_values = torch.stack([transform(image) for image in images])
    else:
        pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]

    prompt_tok = tokenizer(texts, padding="max_length", truncation=True,
                           max_length=max_len, return_tensors="pt")
    target_texts = [answer or " " for answer in answers]
    target_tok = tokenizer(target_texts, padding="max_length", truncation=True,
                           max_length=max_gen_len, return_tensors="pt")

    return {
        "pixel_values": pixel_values,
        "input_ids": prompt_tok["input_ids"],
        "attention_mask": prompt_tok["attention_mask"],
        "gen_target_ids": target_tok["input_ids"],
        "gen_attention_mask": target_tok["attention_mask"],
        "batch_size": len(batch),
        "benchmarks": [sample["benchmark"] for sample in batch],
        "all_answers": [sample["all_answers"] for sample in batch],
        "question_types": [sample.get("question_type", "") for sample in batch],
    }
# ══════════════════════════════════════════════════════════════════════════
# GENERATIVE HEAD with SCHEDULED SAMPLING + BEAM SEARCH
# ══════════════════════════════════════════════════════════════════════════
class GenerativeDecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer with two cross-attention memories.

    The four sub-layers are each applied as ``x = x + sublayer(norm(x))``:
    1. self-attention over the token sequence (optionally masked),
    2. cross-attention to ``z_final``,
    3. cross-attention to ``evidence``,
    4. a position-wise feed-forward network (4x expansion, GELU).
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()

        def attention():
            return nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                         dropout=dropout, batch_first=True)

        # Attribute names are the state_dict keys that checkpoints are saved and
        # loaded with; creation order is kept so parameter init order is unchanged.
        self.self_attn = attention()
        self.self_attn_norm = nn.LayerNorm(hidden_dim)
        self.state_cross_attn = attention()
        self.state_cross_norm = nn.LayerNorm(hidden_dim)
        self.evidence_cross_attn = attention()
        self.evidence_cross_norm = nn.LayerNorm(hidden_dim)
        expanded_dim = hidden_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(expanded_dim, hidden_dim),
            nn.Dropout(dropout),
        )
        self.ffn_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, z_final, evidence, causal_mask=None):
        """Apply the four residual sub-layers to ``x`` and return the result.

        ``causal_mask`` is passed as ``attn_mask`` to the self-attention only.
        """
        normed = self.self_attn_norm(x)
        x = x + self.self_attn(normed, normed, normed, attn_mask=causal_mask)[0]
        normed = self.state_cross_norm(x)
        x = x + self.state_cross_attn(normed, z_final, z_final)[0]
        normed = self.evidence_cross_norm(x)
        x = x + self.evidence_cross_attn(normed, evidence, evidence)[0]
        return x + self.ffn(self.ffn_norm(x))
class GenerativeHead(nn.Module):
    """
    Phase 3.1 autoregressive answer decoder.

    Token embeddings plus learned positional embeddings feed a stack of
    ``GenerativeDecoderLayer`` blocks. Each block cross-attends to ``z_final``
    and ``evidence``. A final LayerNorm and an LM head (weight tied to the
    token embedding) produce the logits.

    - ``forward``: training loss with scheduled sampling.
    - ``generate_greedy`` / ``generate_beam``: inference-time decoding.
    """

    def __init__(self, hidden_dim, vocab_size, num_layers=4, num_heads=12,
                 max_gen_len=32, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.max_gen_len = max_gen_len
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        # Learned positions: sequences longer than max_gen_len cannot be decoded.
        self.pos_embedding = nn.Embedding(max_gen_len, hidden_dim)
        self.layers = nn.ModuleList([
            GenerativeDecoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)
        ])
        self.output_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        # Weight tying: the output projection shares the input embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def _decode_step(self, token_ids, z_final, evidence):
        """Run the decoder over ``token_ids`` and return logits for every position.

        Args:
            token_ids: (B, T) long tensor with T <= max_gen_len.
            z_final, evidence: memories passed to each layer's cross-attention.

        Returns:
            (B, T, vocab_size) logits. Position t scores token t + 1, because a
            causal mask hides later positions.
        """
        seq_len = token_ids.size(1)
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        x = self.token_embedding(token_ids) + self.pos_embedding(positions)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, z_final, evidence, causal_mask)
        return self.lm_head(self.output_norm(x))

    def forward(self, z_final, evidence, target_ids, pad_token_id=0,
                teacher_forcing_ratio=1.0):
        """Training forward pass with scheduled sampling.

        Args:
            z_final, evidence: cross-attention memories for the decoder layers.
            target_ids: (B, T) long tensor of tokenized answers.
                ``target_ids[:, 0]`` is used as the start token.
            pad_token_id: label id excluded from the loss (``ignore_index``).
            teacher_forcing_ratio: probability that the next decoder input is
                the ground-truth token rather than the model's own argmax.
                - ``>= 1.0``: one batched, fully teacher-forced pass.
                - Otherwise: inputs are built step by step, with the choice
                  made independently for every sample at every step.

        Returns:
            (logits, loss):
            - logits: (B, T, vocab_size); position t predicts token t + 1.
            - loss: mean next-token cross-entropy over non-pad labels.
        """
        B, seq_len = target_ids.shape
        device = target_ids.device
        if teacher_forcing_ratio >= 1.0:
            # ── Pure teacher forcing (fast, batched) ──
            logits = self._decode_step(target_ids, z_final, evidence)
        else:
            # ── Scheduled sampling: mix ground truth with the model's predictions ──
            step_logits = []
            current_input = target_ids[:, :1]  # start with the first target token
            for t in range(seq_len):
                last = self._decode_step(current_input, z_final, evidence)[:, -1]
                step_logits.append(last)
                if t < seq_len - 1:
                    gold = target_ids[:, t + 1]
                    predicted = last.argmax(dim=-1)
                    # Decide per sample, not once for the whole batch, so every
                    # batch mixes both kinds of input.
                    use_gold = torch.rand(B, device=device) < teacher_forcing_ratio
                    next_token = torch.where(use_gold, gold, predicted)
                    current_input = torch.cat([current_input, next_token.unsqueeze(1)], dim=1)
            logits = torch.stack(step_logits, dim=1)
        # Next-token loss: logits at t are scored against the label at t + 1.
        shift_logits = logits[:, :-1].reshape(-1, self.vocab_size)
        shift_labels = target_ids[:, 1:].reshape(-1)
        loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate_greedy(self, z_final, evidence, start_token_id,
                        max_length=32, eos_token_id=None):
        """Greedy autoregressive generation (fallback).

        Returns a (B, L) long tensor with L <= max_length. Column 0 is
        ``start_token_id``. If ``eos_token_id`` is given, a row that has
        produced EOS is filled with EOS from then on. Decoding stops early
        once every row has finished.
        """
        B = z_final.size(0)
        device = z_final.device
        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, dtype=torch.bool, device=device)
        for _ in range(max_length - 1):
            logits = self._decode_step(generated, z_final, evidence)
            next_token = logits[:, -1].argmax(dim=-1)
            if eos_token_id is not None:
                # Freeze finished rows so text after EOS is never emitted.
                next_token = next_token.masked_fill(finished, eos_token_id)
                finished |= next_token == eos_token_id
            generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
            if eos_token_id is not None and bool(finished.all()):
                break
        return generated

    @torch.no_grad()
    def generate_beam(self, z_final, evidence, start_token_id,
                      max_length=32, eos_token_id=None, beam_width=5):
        """
        Beam search generation, run independently for each sample in the batch.

        All live beams of a sample share one length, so they are decoded in a
        single batched decoder call per step. Beams ending in
        ``eos_token_id`` are retired.

        The returned sequence per sample is the one with the best
        length-normalized log-probability. Columns are ``start_token_id``
        followed by the generated tokens.

        Returns:
            (B, L) long tensor, L <= max_length. Shorter rows are right-padded
            with ``eos_token_id`` (or 0 when no EOS id is given). EOS is a
            special token that decoders skip, whereas id 0 is an ordinary
            vocabulary token for many tokenizers.
        """
        B = z_final.size(0)
        device = z_final.device
        pad_value = eos_token_id if eos_token_id is not None else 0

        def has_ended(seq):
            return eos_token_id is not None and seq[0, -1].item() == eos_token_id

        def normalized(beam):
            score, seq = beam
            return score / max(seq.size(1), 1)

        all_results = []
        for b in range(B):
            z_b = z_final[b:b + 1]     # (1, N_s, D)
            ev_b = evidence[b:b + 1]   # (1, N_e, D)
            # Each beam: (summed log-prob, (1, T) token tensor)
            beams = [(0.0, torch.tensor([[start_token_id]], dtype=torch.long, device=device))]
            completed = []
            for _ in range(max_length - 1):
                active = []
                for beam in beams:
                    (completed if has_ended(beam[1]) else active).append(beam)
                if not active:
                    beams = []
                    break
                seqs = torch.cat([seq for _, seq in active], dim=0)  # (k, T)
                k = seqs.size(0)
                logits = self._decode_step(seqs, z_b.expand(k, -1, -1), ev_b.expand(k, -1, -1))
                log_probs = F.log_softmax(logits[:, -1].float(), dim=-1)  # (k, V)
                width = min(beam_width, log_probs.size(-1))
                topk_lp, topk_ids = log_probs.topk(width, dim=-1)
                candidates = []
                for (score, seq), lps, ids in zip(active, topk_lp.tolist(), topk_ids.tolist()):
                    for lp, tok in zip(lps, ids):
                        candidates.append((score + lp, torch.cat([seq, seq.new_tensor([[tok]])], dim=1)))
                # Length-normalize scores and keep the top beams.
                candidates.sort(key=normalized, reverse=True)
                beams = candidates[:beam_width]
                # Early stop if every kept beam has just produced EOS.
                if eos_token_id is not None and all(has_ended(seq) for _, seq in beams):
                    completed.extend(beams)
                    beams = []
                    break
            best = max(completed + beams, key=normalized)
            all_results.append(best[1])

        # Right-pad to a common length.
        max_len = max(r.size(1) for r in all_results)
        padded = torch.full((B, max_len), pad_value, dtype=torch.long, device=device)
        for i, r in enumerate(all_results):
            padded[i, :r.size(1)] = r[0]
        return padded
# ══════════════════════════════════════════════════════════════════════════
# EVALUATION METRICS (same as Phase 3.0)
# ══════════════════════════════════════════════════════════════════════════
def normalized_levenshtein(s1, s2):
    """Return the edit distance between *s1* and *s2* divided by the longer length.

    Both strings are lower-cased and stripped first:
    - identical strings score 0.0;
    - if exactly one is empty the score is 1.0.
    The DP keeps only two rows, so memory grows with ``len(s2)`` alone.
    """
    a = s1.lower().strip()
    b = s2.lower().strip()
    if a == b:
        return 0.0
    if not a or not b:
        return 1.0
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            substitution = previous[j - 1] + (char_a != char_b)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current
    return previous[-1] / max(len(a), len(b))
def compute_anls(predictions, ground_truths, threshold=0.5):
    """Average Normalized Levenshtein Similarity (ANLS), as a percentage.

    For each prediction, each reference scores ``1 - NL(pred, gt)`` when the
    normalized distance ``NL`` is below ``threshold``, and 0 otherwise. The
    best reference score is kept.

    Args:
        predictions: predicted answer strings.
        ground_truths: per-sample reference answers, as a list of strings or a
            single string. An empty or None entry scores 0.
        threshold: distance cut-off (0.5 in the standard metric).

    Returns:
        Mean score x 100, or 0.0 when there are no predictions.
    """
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        if isinstance(gts, str):
            # Wrap a bare string so it is not iterated character by character.
            gts = [gts]
        best = 0.0
        for gt in gts or ():
            # Compute the distance once per reference (it was computed twice).
            dist = normalized_levenshtein(str(pred), str(gt))
            if dist < threshold:
                best = max(best, 1.0 - dist)
        scores.append(best)
    return np.mean(scores) * 100 if scores else 0.0
def compute_vqa_accuracy(predictions, ground_truths):
    """Simplified VQA accuracy, as a percentage.

    A prediction earns ``min(#matching references / 3, 1)``. Matching is exact
    after lower-casing and stripping both sides. The result is the mean over
    samples x 100, or 0.0 when there are no predictions.
    """
    per_sample = []
    for pred, refs in zip(predictions, ground_truths):
        target = str(pred).lower().strip()
        hits = sum(str(ref).lower().strip() == target for ref in refs)
        per_sample.append(min(hits / 3.0, 1.0))
    if not per_sample:
        return 0.0
    return np.mean(per_sample) * 100
def compute_relaxed_accuracy(predictions, ground_truths, tolerance=0.05):
    """ChartQA relaxed accuracy, as a percentage.

    When both strings parse as numbers (commas and '%' removed), the
    prediction is correct if its relative error is within ``tolerance``. For
    a zero reference, ``|pred|`` must be within ``tolerance`` instead.
    Otherwise the lower-cased, stripped strings must match exactly.
    """
    def to_number(text):
        return float(text.replace(',', '').replace('%', ''))

    outcomes = []
    for pred, gt in zip(predictions, ground_truths):
        pred_s = str(pred).strip().lower()
        gt_s = str(gt).strip().lower()
        try:
            gt_val = to_number(gt_s)
            pred_val = to_number(pred_s)
        except (ValueError, ZeroDivisionError):
            outcomes.append(pred_s == gt_s)
            continue
        if gt_val == 0:
            outcomes.append(abs(pred_val) <= tolerance)
        else:
            outcomes.append(abs(pred_val - gt_val) / abs(gt_val) <= tolerance)
    return np.mean(outcomes) * 100 if outcomes else 0.0
# ══════════════════════════════════════════════════════════════════════════
# SCHEDULED SAMPLING SCHEDULE
# ══════════════════════════════════════════════════════════════════════════
def get_teacher_forcing_ratio(epoch, total_epochs, start_ratio=1.0, end_ratio=0.5):
    """
    Teacher-forcing probability for ``epoch``, decaying linearly from
    ``start_ratio`` at epoch 0 to ``end_ratio`` at epoch ``total_epochs - 1``.

    With the defaults:
    - epoch 0 is fully teacher-forced;
    - the final epoch feeds the model its own predictions half of the time.
    This narrows the gap to evaluation, where decoding is entirely free-running.

    Runs of one epoch or fewer always return ``start_ratio``.
    """
    if total_epochs > 1:
        fraction = epoch / (total_epochs - 1)
        return start_ratio - (start_ratio - end_ratio) * fraction
    return start_ratio
# ══════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════
def download_checkpoint(hub_model_id, filename):
    """Fetch ``filename`` from the Hub model repo ``hub_model_id``; return the local file path."""
    from huggingface_hub import hf_hub_download
    local_path = hf_hub_download(
        repo_id=hub_model_id,
        filename=filename,
        repo_type="model",
    )
    log.info(f"Downloaded checkpoint: {local_path}")
    return local_path
def main():
    """Fine-tune the Phase 3.0 MR-JEPA checkpoint with the Phase 3.1 recipe.

    Steps:
    1. Parse CLI args.
    2. Import the Phase 1 model definitions from the Hub repo.
    3. Load the Phase 3.0 checkpoint and rebuild the generative head,
       reusing shape-compatible weights.
    4. For each epoch:
       - train on ScienceQA multiple choice;
       - train generatively on the open-ended benchmarks with scheduled
         sampling;
       - evaluate, using beam search for the open-ended benchmarks;
       - keep the checkpoint with the best composite score.
    5. The Trackio run is always closed, and results are pushed when
       ``cfg.push_to_hub`` is set.
    """
    parser = argparse.ArgumentParser(description="MR-JEPA Phase 3.1 Training")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Local path to checkpoint. Default: download Phase 3.0 from Hub.")
    parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
    parser.add_argument("--run_name", default="hybrid_main_phase3_1")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--core_lr", type=float, default=5e-5)
    parser.add_argument("--backbone_lr", type=float, default=5e-6)
    parser.add_argument("--text_lr", type=float, default=5e-6)
    # ── Phase 3.1 improvements ──
    parser.add_argument("--gen_weight", type=float, default=2.0,
                        help="Generative loss weight (was 0.5 in 3.0)")
    parser.add_argument("--max_gen_len", type=int, default=32,
                        help="Max generation length (was 64 in 3.0)")
    parser.add_argument("--beam_width", type=int, default=5,
                        help="Beam search width for evaluation (was greedy in 3.0)")
    parser.add_argument("--tf_start", type=float, default=1.0,
                        help="Teacher forcing ratio at epoch 0")
    parser.add_argument("--tf_end", type=float, default=0.5,
                        help="Teacher forcing ratio at final epoch")
    # ──────────────────────────────
    parser.add_argument("--max_eval_samples", type=int, default=200)
    parser.add_argument("--max_train_samples", type=int, default=0)
    parser.add_argument("--output_dir", default="./outputs/mrjepa_phase3_1")
    parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio")
    args = parser.parse_args()

    # ── Import Phase 1 model definitions ──
    log.info("Downloading Phase 1 training script for model definitions...")
    from huggingface_hub import hf_hub_download
    p1_script = hf_hub_download(repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model")
    import importlib.util
    spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
    p1 = importlib.util.module_from_spec(spec)
    # SECURITY NOTE: this executes a script downloaded from the Hub repo;
    # only point --hub_model_id at a repository you trust.
    spec.loader.exec_module(p1)

    # ── Load Phase 3.0 checkpoint (includes gen_head weights) ──
    if args.checkpoint and os.path.exists(args.checkpoint):
        ckpt_path = args.checkpoint
    else:
        if args.checkpoint:
            # Report it when a user-supplied path is missing, instead of switching silently.
            log.warning(f"--checkpoint {args.checkpoint} does not exist; "
                        f"downloading the Phase 3.0 checkpoint from the Hub instead")
        ckpt_path = download_checkpoint(args.hub_model_id,
                                        "checkpoints/hybrid_main_phase3_best.pt")
    log.info(f"Loading Phase 3.0 checkpoint: {ckpt_path}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    saved_cfg = ckpt["config"]
    cfg = p1.Config()
    for k, v in saved_cfg.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
    cfg.phase = 3
    cfg.epochs = args.epochs
    cfg.batch_size = args.batch_size
    cfg.grad_accum = args.grad_accum
    cfg.lr = args.core_lr
    cfg.backbone_lr = args.backbone_lr
    cfg.output_dir = args.output_dir
    cfg.run_name = args.run_name
    cfg.freeze_backbone = True
    cfg.freeze_text = True
    cfg.max_eval_samples = args.max_eval_samples
    cfg.resolve()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Device: {device}")
    os.makedirs(cfg.output_dir, exist_ok=True)

    # ── Trackio ──
    import trackio
    trackio.init(
        name=args.run_name, project="MR-JEPA", space_id=args.trackio_space,
        config={
            "phase": "3.1", "epochs": args.epochs,
            "core_lr": args.core_lr, "backbone_lr": args.backbone_lr,
            "text_lr": args.text_lr, "gen_weight": args.gen_weight,
            "max_gen_len": args.max_gen_len, "beam_width": args.beam_width,
            "tf_start": args.tf_start, "tf_end": args.tf_end,
            "batch_size": args.batch_size, "grad_accum": args.grad_accum,
            "backbone": cfg.backbone, "K": cfg.K,
            "improvements": "gen_weight_2.0, gen_len_32, scheduled_sampling, beam_search",
        }
    )
    log.info(f"Trackio → https://huggingface.co/spaces/{args.trackio_space}")

    # ── Build model ──
    log.info("Building model...")
    model = p1.MRJEPAModel(cfg)
    model.evidence.load_state_dict(ckpt["evidence"])
    model.rollout.load_state_dict(ckpt["rollout"])
    model.disc.load_state_dict(ckpt["disc"])
    model.target.t_ev.load_state_dict(ckpt["target_ev"])
    model.target.t_ro.load_state_dict(ckpt["target_ro"])
    log.info(f"Loaded core weights from Phase 3.0 (epoch={ckpt.get('epoch','?')}, "
             f"composite={ckpt.get('composite_score','?')})")

    # ── Generative head: new architecture with max_gen_len=32 ──
    tokenizer = model.txt.tokenizer
    actual_vocab_size = len(tokenizer)
    gen_head = GenerativeHead(
        hidden_dim=cfg.rollout_dim,
        vocab_size=actual_vocab_size,
        num_layers=4,
        num_heads=cfg.predictor_heads,
        max_gen_len=args.max_gen_len,
        dropout=0.1,
    )
    # Load Phase 3.0 gen_head weights where shapes match
    if "gen_head" in ckpt:
        p3_gen = ckpt["gen_head"]
        new_sd = gen_head.state_dict()
        loaded, skipped = 0, 0
        for k, v in p3_gen.items():
            if k in new_sd and new_sd[k].shape == v.shape:
                new_sd[k] = v
                loaded += 1
            elif (k == "pos_embedding.weight" and k in new_sd and v.dim() == 2
                  and v.shape[1] == new_sd[k].shape[1] and v.shape[0] >= new_sd[k].shape[0]):
                # The old head had more positions; its first max_gen_len rows are
                # the same learned positions, so keep them instead of re-initialising.
                new_sd[k] = v[:new_sd[k].shape[0]].clone()
                loaded += 1
                log.info(f" Truncated {k}: ckpt {v.shape} -> new {new_sd[k].shape}")
            elif k in new_sd:
                skipped += 1
                log.info(f" Shape mismatch for {k}: ckpt {v.shape} vs new {new_sd[k].shape}")
            else:
                skipped += 1
        gen_head.load_state_dict(new_sd)
        log.info(f"Loaded {loaded} gen_head params from Phase 3.0 ({skipped} skipped)")
    else:
        log.warning("No gen_head in checkpoint — starting from scratch")
    model.gen_head = gen_head

    # ── Unfreeze backbone layers ──
    log.info("Unfreezing last 6 visual layers, last 4 text layers")
    model.vis.unfreeze_last(6)
    model.txt.unfreeze_last(4)
    model = model.to(device)
    total_p = sum(p.numel() for p in model.parameters())
    train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)")

    # ── Datasets ──
    transform = model.vis.get_transform()
    mc_max = args.max_train_samples if args.max_train_samples > 0 else 0
    train_mc_ds = p1.ScienceQADataset("train", max_samples=mc_max, transform=transform,
                                      tokenizer=tokenizer, max_len=cfg.max_text_len,
                                      max_opts=cfg.max_options)
    eval_mc_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples,
                                     transform=transform, tokenizer=tokenizer,
                                     max_len=cfg.max_text_len, max_opts=cfg.max_options)
    mc_coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options)
    train_mc_dl = DataLoader(train_mc_ds, batch_size=cfg.batch_size, shuffle=True,
                             num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True)
    eval_mc_dl = DataLoader(eval_mc_ds, batch_size=cfg.batch_size, shuffle=False,
                            num_workers=2, collate_fn=mc_coll, pin_memory=True)
    max_open = args.max_train_samples if args.max_train_samples > 0 else 5000
    open_coll = lambda batch: collate_open_ended(batch, transform, tokenizer,
                                                 cfg.max_text_len, args.max_gen_len)
    train_open_dls = {}
    eval_open_dls = {}
    for bm, tr_split, ev_split in [("docvqa", "validation", "validation"),
                                   ("chartqa", "test", "test"),
                                   ("textvqa", "train", "validation")]:
        train_open_dls[bm] = DataLoader(
            OpenEndedDataset(bm, tr_split, max_samples=max_open, transform=transform,
                             tokenizer=tokenizer, max_len=cfg.max_text_len,
                             max_gen_len=args.max_gen_len),
            batch_size=cfg.batch_size, shuffle=True, num_workers=2,
            collate_fn=open_coll, pin_memory=True, drop_last=True)
        eval_open_dls[bm] = DataLoader(
            OpenEndedDataset(bm, ev_split, max_samples=args.max_eval_samples,
                             transform=transform, tokenizer=tokenizer,
                             max_len=cfg.max_text_len, max_gen_len=args.max_gen_len),
            batch_size=cfg.batch_size, shuffle=False, num_workers=2,
            collate_fn=open_coll, pin_memory=True)

    # ── Optimizer ──
    backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
    text_params = [p for p in model.txt.parameters() if p.requires_grad]
    bb_txt_ids = {id(p) for p in backbone_params + text_params}
    core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids]
    param_groups = [
        {"params": core_params, "lr": args.core_lr},
        {"params": backbone_params, "lr": args.backbone_lr},
        {"params": text_params, "lr": args.text_lr},
    ]
    optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)
    mc_steps = len(train_mc_dl)
    open_steps = sum(len(dl) for dl in train_open_dls.values())
    total_steps = cfg.epochs * (mc_steps + open_steps) // cfg.grad_accum
    warmup_steps = int(total_steps * 0.1)

    def lr_lambda(step):
        # Linear warmup, then cosine decay to 1% of the base learning rate.
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id or 0
    if tokenizer.eos_token_id is not None and pad_token_id == tokenizer.eos_token_id:
        # The generative loss uses ignore_index=pad_token_id, so EOS targets are dropped.
        log.warning("pad_token_id == eos_token_id: EOS target tokens are excluded "
                    "from the generative loss")

    log.info(f"Phase 3.1: {cfg.epochs} epochs | gen_weight={args.gen_weight} | "
             f"max_gen_len={args.max_gen_len} | beam_width={args.beam_width}")
    log.info(f" Teacher forcing: {args.tf_start:.0%} → {args.tf_end:.0%}")
    log.info(f" MC batches/epoch: {mc_steps} | Open batches/epoch: {open_steps}")
    log.info(f" Total opt steps: ~{total_steps} | Warmup: {warmup_steps}")

    global_step = 0
    best_composite = 0.0
    # Bound before `try`: the finally-block passes it to push_results even if
    # training fails before the first evaluation (previously a NameError).
    eval_results = None
    amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
    use_amp = cfg.bf16 and device.type == "cuda"
    trainable = [p for p in model.parameters() if p.requires_grad]

    try:
        for epoch in range(cfg.epochs):
            model.train()
            epoch_losses = defaultdict(list)
            epoch_mc_correct, epoch_mc_total = 0, 0
            optimizer.zero_grad()
            batch_count = 0

            # ── Scheduled sampling ratio for this epoch ──
            tf_ratio = get_teacher_forcing_ratio(epoch, cfg.epochs, args.tf_start, args.tf_end)
            log.info(f"Phase 3.1 Epoch {epoch}: teacher_forcing={tf_ratio:.2f}")

            # ── MC training ──
            log.info(f" MC training on ScienceQA...")
            for bi, batch in enumerate(train_mc_dl):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                    losses, preds = model(**batch)
                loss = losses["total"] / cfg.grad_accum
                loss.backward()
                batch_count += 1
                if batch_count % cfg.grad_accum == 0:
                    nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    model.update_target(global_step, total_steps)
                    global_step += 1
                for k, v in losses.items():
                    if isinstance(v, torch.Tensor):
                        epoch_losses[f"mc_{k}"].append(v.item())
                epoch_mc_correct += (preds == batch["labels"]).sum().item()
                epoch_mc_total += batch["batch_size"]
                if bi % 100 == 0:
                    avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")}
                    acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100
                    log.info(f" E{epoch} MC B{bi}/{mc_steps} | loss={avg.get('mc_total',0):.4f} | acc={acc:.1f}%")
                    trackio.log({"train/mc_loss": avg.get("mc_total", 0), "train/mc_accuracy": acc,
                                 "train/lr": scheduler.get_last_lr()[0], "train/epoch": epoch,
                                 "train/step": global_step, "train/tf_ratio": tf_ratio})

            # ── Open-ended training (with scheduled sampling) ──
            log.info(f" Open-ended training (tf_ratio={tf_ratio:.2f})...")
            gen_losses = defaultdict(list)
            open_iters = {n: iter(dl) for n, dl in train_open_dls.items()}
            # An ordered list (not a set) keeps the round-robin order reproducible across runs.
            open_active = list(open_iters)
            obi = 0
            while open_active:
                for name in list(open_active):
                    try:
                        batch = next(open_iters[name])
                    except StopIteration:
                        open_active.remove(name)
                        continue
                    bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                        vis_tok = model.vis(bt["pixel_values"]).float()
                        txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
                        evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
                        if model._use_rollout:
                            _, z_final, z_proj = model.rollout(evidence)
                        else:
                            B2 = bt["batch_size"]
                            z0 = model.rollout.init_tokens.expand(B2, -1, -1) + \
                                model.rollout.z0_proj(F.adaptive_avg_pool1d(
                                    evidence.permute(0, 2, 1), model.rollout.num_tokens).permute(0, 2, 1))
                            z_final, z_proj = z0, model.rollout.out_proj(z0).unsqueeze(1)
                        jepa_loss_val = torch.tensor(0.0, device=device)
                        if model._use_jepa:
                            target_proj = model.target(vis_tok.detach(), txt_tok.detach(),
                                                       bt["attention_mask"].detach())
                            jl = model.jepa_loss(z_proj, target_proj, torch.tensor(0.0, device=device))
                            jepa_loss_val = jl["jepa"] + jl["reg"]
                        # ── Generative loss with scheduled sampling ──
                        _, gen_loss = model.gen_head(
                            z_final, evidence, bt["gen_target_ids"],
                            pad_token_id=pad_token_id,
                            teacher_forcing_ratio=tf_ratio,
                        )
                        total_loss = cfg.jepa_weight * jepa_loss_val + args.gen_weight * gen_loss
                    loss = total_loss / cfg.grad_accum
                    loss.backward()
                    batch_count += 1
                    if batch_count % cfg.grad_accum == 0:
                        nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        model.update_target(global_step, total_steps)
                        global_step += 1
                    gen_losses[f"{name}_gen"].append(gen_loss.item())
                    gen_losses[f"{name}_total"].append(total_loss.item())
                    obi += 1
                    if obi % 100 == 0:
                        avg = {k: np.mean(v[-100:]) for k, v in gen_losses.items()}
                        log.info(f" E{epoch} OPEN B{obi} | " + " | ".join(f"{k}={v:.4f}" for k, v in avg.items()))
                        trackio.log({f"train/{k}": v for k, v in avg.items()})

            # ── Evaluation (with beam search) ──
            log.info(f" Evaluating (beam_width={args.beam_width})...")
            mc_eval_acc = p1.evaluate(model, eval_mc_dl, device, cfg)
            log.info(f" ScienceQA eval accuracy: {mc_eval_acc:.1f}%")
            eval_results = evaluate_generative_beam(
                model, eval_open_dls, device, cfg, tokenizer,
                args.max_gen_len, amp_dtype, args.beam_width
            )
            for bm, metrics in eval_results.items():
                for mk, mv in metrics.items():
                    log.info(f" {bm} {mk}: {mv:.2f}")
            all_scores = [mc_eval_acc] + [v for m in eval_results.values() for v in m.values()]
            composite = np.mean(all_scores)
            log.info(f"=== Phase 3.1 Epoch {epoch} | MC: {mc_eval_acc:.1f}% | "
                     f"Composite: {composite:.1f} | tf={tf_ratio:.2f} ===")
            trackio.log({
                "eval/scienceqa_accuracy": mc_eval_acc,
                "eval/composite_score": composite,
                "eval/epoch": epoch, "eval/tf_ratio": tf_ratio,
                **{f"eval/{bm}_{mk}": mv for bm, m in eval_results.items() for mk, mv in m.items()},
            })
            if composite > best_composite:
                best_composite = composite
                save_checkpoint(model, cfg, epoch, mc_eval_acc, eval_results, composite)
                log.info(f" ★ New best composite: {best_composite:.1f}")
        log.info(f"Phase 3.1 complete. Best composite: {best_composite:.1f}")
    finally:
        trackio.log({"final/best_composite": best_composite, "final/phase": "3.1",
                     "final/total_steps": global_step})
        trackio.finish()
        if cfg.push_to_hub:
            push_results(cfg, args, best_composite, eval_results)
# ══════════════════════════════════════════════════════════════════════════
# BEAM SEARCH EVALUATION
# ══════════════════════════════════════════════════════════════════════════
@torch.no_grad()
def evaluate_generative_beam(model, eval_dls, device, cfg, tokenizer,
                             max_gen_len, amp_dtype, beam_width):
    """Evaluate open-ended benchmarks using beam-search decoding.

    For each benchmark loader:
    - run the vision/text encoders, the evidence module and the rollout as in
      training;
    - decode answers with ``model.gen_head.generate_beam``;
    - score them with the benchmark's metric (DocVQA → ANLS, ChartQA → relaxed
      accuracy, TextVQA → VQA accuracy).

    Args:
        model: MR-JEPA model with an attached ``gen_head``.
        eval_dls: benchmark name → DataLoader yielding ``collate_open_ended`` batches.
        device: device the batch tensors are moved to.
        cfg: run config; only ``cfg.bf16`` is read here (autocast switch).
        tokenizer: supplies the start/stop token ids and decodes predictions.
        max_gen_len: maximum decoded length, including the start token.
        amp_dtype: autocast dtype.
        beam_width: beam size for ``generate_beam``.

    Returns:
        dict mapping benchmark name → {metric name: score in percent}. Names
        other than docvqa/chartqa/textvqa produce no entry. The model's
        train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    results = {}

    def first_defined(*token_ids):
        # `a or b` would wrongly skip a legitimate token id of 0.
        return next(t for t in token_ids if t is not None)

    start_token_id = first_defined(tokenizer.bos_token_id, tokenizer.cls_token_id, 1)
    # Fall back to SEP when the tokenizer has no EOS, mirroring the CLS start
    # fallback above. Without a stop token, beam search always runs to
    # max_gen_len and emits tokens past the answer's end. Pad positions are
    # ignored by the training loss, so those tokens were never trained.
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = getattr(tokenizer, "sep_token_id", None)
    use_amp = cfg.bf16 and device.type == "cuda"

    try:
        for benchmark, dl in eval_dls.items():
            predictions, ground_truths = [], []
            for batch in dl:
                bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                    vis_tok = model.vis(bt["pixel_values"]).float()
                    txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
                    evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
                    if model._use_rollout:
                        _, z_final, _ = model.rollout(evidence)
                    else:
                        B2 = bt["batch_size"]
                        z_final = model.rollout.init_tokens.expand(B2, -1, -1) + model.rollout.z0_proj(
                            F.adaptive_avg_pool1d(evidence.permute(0, 2, 1),
                                                  model.rollout.num_tokens).permute(0, 2, 1))
                    gen_ids = model.gen_head.generate_beam(
                        z_final, evidence, start_token_id,
                        max_length=max_gen_len, eos_token_id=eos_token_id,
                        beam_width=beam_width,
                    )
                predictions.extend(tokenizer.decode(row, skip_special_tokens=True).strip()
                                   for row in gen_ids)
                ground_truths.extend(batch["all_answers"])

            # Log a few sample predictions for debugging
            for j in range(min(3, len(predictions))):
                gt_sample = ground_truths[j] if j < len(ground_truths) else "?"
                log.info(f" [{benchmark}] pred: '{predictions[j]}' | gt: '{gt_sample}'")

            if benchmark == "docvqa":
                results[benchmark] = {"anls": compute_anls(predictions, ground_truths)}
            elif benchmark == "chartqa":
                # An empty reference list would raise IndexError on g[0].
                gt_flat = [(g[0] if g else "") if isinstance(g, list) else g for g in ground_truths]
                results[benchmark] = {"relaxed_accuracy": compute_relaxed_accuracy(predictions, gt_flat)}
            elif benchmark == "textvqa":
                results[benchmark] = {"vqa_accuracy": compute_vqa_accuracy(predictions, ground_truths)}
    finally:
        model.train(was_training)
    return results
# ══════════════════════════════════════════════════════════════════════════
# CHECKPOINT & HUB
# ══════════════════════════════════════════════════════════════════════════
def save_checkpoint(model, cfg, epoch, mc_acc, open_results, composite):
    """Write the current model state and run metadata to
    ``<cfg.output_dir>/checkpoint_best.pt``, overwriting any previous file.

    The payload contains:
    - state dicts for evidence, rollout, disc, gen_head and the two target
      sub-modules;
    - ``cfg.__dict__``, the epoch and the evaluation scores;
    - the phase tag "3.1".
    """
    payload = {
        "evidence": model.evidence.state_dict(),
        "rollout": model.rollout.state_dict(),
        "disc": model.disc.state_dict(),
        "gen_head": model.gen_head.state_dict(),
        "target_ev": model.target.t_ev.state_dict(),
        "target_ro": model.target.t_ro.state_dict(),
        "config": cfg.__dict__,
        "epoch": epoch,
        "mc_eval_acc": mc_acc,
        "open_results": open_results,
        "composite_score": composite,
        "phase": "3.1",
    }
    ckpt_file = os.path.join(cfg.output_dir, "checkpoint_best.pt")
    torch.save(payload, ckpt_file)
    log.info(f"Saved checkpoint: {ckpt_file} (composite={composite:.1f})")
def push_results(cfg, args, best_composite, eval_results):
    """Upload the Phase 3.1 results summary and best checkpoint to the Hub.

    What it uploads:
    - Writes ``results_<run_name>.json`` into ``cfg.output_dir`` and uploads it
      as ``results/<run_name>.json`` in ``cfg.hub_model_id``.
    - If ``checkpoint_best.pt`` exists, uploads it as
      ``checkpoints/<run_name>_best.pt``.

    This is best-effort: any exception is logged with its traceback and not
    re-raised. ``main`` calls this from its ``finally`` block, so a failed
    upload must not mask the training outcome.

    Args:
        cfg: run config (reads run_name, backbone, K, epochs, output_dir, hub_model_id).
        args: parsed CLI args holding the Phase 3.1 hyper-parameters.
        best_composite: best composite score reached.
        eval_results: last per-benchmark metrics, or None if no evaluation ran.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        results = {
            "run_name": cfg.run_name, "phase": "3.1",
            "backbone": cfg.backbone, "K": cfg.K,
            "best_composite_score": best_composite,
            "gen_weight": args.gen_weight, "max_gen_len": args.max_gen_len,
            "beam_width": args.beam_width,
            "tf_start": args.tf_start, "tf_end": args.tf_end,
            "epochs": cfg.epochs, "core_lr": args.core_lr,
            "open_results": dict(eval_results or {}),
            "improvements": ["gen_weight_2.0", "gen_len_32", "scheduled_sampling", "beam_search"],
        }
        rp = os.path.join(cfg.output_dir, f"results_{cfg.run_name}.json")
        with open(rp, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        api.upload_file(path_or_fileobj=rp, path_in_repo=f"results/{cfg.run_name}.json",
                        repo_id=cfg.hub_model_id, repo_type="model")
        best_ckpt = os.path.join(cfg.output_dir, "checkpoint_best.pt")
        if os.path.exists(best_ckpt):
            api.upload_file(path_or_fileobj=best_ckpt,
                            path_in_repo=f"checkpoints/{cfg.run_name}_best.pt",
                            repo_id=cfg.hub_model_id, repo_type="model")
        log.info(f"Pushed Phase 3.1 results to {cfg.hub_model_id}")
    except Exception as e:
        # Deliberately non-fatal; log.exception keeps the traceback for debugging.
        log.exception(f"Push failed: {e}")
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()