Add Phase 3.1 training: gen_weight 2.0, gen_len 32, scheduled sampling, beam search
Browse files- train_phase3_1.py +864 -0
train_phase3_1.py
ADDED
|
@@ -0,0 +1,864 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
MR-JEPA Phase 3.1 Training — Improved Generative Decoder
|
| 4 |
+
|
| 5 |
+
Loads the Phase 3.0 checkpoint (with partially-trained gen_head) and applies
|
| 6 |
+
four targeted improvements to break through the 0% generative metrics:
|
| 7 |
+
|
| 8 |
+
1. gen_weight: 0.5 → 2.0 (4× stronger generative gradient signal)
|
| 9 |
+
2. max_gen_len: 64 → 32 (shorter targets, less padding noise)
|
| 10 |
+
3. Scheduled sampling (100% teacher forcing → 50% free-running, linear)
|
| 11 |
+
4. Beam search evaluation (beam_width=5 instead of greedy argmax)
|
| 12 |
+
|
| 13 |
+
Resumes from: checkpoints/hybrid_main_phase3_best.pt (gen_head pre-trained)
|
| 14 |
+
Training data: same as Phase 3.0 (ScienceQA MC + DocVQA/ChartQA/TextVQA open-ended)
|
| 15 |
+
|
| 16 |
+
Usage:
|
| 17 |
+
python train_phase3_1.py
|
| 18 |
+
python train_phase3_1.py --gen_weight 2.0 --max_gen_len 32 --beam_width 5
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import json
|
| 24 |
+
import math
|
| 25 |
+
import copy
|
| 26 |
+
import random
|
| 27 |
+
import logging
|
| 28 |
+
import argparse
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
import torch.nn.functional as F
|
| 35 |
+
from torch.optim import AdamW
|
| 36 |
+
from torch.utils.data import Dataset, DataLoader
|
| 37 |
+
from PIL import Image
|
| 38 |
+
|
| 39 |
+
# Root logging setup: INFO level, lines formatted as "HH:MM:SS | LEVEL | message".
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# Named logger shared by the code in this script.
log = logging.getLogger("mrjepa-p3.1")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 48 |
+
# OPEN-ENDED DATASET (same as Phase 3.0)
|
| 49 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 50 |
+
|
| 51 |
+
class OpenEndedDataset(Dataset):
    """Open-ended VQA samples from DocVQA, ChartQA or TextVQA.

    Rows are loaded from the ``lmms-lab`` Hugging Face datasets and returned
    un-tokenized: each item is a dict with the RGB image, the question text
    (with OCR tokens appended when the row provides them), one answer to
    train on, and the list of all reference answers.
    """

    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192, max_gen_len=32):
        """Load ``split`` of ``benchmark``.

        Args:
            benchmark: one of ``"docvqa"``, ``"chartqa"``, ``"textvqa"``.
            split: dataset split name passed to ``load_dataset``.
            max_samples: keep only the first ``max_samples`` rows when > 0.
            transform, tokenizer, max_len, max_gen_len: stored on the
                instance; not used by ``__getitem__``.

        Raises:
            ValueError: if ``benchmark`` is not one of the supported names.
        """
        from datasets import load_dataset

        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_gen_len = max_gen_len
        log.info(f"Loading {benchmark} {split}...")

        # Positional ``load_dataset`` arguments for each supported benchmark.
        hub_sources = {
            "docvqa": ("lmms-lab/DocVQA", "DocVQA"),
            "chartqa": ("lmms-lab/ChartQA",),
            "textvqa": ("lmms-lab/textvqa",),
        }
        if benchmark not in hub_sources:
            raise ValueError(f"Unknown benchmark: {benchmark}")
        ds = load_dataset(*hub_sources[benchmark], split=split)

        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))
        self.data = ds
        log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)

    @staticmethod
    def _answers_or_default(row):
        """Return the row's ``answers``, or ``[""]`` when missing or ``None``.

        ``row.get("answers", [""])`` alone only covers a missing key; a key
        that is present with a ``None`` value would otherwise reach the
        answer-counting code and raise ``TypeError``.
        """
        answers = row.get("answers")
        return [""] if answers is None else answers

    def __getitem__(self, idx):
        """Build the sample dict for row ``idx``.

        Returns:
            dict with keys ``image``, ``text``, ``answer``, ``all_answers``,
            ``benchmark``, ``ocr_text`` and ``question_type``.
        """
        from collections import Counter

        row = self.data[idx]

        img = row.get("image")
        if img is None:
            # Blank white placeholder so rows without an image still collate.
            img = Image.new("RGB", (256, 256), "white")
        else:
            img = img.convert("RGB")

        question = row["question"]

        if self.benchmark == "docvqa":
            answers = self._answers_or_default(row)
            answer = answers[0] if answers else ""
            all_answers = answers
        elif self.benchmark == "chartqa":
            answer = str(row.get("answer", ""))
            all_answers = [answer]
        elif self.benchmark == "textvqa":
            answers = self._answers_or_default(row)
            # Train on the most frequent normalised annotator answer.
            answer_counts = Counter(a.lower().strip() for a in answers)
            answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
            all_answers = answers
        else:
            answer = ""
            all_answers = [""]

        # At most the first 50 OCR tokens are appended to the question.
        ocr_tokens = row.get("ocr_tokens", [])
        ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""
        text = question
        if ocr_text:
            text += f" [OCR: {ocr_text}]"

        return {
            "image": img, "text": text, "answer": answer,
            "all_answers": all_answers, "benchmark": self.benchmark,
            "ocr_text": ocr_text,
            "question_type": row.get("type", row.get("question_types", [""])),
        }
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def collate_open_ended(batch, transform, tokenizer, max_len, max_gen_len):
    """Collate open-ended samples into one batch dict.

    Images go through ``transform`` (called per image, or once on the whole
    list when it exposes ``feature_extractor`` or is not callable), questions
    are tokenized to ``max_len`` and answers to ``max_gen_len``, both padded
    to the maximum length. Empty answers are replaced by a single space
    before tokenization.
    """
    # Processor-style transforms take the whole image list by keyword.
    processor_style = (
        not hasattr(transform, '__call__') or hasattr(transform, 'feature_extractor')
    )
    if processor_style:
        pixel_values = transform(
            images=[sample["image"] for sample in batch], return_tensors="pt"
        )["pixel_values"]
    else:
        pixel_values = torch.stack([transform(sample["image"]) for sample in batch])

    shared_tok_args = dict(padding="max_length", truncation=True, return_tensors="pt")

    question_tok = tokenizer(
        [sample["text"] for sample in batch], max_length=max_len, **shared_tok_args
    )
    answer_tok = tokenizer(
        [sample["answer"] or " " for sample in batch],
        max_length=max_gen_len, **shared_tok_args
    )

    collated = {"pixel_values": pixel_values}
    collated["input_ids"] = question_tok["input_ids"]
    collated["attention_mask"] = question_tok["attention_mask"]
    collated["gen_target_ids"] = answer_tok["input_ids"]
    collated["gen_attention_mask"] = answer_tok["attention_mask"]
    collated["batch_size"] = len(batch)
    collated["benchmarks"] = [sample["benchmark"] for sample in batch]
    collated["all_answers"] = [sample["all_answers"] for sample in batch]
    collated["question_types"] = [sample.get("question_type", "") for sample in batch]
    return collated
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 141 |
+
# GENERATIVE HEAD with SCHEDULED SAMPLING + BEAM SEARCH
|
| 142 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 143 |
+
|
| 144 |
+
class GenerativeDecoderLayer(nn.Module):
    """One decoder layer with four residual sub-blocks.

    In order: masked self-attention, cross-attention over ``z_final``,
    cross-attention over ``evidence``, and a GELU feed-forward network.
    Each sub-block applies its LayerNorm to the input first and adds the
    result back onto the un-normalised input.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.1):
        super().__init__()
        attn_kwargs = dict(embed_dim=hidden_dim, num_heads=num_heads,
                           dropout=dropout, batch_first=True)

        self.self_attn = nn.MultiheadAttention(**attn_kwargs)
        self.self_attn_norm = nn.LayerNorm(hidden_dim)

        self.state_cross_attn = nn.MultiheadAttention(**attn_kwargs)
        self.state_cross_norm = nn.LayerNorm(hidden_dim)

        self.evidence_cross_attn = nn.MultiheadAttention(**attn_kwargs)
        self.evidence_cross_norm = nn.LayerNorm(hidden_dim)

        inner_dim = hidden_dim * 4
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, hidden_dim),
            nn.Dropout(dropout),
        )
        self.ffn_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, z_final, evidence, causal_mask=None):
        """Apply the layer to ``x`` and return a tensor of the same shape.

        ``causal_mask`` is forwarded as ``attn_mask`` to the self-attention
        only; both cross-attentions are unmasked.
        """
        # Self-attention over the token sequence.
        normed = self.self_attn_norm(x)
        attended, _ = self.self_attn(normed, normed, normed, attn_mask=causal_mask)
        x = x + attended

        # Cross-attention: queries from the tokens, keys/values from z_final.
        normed = self.state_cross_norm(x)
        attended, _ = self.state_cross_attn(normed, z_final, z_final)
        x = x + attended

        # Cross-attention: queries from the tokens, keys/values from evidence.
        normed = self.evidence_cross_norm(x)
        attended, _ = self.evidence_cross_attn(normed, evidence, evidence)
        x = x + attended

        # Position-wise feed-forward.
        return x + self.ffn(self.ffn_norm(x))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class GenerativeHead(nn.Module):
    """
    Phase 3.1 generative decoder with:
    - Scheduled sampling during training (teacher forcing warmup)
    - Beam search during evaluation

    Token embeddings plus learned position embeddings run through a stack of
    ``GenerativeDecoderLayer`` blocks; the output projection shares its
    weight matrix with the token embedding.
    """

    def __init__(self, hidden_dim, vocab_size, num_layers=4, num_heads=12,
                 max_gen_len=32, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.max_gen_len = max_gen_len
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
        # One learned vector per position: sequences longer than max_gen_len
        # cannot be embedded.
        self.pos_embedding = nn.Embedding(max_gen_len, hidden_dim)
        self.layers = nn.ModuleList([
            GenerativeDecoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)
        ])
        self.output_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

    def _decode_step(self, token_ids, z_final, evidence):
        """Run the decoder over ``token_ids`` and return logits for every position.

        Args:
            token_ids: (B, T) token ids.
            z_final: keys/values for each layer's state cross-attention.
            evidence: keys/values for each layer's evidence cross-attention.

        Returns:
            (B, T, vocab_size) logits. Callers that want the next-token
            distribution index the last position themselves (``[:, -1]``).
        """
        seq_len = token_ids.size(1)
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        x = self.token_embedding(token_ids) + self.pos_embedding(positions)
        # True above the diagonal = position i may not attend to j > i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool), diagonal=1
        )
        for layer in self.layers:
            x = layer(x, z_final, evidence, causal_mask)
        logits = self.lm_head(self.output_norm(x))
        return logits

    def forward(self, z_final, evidence, target_ids, pad_token_id=0,
                teacher_forcing_ratio=1.0):
        """
        Training forward with scheduled sampling.

        teacher_forcing_ratio=1.0 → pure teacher forcing (use ground truth at every step)
        teacher_forcing_ratio=0.5 → 50% of tokens use model's own prediction

        Args:
            z_final, evidence: conditioning passed to every decoder layer.
            target_ids: (B, T) ground-truth token ids.
            pad_token_id: label id excluded from the loss.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the ground-truth token as the next
                input instead of the model's argmax prediction.

        Returns:
            ``(logits, loss)`` where ``logits`` is (B, T, vocab_size) and
            ``loss`` is the mean cross-entropy of predicting
            ``target_ids[:, t + 1]`` from position ``t``.
        """
        B, seq_len = target_ids.shape
        device = target_ids.device

        if teacher_forcing_ratio >= 1.0:
            # ── Pure teacher forcing (fast, batched) ──
            logits = self._decode_step(target_ids, z_final, evidence)
        else:
            # ── Scheduled sampling: mix teacher forcing with free-running ──
            # Each step re-runs the decoder on the whole prefix built so far.
            logits = torch.zeros(B, seq_len, self.vocab_size, device=device)
            current_input = target_ids[:, :1]  # start with first token

            for t in range(seq_len):
                step_logits = self._decode_step(current_input, z_final, evidence)
                logits[:, t] = step_logits[:, -1]  # logits at last position

                if t < seq_len - 1:
                    # Decide: teacher forcing or free-running for next input
                    use_teacher = random.random() < teacher_forcing_ratio
                    if use_teacher:
                        next_token = target_ids[:, t + 1:t + 2]
                    else:
                        next_token = step_logits[:, -1].argmax(dim=-1, keepdim=True)
                    current_input = torch.cat([current_input, next_token], dim=1)

        # Loss: next-token prediction
        shift_logits = logits[:, :-1].contiguous()
        shift_labels = target_ids[:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, self.vocab_size),
            shift_labels.view(-1),
            ignore_index=pad_token_id,
        )
        return logits, loss

    @torch.no_grad()
    def generate_greedy(self, z_final, evidence, start_token_id,
                        max_length=32, eos_token_id=None, pad_token_id=0):
        """Greedy autoregressive generation (fallback).

        Args:
            z_final, evidence: conditioning; the batch size is read from
                ``z_final.size(0)``.
            start_token_id: first token of every sequence.
            max_length: maximum sequence length including the start token;
                capped at ``max_gen_len`` because no position embedding
                exists beyond it.
            eos_token_id: when given, a row is finished once it emits this id.
            pad_token_id: written to finished rows on every later step.

        Returns:
            (B, L) long tensor with ``L <= max_length``. Generation stops as
            soon as every row has emitted ``eos_token_id`` at some step, not
            only when they all emit it on the same step.
        """
        B = z_final.size(0)
        device = z_final.device
        max_length = min(max_length, self.max_gen_len)
        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(B, 1, dtype=torch.bool, device=device)
        for _ in range(max_length - 1):
            logits = self._decode_step(generated, z_final, evidence)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            if eos_token_id is not None:
                # Rows that already ended only receive padding from here on.
                next_token = next_token.masked_fill(finished, pad_token_id)
                finished = finished | (next_token == eos_token_id)
            generated = torch.cat([generated, next_token], dim=1)
            if eos_token_id is not None and bool(finished.all()):
                break
        return generated

    @torch.no_grad()
    def generate_beam(self, z_final, evidence, start_token_id,
                      max_length=32, eos_token_id=None, beam_width=5,
                      pad_token_id=0):
        """
        Beam search generation.

        Processes each sample in the batch independently with beam search.
        Returns the highest-scoring complete sequence per sample.

        Args:
            z_final, evidence: conditioning; sample ``b`` uses row ``b``.
            start_token_id: first token of every hypothesis.
            max_length: maximum sequence length including the start token;
                capped at ``max_gen_len``.
            eos_token_id: hypotheses ending in this id are set aside as
                completed and no longer expanded.
            beam_width: hypotheses kept per step; the per-hypothesis
                expansion is capped at the vocabulary size.
            pad_token_id: fill value for sequences shorter than the longest
                result in the batch.

        Returns:
            (B, L) long tensor of the best hypothesis per sample, ranked by
            summed log-probability divided by sequence length.
        """
        B = z_final.size(0)
        device = z_final.device
        max_length = min(max_length, self.max_gen_len)

        def ended(seq):
            return eos_token_id is not None and seq[0, -1].item() == eos_token_id

        def length_normalized(entry):
            score, seq = entry
            return score / max(seq.size(1), 1)

        all_results = []
        for b in range(B):
            z_b = z_final[b:b + 1]
            ev_b = evidence[b:b + 1]

            # Each beam: (log_prob, token_ids_tensor of shape (1, T))
            beams = [(0.0, torch.tensor([[start_token_id]], dtype=torch.long, device=device))]
            completed = []

            for _ in range(max_length - 1):
                # Set finished hypotheses aside; only live ones are expanded.
                live = []
                for entry in beams:
                    (completed if ended(entry[1]) else live).append(entry)
                beams = live
                if not live or beam_width < 1:
                    break

                # Every live hypothesis has the same length at a given step,
                # so they can share one batched decoder pass.
                seqs = torch.cat([seq for _, seq in live], dim=0)
                n_live = seqs.size(0)
                logits = self._decode_step(
                    seqs,
                    z_b.expand(n_live, *z_b.shape[1:]),
                    ev_b.expand(n_live, *ev_b.shape[1:]),
                )
                log_probs = F.log_softmax(logits[:, -1], dim=-1)  # (n_live, V)

                # topk cannot return more entries than the vocabulary holds.
                k = min(beam_width, log_probs.size(-1))
                topk_lp, topk_ids = log_probs.topk(k, dim=-1)
                topk_lp = topk_lp.tolist()

                candidates = []
                for i, (score, seq) in enumerate(live):
                    for j in range(k):
                        new_seq = torch.cat([seq, topk_ids[i, j].view(1, 1)], dim=1)
                        candidates.append((score + topk_lp[i][j], new_seq))

                # Length-normalize scores and keep top beams
                candidates.sort(key=length_normalized, reverse=True)
                beams = candidates[:beam_width]

                # Early stop if all beams ended
                if all(ended(seq) for _, seq in beams):
                    break

            # Merge completed and remaining, pick best
            pool = completed + beams
            if pool:
                all_results.append(max(pool, key=length_normalized)[1])
            else:
                all_results.append(
                    torch.tensor([[start_token_id]], dtype=torch.long, device=device))

        # Pad to same length (an empty batch yields a (0, 1) tensor).
        max_len = max((r.size(1) for r in all_results), default=1)
        padded = torch.full((B, max_len), pad_token_id, dtype=torch.long, device=device)
        for i, r in enumerate(all_results):
            padded[i, :r.size(1)] = r[0]
        return padded
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 327 |
+
# EVALUATION METRICS (same as Phase 3.0)
|
| 328 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 329 |
+
|
| 330 |
+
def normalized_levenshtein(s1, s2):
    """Return the edit distance between two strings scaled into [0, 1].

    Both strings are lower-cased and stripped first. Equal strings give 0.0;
    if exactly one is empty the result is 1.0; otherwise the Levenshtein
    distance is divided by the length of the longer string.
    """
    a = s1.lower().strip()
    b = s2.lower().strip()
    if a == b:
        return 0.0
    if not a or not b:
        return 1.0

    # Rolling two-row DP: `prev` holds distances for the previous prefix of `a`.
    prev = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        curr = [i]
        for j, ch_b in enumerate(b, start=1):
            deletion = prev[j] + 1
            insertion = curr[j - 1] + 1
            substitution = prev[j - 1] + (0 if ch_a == ch_b else 1)
            curr.append(min(deletion, insertion, substitution))
        prev = curr
    return prev[-1] / max(len(a), len(b))
|
| 343 |
+
|
| 344 |
+
def compute_anls(predictions, ground_truths, threshold=0.5):
    """Return the mean ANLS score, as a percentage, over all samples.

    For each prediction the score is the best ``1 - NL`` over its reference
    answers, where ``NL`` is ``normalized_levenshtein`` and any reference
    with ``NL >= threshold`` contributes 0. Samples with no references score
    0. The distance is computed once per (prediction, reference) pair.

    Args:
        predictions: iterable of predicted answers (converted with ``str``).
        ground_truths: iterable, aligned with ``predictions``, of reference
            answer collections.
        threshold: distances at or above this value count as a miss.

    Returns:
        Mean score times 100, or 0.0 when there are no samples.
    """
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        pred_text = str(pred)
        best = 0.0
        for gt in gts or ():
            distance = normalized_levenshtein(pred_text, str(gt))
            if distance < threshold:
                best = max(best, 1.0 - distance)
        scores.append(best)
    return np.mean(scores) * 100 if scores else 0.0
|
| 350 |
+
|
| 351 |
+
def compute_vqa_accuracy(predictions, ground_truths):
    """Return the mean VQA accuracy, as a percentage.

    Per sample the score is ``min(matches / 3, 1)``, where ``matches`` counts
    the reference answers equal to the prediction after lower-casing and
    stripping both sides. Returns 0.0 when there are no samples.
    """
    per_sample = []
    for pred, gts in zip(predictions, ground_truths):
        wanted = str(pred).lower().strip()
        matches = 0
        for gt in gts:
            if str(gt).lower().strip() == wanted:
                matches += 1
        per_sample.append(min(matches / 3.0, 1.0))

    if not per_sample:
        return 0.0
    return np.mean(per_sample) * 100
|
| 357 |
+
|
| 358 |
+
def compute_relaxed_accuracy(predictions, ground_truths, tolerance=0.05):
    """Return relaxed accuracy, as a percentage.

    Both sides are stripped and lower-cased. A sample is correct when:
      * the two strings are identical, or
      * both parse as numbers (after removing ``,`` and ``%``) and the
        relative error against the reference is within ``tolerance``; a
        reference of exactly 0 compares the absolute value of the
        prediction against ``tolerance`` instead.

    The exact-match check runs first so that identical strings which parse
    to a non-finite float (e.g. ``"inf"`` or ``"nan"``) are scored correct
    instead of failing the NaN-valued numeric comparison.

    Args:
        predictions: iterable of predicted answers.
        ground_truths: iterable, aligned with ``predictions``, holding one
            reference answer per sample.
        tolerance: maximum allowed relative error.

    Returns:
        Percentage of correct samples, or 0.0 when there are no samples.
    """
    correct = []
    for pred, gt in zip(predictions, ground_truths):
        ps, gs = str(pred).strip().lower(), str(gt).strip().lower()
        if ps == gs:
            correct.append(True)
            continue
        try:
            gv = float(gs.replace(',', '').replace('%', ''))
            pv = float(ps.replace(',', '').replace('%', ''))
        except ValueError:
            # At least one side is not numeric and the strings already differ.
            correct.append(False)
            continue
        if gv != 0:
            correct.append(abs(pv - gv) / abs(gv) <= tolerance)
        else:
            correct.append(abs(pv) <= tolerance)
    return np.mean(correct) * 100 if correct else 0.0
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 372 |
+
# SCHEDULED SAMPLING SCHEDULE
|
| 373 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 374 |
+
|
| 375 |
+
def get_teacher_forcing_ratio(epoch, total_epochs, start_ratio=1.0, end_ratio=0.5):
    """
    Teacher-forcing ratio for ``epoch`` under a linear schedule.

    The ratio equals ``start_ratio`` at epoch 0 and ``end_ratio`` at epoch
    ``total_epochs - 1``, interpolating linearly in between. With the
    defaults that is 100% teacher forcing at first and 50% at the end.
    Runs of one epoch or fewer always use ``start_ratio``.

    Lowering the ratio over training exposes the model to its own
    predictions, which is what it sees when generating freely at eval time.
    """
    if total_epochs <= 1:
        return start_ratio

    last_epoch = total_epochs - 1
    decay_span = start_ratio - end_ratio
    fraction_done = epoch / last_epoch
    return start_ratio - decay_span * fraction_done
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 391 |
+
# MAIN
|
| 392 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 393 |
+
|
| 394 |
+
def download_checkpoint(hub_model_id, filename):
    """Download ``filename`` from the ``hub_model_id`` model repo on the Hub.

    Logs the resulting location and returns the path reported by
    ``hf_hub_download``.
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_type="model",
        repo_id=hub_model_id,
        filename=filename,
    )
    log.info(f"Downloaded checkpoint: {local_path}")
    return local_path
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def main():
|
| 402 |
+
parser = argparse.ArgumentParser(description="MR-JEPA Phase 3.1 Training")
|
| 403 |
+
parser.add_argument("--checkpoint", type=str, default=None,
|
| 404 |
+
help="Local path to checkpoint. Default: download Phase 3.0 from Hub.")
|
| 405 |
+
parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
|
| 406 |
+
parser.add_argument("--run_name", default="hybrid_main_phase3_1")
|
| 407 |
+
parser.add_argument("--epochs", type=int, default=10)
|
| 408 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 409 |
+
parser.add_argument("--grad_accum", type=int, default=16)
|
| 410 |
+
parser.add_argument("--core_lr", type=float, default=5e-5)
|
| 411 |
+
parser.add_argument("--backbone_lr", type=float, default=5e-6)
|
| 412 |
+
parser.add_argument("--text_lr", type=float, default=5e-6)
|
| 413 |
+
# ── Phase 3.1 improvements ──
|
| 414 |
+
parser.add_argument("--gen_weight", type=float, default=2.0,
|
| 415 |
+
help="Generative loss weight (was 0.5 in 3.0)")
|
| 416 |
+
parser.add_argument("--max_gen_len", type=int, default=32,
|
| 417 |
+
help="Max generation length (was 64 in 3.0)")
|
| 418 |
+
parser.add_argument("--beam_width", type=int, default=5,
|
| 419 |
+
help="Beam search width for evaluation (was greedy in 3.0)")
|
| 420 |
+
parser.add_argument("--tf_start", type=float, default=1.0,
|
| 421 |
+
help="Teacher forcing ratio at epoch 0")
|
| 422 |
+
parser.add_argument("--tf_end", type=float, default=0.5,
|
| 423 |
+
help="Teacher forcing ratio at final epoch")
|
| 424 |
+
# ──────────────────────────────
|
| 425 |
+
parser.add_argument("--max_eval_samples", type=int, default=200)
|
| 426 |
+
parser.add_argument("--max_train_samples", type=int, default=0)
|
| 427 |
+
parser.add_argument("--output_dir", default="./outputs/mrjepa_phase3_1")
|
| 428 |
+
parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio")
|
| 429 |
+
args = parser.parse_args()
|
| 430 |
+
|
| 431 |
+
# ── Import Phase 1 model definitions ──
|
| 432 |
+
log.info("Downloading Phase 1 training script for model definitions...")
|
| 433 |
+
from huggingface_hub import hf_hub_download
|
| 434 |
+
p1_script = hf_hub_download(repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model")
|
| 435 |
+
import importlib.util
|
| 436 |
+
spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
|
| 437 |
+
p1 = importlib.util.module_from_spec(spec)
|
| 438 |
+
spec.loader.exec_module(p1)
|
| 439 |
+
|
| 440 |
+
# ── Load Phase 3.0 checkpoint (includes gen_head weights) ──
|
| 441 |
+
if args.checkpoint and os.path.exists(args.checkpoint):
|
| 442 |
+
ckpt_path = args.checkpoint
|
| 443 |
+
else:
|
| 444 |
+
ckpt_path = download_checkpoint(args.hub_model_id,
|
| 445 |
+
"checkpoints/hybrid_main_phase3_best.pt")
|
| 446 |
+
|
| 447 |
+
log.info(f"Loading Phase 3.0 checkpoint: {ckpt_path}")
|
| 448 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 449 |
+
|
| 450 |
+
saved_cfg = ckpt["config"]
|
| 451 |
+
cfg = p1.Config()
|
| 452 |
+
for k, v in saved_cfg.items():
|
| 453 |
+
if hasattr(cfg, k):
|
| 454 |
+
setattr(cfg, k, v)
|
| 455 |
+
|
| 456 |
+
cfg.phase = 3
|
| 457 |
+
cfg.epochs = args.epochs
|
| 458 |
+
cfg.batch_size = args.batch_size
|
| 459 |
+
cfg.grad_accum = args.grad_accum
|
| 460 |
+
cfg.lr = args.core_lr
|
| 461 |
+
cfg.backbone_lr = args.backbone_lr
|
| 462 |
+
cfg.output_dir = args.output_dir
|
| 463 |
+
cfg.run_name = args.run_name
|
| 464 |
+
cfg.freeze_backbone = True
|
| 465 |
+
cfg.freeze_text = True
|
| 466 |
+
cfg.max_eval_samples = args.max_eval_samples
|
| 467 |
+
cfg.resolve()
|
| 468 |
+
|
| 469 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 470 |
+
log.info(f"Device: {device}")
|
| 471 |
+
os.makedirs(cfg.output_dir, exist_ok=True)
|
| 472 |
+
|
| 473 |
+
# ── Trackio ──
|
| 474 |
+
import trackio
|
| 475 |
+
trackio.init(
|
| 476 |
+
name=args.run_name, project="MR-JEPA", space_id=args.trackio_space,
|
| 477 |
+
config={
|
| 478 |
+
"phase": "3.1", "epochs": args.epochs,
|
| 479 |
+
"core_lr": args.core_lr, "backbone_lr": args.backbone_lr,
|
| 480 |
+
"text_lr": args.text_lr, "gen_weight": args.gen_weight,
|
| 481 |
+
"max_gen_len": args.max_gen_len, "beam_width": args.beam_width,
|
| 482 |
+
"tf_start": args.tf_start, "tf_end": args.tf_end,
|
| 483 |
+
"batch_size": args.batch_size, "grad_accum": args.grad_accum,
|
| 484 |
+
"backbone": cfg.backbone, "K": cfg.K,
|
| 485 |
+
"improvements": "gen_weight_2.0, gen_len_32, scheduled_sampling, beam_search",
|
| 486 |
+
}
|
| 487 |
+
)
|
| 488 |
+
log.info(f"Trackio → https://huggingface.co/spaces/{args.trackio_space}")
|
| 489 |
+
|
| 490 |
+
# ── Build model ──
|
| 491 |
+
log.info("Building model...")
|
| 492 |
+
model = p1.MRJEPAModel(cfg)
|
| 493 |
+
model.evidence.load_state_dict(ckpt["evidence"])
|
| 494 |
+
model.rollout.load_state_dict(ckpt["rollout"])
|
| 495 |
+
model.disc.load_state_dict(ckpt["disc"])
|
| 496 |
+
model.target.t_ev.load_state_dict(ckpt["target_ev"])
|
| 497 |
+
model.target.t_ro.load_state_dict(ckpt["target_ro"])
|
| 498 |
+
log.info(f"Loaded core weights from Phase 3.0 (epoch={ckpt.get('epoch','?')}, "
|
| 499 |
+
f"composite={ckpt.get('composite_score','?')})")
|
| 500 |
+
|
| 501 |
+
# ── Generative head: new architecture with max_gen_len=32 ──
|
| 502 |
+
tokenizer = model.txt.tokenizer
|
| 503 |
+
actual_vocab_size = len(tokenizer)
|
| 504 |
+
|
| 505 |
+
gen_head = GenerativeHead(
|
| 506 |
+
hidden_dim=cfg.rollout_dim,
|
| 507 |
+
vocab_size=actual_vocab_size,
|
| 508 |
+
num_layers=4,
|
| 509 |
+
num_heads=cfg.predictor_heads,
|
| 510 |
+
max_gen_len=args.max_gen_len,
|
| 511 |
+
dropout=0.1,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Load Phase 3.0 gen_head weights where shapes match
|
| 515 |
+
if "gen_head" in ckpt:
|
| 516 |
+
p3_gen = ckpt["gen_head"]
|
| 517 |
+
new_sd = gen_head.state_dict()
|
| 518 |
+
loaded, skipped = 0, 0
|
| 519 |
+
for k, v in p3_gen.items():
|
| 520 |
+
if k in new_sd and new_sd[k].shape == v.shape:
|
| 521 |
+
new_sd[k] = v
|
| 522 |
+
loaded += 1
|
| 523 |
+
elif k in new_sd:
|
| 524 |
+
skipped += 1
|
| 525 |
+
log.info(f" Shape mismatch for {k}: ckpt {v.shape} vs new {new_sd[k].shape}")
|
| 526 |
+
else:
|
| 527 |
+
skipped += 1
|
| 528 |
+
gen_head.load_state_dict(new_sd)
|
| 529 |
+
log.info(f"Loaded {loaded} gen_head params from Phase 3.0 ({skipped} skipped)")
|
| 530 |
+
else:
|
| 531 |
+
log.warning("No gen_head in checkpoint — starting from scratch")
|
| 532 |
+
|
| 533 |
+
model.gen_head = gen_head
|
| 534 |
+
|
| 535 |
+
# ── Unfreeze backbone layers ──
|
| 536 |
+
log.info("Unfreezing last 6 visual layers, last 4 text layers")
|
| 537 |
+
model.vis.unfreeze_last(6)
|
| 538 |
+
model.txt.unfreeze_last(4)
|
| 539 |
+
|
| 540 |
+
model = model.to(device)
|
| 541 |
+
total_p = sum(p.numel() for p in model.parameters())
|
| 542 |
+
train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 543 |
+
log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)")
|
| 544 |
+
|
| 545 |
+
# ── Datasets ──
|
| 546 |
+
transform = model.vis.get_transform()
|
| 547 |
+
mc_max = args.max_train_samples if args.max_train_samples > 0 else 0
|
| 548 |
+
train_mc_ds = p1.ScienceQADataset("train", max_samples=mc_max, transform=transform,
|
| 549 |
+
tokenizer=tokenizer, max_len=cfg.max_text_len,
|
| 550 |
+
max_opts=cfg.max_options)
|
| 551 |
+
eval_mc_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples,
|
| 552 |
+
transform=transform, tokenizer=tokenizer,
|
| 553 |
+
max_len=cfg.max_text_len, max_opts=cfg.max_options)
|
| 554 |
+
mc_coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options)
|
| 555 |
+
train_mc_dl = DataLoader(train_mc_ds, batch_size=cfg.batch_size, shuffle=True,
|
| 556 |
+
num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True)
|
| 557 |
+
eval_mc_dl = DataLoader(eval_mc_ds, batch_size=cfg.batch_size, shuffle=False,
|
| 558 |
+
num_workers=2, collate_fn=mc_coll, pin_memory=True)
|
| 559 |
+
|
| 560 |
+
max_open = args.max_train_samples if args.max_train_samples > 0 else 5000
|
| 561 |
+
open_coll = lambda batch: collate_open_ended(batch, transform, tokenizer,
|
| 562 |
+
cfg.max_text_len, args.max_gen_len)
|
| 563 |
+
|
| 564 |
+
train_open_dls = {}
|
| 565 |
+
eval_open_dls = {}
|
| 566 |
+
for bm, tr_split, ev_split in [("docvqa","validation","validation"),
|
| 567 |
+
("chartqa","test","test"),
|
| 568 |
+
("textvqa","train","validation")]:
|
| 569 |
+
train_open_dls[bm] = DataLoader(
|
| 570 |
+
OpenEndedDataset(bm, tr_split, max_samples=max_open, transform=transform,
|
| 571 |
+
tokenizer=tokenizer, max_len=cfg.max_text_len,
|
| 572 |
+
max_gen_len=args.max_gen_len),
|
| 573 |
+
batch_size=cfg.batch_size, shuffle=True, num_workers=2,
|
| 574 |
+
collate_fn=open_coll, pin_memory=True, drop_last=True)
|
| 575 |
+
eval_open_dls[bm] = DataLoader(
|
| 576 |
+
OpenEndedDataset(bm, ev_split, max_samples=args.max_eval_samples,
|
| 577 |
+
transform=transform, tokenizer=tokenizer,
|
| 578 |
+
max_len=cfg.max_text_len, max_gen_len=args.max_gen_len),
|
| 579 |
+
batch_size=cfg.batch_size, shuffle=False, num_workers=2,
|
| 580 |
+
collate_fn=open_coll, pin_memory=True)
|
| 581 |
+
|
| 582 |
+
# ── Optimizer ──
|
| 583 |
+
backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
|
| 584 |
+
text_params = [p for p in model.txt.parameters() if p.requires_grad]
|
| 585 |
+
bb_txt_ids = {id(p) for p in backbone_params + text_params}
|
| 586 |
+
core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids]
|
| 587 |
+
param_groups = [
|
| 588 |
+
{"params": core_params, "lr": args.core_lr},
|
| 589 |
+
{"params": backbone_params, "lr": args.backbone_lr},
|
| 590 |
+
{"params": text_params, "lr": args.text_lr},
|
| 591 |
+
]
|
| 592 |
+
optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)
|
| 593 |
+
|
| 594 |
+
mc_steps = len(train_mc_dl)
|
| 595 |
+
open_steps = sum(len(dl) for dl in train_open_dls.values())
|
| 596 |
+
total_steps = cfg.epochs * (mc_steps + open_steps) // cfg.grad_accum
|
| 597 |
+
warmup_steps = int(total_steps * 0.1)
|
| 598 |
+
|
| 599 |
+
def lr_lambda(step):
|
| 600 |
+
if step < warmup_steps:
|
| 601 |
+
return step / max(warmup_steps, 1)
|
| 602 |
+
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
|
| 603 |
+
return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))
|
| 604 |
+
|
| 605 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 606 |
+
|
| 607 |
+
pad_token_id = tokenizer.pad_token_id
|
| 608 |
+
if pad_token_id is None:
|
| 609 |
+
pad_token_id = tokenizer.eos_token_id or 0
|
| 610 |
+
|
| 611 |
+
log.info(f"Phase 3.1: {cfg.epochs} epochs | gen_weight={args.gen_weight} | "
|
| 612 |
+
f"max_gen_len={args.max_gen_len} | beam_width={args.beam_width}")
|
| 613 |
+
log.info(f" Teacher forcing: {args.tf_start:.0%} → {args.tf_end:.0%}")
|
| 614 |
+
log.info(f" MC batches/epoch: {mc_steps} | Open batches/epoch: {open_steps}")
|
| 615 |
+
log.info(f" Total opt steps: ~{total_steps} | Warmup: {warmup_steps}")
|
| 616 |
+
|
| 617 |
+
global_step = 0
|
| 618 |
+
best_composite = 0.0
|
| 619 |
+
amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
|
| 620 |
+
trainable = [p for p in model.parameters() if p.requires_grad]
|
| 621 |
+
|
| 622 |
+
try:
|
| 623 |
+
for epoch in range(cfg.epochs):
|
| 624 |
+
model.train()
|
| 625 |
+
epoch_losses = defaultdict(list)
|
| 626 |
+
epoch_mc_correct, epoch_mc_total = 0, 0
|
| 627 |
+
optimizer.zero_grad()
|
| 628 |
+
batch_count = 0
|
| 629 |
+
|
| 630 |
+
# ── Scheduled sampling ratio for this epoch ──
|
| 631 |
+
tf_ratio = get_teacher_forcing_ratio(epoch, cfg.epochs, args.tf_start, args.tf_end)
|
| 632 |
+
log.info(f"Phase 3.1 Epoch {epoch}: teacher_forcing={tf_ratio:.2f}")
|
| 633 |
+
|
| 634 |
+
# ── MC training ──
|
| 635 |
+
log.info(f" MC training on ScienceQA...")
|
| 636 |
+
for bi, batch in enumerate(train_mc_dl):
|
| 637 |
+
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
| 638 |
+
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"):
|
| 639 |
+
losses, preds = model(**batch)
|
| 640 |
+
loss = losses["total"] / cfg.grad_accum
|
| 641 |
+
loss.backward()
|
| 642 |
+
batch_count += 1
|
| 643 |
+
if batch_count % cfg.grad_accum == 0:
|
| 644 |
+
nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
|
| 645 |
+
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
| 646 |
+
model.update_target(global_step, total_steps)
|
| 647 |
+
global_step += 1
|
| 648 |
+
for k, v in losses.items():
|
| 649 |
+
if isinstance(v, torch.Tensor): epoch_losses[f"mc_{k}"].append(v.item())
|
| 650 |
+
epoch_mc_correct += (preds == batch["labels"]).sum().item()
|
| 651 |
+
epoch_mc_total += batch["batch_size"]
|
| 652 |
+
if bi % 100 == 0:
|
| 653 |
+
avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")}
|
| 654 |
+
acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100
|
| 655 |
+
log.info(f" E{epoch} MC B{bi}/{mc_steps} | loss={avg.get('mc_total',0):.4f} | acc={acc:.1f}%")
|
| 656 |
+
trackio.log({"train/mc_loss": avg.get("mc_total",0), "train/mc_accuracy": acc,
|
| 657 |
+
"train/lr": scheduler.get_last_lr()[0], "train/epoch": epoch,
|
| 658 |
+
"train/step": global_step, "train/tf_ratio": tf_ratio})
|
| 659 |
+
|
| 660 |
+
# ── Open-ended training (with scheduled sampling) ──
|
| 661 |
+
log.info(f" Open-ended training (tf_ratio={tf_ratio:.2f})...")
|
| 662 |
+
gen_losses = defaultdict(list)
|
| 663 |
+
open_iters = {n: iter(dl) for n, dl in train_open_dls.items()}
|
| 664 |
+
open_active = set(open_iters.keys())
|
| 665 |
+
obi = 0
|
| 666 |
+
while open_active:
|
| 667 |
+
for name in list(open_active):
|
| 668 |
+
try:
|
| 669 |
+
batch = next(open_iters[name])
|
| 670 |
+
except StopIteration:
|
| 671 |
+
open_active.discard(name); continue
|
| 672 |
+
bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
| 673 |
+
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"):
|
| 674 |
+
vis_tok = model.vis(bt["pixel_values"]).float()
|
| 675 |
+
txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
|
| 676 |
+
evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
|
| 677 |
+
if model._use_rollout:
|
| 678 |
+
traj, z_final, z_proj = model.rollout(evidence)
|
| 679 |
+
else:
|
| 680 |
+
B2 = bt["batch_size"]
|
| 681 |
+
z0 = model.rollout.init_tokens.expand(B2,-1,-1) + \
|
| 682 |
+
model.rollout.z0_proj(F.adaptive_avg_pool1d(
|
| 683 |
+
evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
|
| 684 |
+
z_final, z_proj = z0, model.rollout.out_proj(z0).unsqueeze(1)
|
| 685 |
+
|
| 686 |
+
jepa_loss_val = torch.tensor(0.0, device=device)
|
| 687 |
+
if model._use_jepa:
|
| 688 |
+
target_proj = model.target(vis_tok.detach(), txt_tok.detach(), bt["attention_mask"].detach())
|
| 689 |
+
jl = model.jepa_loss(z_proj, target_proj, torch.tensor(0.0, device=device))
|
| 690 |
+
jepa_loss_val = jl["jepa"] + jl["reg"]
|
| 691 |
+
|
| 692 |
+
# ── Generative loss with scheduled sampling ──
|
| 693 |
+
_, gen_loss = model.gen_head(
|
| 694 |
+
z_final, evidence, bt["gen_target_ids"],
|
| 695 |
+
pad_token_id=pad_token_id,
|
| 696 |
+
teacher_forcing_ratio=tf_ratio,
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
total_loss = cfg.jepa_weight * jepa_loss_val + args.gen_weight * gen_loss
|
| 700 |
+
loss = total_loss / cfg.grad_accum
|
| 701 |
+
|
| 702 |
+
loss.backward()
|
| 703 |
+
batch_count += 1
|
| 704 |
+
if batch_count % cfg.grad_accum == 0:
|
| 705 |
+
nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
|
| 706 |
+
optimizer.step(); scheduler.step(); optimizer.zero_grad()
|
| 707 |
+
model.update_target(global_step, total_steps); global_step += 1
|
| 708 |
+
|
| 709 |
+
gen_losses[f"{name}_gen"].append(gen_loss.item())
|
| 710 |
+
gen_losses[f"{name}_total"].append(total_loss.item())
|
| 711 |
+
obi += 1
|
| 712 |
+
if obi % 100 == 0:
|
| 713 |
+
avg = {k: np.mean(v[-100:]) for k, v in gen_losses.items()}
|
| 714 |
+
log.info(f" E{epoch} OPEN B{obi} | " + " | ".join(f"{k}={v:.4f}" for k,v in avg.items()))
|
| 715 |
+
trackio.log({f"train/{k}": v for k, v in avg.items()})
|
| 716 |
+
|
| 717 |
+
# ── Evaluation (with beam search) ──
|
| 718 |
+
log.info(f" Evaluating (beam_width={args.beam_width})...")
|
| 719 |
+
mc_eval_acc = p1.evaluate(model, eval_mc_dl, device, cfg)
|
| 720 |
+
log.info(f" ScienceQA eval accuracy: {mc_eval_acc:.1f}%")
|
| 721 |
+
|
| 722 |
+
eval_results = evaluate_generative_beam(
|
| 723 |
+
model, eval_open_dls, device, cfg, tokenizer,
|
| 724 |
+
args.max_gen_len, amp_dtype, args.beam_width
|
| 725 |
+
)
|
| 726 |
+
for bm, metrics in eval_results.items():
|
| 727 |
+
for mk, mv in metrics.items():
|
| 728 |
+
log.info(f" {bm} {mk}: {mv:.2f}")
|
| 729 |
+
|
| 730 |
+
all_scores = [mc_eval_acc] + [v for m in eval_results.values() for v in m.values()]
|
| 731 |
+
composite = np.mean(all_scores)
|
| 732 |
+
log.info(f"=== Phase 3.1 Epoch {epoch} | MC: {mc_eval_acc:.1f}% | "
|
| 733 |
+
f"Composite: {composite:.1f} | tf={tf_ratio:.2f} ===")
|
| 734 |
+
|
| 735 |
+
trackio.log({
|
| 736 |
+
"eval/scienceqa_accuracy": mc_eval_acc,
|
| 737 |
+
"eval/composite_score": composite,
|
| 738 |
+
"eval/epoch": epoch, "eval/tf_ratio": tf_ratio,
|
| 739 |
+
**{f"eval/{bm}_{mk}": mv for bm, m in eval_results.items() for mk, mv in m.items()},
|
| 740 |
+
})
|
| 741 |
+
|
| 742 |
+
if composite > best_composite:
|
| 743 |
+
best_composite = composite
|
| 744 |
+
save_checkpoint(model, cfg, epoch, mc_eval_acc, eval_results, composite)
|
| 745 |
+
log.info(f" ★ New best composite: {best_composite:.1f}")
|
| 746 |
+
|
| 747 |
+
log.info(f"Phase 3.1 complete. Best composite: {best_composite:.1f}")
|
| 748 |
+
|
| 749 |
+
finally:
|
| 750 |
+
trackio.log({"final/best_composite": best_composite, "final/phase": "3.1",
|
| 751 |
+
"final/total_steps": global_step})
|
| 752 |
+
trackio.finish()
|
| 753 |
+
|
| 754 |
+
if cfg.push_to_hub:
|
| 755 |
+
push_results(cfg, args, best_composite, eval_results)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 759 |
+
# BEAM SEARCH EVALUATION
|
| 760 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 761 |
+
|
| 762 |
+
@torch.no_grad()
def evaluate_generative_beam(model, eval_dls, device, cfg, tokenizer,
                             max_gen_len, amp_dtype, beam_width):
    """Evaluate open-ended benchmarks using beam search decoding.

    Args:
        model: Network exposing ``vis``, ``txt``, ``evidence``, ``rollout``,
            ``gen_head.generate_beam`` and the ``_use_rollout`` flag.
        eval_dls: Mapping of benchmark name -> iterable of batch dicts. Each
            batch provides "pixel_values", "input_ids", "attention_mask",
            "all_answers" and (only read when rollout is disabled)
            "batch_size".
        device: ``torch.device`` that tensor entries of each batch are moved to.
        cfg: Run configuration; only ``cfg.bf16`` is read here.
        tokenizer: Supplies ``bos_token_id`` / ``cls_token_id`` /
            ``eos_token_id`` and ``decode``.
        max_gen_len: Forwarded to ``generate_beam`` as ``max_length``.
        amp_dtype: dtype handed to ``torch.autocast``.
        beam_width: Forwarded to ``generate_beam``.

    Returns:
        Dict of benchmark name -> ``{metric_name: score}``. Only "docvqa",
        "chartqa" and "textvqa" are scored; any other benchmark is decoded
        but gets no entry.

    The model is put in eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if evaluation
    raises.
    """
    # Token id 0 is a legitimate vocabulary id, so only fall back when an id
    # is actually missing (None) rather than merely falsy.
    start_token_id = tokenizer.bos_token_id
    if start_token_id is None:
        start_token_id = tokenizer.cls_token_id
    if start_token_id is None:
        start_token_id = 1
    eos_token_id = tokenizer.eos_token_id

    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = getattr(model, "training", True)
    model.eval()
    results = {}
    try:
        for benchmark, dl in eval_dls.items():
            predictions, ground_truths = [], []
            for batch in dl:
                bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                      for k, v in batch.items()}
                with torch.autocast(device_type="cuda", dtype=amp_dtype,
                                    enabled=cfg.bf16 and device.type == "cuda"):
                    vis_tok = model.vis(bt["pixel_values"]).float()
                    txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
                    evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
                    if model._use_rollout:
                        _, z_final, _ = model.rollout(evidence)
                    else:
                        # Rollout disabled: build the initial latent tokens
                        # directly from the pooled evidence.
                        B2 = bt["batch_size"]
                        pooled = F.adaptive_avg_pool1d(
                            evidence.permute(0, 2, 1), model.rollout.num_tokens
                        ).permute(0, 2, 1)
                        z_final = (model.rollout.init_tokens.expand(B2, -1, -1)
                                   + model.rollout.z0_proj(pooled))

                    # NOTE(review): decoding is kept inside the autocast region
                    # with the rest of the forward pass — confirm this matches
                    # the intended nesting.
                    gen_ids = model.gen_head.generate_beam(
                        z_final, evidence, start_token_id,
                        max_length=max_gen_len, eos_token_id=eos_token_id,
                        beam_width=beam_width,
                    )
                for row in gen_ids:
                    predictions.append(
                        tokenizer.decode(row, skip_special_tokens=True).strip())
                ground_truths.extend(batch["all_answers"])

            # Log a few sample predictions for debugging
            for j in range(min(3, len(predictions))):
                gt_sample = ground_truths[j] if j < len(ground_truths) else "?"
                log.info(f" [{benchmark}] pred: '{predictions[j]}' | gt: '{gt_sample}'")

            if benchmark == "docvqa":
                results[benchmark] = {"anls": compute_anls(predictions, ground_truths)}
            elif benchmark == "chartqa":
                # Relaxed accuracy compares against a single reference, so
                # keep only the first answer when a list is provided.
                gt_flat = [g[0] if isinstance(g, list) else g for g in ground_truths]
                results[benchmark] = {
                    "relaxed_accuracy": compute_relaxed_accuracy(predictions, gt_flat)}
            elif benchmark == "textvqa":
                results[benchmark] = {
                    "vqa_accuracy": compute_vqa_accuracy(predictions, ground_truths)}
    finally:
        model.train(was_training)
    return results
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 813 |
+
# CHECKPOINT & HUB
|
| 814 |
+
# ══════════════════════════════════════════════════════════════════════════
|
| 815 |
+
|
| 816 |
+
def save_checkpoint(model, cfg, epoch, mc_acc, open_results, composite):
    """Write the best-so-far checkpoint to ``<cfg.output_dir>/checkpoint_best.pt``.

    Args:
        model: Network whose ``evidence``, ``rollout``, ``disc``, ``gen_head``,
            ``target.t_ev`` and ``target.t_ro`` state dicts are saved.
        cfg: Run configuration; ``cfg.output_dir`` selects the destination and
            ``cfg.__dict__`` is stored under "config".
        epoch: Epoch index stored alongside the weights.
        mc_acc: Multiple-choice eval accuracy stored as "mc_eval_acc".
        open_results: Open-ended benchmark metrics stored as "open_results".
        composite: Composite score stored as "composite_score".

    The file is first written to a temporary sibling and then moved into
    place, so an interrupted save cannot corrupt an existing checkpoint.
    """
    if cfg.output_dir:
        os.makedirs(cfg.output_dir, exist_ok=True)
    path = os.path.join(cfg.output_dir, "checkpoint_best.pt")
    state = {
        "evidence": model.evidence.state_dict(),
        "rollout": model.rollout.state_dict(),
        "disc": model.disc.state_dict(),
        "gen_head": model.gen_head.state_dict(),
        "target_ev": model.target.t_ev.state_dict(),
        "target_ro": model.target.t_ro.state_dict(),
        "config": cfg.__dict__,
        "epoch": epoch, "mc_eval_acc": mc_acc,
        "open_results": open_results, "composite_score": composite,
        "phase": "3.1",
    }
    tmp_path = path + ".tmp"
    try:
        torch.save(state, tmp_path)
        # os.replace swaps the file in one step, keeping the old checkpoint
        # intact until the new one is fully written.
        os.replace(tmp_path, path)
    finally:
        # Only present if the save or the replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    log.info(f"Saved checkpoint: {path} (composite={composite:.1f})")
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
def push_results(cfg, args, best_composite, eval_results):
    """Write a JSON run summary and upload it to the Hub model repo.

    The summary is saved as ``results_<run_name>.json`` in ``cfg.output_dir``
    and uploaded to ``results/<run_name>.json`` in ``cfg.hub_model_id``. If
    ``checkpoint_best.pt`` exists in the output directory it is uploaded too.

    This is best-effort: any exception is logged and swallowed.
    """
    try:
        from huggingface_hub import HfApi

        hub = HfApi()
        summary = dict(
            run_name=cfg.run_name,
            phase="3.1",
            backbone=cfg.backbone,
            K=cfg.K,
            best_composite_score=best_composite,
            gen_weight=args.gen_weight,
            max_gen_len=args.max_gen_len,
            beam_width=args.beam_width,
            tf_start=args.tf_start,
            tf_end=args.tf_end,
            epochs=cfg.epochs,
            core_lr=args.core_lr,
            # eval_results may be None when no evaluation was recorded.
            open_results=dict((eval_results or {}).items()),
            improvements=["gen_weight_2.0", "gen_len_32", "scheduled_sampling", "beam_search"],
        )

        summary_path = os.path.join(cfg.output_dir, f"results_{cfg.run_name}.json")
        with open(summary_path, "w") as fh:
            json.dump(summary, fh, indent=2)

        # Both uploads go to the same model repository.
        destination = {"repo_id": cfg.hub_model_id, "repo_type": "model"}
        hub.upload_file(
            path_or_fileobj=summary_path,
            path_in_repo=f"results/{cfg.run_name}.json",
            **destination,
        )

        ckpt_path = os.path.join(cfg.output_dir, "checkpoint_best.pt")
        if os.path.exists(ckpt_path):
            hub.upload_file(
                path_or_fileobj=ckpt_path,
                path_in_repo=f"checkpoints/{cfg.run_name}_best.pt",
                **destination,
            )
        log.info(f"Pushed Phase 3.1 results to {cfg.hub_model_id}")
    except Exception as err:
        log.error(f"Push failed: {err}")
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
# Run the Phase 3.1 training entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|