#!/usr/bin/env python3
"""
MR-JEPA Phase 4 — SmolLM2-135M Generative Decoder

Replaces the random-init 4-layer transformer decoder (which produced 0% generative
metrics after 10+ epochs) with SmolLM2-135M-Instruct as a pre-trained LM decoder.

Architecture (BLIP-2 / LLaVA-1.5 pattern):

    z_K (768d)        ──→ Bridge MLP (768→576→576) ──→ visual soft prompt tokens
    evidence (N×768d) ──→ same Bridge MLP          ──→ evidence soft prompt tokens
    [vis_tokens, ev_tokens, text_tokens] ──→ SmolLM2-135M ──→ next-token prediction

Training recipe (2-stage, following LLaVA/BLIP-2):
    Stage 1: Freeze SmolLM2, train only the bridge MLP. LR=1e-3.
    Stage 2: Unfreeze SmolLM2, joint fine-tuning. LR=2e-5, cosine decay.

Key improvements over Phase 3.x:
    1. Pre-trained 30-layer LM decoder (135M params) vs random-init 4-layer (7M params)
    2. LLaVA-1.5 two-layer MLP bridge (nonlinear alignment) vs none
    3. Label smoothing (ε=0.1) to combat repetition collapse
    4. Repetition penalty + nucleus sampling in evaluation
    5. SmolLM2 tokenizer (49K vocab, ChatML) vs Qwen3 tokenizer (152K vocab)
    6. Proper label masking: -100 for visual prefix, question, and pad tokens

Resumes JEPA/Evidence/Rollout/Disc from the Phase 3.1 checkpoint.
SmolLM2-135M is loaded fresh from the HuggingFace Hub.

Usage:
    python train_phase4.py
    python train_phase4.py --stage 1 --epochs 5 --bridge_lr 1e-3
    python train_phase4.py --stage 2 --epochs 10 --lm_lr 2e-5
"""

import os
import sys
import json
import math
import copy
import random
import logging
import argparse
from collections import defaultdict, Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("mrjepa-p4")


# ══════════════════════════════════════════════════════════════════════════
# BRIDGE MODULE: JEPA latent space → SmolLM2 embedding space
# ══════════════════════════════════════════════════════════════════════════

class VisionLanguageBridge(nn.Module):
    """
    LLaVA-1.5 style 2-layer MLP connector.

    Projects JEPA representations (768d) into SmolLM2 space (576d).
    Applied to both z_K (global JEPA latent) and evidence tokens.

    The nonlinear projection is critical — BLIP-2 showed a linear projection works,
    but LLaVA-1.5 showed a two-layer MLP is significantly better for VQA.
    """

    def __init__(self, jepa_dim=768, lm_dim=576):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(jepa_dim, lm_dim),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim),
        )
        # Small-gain initialization keeps the starting map close to an identity-like projection
        nn.init.xavier_uniform_(self.proj[0].weight, gain=0.5)
        nn.init.zeros_(self.proj[0].bias)
        nn.init.xavier_uniform_(self.proj[2].weight, gain=0.1)
        nn.init.zeros_(self.proj[2].bias)

    def forward(self, features):
        """
        Args:
            features: [B, N, 768] — either z_K or evidence tokens
        Returns:
            projected: [B, N, 576] — in SmolLM2 embedding space
        """
        return self.proj(features)
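

# Shape sketch for the bridge (illustrative comment, not executed):
#   bridge = VisionLanguageBridge(jepa_dim=768, lm_dim=576)
#   bridge(torch.randn(2, 64, 768)).shape   # -> torch.Size([2, 64, 576])
# The same module is shared by z_K and the subsampled evidence tokens below.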


# ══════════════════════════════════════════════════════════════════════════
# SmolLM2 GENERATIVE DECODER
# ══════════════════════════════════════════════════════════════════════════

class SmolLMDecoder(nn.Module):
    """
    Wraps SmolLM2-135M-Instruct as the generative decoder.

    Architecture:
        1. Bridge MLP projects z_K + evidence from JEPA space (768d) to LM space (576d)
        2. Projected tokens are prepended as "soft visual prompts" before the text tokens
        3. SmolLM2 processes [vis_prefix | text_tokens] with causal attention
        4. Loss is computed only on answer tokens (visual prefix masked with -100)

    This follows the BLIP-2 / LLaVA pattern exactly:
        "projected query embeddings are prepended to the input text embeddings.
         They function as soft visual prompts that condition the LLM on visual
         representation." — Li et al., BLIP-2 §3.3
    """

    def __init__(self, jepa_dim=768, freeze_lm=True, label_smoothing=0.1,
                 num_evidence_tokens=8):
        super().__init__()
        from transformers import AutoModelForCausalLM, AutoTokenizer

        log.info("Loading SmolLM2-135M-Instruct...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct"
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.lm = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct",
            torch_dtype=torch.bfloat16,
        )
        self.lm_dim = self.lm.config.hidden_size      # 576
        self.vocab_size = self.lm.config.vocab_size   # 49152
        log.info(f"SmolLM2: hidden={self.lm_dim}, vocab={self.vocab_size}, "
                 f"layers={self.lm.config.num_hidden_layers}")

        if freeze_lm:
            for p in self.lm.parameters():
                p.requires_grad = False
            log.info("SmolLM2 weights frozen (Stage 1: train bridge only)")
        else:
            log.info("SmolLM2 weights trainable (Stage 2: full fine-tuning)")

        # Bridge MLP: JEPA space → SmolLM2 space
        self.bridge = VisionLanguageBridge(jepa_dim, self.lm_dim)

        # How many evidence tokens to use as soft prompts
        # (subsample from 64 to avoid a very long prefix)
        self.num_evidence_tokens = num_evidence_tokens
        if num_evidence_tokens < 64:
            self.ev_pool = nn.Linear(jepa_dim, jepa_dim)  # learned pooling projection
        else:
            self.ev_pool = None

        self.label_smoothing = label_smoothing
        self.freeze_lm = freeze_lm

    def unfreeze_lm(self):
        """Unfreeze SmolLM2 for Stage 2 fine-tuning."""
        for p in self.lm.parameters():
            p.requires_grad = True
        self.freeze_lm = False
        log.info("SmolLM2 unfrozen for Stage 2")

    def _subsample_evidence(self, evidence):
        """Subsample evidence tokens from 64 → num_evidence_tokens."""
        B, N, D = evidence.shape
        if N <= self.num_evidence_tokens:
            return evidence
        if self.ev_pool is not None:
            # Strided selection followed by the learned pooling projection
            stride = N // self.num_evidence_tokens
            indices = torch.arange(0, N, stride, device=evidence.device)[:self.num_evidence_tokens]
            return self.ev_pool(evidence[:, indices])
        return evidence[:, :self.num_evidence_tokens]

    def prepare_inputs(self, z_final, evidence, questions, answers=None,
                       max_answer_len=32):
        """
        Prepare inputs for the SmolLM2 forward pass.

        Args:
            z_final:  [B, N_state, 768] — JEPA latent states
            evidence: [B, N_ev, 768]    — evidence memory tokens
            questions: list[str] — question texts
            answers:   list[str] or None — answer texts (None for generation)
            max_answer_len: int — max tokens for the answer

        Returns:
            inputs_embeds:  [B, N_vis + N_text, 576]
            attention_mask: [B, N_vis + N_text]
            labels:         [B, N_vis + N_text] or None
            n_vis_tokens:   int — number of visual prefix tokens
        """
        device = z_final.device
        B = z_final.size(0)

        # 1. Project JEPA features to LM space
        vis_embeds = self.bridge(z_final)            # [B, N_state, 576]
        ev_sub = self._subsample_evidence(evidence)  # [B, N_ev_sub, 768]
        ev_embeds = self.bridge(ev_sub)              # [B, N_ev_sub, 576]

        # Concatenate visual prefix: [z_K tokens | evidence tokens]
        vis_prefix = torch.cat([vis_embeds, ev_embeds], dim=1)  # [B, N_vis, 576]
        n_vis = vis_prefix.size(1)

        # 2. Tokenize text
        if answers is not None:
            # Training: "Question: {q}\nAnswer: {a}<|im_end|>"
            # (eos_token is <|im_end|> for SmolLM2-Instruct's ChatML vocabulary)
            texts = []
            for q, a in zip(questions, answers):
                texts.append(f"Question: {q}\nAnswer: {a}{self.tokenizer.eos_token}")
            tok = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=192 + max_answer_len,
                return_tensors="pt",
            ).to(device)
        else:
            # Generation: "Question: {q}\nAnswer:"
            texts = [f"Question: {q}\nAnswer:" for q in questions]
            tok = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=192,
                return_tensors="pt",
            ).to(device)

        # 3. Look up text token embeddings directly from the LM's embedding table
        text_embeds = self.lm.model.embed_tokens(tok["input_ids"])  # [B, L, 576]

        # 4. Prepend visual soft prompts — cast to LM dtype (bfloat16)
        lm_dtype = text_embeds.dtype
        vis_prefix = vis_prefix.to(lm_dtype)
        inputs_embeds = torch.cat([vis_prefix, text_embeds], dim=1)  # [B, N_vis+L, 576]

        # 5. Extend the attention mask over the visual prefix
        vis_mask = torch.ones(B, n_vis, device=device, dtype=tok["attention_mask"].dtype)
        attention_mask = torch.cat([vis_mask, tok["attention_mask"]], dim=1)

        # 6. Build labels (if training)
        labels = None
        if answers is not None:
            # Visual prefix → -100 (ignored by the loss)
            vis_labels = torch.full((B, n_vis), -100, device=device, dtype=torch.long)
            # Text labels start as a copy of the input ids; the next-token shift
            # happens in forward() / inside the HF loss
            text_labels = tok["input_ids"].clone()
            # Mask padding positions (via the attention mask rather than the pad id,
            # since pad_token may alias eos_token and the trailing EOS should be supervised)
            text_labels[tok["attention_mask"] == 0] = -100
            # Mask the question prefix too — only train on answer generation
            for i, (q, a) in enumerate(zip(questions, answers)):
                q_text = f"Question: {q}\nAnswer:"
                q_tok = self.tokenizer(q_text, add_special_tokens=False)
                q_len = len(q_tok["input_ids"])
                text_labels[i, :min(q_len, text_labels.size(1))] = -100

            labels = torch.cat([vis_labels, text_labels], dim=1)

        return inputs_embeds, attention_mask, labels, n_vis
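    # Sequence layout sketch for one training example (illustrative):
    #
    #   embeds: [ bridged z_K + evidence | question tokens | answer tokens | pads ]
    #   labels: [          -100          |      -100       |  answer ids   | -100 ]
    #
    # so cross-entropy supervises only the answer span; the visual prefix and the
    # question act purely as conditioning context.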
    def forward(self, z_final, evidence, questions, answers, max_answer_len=32):
        """
        Training forward pass.

        Returns:
            loss:   scalar tensor — CE loss with label smoothing
            logits: [B, L, V] — for debugging
        """
        inputs_embeds, attention_mask, labels, n_vis = self.prepare_inputs(
            z_final, evidence, questions, answers, max_answer_len
        )

        outputs = self.lm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

        # Apply label smoothing manually if needed
        if self.label_smoothing > 0 and labels is not None:
            # Recompute the loss with label smoothing — use float32 for stability
            logits = outputs.logits.float()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
                label_smoothing=self.label_smoothing,
            )
        else:
            loss = outputs.loss

        return loss, outputs.logits
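    # Label-smoothing arithmetic (illustrative): with ε=0.1 and V=49152, the smoothed
    # target keeps roughly 0.9 on the gold token and spreads ε/V ≈ 2e-6 over every
    # other token, which discourages the repetition collapse observed in Phase 3.x.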
    @torch.no_grad()
    def generate(self, z_final, evidence, questions, max_new_tokens=32,
                 temperature=0.7, top_p=0.9, repetition_penalty=1.3,
                 no_repeat_ngram_size=3):
        """
        Generate answers with nucleus sampling + repetition penalty.

        Returns:
            predictions: list[str] — decoded answer strings
        """
        device = z_final.device
        B = z_final.size(0)

        # Prepare inputs (no answers → generation mode)
        inputs_embeds, attention_mask, _, n_vis = self.prepare_inputs(
            z_final, evidence, questions, answers=None
        )

        # Generate token by token with sampling
        generated_ids = []
        past_key_values = None
        cur_embeds = inputs_embeds
        cur_mask = attention_mask

        # Track generated tokens for the repetition penalty
        all_generated = [[] for _ in range(B)]

        for step in range(max_new_tokens):
            outputs = self.lm(
                inputs_embeds=cur_embeds,
                attention_mask=cur_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            next_logits = outputs.logits[:, -1, :]  # [B, V]
            past_key_values = outputs.past_key_values

            # Apply repetition penalty (divide positive logits, multiply negative ones)
            if repetition_penalty != 1.0:
                for b in range(B):
                    for token_id in set(all_generated[b]):
                        if next_logits[b, token_id] > 0:
                            next_logits[b, token_id] /= repetition_penalty
                        else:
                            next_logits[b, token_id] *= repetition_penalty

            # Apply no-repeat n-gram blocking
            if no_repeat_ngram_size > 0 and len(all_generated[0]) >= no_repeat_ngram_size - 1:
                for b in range(B):
                    gen = all_generated[b]
                    if len(gen) >= no_repeat_ngram_size - 1:
                        ngram_prefix = tuple(gen[-(no_repeat_ngram_size - 1):])
                        # Find all n-grams in the history that start with this prefix
                        # and block their continuations
                        for i in range(len(gen) - no_repeat_ngram_size + 1):
                            if tuple(gen[i:i + no_repeat_ngram_size - 1]) == ngram_prefix:
                                blocked = gen[i + no_repeat_ngram_size - 1]
                                next_logits[b, blocked] = float('-inf')

            # Temperature scaling + nucleus sampling
            if temperature > 0:
                next_logits = next_logits / temperature
                # Top-p (nucleus) sampling
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Drop tokens whose preceding cumulative prob already reaches top_p
                # (always keeps at least the highest-probability token)
                sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
                sorted_logits[sorted_mask] = float('-inf')
                # Sample
                probs = F.softmax(sorted_logits, dim=-1)
                sampled_idx = torch.multinomial(probs, 1)            # [B, 1]
                next_tokens = sorted_indices.gather(1, sampled_idx)  # [B, 1]
            else:
                next_tokens = next_logits.argmax(dim=-1, keepdim=True)  # [B, 1]

            generated_ids.append(next_tokens)

            # Update tracking
            for b in range(B):
                all_generated[b].append(next_tokens[b, 0].item())

            # Check for EOS
            if (next_tokens == self.tokenizer.eos_token_id).all():
                break

            # Prepare the next step's input (only the new token embedding;
            # the KV cache holds the rest of the context)
            cur_embeds = self.lm.model.embed_tokens(next_tokens)
            cur_mask = torch.cat([
                cur_mask,
                torch.ones(B, 1, device=device, dtype=cur_mask.dtype)
            ], dim=1)

        # Decode
        if generated_ids:
            gen_tensor = torch.cat(generated_ids, dim=1)  # [B, T]
            predictions = []
            for i in range(B):
                text = self.tokenizer.decode(gen_tensor[i], skip_special_tokens=True)
                # Clean up whitespace; short answers need no further post-processing
                text = text.strip()
                predictions.append(text)
        else:
            predictions = [""] * B

        return predictions
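

# Two-stage optimization sketch (illustrative; the real optimizers live inside
# run_training_stage() in main(), and the learning rates below simply restate the
# recipe from the module docstring and the --bridge_lr / --lm_lr flags):
#
#   stage1_opt = AdamW(decoder.bridge.parameters(), lr=1e-3)   # Stage 1: SmolLM2 frozen
#   decoder.unfreeze_lm()
#   stage2_opt = AdamW(decoder.parameters(), lr=2e-5)          # Stage 2: joint fine-tuning, cosine decay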


# ══════════════════════════════════════════════════════════════════════════
# OPEN-ENDED DATASET (reused from Phase 3.x)
# ══════════════════════════════════════════════════════════════════════════

class OpenEndedDataset(Dataset):
    def __init__(self, benchmark, split, max_samples=0, transform=None,
                 tokenizer=None, max_len=192):
        from datasets import load_dataset

        self.benchmark = benchmark
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len

        log.info(f"Loading {benchmark} {split}...")
        if benchmark == "docvqa":
            ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
        elif benchmark == "chartqa":
            ds = load_dataset("lmms-lab/ChartQA", split=split)
        elif benchmark == "textvqa":
            ds = load_dataset("lmms-lab/textvqa", split=split)
        else:
            raise ValueError(f"Unknown benchmark: {benchmark}")

        if max_samples > 0:
            ds = ds.select(range(min(max_samples, len(ds))))
        self.data = ds
        log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]

        img = row.get("image")
        if img is None:
            img = Image.new("RGB", (256, 256), "white")
        else:
            img = img.convert("RGB")

        question = row["question"]

        if self.benchmark == "docvqa":
            answers = row.get("answers", [""])
            answer = answers[0] if answers else ""
            all_answers = answers
        elif self.benchmark == "chartqa":
            answer = str(row.get("answer", ""))
            all_answers = [answer]
        elif self.benchmark == "textvqa":
            answers = row.get("answers", [""])
            answer_counts = Counter(a.lower().strip() for a in answers)
            answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
            all_answers = answers
        else:
            answer = ""
            all_answers = [""]

        ocr_tokens = row.get("ocr_tokens", [])
        ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""
        text = question
        if ocr_text:
            text += f" [OCR: {ocr_text}]"

        return {
            "image": img,
            "text": text,
            "answer": answer,
            "all_answers": all_answers,
            "benchmark": self.benchmark,
        }


def collate_open_ended_p4(batch, transform, qwen_tokenizer, max_len):
    """Collate for Phase 4 — we only need the image, question text, and answer string."""
    images = [s["image"] for s in batch]
    texts = [s["text"] for s in batch]
    answers = [s["answer"] for s in batch]

    # Torchvision-style callable vs HF image processor
    if hasattr(transform, '__call__') and not hasattr(transform, 'feature_extractor'):
        pixel_values = torch.stack([transform(img) for img in images])
    else:
        pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]

    # Tokenize with the Qwen tokenizer (for the JEPA text encoder)
    tok = qwen_tokenizer(texts, padding="max_length", truncation=True,
                         max_length=max_len, return_tensors="pt")

    return {
        "pixel_values": pixel_values,
        "input_ids": tok["input_ids"],
        "attention_mask": tok["attention_mask"],
        "questions": texts,
        "answers": answers,
        "batch_size": len(batch),
        "benchmarks": [s["benchmark"] for s in batch],
        "all_answers": [s["all_answers"] for s in batch],
    }


# ══════════════════════════════════════════════════════════════════════════
# EVALUATION METRICS (same as Phase 3.x)
# ══════════════════════════════════════════════════════════════════════════

def normalized_levenshtein(s1, s2):
    s1, s2 = s1.lower().strip(), s2.lower().strip()
    if s1 == s2:
        return 0.0
    l1, l2 = len(s1), len(s2)
    if l1 == 0 or l2 == 0:
        return 1.0
    m = [[0] * (l2 + 1) for _ in range(l1 + 1)]
    for i in range(l1 + 1):
        m[i][0] = i
    for j in range(l2 + 1):
        m[0][j] = j
    for i in range(1, l1 + 1):
        for j in range(1, l2 + 1):
            c = 0 if s1[i - 1] == s2[j - 1] else 1
            m[i][j] = min(m[i - 1][j] + 1, m[i][j - 1] + 1, m[i - 1][j - 1] + c)
    return m[l1][l2] / max(l1, l2)


def compute_anls(predictions, ground_truths, threshold=0.5):
    # ANLS: best (1 - normalized edit distance) over the reference answers, zeroed
    # when the distance reaches the threshold; the mean-over-samples reduction below
    # follows the standard ANLS convention (assumed here).
    scores = []
    for pred, gts in zip(predictions, ground_truths):
        mx = max(
            (1.0 - normalized_levenshtein(str(pred), str(gt)))
            if normalized_levenshtein(str(pred), str(gt)) < threshold else 0.0
            for gt in gts
        )
        scores.append(mx)
    return sum(scores) / max(len(scores), 1)
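# Worked ANLS example (illustrative): pred="2019" vs gt="2019" gives NL=0.0 and score 1.0;
# pred="19" vs gt="2019" gives NL=0.5, which is not < 0.5, so score 0.0;
# pred="2O19" vs gt="2019" gives NL=0.25 and score 0.75.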


# … (argparse setup, model and dataloader construction, and the body of main() /
#     run_training_stage() up to the per-epoch evaluation are elided here) …

            if composite > best_composite:
                best_composite = composite
                save_phase4_checkpoint(
                    jepa_model, decoder, cfg, args, abs_epoch,
                    mc_eval_acc, gen_results, composite, stage,
                )
                log.info(f" ★ New best composite: {best_composite:.1f}")

        return best_composite

    # ── Execute training ──
    best_overall = 0.0
    try:
        if args.auto_transition:
            # Stage 1: Freeze LM, train bridge
            if args.stage1_epochs > 0:
                s1_best = run_training_stage(stage=1, num_epochs=args.stage1_epochs,
                                             start_epoch=0)
                best_overall = max(best_overall, s1_best)
            # Stage 2: Unfreeze all
            if args.stage2_epochs > 0:
                s2_best = run_training_stage(stage=2, num_epochs=args.stage2_epochs,
                                             start_epoch=args.stage1_epochs)
                best_overall = max(best_overall, s2_best)
        else:
            best_overall = run_training_stage(stage=args.stage, num_epochs=args.epochs)

        log.info(f"\nPhase 4 complete. Best composite: {best_overall:.1f}")
    finally:
        trackio.log({"final/best_composite": best_overall, "final/phase": "4"})
        trackio.finish()

    # Push final results
    push_phase4_results(cfg, args, best_overall)


def save_phase4_checkpoint(jepa_model, decoder, cfg, args, epoch, mc_acc,
                           gen_results, composite, stage):
    """Save a combined checkpoint."""
    path = os.path.join(args.output_dir, "checkpoint_best.pt")
    torch.save({
        "evidence": jepa_model.evidence.state_dict(),
        "rollout": jepa_model.rollout.state_dict(),
        "disc": jepa_model.disc.state_dict(),
        "target_ev": jepa_model.target.t_ev.state_dict(),
        "target_ro": jepa_model.target.t_ro.state_dict(),
        "bridge": decoder.bridge.state_dict(),
        "ev_pool": decoder.ev_pool.state_dict() if decoder.ev_pool is not None else None,
        "smollm2": decoder.lm.state_dict(),
        "config": cfg.__dict__,
        "phase4_args": vars(args),
        "epoch": epoch,
        "stage": stage,
        "mc_eval_acc": mc_acc,
        "gen_results": gen_results,
        "composite_score": composite,
        "phase": "4",
    }, path)
    log.info(f"Saved Phase 4 checkpoint: {path} (composite={composite:.1f})")


def push_phase4_results(cfg, args, best_composite):
    """Push results and the best checkpoint to the Hub."""
    try:
        from huggingface_hub import HfApi
        api = HfApi()

        results = {
            "run_name": args.run_name,
            "phase": "4",
            "decoder": "SmolLM2-135M-Instruct",
            "bridge": "LLaVA-1.5 MLP (768→576→576)",
            "backbone": cfg.backbone,
            "K": cfg.K,
            "best_composite_score": best_composite,
            "stage1_epochs": args.stage1_epochs,
            "stage2_epochs": args.stage2_epochs,
            "bridge_lr": args.bridge_lr,
            "lm_lr": args.lm_lr,
            "core_lr": args.core_lr,
            "label_smoothing": args.label_smoothing,
            "num_evidence_tokens": args.num_evidence_tokens,
            "gen_weight": args.gen_weight,
        }
        rp = os.path.join(args.output_dir, f"results_{args.run_name}.json")
        with open(rp, "w") as f:
            json.dump(results, f, indent=2)
        api.upload_file(path_or_fileobj=rp,
                        path_in_repo=f"results/{args.run_name}.json",
                        repo_id=args.hub_model_id, repo_type="model")

        best_ckpt = os.path.join(args.output_dir, "checkpoint_best.pt")
        if os.path.exists(best_ckpt):
            api.upload_file(path_or_fileobj=best_ckpt,
                            path_in_repo=f"checkpoints/{args.run_name}_best.pt",
                            repo_id=args.hub_model_id, repo_type="model")

        # Also upload the training script itself
        script_path = os.path.abspath(__file__)
        api.upload_file(path_or_fileobj=script_path,
                        path_in_repo="train_phase4.py",
                        repo_id=args.hub_model_id, repo_type="model")
        log.info(f"Pushed Phase 4 results to {args.hub_model_id}")
    except Exception as e:
        log.error(f"Push failed: {e}")


if __name__ == "__main__":
    main()
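

# Reload sketch (illustrative; assumes module shapes matching those used at save time):
#   state = torch.load("checkpoint_best.pt", map_location="cpu")
#   decoder.bridge.load_state_dict(state["bridge"])
#   decoder.lm.load_state_dict(state["smollm2"])
#   if state["ev_pool"] is not None:
#       decoder.ev_pool.load_state_dict(state["ev_pool"])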