"""
MR-JEPA Phase 4 — SmolLM2-135M Generative Decoder

Replaces the random-init 4-layer transformer decoder (which produced 0% generative
metrics after 10+ epochs) with SmolLM2-135M-Instruct as a pre-trained LM decoder.

Architecture (BLIP-2 / LLaVA-1.5 pattern):
    z_K (768d)        ──→ Bridge MLP (768→576→576) ──→ visual soft prompt tokens
    evidence (N×768d) ──→ same Bridge MLP          ──→ evidence soft prompt tokens
    [vis_tokens, ev_tokens, text_tokens] ──→ SmolLM2-135M ──→ next-token prediction

Training recipe (2-stage, following LLaVA/BLIP-2):
    Stage 1: Freeze SmolLM2, train only the bridge MLP. LR=1e-3.
    Stage 2: Unfreeze SmolLM2, joint fine-tuning. LR=2e-5, cosine decay.

Key improvements over Phase 3.x:
    1. Pre-trained 30-layer LM decoder (135M params) vs random-init 4-layer (7M params)
    2. LLaVA-1.5 two-layer MLP bridge (nonlinear alignment) vs none
    3. Label smoothing (ε=0.1) to combat repetition collapse
    4. Repetition penalty + nucleus sampling in evaluation
    5. SmolLM2 tokenizer (49K vocab, ChatML) vs Qwen3 tokenizer (152K vocab)
    6. Proper label masking: -100 for the visual prefix and pad tokens

Resumes JEPA/Evidence/Rollout/Disc from the Phase 3.1 checkpoint.
SmolLM2-135M is loaded fresh from the HuggingFace Hub.

Usage:
    python train_phase4.py
    python train_phase4.py --stage 1 --epochs 5 --bridge_lr 1e-3
    python train_phase4.py --stage 2 --epochs 10 --lm_lr 2e-5
"""
|
|
import os
import sys
import json
import math
import copy
import random
import logging
import argparse
from collections import defaultdict, Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("mrjepa-p4")
|
|
|
|
|
|
class VisionLanguageBridge(nn.Module):
    """
    LLaVA-1.5 style 2-layer MLP connector.
    Projects JEPA representations (768d) into SmolLM2 space (576d).

    Applied to both z_K (global JEPA latent) and evidence tokens.
    The nonlinear projection matters — BLIP-2 showed a linear connector works,
    but LLaVA-1.5 showed an MLP is significantly better for VQA.
    """
    def __init__(self, jepa_dim=768, lm_dim=576):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(jepa_dim, lm_dim),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim),
        )
        # Conservative init: small gains keep the projected soft prompts close
        # to the LM's embedding distribution at the start of Stage 1.
        nn.init.xavier_uniform_(self.proj[0].weight, gain=0.5)
        nn.init.zeros_(self.proj[0].bias)
        nn.init.xavier_uniform_(self.proj[2].weight, gain=0.1)
        nn.init.zeros_(self.proj[2].bias)

    def forward(self, features):
        """
        Args:
            features: [B, N, 768] — either z_K or evidence tokens
        Returns:
            projected: [B, N, 576] — in SmolLM2 embedding space
        """
        return self.proj(features)
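
# Shape sketch (illustrative only; sizes are arbitrary, not part of training):
# the bridge maps [B, N, jepa_dim] to [B, N, lm_dim] without changing the
# token count.
#
#   bridge = VisionLanguageBridge(jepa_dim=768, lm_dim=576)
#   out = bridge(torch.randn(2, 8, 768))   # out.shape == (2, 8, 576)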
|
|
|
|
|
|
class SmolLMDecoder(nn.Module):
    """
    Wraps SmolLM2-135M-Instruct as the generative decoder.

    Architecture:
        1. Bridge MLP projects z_K + evidence from JEPA space (768d) to LM space (576d)
        2. Projected tokens are prepended as "soft visual prompts" before the text tokens
        3. SmolLM2 processes [vis_prefix | text_tokens] with causal attention
        4. Loss is computed only on answer tokens (visual prefix masked with -100)

    This follows the BLIP-2 / LLaVA pattern exactly:
        "projected query embeddings are prepended to the input text embeddings.
        They function as soft visual prompts that condition the LLM on visual
        representation." — Li et al., BLIP-2 §3.3
    """
    def __init__(self, jepa_dim=768, freeze_lm=True, label_smoothing=0.1,
                 num_evidence_tokens=8):
        super().__init__()
        from transformers import AutoModelForCausalLM, AutoTokenizer

        log.info("Loading SmolLM2-135M-Instruct...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct"
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.lm = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct",
            torch_dtype=torch.bfloat16,
        )

        self.lm_dim = self.lm.config.hidden_size
        self.vocab_size = self.lm.config.vocab_size

        log.info(f"SmolLM2: hidden={self.lm_dim}, vocab={self.vocab_size}, "
                 f"layers={self.lm.config.num_hidden_layers}")

        if freeze_lm:
            for p in self.lm.parameters():
                p.requires_grad = False
            log.info("SmolLM2 weights frozen (Stage 1: train bridge only)")
        else:
            log.info("SmolLM2 weights trainable (Stage 2: full fine-tuning)")

        self.bridge = VisionLanguageBridge(jepa_dim, self.lm_dim)

        # Evidence subsampling: keep a strided subset of the 64 evidence
        # tokens, passed through a learned projection (see _subsample_evidence).
        self.num_evidence_tokens = num_evidence_tokens
        if num_evidence_tokens < 64:
            self.ev_pool = nn.Linear(jepa_dim, jepa_dim)
        else:
            self.ev_pool = None

        self.label_smoothing = label_smoothing
        self.freeze_lm = freeze_lm

    def unfreeze_lm(self):
        """Unfreeze SmolLM2 for Stage 2 fine-tuning."""
        for p in self.lm.parameters():
            p.requires_grad = True
        self.freeze_lm = False
        log.info("SmolLM2 unfrozen for Stage 2")

    def _subsample_evidence(self, evidence):
        """Subsample evidence tokens from 64 → num_evidence_tokens."""
        B, N, D = evidence.shape
        if N <= self.num_evidence_tokens:
            return evidence

        if self.ev_pool is not None:
            # Strided subsample keeps coverage of the whole memory; ev_pool
            # lets the kept tokens compensate for the dropped ones.
            stride = N // self.num_evidence_tokens
            indices = torch.arange(0, N, stride, device=evidence.device)[:self.num_evidence_tokens]
            return self.ev_pool(evidence[:, indices])
        return evidence[:, :self.num_evidence_tokens]
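
    # Illustrative: with N=64 evidence tokens and num_evidence_tokens=8 the
    # stride is 8, so indices [0, 8, 16, 24, 32, 40, 48, 56] survive, an even
    # sweep over the memory rather than just its first 8 slots.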

    def prepare_inputs(self, z_final, evidence, questions, answers=None,
                       max_answer_len=32):
        """
        Prepare inputs for the SmolLM2 forward pass.

        Args:
            z_final: [B, N_state, 768] — JEPA latent states
            evidence: [B, N_ev, 768] — evidence memory tokens
            questions: list[str] — question texts
            answers: list[str] or None — answer texts (None for generation)
            max_answer_len: int — max tokens for the answer

        Returns:
            inputs_embeds: [B, N_vis + N_text, 576]
            attention_mask: [B, N_vis + N_text]
            labels: [B, N_vis + N_text] or None
            n_vis_tokens: int — number of visual prefix tokens
        """
        device = z_final.device
        B = z_final.size(0)

        # Project JEPA latents and the subsampled evidence into LM space.
        vis_embeds = self.bridge(z_final)
        ev_sub = self._subsample_evidence(evidence)
        ev_embeds = self.bridge(ev_sub)

        # Soft visual prompt: [z_K tokens | evidence tokens].
        vis_prefix = torch.cat([vis_embeds, ev_embeds], dim=1)
        n_vis = vis_prefix.size(1)

        if answers is not None:
            # Training: "question + answer" with a terminating EOS so the model
            # learns to stop. Right padding keeps the question-length label
            # mask below aligned with position 0.
            self.tokenizer.padding_side = "right"
            texts = [f"Question: {q}\nAnswer: {a}{self.tokenizer.eos_token}"
                     for q, a in zip(questions, answers)]
            tok = self.tokenizer(
                texts, padding="max_length", truncation=True,
                max_length=192 + max_answer_len,
                return_tensors="pt",
            ).to(device)
        else:
            # Generation: left padding, so logits[:, -1] always sits on a real
            # token rather than on padding in shorter rows of the batch.
            self.tokenizer.padding_side = "left"
            texts = [f"Question: {q}\nAnswer:" for q in questions]
            tok = self.tokenizer(
                texts, padding="max_length", truncation=True,
                max_length=192,
                return_tensors="pt",
            ).to(device)

        text_embeds = self.lm.get_input_embeddings()(tok["input_ids"])

        # Match the LM's dtype (bfloat16) before concatenating.
        lm_dtype = text_embeds.dtype
        vis_prefix = vis_prefix.to(lm_dtype)
        inputs_embeds = torch.cat([vis_prefix, text_embeds], dim=1)

        # The visual prefix is always attended to.
        vis_mask = torch.ones(B, n_vis, device=device, dtype=tok["attention_mask"].dtype)
        attention_mask = torch.cat([vis_mask, tok["attention_mask"]], dim=1)

        labels = None
        if answers is not None:
            # No loss on the visual prefix.
            vis_labels = torch.full((B, n_vis), -100, device=device, dtype=torch.long)

            text_labels = tok["input_ids"].clone()
            # Mask padding via the attention mask: pad may equal eos here
            # (we fall back to eos above), so matching on pad_token_id would
            # also erase the genuine EOS that terminates each answer.
            text_labels[tok["attention_mask"] == 0] = -100

            # Mask the question prefix so the loss covers only the answer.
            for i, (q, a) in enumerate(zip(questions, answers)):
                q_text = f"Question: {q}\nAnswer:"
                q_tok = self.tokenizer(q_text, add_special_tokens=False)
                q_len = len(q_tok["input_ids"])
                text_labels[i, :min(q_len, text_labels.size(1))] = -100

            labels = torch.cat([vis_labels, text_labels], dim=1)

        return inputs_embeds, attention_mask, labels, n_vis
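
    # Illustrative layout of one training row (widths arbitrary):
    #
    #   embeds: [ z_K | evidence | "Question: ...\nAnswer: ..." <eos> <pad>.. ]
    #   labels: [-100 |   -100   |  -100 (question)  answer-ids  eos    -100 ]
    #
    # Only the answer span and its EOS contribute to the loss.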

    def forward(self, z_final, evidence, questions, answers,
                max_answer_len=32):
        """
        Training forward pass.

        Returns:
            loss: scalar tensor — CE loss with label smoothing
            logits: [B, L, V] — for debugging
        """
        inputs_embeds, attention_mask, labels, n_vis = self.prepare_inputs(
            z_final, evidence, questions, answers, max_answer_len
        )

        outputs = self.lm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

        # HF's built-in loss is plain CE; recompute with label smoothing,
        # shifting so position t predicts token t+1.
        if self.label_smoothing > 0 and labels is not None:
            logits = outputs.logits.float()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
                label_smoothing=self.label_smoothing,
            )
        else:
            loss = outputs.loss

        return loss, outputs.logits
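
    # Label smoothing, concretely: with ε=0.1 and vocab size V, each target
    # distribution puts 0.9 on the gold token and 0.1/V on every other token
    # (about 2e-6 for V=49152). Capping confidence this way is what combats
    # the repetition collapse seen with the Phase 3.x decoder.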

    @torch.no_grad()
    def generate(self, z_final, evidence, questions,
                 max_new_tokens=32, temperature=0.7, top_p=0.9,
                 repetition_penalty=1.3, no_repeat_ngram_size=3):
        """
        Generate answers with nucleus sampling + repetition penalty.

        Returns:
            predictions: list[str] — decoded answer strings
        """
        device = z_final.device
        B = z_final.size(0)

        inputs_embeds, attention_mask, _, n_vis = self.prepare_inputs(
            z_final, evidence, questions, answers=None
        )

        generated_ids = []
        past_key_values = None
        cur_embeds = inputs_embeds
        cur_mask = attention_mask

        # Per-sequence history for repetition penalty and n-gram blocking.
        all_generated = [[] for _ in range(B)]

        for step in range(max_new_tokens):
            outputs = self.lm(
                inputs_embeds=cur_embeds,
                attention_mask=cur_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            next_logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Repetition penalty (CTRL-style): dampen logits of tokens the
            # sequence has already produced.
            if repetition_penalty != 1.0:
                for b in range(B):
                    for token_id in set(all_generated[b]):
                        if next_logits[b, token_id] > 0:
                            next_logits[b, token_id] /= repetition_penalty
                        else:
                            next_logits[b, token_id] *= repetition_penalty
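
            # For example, with penalty 1.3 a previously generated token with
            # logit +2.0 drops to about +1.54, and one with logit -1.0 drops
            # further to -1.3, so repeats lose mass in both cases.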

            # Block any token that would complete an already-generated n-gram.
            if no_repeat_ngram_size > 0 and len(all_generated[0]) >= no_repeat_ngram_size - 1:
                for b in range(B):
                    gen = all_generated[b]
                    if len(gen) >= no_repeat_ngram_size - 1:
                        ngram_prefix = tuple(gen[-(no_repeat_ngram_size - 1):])
                        # Find past occurrences of the current prefix and ban
                        # the token that followed each of them.
                        for i in range(len(gen) - no_repeat_ngram_size + 1):
                            if tuple(gen[i:i + no_repeat_ngram_size - 1]) == ngram_prefix:
                                blocked = gen[i + no_repeat_ngram_size - 1]
                                next_logits[b, blocked] = float('-inf')

            if temperature > 0:
                next_logits = next_logits / temperature

                # Nucleus (top-p) filtering on the sorted distribution.
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Shift by one token so the first token crossing top_p is kept.
                sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
                sorted_logits[sorted_mask] = float('-inf')

                probs = F.softmax(sorted_logits, dim=-1)
                sampled_idx = torch.multinomial(probs, 1)
                next_tokens = sorted_indices.gather(1, sampled_idx)
            else:
                # temperature == 0 falls back to greedy decoding.
                next_tokens = next_logits.argmax(dim=-1, keepdim=True)
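
            # Worked example of the shifted-cumsum mask above: sorted probs
            # [0.6, 0.3, 0.1] with top_p=0.9 give cumsum-prob = [0.0, 0.6, 0.9],
            # so the mask is [F, F, T] and sampling renormalizes over 0.6/0.3.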

            generated_ids.append(next_tokens)
            for b in range(B):
                all_generated[b].append(next_tokens[b, 0].item())

            # Stop early once every sequence has just emitted EOS.
            if (next_tokens == self.tokenizer.eos_token_id).all():
                break

            # Feed only the new token; the KV cache covers the prefix.
            cur_embeds = self.lm.get_input_embeddings()(next_tokens)
            cur_mask = torch.cat([
                cur_mask,
                torch.ones(B, 1, device=device, dtype=cur_mask.dtype)
            ], dim=1)

        if generated_ids:
            gen_tensor = torch.cat(generated_ids, dim=1)
            predictions = []
            for i in range(B):
                ids = gen_tensor[i].tolist()
                # Truncate at the first EOS so sequences that finished early
                # don't carry tokens sampled after their EOS.
                if self.tokenizer.eos_token_id in ids:
                    ids = ids[:ids.index(self.tokenizer.eos_token_id)]
                text = self.tokenizer.decode(ids, skip_special_tokens=True)
                predictions.append(text.strip())
        else:
            predictions = [""] * B

        return predictions
|
|
|
|
|
|
class OpenEndedDataset(Dataset):
    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192):
        from datasets import load_dataset
        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len
        log.info(f"Loading {benchmark} {split}...")
        if benchmark == "docvqa":
            ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
        elif benchmark == "chartqa":
            ds = load_dataset("lmms-lab/ChartQA", split=split)
        elif benchmark == "textvqa":
            ds = load_dataset("lmms-lab/textvqa", split=split)
        else:
            raise ValueError(f"Unknown benchmark: {benchmark}")
        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))
        self.data = ds
        log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        img = row.get("image")
        if img is None:
            img = Image.new("RGB", (256, 256), "white")
        else:
            img = img.convert("RGB")
        question = row["question"]
        if self.benchmark == "docvqa":
            answers = row.get("answers", [""])
            answer = answers[0] if answers else ""
            all_answers = answers
        elif self.benchmark == "chartqa":
            answer = str(row.get("answer", ""))
            all_answers = [answer]
        elif self.benchmark == "textvqa":
            answers = row.get("answers", [""])
            # Train on the most common human answer; keep all for eval.
            answer_counts = Counter(a.lower().strip() for a in answers)
            answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
            all_answers = answers
        else:
            answer = ""
            all_answers = [""]
        ocr_tokens = row.get("ocr_tokens", [])
        ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""
        text = question
        if ocr_text:
            text += f" [OCR: {ocr_text}]"
        return {
            "image": img, "text": text, "answer": answer,
            "all_answers": all_answers, "benchmark": self.benchmark,
        }
|
|
|
|
def collate_open_ended_p4(batch, transform, qwen_tokenizer, max_len):
    """Collate for Phase 4 — we only need image, question text, and answer string."""
    images = [s["image"] for s in batch]
    texts = [s["text"] for s in batch]
    answers = [s["answer"] for s in batch]

    # torchvision-style transforms are applied per image; HF processors take
    # the whole list at once.
    if hasattr(transform, '__call__') and not hasattr(transform, 'feature_extractor'):
        pixel_values = torch.stack([transform(img) for img in images])
    else:
        pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]

    # Tokenized question (Qwen3 tokenizer) feeds the JEPA text encoder; the
    # raw question strings feed the SmolLM2 decoder, which tokenizes itself.
    tok = qwen_tokenizer(texts, padding="max_length", truncation=True,
                         max_length=max_len, return_tensors="pt")

    return {
        "pixel_values": pixel_values,
        "input_ids": tok["input_ids"],
        "attention_mask": tok["attention_mask"],
        "questions": texts,
        "answers": answers,
        "batch_size": len(batch),
        "benchmarks": [s["benchmark"] for s in batch],
        "all_answers": [s["all_answers"] for s in batch],
    }
|
|
|
|
|
|
def normalized_levenshtein(s1, s2):
    """Levenshtein distance divided by the longer string's length (0 = exact match)."""
    s1, s2 = s1.lower().strip(), s2.lower().strip()
    if s1 == s2:
        return 0.0
    l1, l2 = len(s1), len(s2)
    if l1 == 0 or l2 == 0:
        return 1.0
    m = [[0] * (l2 + 1) for _ in range(l1 + 1)]
    for i in range(l1 + 1):
        m[i][0] = i
    for j in range(l2 + 1):
        m[0][j] = j
    for i in range(1, l1 + 1):
        for j in range(1, l2 + 1):
            c = 0 if s1[i - 1] == s2[j - 1] else 1
            m[i][j] = min(m[i - 1][j] + 1, m[i][j - 1] + 1, m[i - 1][j - 1] + c)
    return m[l1][l2] / max(l1, l2)
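
# Worked example: normalized_levenshtein("12%", "12") is one deletion over
# max(3, 2) characters = 1/3; identical strings score 0.0, fully disjoint
# strings score 1.0.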
|
|
def compute_anls(predictions, ground_truths, threshold=0.5):
    """ANLS (DocVQA metric): 1 - NL distance, zeroed below the similarity threshold."""
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        best = 0.0
        for gt in gts:
            nl = normalized_levenshtein(str(pred), str(gt))
            best = max(best, 1.0 - nl if nl < threshold else 0.0)
        scores.append(best)
    return np.mean(scores) * 100 if scores else 0.0
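
# Worked example: pred "42" vs gts ["42%"] gives NL = 1/3 < 0.5, scoring
# 1 - 1/3 ≈ 0.667; any distance at or above the threshold scores 0 outright.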
|
|
def compute_vqa_accuracy(predictions, ground_truths):
    """VQA-v2 style accuracy: min(#matching human answers / 3, 1)."""
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        pn = str(pred).lower().strip()
        matches = sum(1 for gt in gts if str(gt).lower().strip() == pn)
        scores.append(min(matches / 3.0, 1.0))
    return np.mean(scores) * 100 if scores else 0.0


def compute_relaxed_accuracy(predictions, ground_truths, tolerance=0.05):
    """ChartQA relaxed accuracy: numeric answers may deviate by ±5%."""
    correct = []
    for pred, gt in zip(predictions, ground_truths):
        ps, gs = str(pred).strip().lower(), str(gt).strip().lower()
        try:
            gv = float(gs.replace(',', '').replace('%', ''))
            pv = float(ps.replace(',', '').replace('%', ''))
            correct.append(abs(pv - gv) / abs(gv) <= tolerance if gv != 0 else abs(pv) <= tolerance)
        except (ValueError, ZeroDivisionError):
            # Non-numeric answers fall back to exact string match.
            correct.append(ps == gs)
    return np.mean(correct) * 100 if correct else 0.0
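
# Worked example: gt "12.5", pred "12.9" is |12.9 - 12.5| / 12.5 = 0.032,
# within the 5% tolerance, so it counts; pred "13.5" (8% off) would not.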
|
|
|
|
|
|
def download_checkpoint(hub_model_id, filename):
    from huggingface_hub import hf_hub_download
    path = hf_hub_download(repo_id=hub_model_id, filename=filename, repo_type="model")
    log.info(f"Downloaded checkpoint: {path}")
    return path


def load_jepa_model(hub_model_id, ckpt_filename, device):
    """Load the Phase 3.1 JEPA model (everything except gen_head)."""
    # Import the Phase 1 training script from the Hub so the model classes
    # match the checkpoint exactly.
    from huggingface_hub import hf_hub_download
    p1_script = hf_hub_download(repo_id=hub_model_id, filename="train_mrjepa.py", repo_type="model")
    import importlib.util
    spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
    p1 = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(p1)

    ckpt_path = download_checkpoint(hub_model_id, ckpt_filename)
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Rebuild the config from the checkpoint, keeping any newer defaults the
    # saved config doesn't know about.
    saved_cfg = ckpt["config"]
    cfg = p1.Config()
    for k, v in saved_cfg.items():
        if hasattr(cfg, k):
            setattr(cfg, k, v)
    cfg.resolve()

    model = p1.MRJEPAModel(cfg)
    model.evidence.load_state_dict(ckpt["evidence"])
    model.rollout.load_state_dict(ckpt["rollout"])
    model.disc.load_state_dict(ckpt["disc"])
    model.target.t_ev.load_state_dict(ckpt["target_ev"])
    model.target.t_ro.load_state_dict(ckpt["target_ro"])

    log.info(f"Loaded JEPA weights from {ckpt_filename} "
             f"(epoch={ckpt.get('epoch','?')}, score={ckpt.get('composite_score','?')})")

    return model, cfg, p1
|
|
|
|
|
|
@torch.no_grad()
def evaluate_generative(jepa_model, decoder, eval_dls, device, cfg,
                        amp_dtype, max_new_tokens=32):
    """Evaluate open-ended benchmarks using SmolLM2 generation."""
    jepa_model.eval()
    decoder.eval()
    results = {}

    for benchmark, dl in eval_dls.items():
        predictions, ground_truths = [], []

        for batch in dl:
            bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                  for k, v in batch.items()}

            with torch.autocast(device_type="cuda", dtype=amp_dtype,
                                enabled=cfg.bf16 and device.type == "cuda"):
                vis_tok = jepa_model.vis(bt["pixel_values"]).float()
                txt_tok = jepa_model.txt(bt["input_ids"], bt["attention_mask"]).float()
                evidence, _, _ = jepa_model.evidence(vis_tok, txt_tok, bt["attention_mask"])

                if jepa_model._use_rollout:
                    _, z_final, _ = jepa_model.rollout(evidence)
                else:
                    # Rollout disabled: build z_0 directly from pooled evidence.
                    B2 = bt["batch_size"]
                    z_final = jepa_model.rollout.init_tokens.expand(B2, -1, -1) + \
                        jepa_model.rollout.z0_proj(
                            F.adaptive_avg_pool1d(evidence.permute(0, 2, 1),
                                                  jepa_model.rollout.num_tokens).permute(0, 2, 1))

            preds = decoder.generate(
                z_final.float(), evidence.float(), bt["questions"],
                max_new_tokens=max_new_tokens,
                temperature=0.7, top_p=0.9,
                repetition_penalty=1.3, no_repeat_ngram_size=3,
            )

            predictions.extend(preds)
            ground_truths.extend(batch["all_answers"])

        # Log a few qualitative samples per benchmark.
        for j in range(min(5, len(predictions))):
            gt_sample = ground_truths[j] if j < len(ground_truths) else "?"
            log.info(f"  [{benchmark}] pred: '{predictions[j][:80]}' | gt: '{gt_sample}'")

        if benchmark == "docvqa":
            results[benchmark] = {"anls": compute_anls(predictions, ground_truths)}
        elif benchmark == "chartqa":
            gt_flat = [g[0] if isinstance(g, list) else g for g in ground_truths]
            results[benchmark] = {"relaxed_accuracy": compute_relaxed_accuracy(predictions, gt_flat)}
        elif benchmark == "textvqa":
            results[benchmark] = {"vqa_accuracy": compute_vqa_accuracy(predictions, ground_truths)}

    jepa_model.train()
    decoder.train()
    return results
|
|
|
|
|
|
def main():
    parser = argparse.ArgumentParser(description="MR-JEPA Phase 4: SmolLM2 Decoder")
    parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
    parser.add_argument("--ckpt", default="checkpoints/hybrid_main_phase3_1_best.pt",
                        help="JEPA checkpoint to load")
    parser.add_argument("--run_name", default="phase4_smollm2")
    parser.add_argument("--stage", type=int, default=1, choices=[1, 2],
                        help="1=freeze LM, train bridge; 2=unfreeze all")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--grad_accum", type=int, default=32)
    parser.add_argument("--bridge_lr", type=float, default=1e-3,
                        help="Bridge MLP learning rate (Stage 1)")
    parser.add_argument("--lm_lr", type=float, default=2e-5,
                        help="SmolLM2 learning rate (Stage 2)")
    parser.add_argument("--core_lr", type=float, default=5e-5,
                        help="JEPA core module learning rate")
    parser.add_argument("--backbone_lr", type=float, default=5e-6)
    parser.add_argument("--text_lr", type=float, default=5e-6)
    parser.add_argument("--label_smoothing", type=float, default=0.1)
    parser.add_argument("--num_evidence_tokens", type=int, default=8,
                        help="Evidence tokens as soft prompts (subsampled from 64)")
    parser.add_argument("--max_answer_len", type=int, default=32)
    parser.add_argument("--max_eval_samples", type=int, default=200)
    parser.add_argument("--max_train_samples", type=int, default=5000)
    parser.add_argument("--gen_weight", type=float, default=2.0)
    parser.add_argument("--output_dir", default="./outputs/mrjepa_phase4")
    parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio")

    parser.add_argument("--stage1_epochs", type=int, default=3,
                        help="Auto-transition: Stage 1 epochs (0=skip)")
    parser.add_argument("--stage2_epochs", type=int, default=7,
                        help="Auto-transition: Stage 2 epochs (0=skip)")
    # BooleanOptionalAction adds --no-auto_transition; a store_true flag with
    # default=True could never be switched off.
    parser.add_argument("--auto_transition", action=argparse.BooleanOptionalAction,
                        default=True,
                        help="Auto-transition from Stage 1 → Stage 2")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Restore the Phase 3.1 JEPA stack.
    jepa_model, cfg, p1 = load_jepa_model(args.hub_model_id, args.ckpt, device)

    # Partially unfreeze the encoders for end-to-end fine-tuning.
    jepa_model.vis.unfreeze_last(6)
    jepa_model.txt.unfreeze_last(4)
    jepa_model = jepa_model.to(device)

    # With auto-transition the run always starts frozen (Stage 1).
    freeze_lm = (args.stage == 1) if not args.auto_transition else True
    decoder = SmolLMDecoder(
        jepa_dim=cfg.rollout_dim,
        freeze_lm=freeze_lm,
        label_smoothing=args.label_smoothing,
        num_evidence_tokens=args.num_evidence_tokens,
    ).to(device)

    import trackio
    trackio.init(
        name=args.run_name, project="MR-JEPA", space_id=args.trackio_space,
        config={
            "phase": "4", "stage": args.stage,
            "auto_transition": args.auto_transition,
            "stage1_epochs": args.stage1_epochs,
            "stage2_epochs": args.stage2_epochs,
            "bridge_lr": args.bridge_lr, "lm_lr": args.lm_lr,
            "core_lr": args.core_lr, "backbone_lr": args.backbone_lr,
            "label_smoothing": args.label_smoothing,
            "num_evidence_tokens": args.num_evidence_tokens,
            "gen_weight": args.gen_weight,
            "decoder": "SmolLM2-135M-Instruct",
            "decoder_params": "135M", "bridge": "LLaVA-1.5 MLP",
        }
    )
    log.info(f"Trackio → https://huggingface.co/spaces/{args.trackio_space}")

    jepa_p = sum(p.numel() for p in jepa_model.parameters())
    jepa_tp = sum(p.numel() for p in jepa_model.parameters() if p.requires_grad)
    dec_p = sum(p.numel() for p in decoder.parameters())
    dec_tp = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
    log.info(f"JEPA: {jepa_p:,} total, {jepa_tp:,} trainable")
    log.info(f"Decoder: {dec_p:,} total, {dec_tp:,} trainable")
    log.info(f"Combined: {jepa_p + dec_p:,} total, {jepa_tp + dec_tp:,} trainable")

    # The JEPA text encoder keeps its Qwen3 tokenizer; SmolLM2 has its own.
    qwen_tokenizer = jepa_model.txt.tokenizer
    transform = jepa_model.vis.get_transform()

    # ScienceQA multiple-choice loaders (max_samples=0 means no cap).
    mc_max = 0
    train_mc_ds = p1.ScienceQADataset("train", max_samples=mc_max, transform=transform,
                                      tokenizer=qwen_tokenizer, max_len=cfg.max_text_len,
                                      max_opts=cfg.max_options)
    eval_mc_ds = p1.ScienceQADataset("test", max_samples=args.max_eval_samples,
                                     transform=transform, tokenizer=qwen_tokenizer,
                                     max_len=cfg.max_text_len, max_opts=cfg.max_options)
    mc_coll = lambda batch: p1.collate_fn(batch, transform, qwen_tokenizer,
                                          cfg.max_text_len, cfg.max_options)
    train_mc_dl = DataLoader(train_mc_ds, batch_size=args.batch_size, shuffle=True,
                             num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True)
    eval_mc_dl = DataLoader(eval_mc_ds, batch_size=args.batch_size, shuffle=False,
                            num_workers=2, collate_fn=mc_coll, pin_memory=True)

    # Open-ended train/eval loaders per benchmark.
    open_coll = lambda batch: collate_open_ended_p4(batch, transform, qwen_tokenizer,
                                                    cfg.max_text_len)
    train_open_dls = {}
    eval_open_dls = {}
    for bm, tr_split, ev_split in [("docvqa", "validation", "validation"),
                                   ("chartqa", "test", "test"),
                                   ("textvqa", "train", "validation")]:
        train_open_dls[bm] = DataLoader(
            OpenEndedDataset(bm, tr_split, max_samples=args.max_train_samples,
                             transform=transform, tokenizer=qwen_tokenizer,
                             max_len=cfg.max_text_len),
            batch_size=args.batch_size, shuffle=True, num_workers=2,
            collate_fn=open_coll, pin_memory=True, drop_last=True)
        eval_open_dls[bm] = DataLoader(
            OpenEndedDataset(bm, ev_split, max_samples=args.max_eval_samples,
                             transform=transform, tokenizer=qwen_tokenizer,
                             max_len=cfg.max_text_len),
            batch_size=args.batch_size, shuffle=False, num_workers=2,
            collate_fn=open_coll, pin_memory=True)

    pad_token_id = qwen_tokenizer.pad_token_id or 0
    amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32

    total_epochs = args.stage1_epochs + args.stage2_epochs if args.auto_transition else args.epochs

    def run_training_stage(stage, num_epochs, start_epoch=0):
        """Run one training stage (1 = bridge only, 2 = joint fine-tuning)."""
        log.info(f"\n{'='*60}")
        log.info(f"STAGE {stage}: {'Freeze LM, train bridge' if stage==1 else 'Unfreeze all, joint fine-tuning'}")
        log.info(f"{'='*60}")

        if stage == 2:
            decoder.unfreeze_lm()

        # Bridge (and evidence pooling) parameters.
        bridge_params = list(decoder.bridge.parameters())
        if decoder.ev_pool is not None:
            bridge_params += list(decoder.ev_pool.parameters())

        param_groups = []

        # Bridge gets the full LR in Stage 1, a reduced LR in Stage 2.
        param_groups.append({
            "params": bridge_params,
            "lr": args.bridge_lr if stage == 1 else args.bridge_lr * 0.1,
            "name": "bridge",
        })

        # JEPA core (evidence/rollout/disc), excluding the two encoders.
        jepa_core_params = [p for n, p in jepa_model.named_parameters()
                            if p.requires_grad and 'vis.' not in n and 'txt.' not in n]
        if jepa_core_params:
            param_groups.append({
                "params": jepa_core_params,
                "lr": args.core_lr if stage == 2 else args.core_lr * 0.1,
                "name": "jepa_core",
            })

        # Vision backbone (last 6 blocks unfrozen).
        bb_params = [p for p in jepa_model.vis.parameters() if p.requires_grad]
        if bb_params:
            param_groups.append({
                "params": bb_params,
                "lr": args.backbone_lr,
                "name": "backbone",
            })

        # Text encoder (last 4 blocks unfrozen).
        txt_params = [p for p in jepa_model.txt.parameters() if p.requires_grad]
        if txt_params:
            param_groups.append({
                "params": txt_params,
                "lr": args.text_lr,
                "name": "text_encoder",
            })

        # SmolLM2 itself only trains in Stage 2.
        if stage == 2:
            lm_params = [p for p in decoder.lm.parameters() if p.requires_grad]
            if lm_params:
                param_groups.append({
                    "params": lm_params,
                    "lr": args.lm_lr,
                    "name": "smollm2",
                })

        for pg in param_groups:
            n_params = sum(p.numel() for p in pg["params"])
            log.info(f"  {pg['name']}: {n_params:,} params, lr={pg['lr']:.2e}")

        optimizer = AdamW(param_groups, weight_decay=0.05)

        mc_steps = len(train_mc_dl)
        open_steps = sum(len(dl) for dl in train_open_dls.values())
        total_steps = num_epochs * (mc_steps + open_steps) // args.grad_accum
        warmup_steps = int(total_steps * 0.1)

        def lr_lambda(step):
            # Linear warmup, then cosine decay to a 1% floor.
            if step < warmup_steps:
                return step / max(warmup_steps, 1)
            progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
            return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))
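
        # Schedule shape: linear 0 to 1 over the first 10% of steps, then
        # cosine from 1.0 down to the 0.01 floor; at 50% post-warmup progress
        # the multiplier is 0.01 + 0.99 * 0.5 * (1 + cos(pi/2)) ~= 0.505.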

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        global_step = 0
        best_composite = 0.0
        all_trainable = ([p for p in jepa_model.parameters() if p.requires_grad] +
                         [p for p in decoder.parameters() if p.requires_grad])

        for epoch in range(num_epochs):
            abs_epoch = start_epoch + epoch
            jepa_model.train()
            decoder.train()
            epoch_losses = defaultdict(list)
            epoch_mc_correct, epoch_mc_total = 0, 0
            optimizer.zero_grad()
            batch_count = 0

            # Pass 1: multiple-choice (ScienceQA) batches through the JEPA head.
            log.info(f"  Stage {stage} Epoch {epoch}/{num_epochs}: MC training...")
            for bi, batch in enumerate(train_mc_dl):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}
                with torch.autocast(device_type="cuda", dtype=amp_dtype,
                                    enabled=cfg.bf16 and device.type == "cuda"):
                    losses, preds = jepa_model(**batch)
                loss = losses["total"] / args.grad_accum
                loss.backward()
                batch_count += 1
                if batch_count % args.grad_accum == 0:
                    nn.utils.clip_grad_norm_(all_trainable, cfg.max_grad_norm)
                    optimizer.step(); scheduler.step(); optimizer.zero_grad()
                    jepa_model.update_target(global_step, total_steps)
                    global_step += 1
                for k, v in losses.items():
                    if isinstance(v, torch.Tensor):
                        epoch_losses[f"mc_{k}"].append(v.item())
                epoch_mc_correct += (preds == batch["labels"]).sum().item()
                epoch_mc_total += batch["batch_size"]
                if bi % 100 == 0:
                    avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")}
                    acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100
                    log.info(f"  S{stage} E{epoch} MC B{bi}/{mc_steps} | "
                             f"loss={avg.get('mc_total',0):.4f} | acc={acc:.1f}%")
                    trackio.log({"train/mc_loss": avg.get("mc_total", 0),
                                 "train/mc_accuracy": acc,
                                 "train/lr": scheduler.get_last_lr()[0],
                                 "train/epoch": abs_epoch, "train/stage": stage,
                                 "train/step": global_step})

            # Pass 2: open-ended batches, interleaving the three benchmarks
            # round-robin until every loader is exhausted.
            log.info(f"  Stage {stage} Epoch {epoch}: Open-ended training...")
            gen_losses = defaultdict(list)
            open_iters = {n: iter(dl) for n, dl in train_open_dls.items()}
            open_active = set(open_iters.keys())
            obi = 0

            while open_active:
                for name in list(open_active):
                    try:
                        batch = next(open_iters[name])
                    except StopIteration:
                        open_active.discard(name)
                        continue

                    bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                          for k, v in batch.items()}

                    with torch.autocast(device_type="cuda", dtype=amp_dtype,
                                        enabled=cfg.bf16 and device.type == "cuda"):
                        # Encode image + question into the evidence memory.
                        vis_tok = jepa_model.vis(bt["pixel_values"]).float()
                        txt_tok = jepa_model.txt(bt["input_ids"], bt["attention_mask"]).float()
                        evidence, _, _ = jepa_model.evidence(vis_tok, txt_tok, bt["attention_mask"])

                        if jepa_model._use_rollout:
                            traj, z_final, z_proj = jepa_model.rollout(evidence)
                        else:
                            B2 = bt["batch_size"]
                            z0 = jepa_model.rollout.init_tokens.expand(B2, -1, -1) + \
                                jepa_model.rollout.z0_proj(F.adaptive_avg_pool1d(
                                    evidence.permute(0, 2, 1), jepa_model.rollout.num_tokens).permute(0, 2, 1))
                            z_final = z0
                            z_proj = jepa_model.rollout.out_proj(z0).unsqueeze(1)

                        # Keep the JEPA objective alive so the latent space
                        # doesn't drift while the decoder trains.
                        jepa_loss_val = torch.tensor(0.0, device=device)
                        if jepa_model._use_jepa:
                            target_proj = jepa_model.target(
                                vis_tok.detach(), txt_tok.detach(), bt["attention_mask"].detach())
                            jl = jepa_model.jepa_loss(z_proj, target_proj, torch.tensor(0.0, device=device))
                            jepa_loss_val = jl["jepa"] + jl["reg"]

                        # Generative loss through the SmolLM2 decoder.
                        gen_loss, gen_logits = decoder(
                            z_final.float(), evidence.float(),
                            bt["questions"], bt["answers"],
                            max_answer_len=args.max_answer_len,
                        )

                        total_loss = cfg.jepa_weight * jepa_loss_val + args.gen_weight * gen_loss
                    loss = total_loss / args.grad_accum

                    loss.backward()
                    batch_count += 1
                    if batch_count % args.grad_accum == 0:
                        nn.utils.clip_grad_norm_(all_trainable, cfg.max_grad_norm)
                        optimizer.step(); scheduler.step(); optimizer.zero_grad()
                        jepa_model.update_target(global_step, total_steps)
                        global_step += 1

                    gen_losses[f"{name}_gen"].append(gen_loss.item())
                    gen_losses[f"{name}_total"].append(total_loss.item())
                    obi += 1
                    if obi % 50 == 0:
                        avg = {k: np.mean(v[-50:]) for k, v in gen_losses.items()}
                        log.info(f"  S{stage} E{epoch} OPEN B{obi} | " +
                                 " | ".join(f"{k}={v:.4f}" for k, v in avg.items()))
                        trackio.log({f"train/{k}": v for k, v in avg.items()})

            # End-of-epoch evaluation: MC accuracy + generative metrics.
            log.info(f"  Stage {stage} Epoch {epoch}: Evaluating...")
            mc_eval_acc = p1.evaluate(jepa_model, eval_mc_dl, device, cfg)
            log.info(f"  ScienceQA: {mc_eval_acc:.1f}%")

            gen_results = evaluate_generative(
                jepa_model, decoder, eval_open_dls, device, cfg, amp_dtype,
                max_new_tokens=args.max_answer_len,
            )
            for bm, metrics in gen_results.items():
                for mk, mv in metrics.items():
                    log.info(f"  {bm} {mk}: {mv:.2f}%")

            all_scores = [mc_eval_acc] + [v for m in gen_results.values() for v in m.values()]
            composite = np.mean(all_scores)
            log.info(f"{'='*40}")
            log.info(f"Stage {stage} Epoch {epoch} | MC: {mc_eval_acc:.1f}% | Composite: {composite:.1f}")
            log.info(f"{'='*40}")

            trackio.log({
                "eval/scienceqa_accuracy": mc_eval_acc,
                "eval/composite_score": composite,
                "eval/epoch": abs_epoch, "eval/stage": stage,
                **{f"eval/{bm}_{mk}": mv for bm, m in gen_results.items() for mk, mv in m.items()},
            })

            if composite > best_composite:
                best_composite = composite
                save_phase4_checkpoint(
                    jepa_model, decoder, cfg, args, abs_epoch,
                    mc_eval_acc, gen_results, composite, stage,
                )
                log.info(f"  ★ New best composite: {best_composite:.1f}")

        return best_composite

    best_overall = 0.0

    try:
        if args.auto_transition:
            # Stage 1: bridge-only warm-up.
            if args.stage1_epochs > 0:
                s1_best = run_training_stage(stage=1, num_epochs=args.stage1_epochs, start_epoch=0)
                best_overall = max(best_overall, s1_best)

            # Stage 2: joint fine-tuning, continuing the epoch count.
            if args.stage2_epochs > 0:
                s2_best = run_training_stage(stage=2, num_epochs=args.stage2_epochs,
                                             start_epoch=args.stage1_epochs)
                best_overall = max(best_overall, s2_best)
        else:
            best_overall = run_training_stage(stage=args.stage, num_epochs=args.epochs)

        log.info(f"\nPhase 4 complete. Best composite: {best_overall:.1f}")

    finally:
        trackio.log({"final/best_composite": best_overall, "final/phase": "4"})
        trackio.finish()

    push_phase4_results(cfg, args, best_overall)
|
|
|
|
def save_phase4_checkpoint(jepa_model, decoder, cfg, args, epoch,
                           mc_acc, gen_results, composite, stage):
    """Save combined checkpoint."""
    path = os.path.join(args.output_dir, "checkpoint_best.pt")
    torch.save({
        "evidence": jepa_model.evidence.state_dict(),
        "rollout": jepa_model.rollout.state_dict(),
        "disc": jepa_model.disc.state_dict(),
        "target_ev": jepa_model.target.t_ev.state_dict(),
        "target_ro": jepa_model.target.t_ro.state_dict(),
        "bridge": decoder.bridge.state_dict(),
        "ev_pool": decoder.ev_pool.state_dict() if decoder.ev_pool is not None else None,
        "smollm2": decoder.lm.state_dict(),
        "config": cfg.__dict__,
        "phase4_args": vars(args),
        "epoch": epoch, "stage": stage,
        "mc_eval_acc": mc_acc,
        "gen_results": gen_results,
        "composite_score": composite,
        "phase": "4",
    }, path)
    log.info(f"Saved Phase 4 checkpoint: {path} (composite={composite:.1f})")
|
|
|
|
def push_phase4_results(cfg, args, best_composite):
    """Push results and checkpoint to Hub."""
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        results = {
            "run_name": args.run_name, "phase": "4",
            "decoder": "SmolLM2-135M-Instruct",
            "bridge": "LLaVA-1.5 MLP (768→576→576)",
            "backbone": cfg.backbone, "K": cfg.K,
            "best_composite_score": best_composite,
            "stage1_epochs": args.stage1_epochs,
            "stage2_epochs": args.stage2_epochs,
            "bridge_lr": args.bridge_lr, "lm_lr": args.lm_lr,
            "core_lr": args.core_lr, "label_smoothing": args.label_smoothing,
            "num_evidence_tokens": args.num_evidence_tokens,
            "gen_weight": args.gen_weight,
        }
        rp = os.path.join(args.output_dir, f"results_{args.run_name}.json")
        with open(rp, "w") as f:
            json.dump(results, f, indent=2)

        api.upload_file(path_or_fileobj=rp,
                        path_in_repo=f"results/{args.run_name}.json",
                        repo_id=args.hub_model_id, repo_type="model")

        best_ckpt = os.path.join(args.output_dir, "checkpoint_best.pt")
        if os.path.exists(best_ckpt):
            api.upload_file(path_or_fileobj=best_ckpt,
                            path_in_repo=f"checkpoints/{args.run_name}_best.pt",
                            repo_id=args.hub_model_id, repo_type="model")

        # Keep the training script alongside the checkpoint for reproducibility.
        script_path = os.path.abspath(__file__)
        api.upload_file(path_or_fileobj=script_path,
                        path_in_repo="train_phase4.py",
                        repo_id=args.hub_model_id, repo_type="model")

        log.info(f"Pushed Phase 4 results to {args.hub_model_id}")
    except Exception as e:
        log.error(f"Push failed: {e}")
|
|
|
|
if __name__ == "__main__":
    main()
|
|