#!/usr/bin/env python3
"""
MR-JEPA Phase 2 Training — Perception Fine-tuning + SOTA Visual Diagnostics

Loads the best Phase 1 checkpoint and unfreezes:
  - Last 6 DINOv3-L layers (LR: 1e-5)
  - Last 4 Qwen3-Embedding layers (LR: 1e-5)
  - Reasoning core continues at 1e-4

Visual diagnostics logged to Trackio (state-of-the-art for JEPA):
  1. PCA Feature Maps (V-JEPA 2.1 style) — patch features → RGB via PCA
  2. Attention-style heatmaps (DINO style) — the per-step logger renders a
     patch-vs-mean-feature cosine-similarity heatmap; a multi-head CLS
     attention overlay helper (``attention_heatmap_overlay``) is also provided
  3. RankMe Score (anti-collapse) — effective rank of embedding matrix
  4. Per-dimension Variance (VICReg style) — collapse detection per dim
  5. Latent Trajectory PCA — z₀→z₁→z₂→z₃ projected to 2D
  6. Temporal Straightness (LeWM Eq. 9) — trajectory coherence metric
  7. Evidence Gate Activation Heatmaps — gate values per rollout step
  8. Token Norm Maps (DINOv2-Reg style) — artifact detection
  9. Cross-Attention Weights in Perceiver — which evidence each query attends to
 10. Eigenspectrum Plot — singular value distribution of latent space

All images are persisted to the HF Space JorgeAV/MR-JEPA-Trackio via the
``space_id`` parameter.

Usage:
    python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt
    python train_phase2.py --epochs 10 --backbone_lr 1e-5

Prerequisites:
    Phase 1 must be complete with a saved checkpoint at JorgeAV/MR-JEPA
"""
import argparse
import contextlib
import copy
import io
import json
import logging
import math
import os
import sys
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader

# ── Matplotlib non-interactive backend (MUST be before pyplot import) ──
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import matplotlib.colors as mcolors  # noqa: E402
import seaborn as sns  # noqa: E402
from sklearn.decomposition import PCA as SklearnPCA  # noqa: E402
from PIL import Image as PILImage  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("mrjepa-p2")

# ImageNet normalisation statistics used to undo the input transform for
# display. NOTE(review): assumes model.vis.get_transform() normalises with
# ImageNet mean/std — confirm.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def fig_to_pil(fig: plt.Figure) -> PILImage.Image:
    """Render a matplotlib figure to PNG in memory and return it as a PIL image.

    Trackio accepts ``PIL.Image`` but NOT ``matplotlib.figure.Figure`` directly.
    The figure itself is left open; callers are responsible for closing it.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0.1)
        buf.seek(0)
        # .copy() forces a full decode so the image outlives the buffer.
        return PILImage.open(buf).copy()


def _square_grid_side(n_tokens: int, fallback: int) -> int:
    """Choose a side length for laying ``n_tokens`` out on a square grid.

    Returns the exact square root when ``n_tokens`` is a perfect square,
    otherwise ``fallback`` capped at ``isqrt(n_tokens)`` so that
    ``values[:side * side].reshape(side, side)`` always has enough values.
    """
    side = math.isqrt(n_tokens)
    if side * side == n_tokens:
        return side
    return min(fallback, side)


# ══════════════════════════════════════════════════════════════════════════
# SOTA VISUALIZATION FUNCTIONS (Papers: I-JEPA, V-JEPA 2.1, DINO,
# LeWorldModel, RankMe, VICReg, DINOv2+Registers)
# ══════════════════════════════════════════════════════════════════════════

def rankme(Z: torch.Tensor, epsilon: float = 1e-7) -> float:
    """
    RankMe: effective rank of an embedding matrix via the Shannon entropy of
    its normalised singular values.

    Source: "RankMe: Assessing the Downstream Performance of Pretrained
    Self-Supervised Representations by Their Rank" (arxiv:2210.02885)

    Args:
        Z: (N, D) batch of embeddings; a 3-D (B, N, D) input is flattened to
            (B*N, D).
        epsilon: smoothing added to the denominator and to each probability
            so ``log`` stays finite.

    Returns:
        Scalar effective rank (higher = less collapsed), or 0.0 if the SVD
        fails to converge.
    """
    if Z.dim() == 3:
        Z = Z.reshape(-1, Z.size(-1))
    Z = Z.float()
    Z_centered = Z - Z.mean(0, keepdim=True)
    try:
        S = torch.linalg.svdvals(Z_centered)
    except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
        log.debug("RankMe: SVD failed", exc_info=True)
        return 0.0
    p = S / (S.sum() + epsilon) + epsilon
    return torch.exp(-torch.sum(p * torch.log(p))).item()


def vicreg_collapse_stats(Z: torch.Tensor, collapse_threshold: float = 0.1) -> dict:
    """
    VICReg-style per-dimension variance monitoring.

    Source: "VICReg: Variance-Invariance-Covariance Regularization" (arxiv:2105.04906)

    Args:
        Z: (N, D) embeddings; a 3-D (B, N, D) input is flattened to (B*N, D).
        collapse_threshold: dims whose std is below this count as collapsed.

    Returns:
        dict with ``min_std``/``mean_std``/``max_std`` (floats),
        ``collapsed_dims`` and ``total_dims`` (ints) and ``std_values``
        (numpy array of per-dimension std).
    """
    if Z.dim() == 3:
        Z = Z.reshape(-1, Z.size(-1))
    std_per_dim = Z.float().std(0)  # (D,)
    return {
        "min_std": std_per_dim.min().item(),
        "mean_std": std_per_dim.mean().item(),
        "max_std": std_per_dim.max().item(),
        "collapsed_dims": (std_per_dim < collapse_threshold).sum().item(),
        "total_dims": std_per_dim.size(0),
        "std_values": std_per_dim.detach().cpu().numpy(),
    }


def temporal_straightness(z_seq: torch.Tensor) -> float:
    """
    Temporal straightness metric from LeWorldModel (Eq. 9, Appendix H, Fig. 17).

    Source: "Le World Model" (arxiv:2603.19312)

    Args:
        z_seq: (B, K+1, D) or (B, K+1, N, D) latent trajectory; token axes are
            mean-pooled first.

    Returns:
        Mean cosine similarity between consecutive velocity vectors
        (higher = straighter trajectory), or 0.0 when fewer than two velocity
        vectors exist.
    """
    if z_seq.dim() == 4:
        z_seq = z_seq.mean(dim=2)  # pool tokens → (B, K+1, D)
    v = z_seq[:, 1:] - z_seq[:, :-1]  # (B, K, D)
    if v.size(1) < 2:
        return 0.0
    v_norm = F.normalize(v.float(), dim=-1)
    cos_sim = (v_norm[:, :-1] * v_norm[:, 1:]).sum(-1)  # (B, K-1)
    return cos_sim.mean().item()


def pca_feature_map(patch_features: torch.Tensor, grid_h: int, grid_w: int) -> np.ndarray:
    """
    PCA RGB feature map visualization (V-JEPA 2.1 style).

    Source: "Revisiting Feature Prediction for Learning Visual Representations
    from Video" (arxiv:2603.14482), Section 2.2, Figs. 1, 3.

    Maps the first 3 principal components of the patch features to RGB, each
    component min-max scaled to [0, 1] (a constant component maps to 0.5).

    Args:
        patch_features: (N_patches, D) features.
        grid_h, grid_w: target grid; if ``grid_h * grid_w != N_patches`` the
            largest square grid that fits is used instead.

    Returns:
        (grid_h, grid_w, 3) uint8 array; all zeros when fewer than 3 components
        are available.
    """
    feats = patch_features.float().detach().cpu().numpy()
    if min(3, feats.shape[0], feats.shape[1]) < 3:
        return np.zeros((grid_h, grid_w, 3), dtype=np.uint8)

    pca_feats = SklearnPCA(n_components=3).fit_transform(feats)  # (N, 3)

    lo = pca_feats.min(axis=0)
    span = pca_feats.max(axis=0) - lo
    flat = span <= 1e-8
    pca_feats = np.where(flat, 0.5, (pca_feats - lo) / np.where(flat, 1.0, span))

    n_patches = pca_feats.shape[0]
    if n_patches != grid_h * grid_w:
        grid_h = grid_w = math.isqrt(n_patches)
    return (pca_feats[: grid_h * grid_w].reshape(grid_h, grid_w, 3) * 255).astype(np.uint8)


def eigenspectrum_plot(Z: torch.Tensor, title: str = "Eigenspectrum") -> plt.Figure:
    """
    Plot the singular value spectrum (first 100 values, log scale) of an
    embedding matrix.

    Source: "RankMe" (arxiv:2210.02885), used in DINO, I-JEPA for collapse monitoring.
    Healthy representations show a flat spectrum; collapsed ones show sharp decay.
    If the SVD fails, a flat placeholder spectrum of ten ones is plotted.
    """
    if Z.dim() == 3:
        Z = Z.reshape(-1, Z.size(-1))
    Z_c = Z.float() - Z.float().mean(0, keepdim=True)
    try:
        s = torch.linalg.svdvals(Z_c).detach().cpu().numpy()
    except RuntimeError:
        s = np.ones(10)

    n_show = min(100, len(s))
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.semilogy(s[:n_show], "b-", linewidth=1.5)
    ax.fill_between(range(n_show), s[:n_show], alpha=0.15, color="blue")
    ax.set_xlabel("Singular Value Index", fontsize=10)
    ax.set_ylabel("Singular Value (log)", fontsize=10)
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig


def per_dim_variance_plot(std_values: np.ndarray, title: str = "Per-Dimension Std") -> plt.Figure:
    """
    VICReg-style per-dimension standard deviation bar plot (first 200 dims).

    Source: "VICReg" (arxiv:2105.04906), Section 4.
    Each bar = std of one embedding dimension across the batch; red < 0.1
    (collapsed), orange < 0.3, blue otherwise. Target: all dims near γ=1.0.
    """
    fig, ax = plt.subplots(figsize=(8, 3))
    n = min(len(std_values), 200)
    colors = ["red" if v < 0.1 else ("orange" if v < 0.3 else "steelblue") for v in std_values[:n]]
    ax.bar(range(n), std_values[:n], color=colors, width=1.0, edgecolor="none")
    ax.axhline(y=1.0, color="green", linestyle="--", alpha=0.7, label="Target γ=1.0")
    ax.axhline(y=0.1, color="red", linestyle="--", alpha=0.5, label="Collapse threshold")
    ax.set_xlabel("Dimension Index", fontsize=9)
    ax.set_ylabel("Std", fontsize=9)
    ax.set_title(title, fontsize=11, fontweight="bold")
    ax.legend(fontsize=8)
    ax.set_xlim(-0.5, n - 0.5)
    fig.tight_layout()
    return fig


def trajectory_pca_plot(trajectory: torch.Tensor, title: str = "Latent Trajectory") -> plt.Figure:
    """
    Latent trajectory visualization via PCA projection.

    Source: I-JEPA + LeWorldModel trajectory analysis.

    Args:
        trajectory: (K+1, N_tokens, D) single-sample trajectory.

    Projects the per-step token centroids to 2D and draws arrows z_k → z_{k+1}.
    With fewer than two steps an empty titled figure is returned.
    """
    K_plus_1, _n_tokens, _dim = trajectory.shape
    if K_plus_1 < 2:
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.set_title(title)
        return fig

    centroids = trajectory.mean(dim=1).float().detach().cpu().numpy()
    centered = centroids - centroids.mean(axis=0)
    try:
        coords = SklearnPCA(n_components=2).fit_transform(centered)
    except ValueError:  # e.g. D == 1: fewer features than components
        coords = centered[:, :2]
    if coords.shape[1] < 2:
        coords = np.pad(coords, ((0, 0), (0, 2 - coords.shape[1])))

    fig, ax = plt.subplots(figsize=(6, 6))
    colors = plt.cm.viridis(np.linspace(0, 1, K_plus_1))
    for k in range(K_plus_1 - 1):
        ax.annotate("", xy=coords[k + 1], xytext=coords[k],
                    arrowprops=dict(arrowstyle="->", color=colors[k], lw=2.5))
    for k in range(K_plus_1):
        ax.scatter(coords[k, 0], coords[k, 1], c=[colors[k]], s=150, zorder=5,
                   edgecolors="black", linewidth=1.5)
        ax.annotate(f"z_{k}", (coords[k, 0], coords[k, 1]), textcoords="offset points",
                    xytext=(10, 10), fontsize=12, fontweight="bold", color=colors[k])
    ax.set_xlabel("PC1", fontsize=10)
    ax.set_ylabel("PC2", fontsize=10)
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig


def attention_heatmap_overlay(
    attn_weights: torch.Tensor,
    original_image: np.ndarray,
    grid_h: int = 16,
    grid_w: int = 16,
    title: str = "Attention Heatmap",
    num_heads_to_show: int = 4,
) -> plt.Figure:
    """
    DINO-style multi-head self-attention heatmap overlay.

    Source: "Emerging Properties in Self-Supervised Vision Transformers"
    (arxiv:2104.14294), Section 4.2.2, Fig. 3.

    Args:
        attn_weights: (H, T) or (H, Q, T) per-head attention; for 3-D input the
            first query row is used. A 2-D input is treated as a single head.
        original_image: (H_img, W_img, 3) image the maps are resized onto.
        grid_h: fallback grid side when T is not a perfect square.
        grid_w: unused; kept for interface compatibility.

    NOTE(review): if row 0 includes a CLS/register token among its keys, T is
    not a perfect square and the map is taken from the first side² entries —
    confirm the token layout before relying on spatial alignment.
    """
    if attn_weights.dim() == 2:
        attn_weights = attn_weights.unsqueeze(0)
    n_heads = min(attn_weights.size(0), num_heads_to_show)

    fig, axes = plt.subplots(1, n_heads + 1, figsize=(4 * (n_heads + 1), 4), squeeze=False)
    axes = axes[0]
    axes[0].imshow(original_image)
    axes[0].set_title("Input Image", fontsize=10)
    axes[0].axis("off")

    head_colors = ["Reds", "Blues", "Greens", "Purples", "Oranges", "YlOrRd"]
    out_size = (original_image.shape[1], original_image.shape[0])
    for h in range(n_heads):
        attn = attn_weights[h].float().detach().cpu().numpy()
        attn_map = attn[0] if attn.ndim == 2 else attn
        side = _square_grid_side(attn_map.shape[0], grid_h)
        attn_2d = attn_map[: side * side].reshape(side, side)
        attn_2d = (attn_2d - attn_2d.min()) / (attn_2d.max() - attn_2d.min() + 1e-8)
        attn_resized = np.array(
            PILImage.fromarray((attn_2d * 255).astype(np.uint8)).resize(out_size, PILImage.BILINEAR)
        ) / 255.0
        ax = axes[h + 1]
        ax.imshow(original_image)
        ax.imshow(attn_resized, cmap=head_colors[h % len(head_colors)], alpha=0.6)
        ax.set_title(f"Head {h}", fontsize=10)
        ax.axis("off")
    fig.suptitle(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig


def evidence_gate_heatmap(gate_values: list, title: str = "Evidence Gate Activations") -> plt.Figure:
    """Evidence gate activation visualization, one panel per rollout step.

    Each entry may be a tensor or array of shape (N,), (N, D) or (B, N, D):
    (B, N, D) is averaged over its first axis and (N, D) over its last, giving
    one value per token. The colour scale is fixed to [0, 1].
    NOTE(review): values outside [0, 1] are clipped visually — confirm the
    hooked output is already squashed (e.g. post-sigmoid).
    """
    K = len(gate_values)
    if K == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No gates recorded", ha="center", va="center")
        return fig

    fig, axes = plt.subplots(1, K, figsize=(4 * K, 4), squeeze=False)
    for k, (ax, gv) in enumerate(zip(axes[0], gate_values)):
        if isinstance(gv, torch.Tensor):
            gv = gv.float().detach().cpu().numpy()
        if gv.ndim == 3:
            gv = gv.mean(0)
        mean_gate = gv.mean(axis=-1) if gv.ndim == 2 else gv
        image = mean_gate.reshape(1, -1) if mean_gate.ndim == 1 else mean_gate
        im = ax.imshow(image, cmap="YlOrRd", aspect="auto", vmin=0, vmax=1)
        ax.set_title(f"Step {k+1}: μ={mean_gate.mean():.3f}", fontsize=10)
        ax.set_xlabel("Token Index")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.suptitle(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig


def token_norm_map(tokens: torch.Tensor, grid_h: int = 16, grid_w: int = 16,
                   title: str = "Token Norm Map") -> plt.Figure:
    """
    Token L2-norm map plus norm histogram, for artifact detection
    (DINOv2+Registers style).

    Source: "Vision Transformers Need Registers" (arxiv:2309.16588).

    Args:
        tokens: (N, D) tokens, or (B, N, D) in which case sample 0 is shown.
        grid_h: fallback grid side when N is not a perfect square (capped so
            the reshape never runs out of tokens).
        grid_w: unused; kept for interface compatibility.
    """
    if tokens.dim() == 3:
        tokens = tokens[0]
    norms = tokens.float().norm(dim=-1).detach().cpu().numpy()
    side = _square_grid_side(len(norms), grid_h)
    norm_map = norms[: side * side].reshape(side, side)
    mu, sigma = norms.mean(), norms.std()

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    im = axes[0].imshow(norm_map, cmap="hot", aspect="auto")
    axes[0].set_title(f"{title}\nμ={mu:.2f}, σ={sigma:.2f}", fontsize=10)
    fig.colorbar(im, ax=axes[0])

    axes[1].hist(norms, bins=50, color="steelblue", edgecolor="white", alpha=0.8)
    axes[1].axvline(mu, color="red", linestyle="--", label=f"Mean: {mu:.2f}")
    axes[1].axvline(mu + 2 * sigma, color="orange", linestyle="--",
                    label=f"+2σ: {mu + 2 * sigma:.2f}")
    axes[1].set_title("Norm Distribution", fontsize=10)
    axes[1].set_xlabel("||token||₂")
    axes[1].legend(fontsize=8)
    fig.suptitle(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig


def cross_attention_weights_plot(attn_weights: torch.Tensor,
                                 title: str = "Perceiver Cross-Attention") -> plt.Figure:
    """Heatmap of Perceiver Resampler cross-attention (queries × keys).

    A 3-D input is averaged over its first axis (presumably heads — confirm)
    before plotting.
    """
    if attn_weights.dim() == 3:
        attn = attn_weights.float().mean(0).detach().cpu().numpy()
    else:
        attn = attn_weights.float().detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(attn, cmap="viridis", ax=ax, xticklabels=False, yticklabels=False)
    ax.set_xlabel("Key Tokens (Evidence: Visual | Text)", fontsize=10)
    ax.set_ylabel("Query Tokens (Latent)", fontsize=10)
    ax.set_title(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig


def rollout_comparison_grid(images: list, pca_maps: list, titles: list = None,
                            suptitle: str = "Latent Rollout PCA Maps") -> plt.Figure:
    """
    Grid comparing input images with PCA feature maps at each rollout step.

    Source: LeWorldModel Fig. 7 + V-JEPA 2.1 Fig. 3.

    Args:
        images: per-sample (H, W, 3) arrays; at most 4 rows are drawn.
        pca_maps: per-sample list of per-step RGB maps; the column count comes
            from ``pca_maps[0]``.
        titles: unused; kept for interface compatibility.
        suptitle: figure title.
    """
    n_samples = min(len(images), 4)
    n_steps = len(pca_maps[0]) if pca_maps else 0
    cols = 1 + n_steps
    if n_samples == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No samples", ha="center", va="center")
        ax.axis("off")
        return fig

    fig, axes = plt.subplots(n_samples, cols, figsize=(3 * cols, 3 * n_samples), squeeze=False)
    for i in range(n_samples):
        row = axes[i]
        row[0].imshow(images[i])
        row[0].set_title("Input" if i == 0 else "", fontsize=9)
        row[0].axis("off")
        sample_maps = pca_maps[i] if i < len(pca_maps) else []
        for k in range(n_steps):
            if k < len(sample_maps):
                row[k + 1].imshow(sample_maps[k])
            row[k + 1].set_title(f"z_{k}" if i == 0 else "", fontsize=9)
            row[k + 1].axis("off")
    fig.suptitle(suptitle, fontsize=13, fontweight="bold")
    fig.tight_layout()
    return fig


# ══════════════════════════════════════════════════════════════════════════
# HOOK-BASED DIAGNOSTIC COLLECTOR
# ══════════════════════════════════════════════════════════════════════════

class DiagnosticCollector:
    """Collects intermediate activations via forward hooks for visualization.

    Hooks only record while a :meth:`capture` block is active. Without this
    gate they fired on every training forward pass, copying each activation to
    host memory until the next diagnostics call cleared the buffers.

    Attributes (set in ``__init__``):
        gate_values: outputs of ``block.gate.proj`` for every gated block in
            ``model.rollout.predictor``, in call order.
        cross_attn_weights: element 1 of the tuple returned by the last
            ``model.evidence.layers[-1].xa`` call, when it is a tensor.
        hooks: registered hook handles.
        active: whether hooks currently record.
    """

    def __init__(self):
        self.gate_values = []
        self.cross_attn_weights = []
        self.hooks = []
        self.active = False

    def attach(self, model):
        """Register hooks on ``model`` (previously registered hooks are removed first)."""
        self.detach()
        self.clear()

        predictor = getattr(getattr(model, "rollout", None), "predictor", None)
        if predictor is not None:
            for block in predictor:
                gate = getattr(block, "gate", None)
                if gate is not None:
                    self.hooks.append(gate.proj.register_forward_hook(self._gate_hook))

        layers = getattr(getattr(model, "evidence", None), "layers", None)
        if layers is not None and len(layers) > 0 and hasattr(layers[-1], "xa"):
            self.hooks.append(layers[-1].xa.register_forward_hook(self._xattn_hook))

    def _gate_hook(self, module, inputs, output):
        """Forward hook: store the gate projection output while capturing."""
        if self.active and isinstance(output, torch.Tensor):
            self.gate_values.append(output.detach().cpu())

    def _xattn_hook(self, module, inputs, output):
        """Forward hook: store attention weights (tuple element 1) while capturing.

        Some attention modules return ``(out, None)`` when weights are not
        requested; that case is skipped instead of raising inside forward.
        """
        if (self.active and isinstance(output, tuple) and len(output) > 1
                and isinstance(output[1], torch.Tensor)):
            self.cross_attn_weights.append(output[1].detach().cpu())

    @contextlib.contextmanager
    def capture(self):
        """Record hook outputs for the duration of the ``with`` block (buffers cleared first)."""
        self.clear()
        self.active = True
        try:
            yield self
        finally:
            self.active = False

    def clear(self):
        """Drop all recorded activations (hooks stay registered)."""
        self.gate_values = []
        self.cross_attn_weights = []

    def detach(self):
        """Remove every registered hook."""
        for h in self.hooks:
            h.remove()
        self.hooks = []


# ══════════════════════════════════════════════════════════════════════════
# MAIN VISUALIZATION LOGGING FUNCTION
# ══════════════════════════════════════════════════════════════════════════

def _log_figure(log_dict: dict, key: str, fig: plt.Figure, caption: str) -> None:
    """Rasterise ``fig`` into ``log_dict[key]`` as a ``trackio.Image`` and close it."""
    import trackio

    try:
        log_dict[key] = trackio.Image(fig_to_pil(fig), caption=caption)
    finally:
        plt.close(fig)


def _image_figure(image: np.ndarray, title: str) -> plt.Figure:
    """Wrap a single RGB array in a 4x4-inch titled figure without axes."""
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(image)
    ax.set_title(title, fontsize=11, fontweight="bold")
    ax.axis("off")
    fig.tight_layout()
    return fig


def _denormalize_image(pixel_values: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalisation of one (3, H, W) tensor → (H, W, 3) uint8 array."""
    img = (pixel_values.detach().float().cpu() * _IMAGENET_STD + _IMAGENET_MEAN).clamp(0, 1)
    return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)


def _rollout_pca_maps(trajectory: torch.Tensor) -> list:
    """PCA RGB map for every step of one sample's (K+1, N, D) trajectory."""
    side = max(math.isqrt(trajectory.size(1)), 1)
    return [pca_feature_map(step_tokens, side, side) for step_tokens in trajectory]


def log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
                           diagnostics_collector=None, vis_interval=100):
    """
    Run one no-grad forward pass on ``batch`` and log diagnostics to Trackio.

    Implements SOTA visualizations from:
      - I-JEPA (attention maps), V-JEPA 2.1 (PCA feature maps)
      - LeWorldModel (trajectory, temporal straightness)
      - RankMe (effective rank), VICReg (per-dim variance)
      - DINOv2+Registers (token norm maps)

    Best-effort: any exception is logged and swallowed so diagnostics never
    abort training. The model is put in ``eval()`` for the pass and restored
    to ``train()`` afterwards.

    Args:
        model: MR-JEPA model exposing ``vis``, ``txt``, ``evidence``,
            ``rollout`` and ``_use_rollout``.
        batch: dict with ``pixel_values``, ``input_ids``, ``attention_mask``.
        device: device for the forward pass.
        cfg: config; ``cfg.K`` bounds how many gate activations are shown.
        global_step: optimizer step used in captions.
        epoch: unused; kept for call-site compatibility.
        diagnostics_collector: optional DiagnosticCollector whose hooks are
            captured only during this forward pass.
        vis_interval: unused; kept for call-site compatibility.
    """
    import trackio

    model.eval()
    log_dict = {}
    caption = f"step={global_step}"
    collector = diagnostics_collector
    capture = collector.capture() if collector is not None else contextlib.nullcontext()
    try:
        with torch.no_grad():
            with capture:
                attn_mask = batch["attention_mask"].to(device)
                vis_tok = model.vis(batch["pixel_values"].to(device)).float()
                txt_tok = model.txt(batch["input_ids"].to(device), attn_mask).float()
                evidence, _kv, _ev_mask = model.evidence(vis_tok, txt_tok, attn_mask)
                if model._use_rollout:
                    traj, z_final, _z_proj = model.rollout(evidence)
                else:
                    traj = evidence.unsqueeze(1)
                    z_final = evidence

            has_trajectory = traj.dim() == 4 and traj.size(1) > 1

            # ── RankMe Score (anti-collapse) ──
            evidence_rank = rankme(evidence)
            z_final_rank = rankme(z_final)
            log_dict["diagnostics/rankme_evidence"] = evidence_rank
            log_dict["diagnostics/rankme_z_final"] = z_final_rank
            log_dict["diagnostics/rankme_visual"] = rankme(vis_tok)

            # ── VICReg Collapse Stats ──
            ev_stats = vicreg_collapse_stats(evidence)
            zf_stats = vicreg_collapse_stats(z_final)
            for prefix, stats in (("evidence", ev_stats), ("z_final", zf_stats)):
                log_dict[f"diagnostics/{prefix}_min_std"] = stats["min_std"]
                log_dict[f"diagnostics/{prefix}_mean_std"] = stats["mean_std"]
                log_dict[f"diagnostics/{prefix}_collapsed_dims"] = stats["collapsed_dims"]

            # ── Temporal Straightness (LeWM Eq. 9) + per-step centroid distances ──
            straightness = None
            if traj.dim() == 4:
                straightness = temporal_straightness(traj)
                log_dict["diagnostics/temporal_straightness"] = straightness
                centroids = traj.mean(dim=2)  # (B, K+1, D)
                step_dists = torch.norm(centroids[:, 1:] - centroids[:, :-1], dim=-1).mean(dim=0)
                for k, dist in enumerate(step_dists.tolist()):
                    log_dict[f"diagnostics/step_distance_z{k}_to_z{k+1}"] = dist

            # ── PCA Feature Maps (V-JEPA 2.1) ──
            grid_side = math.isqrt(vis_tok.size(1))
            _log_figure(log_dict, "visuals/pca/context_encoder",
                        _image_figure(pca_feature_map(vis_tok[0], grid_side, grid_side),
                                      "Context Encoder PCA"),
                        caption)
            _log_figure(log_dict, "visuals/pca/evidence_memory",
                        _image_figure(pca_feature_map(evidence[0], 8, 8), "Evidence Memory PCA"),
                        caption)

            if has_trajectory:
                rollout_pcas = _rollout_pca_maps(traj[0])
                n_steps = len(rollout_pcas)
                fig_rollout, axes = plt.subplots(1, n_steps, figsize=(4 * n_steps, 4), squeeze=False)
                for k, (ax, pca_k) in enumerate(zip(axes[0], rollout_pcas)):
                    ax.imshow(pca_k)
                    ax.set_title(f"z_{k}", fontsize=10)
                    ax.axis("off")
                fig_rollout.suptitle("Rollout PCA per Step", fontsize=12, fontweight="bold")
                fig_rollout.tight_layout()
                _log_figure(log_dict, "visuals/pca/rollout_steps", fig_rollout, caption)

            # ── Eigenspectrum Plot ──
            _log_figure(log_dict, "visuals/eigenspectrum/evidence",
                        eigenspectrum_plot(evidence, "Evidence Eigenspectrum"),
                        f"RankMe={evidence_rank:.1f}, step={global_step}")
            _log_figure(log_dict, "visuals/eigenspectrum/z_final",
                        eigenspectrum_plot(z_final, "z_final Eigenspectrum"),
                        f"RankMe={z_final_rank:.1f}, step={global_step}")

            # ── Per-Dimension Variance Plot (VICReg) ──
            _log_figure(log_dict, "visuals/collapse/evidence_std",
                        per_dim_variance_plot(
                            ev_stats["std_values"],
                            f"Evidence Per-Dim Std (collapsed={ev_stats['collapsed_dims']})"),
                        caption)
            _log_figure(log_dict, "visuals/collapse/z_final_std",
                        per_dim_variance_plot(
                            zf_stats["std_values"],
                            f"z_final Per-Dim Std (collapsed={zf_stats['collapsed_dims']})"),
                        caption)

            # ── Latent Trajectory PCA ──
            if has_trajectory:
                _log_figure(log_dict, "visuals/trajectory/pca",
                            trajectory_pca_plot(
                                traj[0], f"Latent Trajectory (straightness={straightness:.3f})"),
                            f"K={traj.size(1)-1}, step={global_step}")

            # ── Token Norm Maps (DINOv2+Registers) ──
            _log_figure(log_dict, "visuals/norms/visual_tokens",
                        token_norm_map(vis_tok[0], grid_side, grid_side, "Visual Token Norms"),
                        caption)
            _log_figure(log_dict, "visuals/norms/evidence_tokens",
                        token_norm_map(evidence[0], 8, 8, "Evidence Token Norms"),
                        caption)

            # ── Cross-Attention Weights (hooks) ──
            if collector is not None and collector.cross_attn_weights:
                attn_w = collector.cross_attn_weights[-1]
                if attn_w is not None and attn_w.dim() >= 2:
                    if attn_w.dim() in (3, 4):
                        # Take the first batch element. NOTE(review): for 3-D
                        # weights this assumes axis 0 is the batch — confirm.
                        attn_w = attn_w[0]
                    _log_figure(log_dict, "visuals/attention/perceiver_xattn",
                                cross_attention_weights_plot(
                                    attn_w, "Perceiver Cross-Attention (Last Layer)"),
                                caption)

            # ── Evidence Gate Heatmaps ──
            if collector is not None and collector.gate_values:
                gate_vals = [gv[0] if gv.dim() == 3 else gv
                             for gv in collector.gate_values[-cfg.K:]]
                if gate_vals:
                    _log_figure(log_dict, "visuals/gates/activations",
                                evidence_gate_heatmap(gate_vals, "Evidence Gate Activations"),
                                caption)
                    for k, gv in enumerate(gate_vals, start=1):
                        if isinstance(gv, torch.Tensor):
                            log_dict[f"diagnostics/gate_mean_step{k}"] = gv.float().mean().item()

            # ── Input Image with Feature Similarity Heatmap (DINO style) ──
            # Cosine similarity of every patch to the mean patch feature.
            vis_feats = vis_tok[0]
            mean_feat = vis_feats.mean(0, keepdim=True)
            sim = F.cosine_similarity(vis_feats, mean_feat.expand_as(vis_feats), dim=-1).cpu().numpy()
            side = math.isqrt(sim.shape[0])
            if side * side == sim.shape[0]:
                img_np = _denormalize_image(batch["pixel_values"][0])
                heat = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)
                heat_resized = np.array(
                    PILImage.fromarray((heat.reshape(side, side) * 255).astype(np.uint8))
                    .resize((img_np.shape[1], img_np.shape[0]), PILImage.BILINEAR)
                ) / 255.0
                fig_overlay, ax = plt.subplots(1, 2, figsize=(8, 4))
                ax[0].imshow(img_np)
                ax[0].set_title("Input Image", fontsize=10)
                ax[0].axis("off")
                ax[1].imshow(img_np)
                ax[1].imshow(heat_resized, cmap="jet", alpha=0.5)
                ax[1].set_title("Feature Similarity Heatmap", fontsize=10)
                ax[1].axis("off")
                fig_overlay.suptitle("Visual Feature Heatmap (DINO-style)", fontsize=12, fontweight="bold")
                fig_overlay.tight_layout()
                _log_figure(log_dict, "visuals/attention/feature_heatmap", fig_overlay, caption)

            # ── Rollout Comparison Grid ──
            if has_trajectory:
                n_show = min(batch["pixel_values"].size(0), 3)
                images_for_grid = [_denormalize_image(batch["pixel_values"][i]) for i in range(n_show)]
                pcas_for_grid = [_rollout_pca_maps(traj[i]) for i in range(n_show)]
                _log_figure(log_dict, "visuals/rollout/comparison_grid",
                            rollout_comparison_grid(images_for_grid, pcas_for_grid,
                                                    suptitle=f"Rollout PCA Grid (K={traj.size(1)-1})"),
                            caption)

        trackio.log(log_dict)
        log.info(f"Logged {len(log_dict)} visual diagnostics at step {global_step}")

    except Exception as e:  # best-effort: diagnostics must never abort training
        log.warning(f"Visual diagnostics failed: {e}", exc_info=True)
    finally:
        model.train()
        plt.close("all")
        if collector is not None:
            collector.clear()


# ══════════════════════════════════════════════════════════════════════════
# PHASE 2 TRAINING
# ══════════════════════════════════════════════════════════════════════════

def download_phase1_checkpoint(hub_model_id: str, run_name: str = "hybrid_main") -> str:
    """Download ``checkpoints/{run_name}_best.pt`` from the Hub model repo and return its local path."""
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(repo_id=hub_model_id, filename=f"checkpoints/{run_name}_best.pt",
                           repo_type="model")
    log.info(f"Downloaded checkpoint: {path}")
    return path


def _parse_args() -> argparse.Namespace:
    """Parse the Phase 2 command-line arguments."""
    parser = argparse.ArgumentParser(description="MR-JEPA Phase 2 Training")
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
    parser.add_argument("--run_name", default="hybrid_main_phase2")
    parser.add_argument("--phase1_run", default="hybrid_main")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--core_lr", type=float, default=1e-4)
    parser.add_argument("--backbone_lr", type=float, default=1e-5)
    parser.add_argument("--text_lr", type=float, default=1e-5)
    parser.add_argument("--unfreeze_visual_layers", type=int, default=6)
    parser.add_argument("--unfreeze_text_layers", type=int, default=4)
    parser.add_argument("--max_eval_samples", type=int, default=500)
    parser.add_argument("--vis_interval", type=int, default=100,
                        help="Optimizer steps between visual diagnostics (<= 0 disables)")
    parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2")
    parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio",
                        help="HF Space ID for persistent Trackio dashboard")
    return parser.parse_args()


def _load_phase1_module(hub_model_id: str):
    """Download ``train_mrjepa.py`` from the Hub repo and import it as a module.

    SECURITY NOTE: this executes code fetched from the Hub; only point
    ``--hub_model_id`` at a repository you trust.
    """
    import importlib.util

    from huggingface_hub import hf_hub_download

    log.info("Downloading Phase 1 training script...")
    p1_script = hf_hub_download(repo_id=hub_model_id, filename="train_mrjepa.py", repo_type="model")
    spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
    p1 = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(p1)
    return p1


def _build_config(p1, ckpt: dict, args: argparse.Namespace):
    """Rebuild the Phase 1 ``Config`` from the checkpoint and apply Phase 2 overrides."""
    cfg = p1.Config()
    for key, value in ckpt["config"].items():
        if hasattr(cfg, key):
            setattr(cfg, key, value)
    cfg.phase = 2
    cfg.epochs = args.epochs
    cfg.batch_size = args.batch_size
    cfg.grad_accum = args.grad_accum
    cfg.lr = args.core_lr
    cfg.backbone_lr = args.backbone_lr
    cfg.output_dir = args.output_dir
    cfg.run_name = args.run_name
    # The model is built with frozen encoders; main() then unfreezes only the
    # last N layers via unfreeze_last().
    cfg.freeze_backbone = True
    cfg.freeze_text = True
    cfg.max_eval_samples = args.max_eval_samples
    cfg.resolve()
    return cfg


def _build_param_groups(model, args: argparse.Namespace) -> list:
    """Split trainable parameters into core / visual-backbone / text-encoder LR groups.

    Group order (core, backbone, text) matters: the training loop logs
    ``scheduler.get_last_lr()`` by these indices.
    """
    backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
    text_params = [p for p in model.txt.parameters() if p.requires_grad]
    encoder_ids = {id(p) for p in backbone_params + text_params}
    core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in encoder_ids]
    log.info(f"Optimizer: core={len(core_params)} @ {args.core_lr}, "
             f"backbone={len(backbone_params)} @ {args.backbone_lr}, "
             f"text={len(text_params)} @ {args.text_lr}")
    return [
        {"params": core_params, "lr": args.core_lr},
        {"params": backbone_params, "lr": args.backbone_lr},
        {"params": text_params, "lr": args.text_lr},
    ]


def _warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int, floor: float = 0.01) -> float:
    """LR multiplier: linear warmup to 1, then cosine decay down to ``floor`` at ``total_steps``."""
    if step < warmup_steps:
        return step / max(warmup_steps, 1)
    progress = min((step - warmup_steps) / max(total_steps - warmup_steps, 1), 1.0)
    return floor + (1.0 - floor) * 0.5 * (1 + math.cos(math.pi * progress))


def main():
    """Phase 2 entry point: fine-tune MR-JEPA starting from the Phase 1 checkpoint."""
    args = _parse_args()
    p1 = _load_phase1_module(args.hub_model_id)

    if args.checkpoint and os.path.exists(args.checkpoint):
        ckpt_path = args.checkpoint
    else:
        if args.checkpoint:
            log.warning(f"Checkpoint {args.checkpoint} not found; downloading from {args.hub_model_id}")
        ckpt_path = download_phase1_checkpoint(args.hub_model_id, args.phase1_run)

    log.info(f"Loading Phase 1 checkpoint: {ckpt_path}")
    # weights_only=False unpickles arbitrary objects (the checkpoint stores a
    # config dict) — only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = _build_config(p1, ckpt, args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Device: {device}")
    os.makedirs(cfg.output_dir, exist_ok=True)

    # ── Initialize Trackio with persistent HF Space ──
    import trackio

    trackio.init(
        name=args.run_name,
        project="MR-JEPA",
        space_id=args.trackio_space,
        config={
            "phase": 2,
            "epochs": args.epochs,
            "core_lr": args.core_lr,
            "backbone_lr": args.backbone_lr,
            "text_lr": args.text_lr,
            "batch_size": args.batch_size,
            "grad_accum": args.grad_accum,
            "unfreeze_visual_layers": args.unfreeze_visual_layers,
            "unfreeze_text_layers": args.unfreeze_text_layers,
            "phase1_best_acc": ckpt.get("eval_acc", "unknown"),
            "vis_interval": args.vis_interval,
            "backbone": cfg.backbone,
            "K": cfg.K,
            "use_jepa": cfg.use_jepa,
            "loss_fn": cfg.loss_fn,
        },
    )
    log.info(f"Trackio initialized → Space: https://huggingface.co/spaces/{args.trackio_space}")

    log.info("Building model...")
    model = p1.MRJEPAModel(cfg)
    model.evidence.load_state_dict(ckpt["evidence"])
    model.rollout.load_state_dict(ckpt["rollout"])
    model.disc.load_state_dict(ckpt["disc"])
    model.target.t_ev.load_state_dict(ckpt["target_ev"])
    model.target.t_ro.load_state_dict(ckpt["target_ro"])
    log.info(f"Loaded Phase 1 weights (epoch={ckpt.get('epoch','?')}, eval_acc={ckpt.get('eval_acc','?')}%)")
    # ``or 0.0`` guards checkpoints that stored eval_acc=None.
    # NOTE: a Phase 2 checkpoint is only saved once Phase 2 beats this value.
    best_acc = ckpt.get("eval_acc") or 0.0
    del ckpt  # release the CPU copy of the Phase 1 state dicts

    log.info(f"Unfreezing last {args.unfreeze_visual_layers} visual layers, "
             f"last {args.unfreeze_text_layers} text layers")
    model.vis.unfreeze_last(args.unfreeze_visual_layers)
    model.txt.unfreeze_last(args.unfreeze_text_layers)
    model = model.to(device)

    total_p = sum(p.numel() for p in model.parameters())
    train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)")
    trackio.log({"model/total_params": total_p, "model/trainable_params": train_p,
                 "model/trainable_pct": 100 * train_p / total_p})

    transform = model.vis.get_transform()
    tokenizer = model.txt.tokenizer
    train_ds = p1.ScienceQADataset("train", transform=transform, tokenizer=tokenizer,
                                   max_len=cfg.max_text_len, max_opts=cfg.max_options)
    eval_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples, transform=transform,
                                  tokenizer=tokenizer, max_len=cfg.max_text_len, max_opts=cfg.max_options)
    coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options)  # noqa: E731
    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2,
                          collate_fn=coll, pin_memory=True, drop_last=True)
    eval_dl = DataLoader(eval_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2,
                         collate_fn=coll, pin_memory=True)

    optimizer = AdamW(_build_param_groups(model, args), weight_decay=cfg.weight_decay)
    # Only full accumulation windows trigger an optimizer step; trailing
    # micro-batches of an epoch never step, so count whole windows per epoch.
    steps_per_epoch = len(train_dl) // cfg.grad_accum
    total_steps = cfg.epochs * steps_per_epoch
    warmup_steps = int(total_steps * cfg.warmup_ratio)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: _warmup_cosine_factor(step, warmup_steps, total_steps))

    diag_collector = DiagnosticCollector()
    diag_collector.attach(model)

    log.info(f"Phase 2: {cfg.epochs} epochs, {len(train_dl)} batches/epoch, ga={cfg.grad_accum}")
    log.info(f"Visual diagnostics every {args.vis_interval} optimizer steps")

    global_step = 0
    amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
    use_amp = cfg.bf16 and device.type == "cuda"
    trainable = [p for p in model.parameters() if p.requires_grad]

    try:
        for epoch in range(cfg.epochs):
            model.train()
            epoch_losses = defaultdict(list)
            epoch_correct = 0
            epoch_total = 0
            # Also discards gradients of the previous epoch's trailing,
            # incomplete accumulation window.
            optimizer.zero_grad()

            for batch_idx, batch in enumerate(train_dl):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                vis_tok = model.vis(batch["pixel_values"]).float()
                txt_tok = model.txt(batch["input_ids"], batch["attention_mask"]).float()

                with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp):
                    evidence, _, _ev_mask = model.evidence(vis_tok, txt_tok, batch["attention_mask"])
                    if model._use_rollout:
                        _traj, z_final, z_proj = model.rollout(evidence)
                    else:
                        # No rollout: a single latent step z0 built from pooled evidence.
                        B = batch["batch_size"]
                        pooled = F.adaptive_avg_pool1d(
                            evidence.permute(0, 2, 1), model.rollout.num_tokens).permute(0, 2, 1)
                        z0 = model.rollout.init_tokens.expand(B, -1, -1) + model.rollout.z0_proj(pooled)
                        z_final = z0
                        z_proj = model.rollout.out_proj(z0).unsqueeze(1)

                    if model._use_jepa:
                        target_proj = model.target(vis_tok.detach(), txt_tok.detach(),
                                                   batch["attention_mask"].detach())
                    else:
                        target_proj = None

                    opt_emb = model.encode_options(batch["opt_input_ids"], batch["opt_attention_mask"])
                    opt_emb = opt_emb.view(batch["batch_size"], cfg.max_options, -1)
                    logits = model.disc(z_final, opt_emb, batch["opt_mask"])
                    task_loss = F.cross_entropy(logits, batch["labels"])

                    if model._use_jepa and target_proj is not None:
                        losses = model.jepa_loss(z_proj, target_proj, task_loss)
                    else:
                        losses = {"total": task_loss, "jepa": torch.tensor(0.0),
                                  "task": task_loss, "reg": torch.tensor(0.0)}

                loss = losses["total"] / cfg.grad_accum
                loss.backward()

                if (batch_idx + 1) % cfg.grad_accum == 0:
                    nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    model.update_target(global_step, total_steps)
                    global_step += 1

                    if args.vis_interval > 0 and global_step % args.vis_interval == 0:
                        log.info(f"Generating visual diagnostics at step {global_step}...")
                        log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
                                               diagnostics_collector=diag_collector,
                                               vis_interval=args.vis_interval)

                preds = logits.argmax(dim=-1)
                for k, v in losses.items():
                    if isinstance(v, torch.Tensor):
                        epoch_losses[k].append(v.item())
                epoch_correct += (preds == batch["labels"]).sum().item()
                epoch_total += batch["batch_size"]

                if batch_idx % 50 == 0:
                    avg = {k: np.mean(v[-50:]) for k, v in epoch_losses.items()}
                    acc = epoch_correct / max(epoch_total, 1) * 100
                    lrs = scheduler.get_last_lr()
                    log.info(f"P2 E{epoch} B{batch_idx}/{len(train_dl)} | "
                             f"loss={avg.get('total',0):.4f} jepa={avg.get('jepa',0):.4f} "
                             f"task={avg.get('task',0):.4f} | acc={acc:.1f}%")
                    trackio.log({
                        "train/loss": avg.get("total", 0),
                        "train/jepa_loss": avg.get("jepa", 0),
                        "train/task_loss": avg.get("task", 0),
                        "train/reg_loss": avg.get("reg", 0),
                        "train/accuracy": acc,
                        "train/lr": lrs[0] if lrs else 0,
                        "train/backbone_lr": lrs[1] if len(lrs) > 1 else 0,
                        "train/text_lr": lrs[2] if len(lrs) > 2 else 0,
                        "train/ema_momentum": model.target.mom,
                        "train/epoch": epoch,
                        "train/step": global_step,
                    })

            eval_acc = p1.evaluate(model, eval_dl, device, cfg)
            train_acc = epoch_correct / max(epoch_total, 1) * 100
            log.info(f"=== Phase 2 Epoch {epoch} | Train: {train_acc:.1f}% | Eval: {eval_acc:.1f}% ===")
            trackio.log({"eval/accuracy": eval_acc, "eval/epoch": epoch,
                         "eval/train_accuracy": train_acc,
                         "eval/best_accuracy": max(best_acc, eval_acc)})

            log.info("Generating epoch-end visual diagnostics...")
            diag_batch = next(iter(eval_dl))
            diag_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                          for k, v in diag_batch.items()}
            log_visual_diagnostics(model, diag_batch, device, cfg, global_step, epoch,
                                   diagnostics_collector=diag_collector,
                                   vis_interval=args.vis_interval)

            if eval_acc > best_acc:
                best_acc = eval_acc
                p1.save_checkpoint(model, cfg, epoch, eval_acc, is_best=True)
                log.info(f"New best accuracy: {best_acc:.1f}%")

        log.info(f"Phase 2 complete. Best eval accuracy: {best_acc:.1f}%")

    finally:
        # ── Ensure Trackio data is persisted even if training crashes ──
        diag_collector.detach()
        trackio.log({"final/best_accuracy": best_acc, "final/phase": 2,
                     "final/total_steps": global_step})
        log.info("Finishing Trackio and syncing to Space...")
        trackio.finish()
        # Belt-and-suspenders: explicit sync to ensure all images are uploaded.
        try:
            trackio.sync(project="MR-JEPA", space_id=args.trackio_space)
            log.info(f"Trackio synced to https://huggingface.co/spaces/{args.trackio_space}")
        except Exception as e:
            log.warning(f"Trackio sync failed (data may still be available via finish): {e}")

    # Only reached when training completed without raising.
    if cfg.push_to_hub:
        p1.push_results(cfg, best_acc)


if __name__ == "__main__":
    main()