| |
| """ |
| MR-JEPA Phase 2 Training — Perception Fine-tuning + SOTA Visual Diagnostics |
| |
| Loads the best Phase 1 checkpoint and unfreezes: |
| - Last 6 DINOv3-L layers (LR: 1e-5) |
| - Last 4 Qwen3-Embedding layers (LR: 1e-5) |
| - Reasoning core continues at 1e-4 |
| |
| Visual diagnostics logged to Trackio (state-of-the-art for JEPA): |
| 1. PCA Feature Maps (V-JEPA 2.1 style) — patch features → RGB via PCA |
| 2. Multi-head Attention Heatmaps (DINO style) — CLS attention per head overlaid on image |
| 3. RankMe Score (anti-collapse) — effective rank of embedding matrix |
| 4. Per-dimension Variance (VICReg style) — collapse detection per dim |
| 5. Latent Trajectory PCA — z₀→z₁→z₂→z₃ projected to 2D |
| 6. Temporal Straightness (LeWM Eq. 9) — trajectory coherence metric |
| 7. Evidence Gate Activation Heatmaps — gate values per rollout step |
| 8. Token Norm Maps (DINOv2-Reg style) — artifact detection |
| 9. Cross-Attention Weights in Perceiver — which evidence each query attends to |
| 10. Eigenspectrum Plot — singular value distribution of latent space |
| |
| All images are persisted to HF Space JorgeAV/MR-JEPA-Trackio via space_id parameter. |
| |
| Usage: |
| python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt |
| python train_phase2.py --epochs 10 --backbone_lr 1e-5 |
| |
| Prerequisites: |
| Phase 1 must be complete with a saved checkpoint at JorgeAV/MR-JEPA |
| """ |
|
|
| import os |
| import sys |
| import json |
| import math |
| import copy |
| import logging |
| import argparse |
| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader |
|
|
| |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.colors as mcolors |
| import seaborn as sns |
| from sklearn.decomposition import PCA as SklearnPCA |
|
|
| from PIL import Image as PILImage |
| import io |
|
|
# Root logging configuration: INFO level, lines formatted as "HH:MM:SS | LEVEL | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
# Named logger shared by every helper in this phase-2 script.
log = logging.getLogger("mrjepa-p2")
|
|
|
|
def fig_to_pil(fig: plt.Figure) -> PILImage.Image:
    """Rasterize a matplotlib figure into an in-memory PIL image.

    Trackio's ``Image`` wrapper accepts ``PIL.Image`` objects but not
    ``matplotlib.figure.Figure``, so every diagnostic figure goes through here.
    The figure is rendered as a PNG at 120 dpi with a tight bounding box.
    The figure itself is left open; closing it is the caller's responsibility.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png', dpi=120, bbox_inches='tight', pad_inches=0.1)
        png_buffer.seek(0)
        # .copy() forces the pixel data to load so the image outlives the buffer.
        return PILImage.open(png_buffer).copy()
|
|
|
|
| |
| |
| |
| |
|
|
def rankme(Z: torch.Tensor, epsilon: float = 1e-7) -> float:
    """
    RankMe: effective rank of an embedding matrix via the Shannon entropy of its
    normalized singular values.
    Source: "RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised
    Representations by Their Rank" (arxiv:2210.02885)

    Args:
        Z: (N, D) batch of embeddings. A 3-D tensor is flattened to (-1, D).
        epsilon: added to the L1 normalizer and to every probability so that
            log(p) stays finite when singular values are zero.

    Returns:
        exp(entropy) of the singular-value distribution (higher = less collapsed).
        Returns 0.0 for an empty matrix or when the SVD fails to converge.
    """
    if Z.dim() == 3:
        Z = Z.reshape(-1, Z.size(-1))
    if Z.numel() == 0:
        return 0.0
    # Center in float32 so low-precision (e.g. bf16) inputs don't lose accuracy in the mean.
    Zf = Z.float()
    Z_centered = Zf - Zf.mean(0, keepdim=True)
    try:
        # Only singular values are needed; svdvals skips computing U and V.
        S = torch.linalg.svdvals(Z_centered)
    except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
        return 0.0
    p = S / (S.sum() + epsilon) + epsilon
    return torch.exp(-torch.sum(p * torch.log(p))).item()
|
|
|
|
def vicreg_collapse_stats(Z: torch.Tensor, collapse_threshold: float = 0.1) -> dict:
    """
    VICReg-style per-dimension variance monitoring.
    Source: "VICReg: Variance-Invariance-Covariance Regularization" (arxiv:2105.04906)

    Args:
        Z: (N, D) embeddings; a 3-D tensor is flattened to (-1, D). The std of each
            dimension is taken over the N rows.
        collapse_threshold: dimensions whose std is below this count as collapsed
            (default 0.1).

    Returns:
        dict with ``min_std``/``mean_std``/``max_std`` (floats), ``collapsed_dims``
        and ``total_dims`` (ints), and ``std_values`` (numpy array of per-dim std).
    """
    if Z.dim() == 3:
        Z = Z.reshape(-1, Z.size(-1))
    Zf = Z.float()
    if Zf.size(0) < 2:
        # The unbiased std of fewer than two rows is NaN, and NaN never compares below the
        # threshold, which would silently report zero collapsed dims. A lone sample has no
        # spread, so report 0 instead.
        std_per_dim = torch.zeros(Zf.size(-1), dtype=Zf.dtype, device=Zf.device)
    else:
        std_per_dim = Zf.std(0)
    return {
        "min_std": std_per_dim.min().item(),
        "mean_std": std_per_dim.mean().item(),
        "max_std": std_per_dim.max().item(),
        "collapsed_dims": (std_per_dim < collapse_threshold).sum().item(),
        "total_dims": std_per_dim.size(0),
        "std_values": std_per_dim.detach().cpu().numpy(),
    }
|
|
|
|
def temporal_straightness(z_seq: torch.Tensor) -> float:
    """
    Temporal straightness metric from LeWorldModel (Eq. 9, Appendix H, Fig. 17).
    Source: "Le World Model" (arxiv:2603.19312)

    Measures geometric coherence of a latent trajectory.

    Args:
        z_seq: (B, K+1, D), (B, K+1, N, D) (tokens are averaged per step), or a single
            trajectory (K+1, D) (treated as batch size 1).

    Returns:
        Mean cosine similarity between consecutive velocity vectors
        (z_{k+1} - z_k). Higher = more coherent (straighter) trajectory.
        Returns 0.0 when there are fewer than two velocities (K+1 < 3).
    """
    if z_seq.dim() == 4:
        z_seq = z_seq.mean(dim=2)
    elif z_seq.dim() == 2:
        # Without this, slicing [:, 1:] would difference along the feature axis.
        z_seq = z_seq.unsqueeze(0)
    velocities = z_seq[:, 1:] - z_seq[:, :-1]
    if velocities.size(1) < 2:
        return 0.0
    v_unit = F.normalize(velocities.float(), dim=-1)
    cos_sim = (v_unit[:, :-1] * v_unit[:, 1:]).sum(-1)
    return cos_sim.mean().item()
|
|
|
|
def pca_feature_map(patch_features: torch.Tensor, grid_h: int, grid_w: int) -> np.ndarray:
    """
    PCA RGB feature map visualization (V-JEPA 2.1 style).
    Source: "Revisiting Feature Prediction for Learning Visual Representations
    from Video" (arxiv:2603.14482), Section 2.2, Figs. 1, 3.

    The first three principal components of the patch features become the R, G and B
    channels, each min-max scaled to [0, 1] (a constant channel is set to 0.5).

    Args:
        patch_features: (N_patches, D) features for one image.
        grid_h, grid_w: spatial layout of the patches. If N_patches != grid_h * grid_w,
            the largest square that fits (floor(sqrt(N_patches))) is used instead and
            trailing patches are dropped.

    Returns:
        (grid_h, grid_w, 3) uint8 array. It is all zeros when there are fewer than
        3 patches or fewer than 3 feature dims.
    """
    X = patch_features.float().detach().cpu().numpy()
    if min(X.shape[0], X.shape[1]) < 3:
        # A 3-component PCA is impossible: return a blank image.
        return np.zeros((grid_h, grid_w, 3), dtype=np.uint8)

    rgb = SklearnPCA(n_components=3).fit_transform(X)
    lo = rgb.min(axis=0)
    span = rgb.max(axis=0) - lo
    for ch in range(3):
        rgb[:, ch] = (rgb[:, ch] - lo[ch]) / span[ch] if span[ch] > 1e-8 else 0.5

    n_rows = rgb.shape[0]
    if n_rows != grid_h * grid_w:
        grid_h = grid_w = int(math.sqrt(n_rows))
    return (rgb[:grid_h * grid_w].reshape(grid_h, grid_w, 3) * 255).astype(np.uint8)
|
|
|
|
def eigenspectrum_plot(Z: torch.Tensor, title: str = "Eigenspectrum",
                       max_values: int = 100) -> plt.Figure:
    """
    Plot the singular value spectrum of an embedding matrix on a log scale.
    Source: "RankMe" (arxiv:2210.02885), used in DINO, I-JEPA for collapse monitoring.

    Healthy representations show a flat spectrum; collapsed ones show sharp decay.

    Args:
        Z: (N, D) embeddings; a 3-D tensor is flattened to (-1, D). Rows are centered
            before the SVD.
        title: axes title.
        max_values: at most this many leading singular values are plotted (default 100).

    Returns:
        A new matplotlib Figure; the caller is responsible for closing it.
    """
    if Z.dim() == 3:
        Z = Z.reshape(-1, Z.size(-1))
    Zf = Z.float()
    Z_c = Zf - Zf.mean(0, keepdim=True)
    try:
        s = torch.linalg.svdvals(Z_c).detach().cpu().numpy()
    except RuntimeError:
        # SVD failed to converge (LinAlgError subclasses RuntimeError): plot a flat placeholder.
        s = np.ones(10)
    n_shown = min(max_values, len(s))
    shown = s[:n_shown]

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.semilogy(shown, 'b-', linewidth=1.5)
    ax.fill_between(range(n_shown), shown, alpha=0.15, color='blue')
    ax.set_xlabel("Singular Value Index", fontsize=10)
    ax.set_ylabel("Singular Value (log)", fontsize=10)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig
|
|
|
|
def per_dim_variance_plot(std_values: np.ndarray, title: str = "Per-Dimension Std") -> plt.Figure:
    """
    VICReg-style per-dimension standard deviation bar plot.
    Source: "VICReg" (arxiv:2105.04906), Section 4.

    One bar per embedding dimension, showing at most the first 200 dimensions.
    Bars are red below 0.1 (the collapse threshold), orange below 0.3, and
    steel-blue otherwise. Dashed reference lines mark the target γ=1.0 and 0.1.

    Returns:
        A new matplotlib Figure; the caller is responsible for closing it.
    """
    n_shown = min(len(std_values), 200)
    shown = std_values[:n_shown]

    def _bar_color(dim_std):
        if dim_std < 0.1:
            return 'red'
        if dim_std < 0.3:
            return 'orange'
        return 'steelblue'

    fig, ax = plt.subplots(figsize=(8, 3))
    ax.bar(range(n_shown), shown, color=[_bar_color(v) for v in shown],
           width=1.0, edgecolor='none')
    ax.axhline(y=1.0, color='green', linestyle='--', alpha=0.7, label='Target γ=1.0')
    ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.5, label='Collapse threshold')
    ax.set_xlabel("Dimension Index", fontsize=9)
    ax.set_ylabel("Std", fontsize=9)
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.legend(fontsize=8)
    ax.set_xlim(-0.5, n_shown - 0.5)
    fig.tight_layout()
    return fig
|
|
|
|
def trajectory_pca_plot(trajectory: torch.Tensor, title: str = "Latent Trajectory") -> plt.Figure:
    """
    Latent trajectory visualization via PCA projection.
    Source: I-JEPA + LeWorldModel trajectory analysis.

    Each rollout step is reduced to its token centroid. The centroids are projected
    to 2-D with PCA, and arrows connect consecutive steps (z₀ → z₁ → ...).

    Args:
        trajectory: (K+1, N_tokens, D) trajectory of a single sample.
        title: axes title.

    Returns:
        A new matplotlib Figure; the caller is responsible for closing it.
        With fewer than two steps, only an empty titled axes is drawn.
    """
    n_steps = trajectory.shape[0]
    if n_steps < 2:
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.set_title(title)
        return fig

    centroids = trajectory.mean(dim=1).float().detach().cpu().numpy()
    centered = centroids - centroids.mean(axis=0)
    try:
        coords = SklearnPCA(n_components=2).fit_transform(centered)
    except (ValueError, np.linalg.LinAlgError):
        # PCA not possible (e.g. D < 2): fall back to the raw leading coordinates.
        coords = centered[:, :2]
    if coords.shape[1] < 2:
        # Pad to two columns so a 1-D latent still plots, along the x-axis.
        coords = np.pad(coords, ((0, 0), (0, 2 - coords.shape[1])))

    fig, ax = plt.subplots(figsize=(6, 6))
    colors = plt.cm.viridis(np.linspace(0, 1, n_steps))

    for k in range(n_steps - 1):
        ax.annotate("", xy=coords[k + 1], xytext=coords[k],
                    arrowprops=dict(arrowstyle="->", color=colors[k], lw=2.5))

    # Unicode subscripts for every step, so labels read z₀, z₁, z₂, ... consistently.
    subscripts = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")
    for k in range(n_steps):
        ax.scatter(coords[k, 0], coords[k, 1], c=[colors[k]], s=150,
                   zorder=5, edgecolors='black', linewidth=1.5)
        ax.annotate(f"z{k}".translate(subscripts), (coords[k, 0], coords[k, 1]),
                    textcoords="offset points", xytext=(10, 10),
                    fontsize=12, fontweight='bold', color=colors[k])

    ax.set_xlabel("PC1", fontsize=10)
    ax.set_ylabel("PC2", fontsize=10)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig
|
|
|
|
def attention_heatmap_overlay(
    attn_weights: torch.Tensor,
    original_image: np.ndarray,
    grid_h: int = 16, grid_w: int = 16,
    title: str = "Attention Heatmap",
    num_heads_to_show: int = 4,
) -> plt.Figure:
    """
    DINO-style multi-head self-attention heatmap overlay.
    Source: "Emerging Properties in Self-Supervised Vision Transformers"
    (arxiv:2104.14294), Section 4.2.2, Fig. 3.

    Args:
        attn_weights: (H, N_q, N_k) per-head attention. For each head the first query
            row is drawn. A 2-D tensor is treated as a single head.
        original_image: (H_img, W_img, 3) image the maps are upsampled onto.
        grid_h, grid_w: patch grid used when the key count is not a perfect square.
            If the grid needs more tokens than are available, the largest square that
            fits is used instead.
        title: figure title.
        num_heads_to_show: at most this many heads are drawn.

    Returns:
        A new matplotlib Figure: the input image followed by one overlay per head.
    """
    if attn_weights.dim() == 2:
        attn_weights = attn_weights.unsqueeze(0)

    n_heads = min(attn_weights.size(0), num_heads_to_show)
    # squeeze=False keeps a 2-D axes array even when only the input panel is drawn.
    fig, axes = plt.subplots(1, n_heads + 1, figsize=(4 * (n_heads + 1), 4), squeeze=False)
    axes = axes[0]
    target_size = (original_image.shape[1], original_image.shape[0])  # PIL wants (W, H)

    axes[0].imshow(original_image)
    axes[0].set_title("Input Image", fontsize=10)
    axes[0].axis('off')

    head_colors = ['Reds', 'Blues', 'Greens', 'Purples', 'Oranges', 'YlOrRd']

    for h in range(n_heads):
        attn = attn_weights[h].float().detach().cpu().numpy()
        attn_map = attn[0] if attn.ndim == 2 else attn
        n_tokens = attn_map.shape[0]
        side = math.isqrt(n_tokens)
        if side * side == n_tokens:
            rows = cols = side
        elif grid_h * grid_w <= n_tokens:
            # Non-square token count (e.g. extra CLS/register tokens): use the given grid.
            rows, cols = grid_h, grid_w
        else:
            rows = cols = side
        attn_2d = attn_map[:rows * cols].reshape(rows, cols)
        attn_2d = (attn_2d - attn_2d.min()) / (attn_2d.max() - attn_2d.min() + 1e-8)
        attn_resized = np.array(PILImage.fromarray((attn_2d * 255).astype(np.uint8)).resize(
            target_size, PILImage.BILINEAR)) / 255.0
        axes[h + 1].imshow(original_image)
        axes[h + 1].imshow(attn_resized, cmap=head_colors[h % len(head_colors)], alpha=0.6)
        axes[h + 1].set_title(f"Head {h}", fontsize=10)
        axes[h + 1].axis('off')

    fig.suptitle(title, fontsize=12, fontweight='bold')
    fig.tight_layout()
    return fig
|
|
|
|
def evidence_gate_heatmap(gate_values: list, title: str = "Evidence Gate Activations") -> plt.Figure:
    """Draw one heatmap panel of evidence-gate activations per rollout step.

    Each entry of ``gate_values`` (tensor or numpy array) is reduced as follows:
    a 3-D entry is first averaged over axis 0, and a 2-D entry is then averaged
    over its last axis to give one value per token. The result is drawn as a
    1 x N strip on a fixed [0, 1] color scale. An empty list yields a placeholder
    figure.

    NOTE(review): DiagnosticCollector hooks ``gate.proj``; if that is a linear layer
    followed by a sigmoid, these values are pre-activation logits and the [0, 1] color
    range clips them. Confirm against the gate module.
    """
    n_steps = len(gate_values)
    if not n_steps:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No gates recorded", ha='center', va='center')
        return fig

    fig, axes = plt.subplots(1, n_steps, figsize=(4 * n_steps, 4))
    axes = [axes] if n_steps == 1 else axes

    for step_idx, values in enumerate(gate_values, start=1):
        ax = axes[step_idx - 1]
        if isinstance(values, torch.Tensor):
            values = values.float().detach().cpu().numpy()
        if values.ndim == 3:
            values = values.mean(0)
        token_gate = values if values.ndim != 2 else values.mean(axis=-1)
        panel = token_gate if token_gate.ndim != 1 else token_gate.reshape(1, -1)
        im = ax.imshow(panel, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
        ax.set_title(f'Step {step_idx}: μ={token_gate.mean():.3f}', fontsize=10)
        ax.set_xlabel("Token Index")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    fig.suptitle(title, fontsize=12, fontweight='bold')
    fig.tight_layout()
    return fig
|
|
|
|
def token_norm_map(tokens: torch.Tensor, grid_h: int = 16, grid_w: int = 16,
                   title: str = "Token Norm Map") -> plt.Figure:
    """
    Token norm map for artifact detection (DINOv2+Registers style).
    Source: "Vision Transformers Need Registers" (arxiv:2309.16588).

    Args:
        tokens: (N, D) tokens; for a 3-D input only the first sample is used.
        grid_h, grid_w: layout used when N is not a perfect square. If the grid needs
            more tokens than are available, the largest square that fits is used instead,
            so a mismatched grid never aborts the caller.
        title: figure title.

    Returns:
        A new matplotlib Figure with the L2-norm map (left) and the norm histogram
        with mean and mean+2σ markers (right).
    """
    if tokens.dim() == 3:
        tokens = tokens[0]
    norms = tokens.float().norm(dim=-1).detach().cpu().numpy()
    n_tokens = len(norms)
    side = math.isqrt(n_tokens)
    if side * side == n_tokens:
        rows = cols = side
    elif grid_h * grid_w <= n_tokens:
        rows, cols = grid_h, grid_w
    else:
        rows = cols = side
    norm_map = norms[:rows * cols].reshape(rows, cols)
    mu, sigma = norms.mean(), norms.std()

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    im = axes[0].imshow(norm_map, cmap='hot', aspect='auto')
    axes[0].set_title(f"{title}\nμ={mu:.2f}, σ={sigma:.2f}", fontsize=10)
    fig.colorbar(im, ax=axes[0])
    axes[1].hist(norms, bins=50, color='steelblue', edgecolor='white', alpha=0.8)
    axes[1].axvline(mu, color='red', linestyle='--', label=f'Mean: {mu:.2f}')
    axes[1].axvline(mu + 2 * sigma, color='orange', linestyle='--',
                    label=f'+2σ: {mu + 2 * sigma:.2f}')
    axes[1].set_title("Norm Distribution", fontsize=10)
    axes[1].set_xlabel("||token||₂")
    axes[1].legend(fontsize=8)
    fig.suptitle(title, fontsize=12, fontweight='bold')
    fig.tight_layout()
    return fig
|
|
|
|
def cross_attention_weights_plot(attn_weights: torch.Tensor,
                                 title: str = "Perceiver Cross-Attention") -> plt.Figure:
    """Heatmap of cross-attention weights in the Perceiver Resampler.

    A 3-D ``attn_weights`` is averaged over its first axis (presumably heads —
    confirm against the caller). The remaining 2-D matrix is drawn with query
    tokens on the y-axis and key tokens on the x-axis.

    Returns:
        A new matplotlib Figure; the caller is responsible for closing it.
    """
    weights = attn_weights.float()
    if weights.dim() == 3:
        weights = weights.mean(0)
    attn = weights.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(attn, cmap='viridis', ax=ax, xticklabels=False, yticklabels=False)
    ax.set_xlabel("Key Tokens (Evidence: Visual | Text)", fontsize=10)
    ax.set_ylabel("Query Tokens (Latent)", fontsize=10)
    ax.set_title(title, fontsize=12, fontweight='bold')
    fig.tight_layout()
    return fig
|
|
|
|
def rollout_comparison_grid(images: list, pca_maps: list, titles: list = None,
                            suptitle: str = "Latent Rollout PCA Maps",
                            max_samples: int = 4) -> plt.Figure:
    """
    Grid comparing original images with PCA feature maps at each rollout step.
    Source: LeWorldModel Fig. 7 + V-JEPA 2.1 Fig. 3.

    Args:
        images: per-sample (H, W, 3) images; one row per sample.
        pca_maps: pca_maps[i][k] is the RGB PCA map of sample i at rollout step k.
            The column count comes from pca_maps[0]. Missing rows or steps are left blank.
        titles: currently unused; accepted for call-site compatibility.
        suptitle: figure title.
        max_samples: at most this many rows are drawn (default 4).

    Returns:
        A new matplotlib Figure; the caller is responsible for closing it.
    """
    n_samples = min(len(images), max_samples)
    if n_samples == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No samples", ha='center', va='center')
        ax.axis('off')
        fig.suptitle(suptitle, fontsize=13, fontweight='bold')
        return fig

    n_steps = len(pca_maps[0]) if pca_maps else 0
    cols = 1 + n_steps
    # squeeze=False keeps axes 2-D even for a single row and/or a single column.
    fig, axes = plt.subplots(n_samples, cols, figsize=(3 * cols, 3 * n_samples), squeeze=False)
    for i in range(n_samples):
        axes[i, 0].imshow(images[i])
        axes[i, 0].set_title("Input" if i == 0 else "", fontsize=9)
        axes[i, 0].axis('off')
        row_maps = pca_maps[i] if i < len(pca_maps) else []
        for k in range(n_steps):
            if k < len(row_maps):
                axes[i, k + 1].imshow(row_maps[k])
                axes[i, k + 1].set_title(f"z_{k}" if i == 0 else "", fontsize=9)
            axes[i, k + 1].axis('off')
    fig.suptitle(suptitle, fontsize=13, fontweight='bold')
    fig.tight_layout()
    return fig
|
|
|
|
| |
| |
| |
|
|
class DiagnosticCollector:
    """Collects intermediate activations via forward hooks for visualization.

    After :meth:`attach`, every forward pass through a hooked module appends a
    detached CPU copy of its output:

    - ``gate_values``: outputs of ``block.gate.proj`` for each rollout predictor block.
    - ``cross_attn_weights``: attention weights (second tuple element) returned by
      the last evidence layer's ``xa`` module.

    Only the newest ``max_records`` entries of each list are kept (``None`` = unbounded).
    This stops hooks left attached during training from growing host memory without limit.
    """

    def __init__(self, max_records=64):
        self.max_records = max_records
        self.gate_values = []
        self.cross_attn_weights = []
        self.hooks = []

    def _record(self, store, tensor):
        """Append a detached CPU copy of ``tensor`` to ``store`` and trim the oldest entries."""
        store.append(tensor.detach().cpu())
        if self.max_records is not None and len(store) > self.max_records:
            del store[:len(store) - self.max_records]

    def _gate_hook(self, module, inputs, output):
        if isinstance(output, torch.Tensor):
            self._record(self.gate_values, output)

    def _xattn_hook(self, module, inputs, output):
        # nn.MultiheadAttention returns (attn_output, None) when need_weights=False;
        # calling .detach() on None inside a hook would crash the forward pass.
        if isinstance(output, tuple) and len(output) > 1 and isinstance(output[1], torch.Tensor):
            self._record(self.cross_attn_weights, output[1])

    def attach(self, model):
        """Register hooks on ``model``; previously registered hooks are removed first."""
        self.detach()
        self.clear()
        predictor = getattr(getattr(model, 'rollout', None), 'predictor', None)
        if predictor is not None:
            for block in predictor:
                proj = getattr(getattr(block, 'gate', None), 'proj', None)
                if proj is not None:
                    self.hooks.append(proj.register_forward_hook(self._gate_hook))
        layers = getattr(getattr(model, 'evidence', None), 'layers', None)
        if layers is not None and len(layers) > 0:
            xa = getattr(layers[-1], 'xa', None)
            if xa is not None:
                self.hooks.append(xa.register_forward_hook(self._xattn_hook))

    def clear(self):
        """Drop all recorded activations (hooks stay registered)."""
        self.gate_values = []
        self.cross_attn_weights = []

    def detach(self):
        """Remove every registered hook."""
        for h in self.hooks:
            h.remove()
        self.hooks = []
|
|
|
|
| |
| |
| |
|
|
def log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
                           diagnostics_collector=None, vis_interval=100):
    """
    Generate and log all visual diagnostics to Trackio.

    Implements SOTA visualizations from:
    - I-JEPA (attention maps), V-JEPA 2.1 (PCA feature maps)
    - LeWorldModel (trajectory, temporal straightness)
    - RankMe (effective rank), VICReg (per-dim variance)
    - DINOv2+Registers (token norm maps)

    Runs one no-grad forward pass over ``batch`` in eval mode, builds scalar metrics
    and figures, and sends them in a single ``trackio.log`` call. Any exception while
    building or logging diagnostics is logged with its traceback and swallowed, so
    diagnostics never interrupt training. Afterwards the model's previous train/eval
    mode is restored, all figures are closed, and ``diagnostics_collector`` is cleared.

    Args:
        model: MR-JEPA model exposing ``vis``, ``txt``, ``evidence``, ``rollout`` and
            ``_use_rollout`` (used as below).
        batch: dict with ``pixel_values``, ``input_ids`` and ``attention_mask`` tensors.
        device: device for the diagnostic forward pass.
        cfg: config; only ``cfg.K`` is read (how many recent gate records to plot).
        global_step: optimizer step, used in image captions.
        epoch: current epoch, logged as ``diagnostics/epoch``.
        diagnostics_collector: optional DiagnosticCollector whose hooks capture gate
            and cross-attention outputs during the forward pass.
        vis_interval: unused here; kept for call-site compatibility.
    """
    import trackio

    was_training = model.training
    model.eval()
    log_dict = {}
    step_caption = f"step={global_step}"
    # ImageNet mean/std; assumed to match the backbone transform - TODO confirm.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def add_figure(key, fig, caption=None):
        """Store ``fig`` in ``log_dict`` as a Trackio image, always closing the figure."""
        try:
            log_dict[key] = trackio.Image(
                fig_to_pil(fig), caption=step_caption if caption is None else caption)
        finally:
            plt.close(fig)

    def to_display_image(img_tensor):
        """Undo the normalization of one (3, H, W) image and return (H, W, 3) uint8."""
        img = (img_tensor.detach().cpu().float() * std + mean).clamp(0, 1)
        return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    def step_pca_maps(sample_traj):
        """PCA RGB map for every rollout step of one (K+1, N, D) trajectory."""
        side = max(int(math.sqrt(sample_traj.size(1))), 1)
        return [pca_feature_map(step_tokens, side, side) for step_tokens in sample_traj]

    try:
        with torch.no_grad():
            attention_mask = batch["attention_mask"].to(device)
            vis_tok = model.vis(batch["pixel_values"].to(device)).float()
            txt_tok = model.txt(batch["input_ids"].to(device), attention_mask).float()
            evidence, _kv, _ev_mask = model.evidence(vis_tok, txt_tok, attention_mask)

            if model._use_rollout:
                traj, z_final, _z_proj = model.rollout(evidence)
            else:
                traj = evidence.unsqueeze(1)
                z_final = evidence
            has_rollout = traj.dim() == 4 and traj.size(1) > 1

            # --- RankMe effective rank ---
            evidence_rank = rankme(evidence)
            z_final_rank = rankme(z_final)
            log_dict["diagnostics/rankme_evidence"] = evidence_rank
            log_dict["diagnostics/rankme_z_final"] = z_final_rank
            log_dict["diagnostics/rankme_visual"] = rankme(vis_tok)

            # --- VICReg per-dimension collapse statistics ---
            ev_stats = vicreg_collapse_stats(evidence)
            zf_stats = vicreg_collapse_stats(z_final)
            for prefix, stats in (("evidence", ev_stats), ("z_final", zf_stats)):
                for stat in ("min_std", "mean_std", "collapsed_dims"):
                    log_dict[f"diagnostics/{prefix}_{stat}"] = stats[stat]

            # --- Temporal straightness and per-step centroid distances ---
            straightness = None
            if traj.dim() == 4:
                straightness = temporal_straightness(traj)
                log_dict["diagnostics/temporal_straightness"] = straightness
                centroids = traj.mean(dim=2)
                for k in range(centroids.size(1) - 1):
                    dist = torch.norm(centroids[:, k + 1] - centroids[:, k], dim=-1).mean().item()
                    log_dict[f"diagnostics/step_distance_z{k}_to_z{k+1}"] = dist

            # --- PCA feature maps ---
            grid_side = int(math.sqrt(vis_tok.size(1)))
            for key, feats, side, label in (
                ("visuals/pca/context_encoder", vis_tok[0], grid_side, "Context Encoder PCA"),
                ("visuals/pca/evidence_memory", evidence[0], 8, "Evidence Memory PCA"),
            ):
                fig = plt.figure(figsize=(4, 4))
                plt.imshow(pca_feature_map(feats, side, side))
                plt.title(label, fontsize=11, fontweight='bold')
                plt.axis('off')
                plt.tight_layout()
                add_figure(key, fig)

            # Sample-0 rollout maps are reused by the comparison grid below.
            sample0_pcas = step_pca_maps(traj[0]) if has_rollout else []
            if has_rollout:
                fig = plt.figure(figsize=(4 * len(sample0_pcas), 4))
                for k, pca_k in enumerate(sample0_pcas):
                    ax = fig.add_subplot(1, len(sample0_pcas), k + 1)
                    ax.imshow(pca_k)
                    ax.set_title(f"z_{k}", fontsize=10)
                    ax.axis('off')
                fig.suptitle("Rollout PCA per Step", fontsize=12, fontweight='bold')
                fig.tight_layout()
                add_figure("visuals/pca/rollout_steps", fig)

            # --- Eigenspectra ---
            add_figure("visuals/eigenspectrum/evidence",
                       eigenspectrum_plot(evidence, "Evidence Eigenspectrum"),
                       f"RankMe={evidence_rank:.1f}, step={global_step}")
            add_figure("visuals/eigenspectrum/z_final",
                       eigenspectrum_plot(z_final, "z_final Eigenspectrum"),
                       f"RankMe={z_final_rank:.1f}, step={global_step}")

            # --- Per-dimension std bars ---
            add_figure("visuals/collapse/evidence_std",
                       per_dim_variance_plot(
                           ev_stats["std_values"],
                           f"Evidence Per-Dim Std (collapsed={ev_stats['collapsed_dims']})"))
            add_figure("visuals/collapse/z_final_std",
                       per_dim_variance_plot(
                           zf_stats["std_values"],
                           f"z_final Per-Dim Std (collapsed={zf_stats['collapsed_dims']})"))

            # --- Latent trajectory ---
            if has_rollout:
                add_figure("visuals/trajectory/pca",
                           trajectory_pca_plot(
                               traj[0], f"Latent Trajectory (straightness={straightness:.3f})"),
                           f"K={traj.size(1)-1}, step={global_step}")

            # --- Token norm maps ---
            add_figure("visuals/norms/visual_tokens",
                       token_norm_map(vis_tok[0], grid_side, grid_side, "Visual Token Norms"))
            add_figure("visuals/norms/evidence_tokens",
                       token_norm_map(evidence[0], 8, 8, "Evidence Token Norms"))

            # --- Perceiver cross-attention (captured by the collector's hook) ---
            if diagnostics_collector is not None and diagnostics_collector.cross_attn_weights:
                attn_w = diagnostics_collector.cross_attn_weights[-1]
                if attn_w is not None and attn_w.dim() >= 2:
                    if attn_w.dim() in (3, 4):
                        attn_w = attn_w[0]  # first sample of the batch
                    add_figure("visuals/attention/perceiver_xattn",
                               cross_attention_weights_plot(
                                   attn_w, "Perceiver Cross-Attention (Last Layer)"))

            # --- Evidence gates (last K records) ---
            # Guard K > 0: a slice of [-0:] would return the whole history.
            if diagnostics_collector is not None and cfg.K > 0:
                recent = diagnostics_collector.gate_values[-cfg.K:]
                gate_vals = [gv[0] if gv.dim() == 3 else gv for gv in recent]
                if gate_vals:
                    add_figure("visuals/gates/activations",
                               evidence_gate_heatmap(gate_vals, "Evidence Gate Activations"))
                    for k, gv in enumerate(gate_vals):
                        if isinstance(gv, torch.Tensor):
                            log_dict[f"diagnostics/gate_mean_step{k+1}"] = gv.float().mean().item()

            # --- Feature-similarity heatmap (cosine to the mean visual token) ---
            img_np = to_display_image(batch["pixel_values"][0])
            vis_feats = vis_tok[0]
            mean_feat = vis_feats.mean(0, keepdim=True)
            sim = F.cosine_similarity(vis_feats, mean_feat.expand_as(vis_feats), dim=-1).detach().cpu()
            side = int(math.sqrt(sim.size(0)))
            if side * side == sim.size(0):
                lo, hi = sim.min().item(), sim.max().item()
                sim_u8 = ((sim.numpy().reshape(side, side) - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)
                heat = np.array(PILImage.fromarray(sim_u8).resize(
                    (img_np.shape[1], img_np.shape[0]), PILImage.BILINEAR)) / 255.0
                fig, ax = plt.subplots(1, 2, figsize=(8, 4))
                ax[0].imshow(img_np)
                ax[0].set_title("Input Image", fontsize=10)
                ax[0].axis('off')
                ax[1].imshow(img_np)
                ax[1].imshow(heat, cmap='jet', alpha=0.5)
                ax[1].set_title("Feature Similarity Heatmap", fontsize=10)
                ax[1].axis('off')
                fig.suptitle("Visual Feature Heatmap (DINO-style)", fontsize=12, fontweight='bold')
                fig.tight_layout()
                add_figure("visuals/attention/feature_heatmap", fig)

            # --- Rollout comparison grid ---
            if has_rollout:
                n_show = min(batch["pixel_values"].size(0), 3)
                images_for_grid = [to_display_image(batch["pixel_values"][i]) for i in range(n_show)]
                pcas_for_grid = [sample0_pcas if i == 0 else step_pca_maps(traj[i])
                                 for i in range(n_show)]
                add_figure("visuals/rollout/comparison_grid",
                           rollout_comparison_grid(
                               images_for_grid, pcas_for_grid,
                               suptitle=f"Rollout PCA Grid (K={traj.size(1)-1})"))

        log_dict["diagnostics/epoch"] = epoch
        trackio.log(log_dict)
        log.info(f"Logged {len(log_dict)} visual diagnostics at step {global_step}")

    except Exception as e:
        # Best-effort: diagnostics must never kill training, but keep the traceback in the log.
        log.warning(f"Visual diagnostics failed: {e}", exc_info=True)
    finally:
        model.train(was_training)
        plt.close('all')
        if diagnostics_collector is not None:
            diagnostics_collector.clear()
|
|
|
|
| |
| |
| |
|
|
def download_phase1_checkpoint(hub_model_id: str, run_name: str = "hybrid_main"):
    """Fetch ``checkpoints/<run_name>_best.pt`` from a Hugging Face model repo.

    Returns the local file path reported by ``hf_hub_download`` and logs it.
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=hub_model_id,
        filename=f"checkpoints/{run_name}_best.pt",
        repo_type="model",
    )
    log.info(f"Downloaded checkpoint: {local_path}")
    return local_path
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="MR-JEPA Phase 2 Training") |
| parser.add_argument("--checkpoint", type=str, default=None) |
| parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA") |
| parser.add_argument("--run_name", default="hybrid_main_phase2") |
| parser.add_argument("--phase1_run", default="hybrid_main") |
| parser.add_argument("--epochs", type=int, default=10) |
| parser.add_argument("--batch_size", type=int, default=16) |
| parser.add_argument("--grad_accum", type=int, default=8) |
| parser.add_argument("--core_lr", type=float, default=1e-4) |
| parser.add_argument("--backbone_lr", type=float, default=1e-5) |
| parser.add_argument("--text_lr", type=float, default=1e-5) |
| parser.add_argument("--unfreeze_visual_layers", type=int, default=6) |
| parser.add_argument("--unfreeze_text_layers", type=int, default=4) |
| parser.add_argument("--max_eval_samples", type=int, default=500) |
| parser.add_argument("--vis_interval", type=int, default=100) |
| parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2") |
| parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio", |
| help="HF Space ID for persistent Trackio dashboard") |
| args = parser.parse_args() |
|
|
| log.info("Downloading Phase 1 training script...") |
| from huggingface_hub import hf_hub_download |
| p1_script = hf_hub_download(repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model") |
| import importlib.util |
| spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script) |
| p1 = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(p1) |
|
|
| if args.checkpoint and os.path.exists(args.checkpoint): |
| ckpt_path = args.checkpoint |
| else: |
| ckpt_path = download_phase1_checkpoint(args.hub_model_id, args.phase1_run) |
|
|
| log.info(f"Loading Phase 1 checkpoint: {ckpt_path}") |
| ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
|
|
| saved_cfg = ckpt["config"] |
| cfg = p1.Config() |
| for k, v in saved_cfg.items(): |
| if hasattr(cfg, k): |
| setattr(cfg, k, v) |
|
|
| cfg.phase = 2 |
| cfg.epochs = args.epochs |
| cfg.batch_size = args.batch_size |
| cfg.grad_accum = args.grad_accum |
| cfg.lr = args.core_lr |
| cfg.backbone_lr = args.backbone_lr |
| cfg.output_dir = args.output_dir |
| cfg.run_name = args.run_name |
| cfg.freeze_backbone = True |
| cfg.freeze_text = True |
| cfg.max_eval_samples = args.max_eval_samples |
| cfg.resolve() |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| log.info(f"Device: {device}") |
| os.makedirs(cfg.output_dir, exist_ok=True) |
|
|
| |
| import trackio |
| trackio.init( |
| name=args.run_name, |
| project="MR-JEPA", |
| space_id=args.trackio_space, |
| config={ |
| "phase": 2, "epochs": args.epochs, |
| "core_lr": args.core_lr, "backbone_lr": args.backbone_lr, "text_lr": args.text_lr, |
| "batch_size": args.batch_size, "grad_accum": args.grad_accum, |
| "unfreeze_visual_layers": args.unfreeze_visual_layers, |
| "unfreeze_text_layers": args.unfreeze_text_layers, |
| "phase1_best_acc": ckpt.get("eval_acc", "unknown"), |
| "vis_interval": args.vis_interval, |
| "backbone": cfg.backbone, "K": cfg.K, "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn, |
| } |
| ) |
| log.info(f"Trackio initialized → Space: https://huggingface.co/spaces/{args.trackio_space}") |
|
|
| log.info("Building model...") |
| model = p1.MRJEPAModel(cfg) |
| model.evidence.load_state_dict(ckpt["evidence"]) |
| model.rollout.load_state_dict(ckpt["rollout"]) |
| model.disc.load_state_dict(ckpt["disc"]) |
| model.target.t_ev.load_state_dict(ckpt["target_ev"]) |
| model.target.t_ro.load_state_dict(ckpt["target_ro"]) |
| log.info(f"Loaded Phase 1 weights (epoch={ckpt.get('epoch','?')}, eval_acc={ckpt.get('eval_acc','?')}%)") |
|
|
| log.info(f"Unfreezing last {args.unfreeze_visual_layers} visual layers, " |
| f"last {args.unfreeze_text_layers} text layers") |
| model.vis.unfreeze_last(args.unfreeze_visual_layers) |
| model.txt.unfreeze_last(args.unfreeze_text_layers) |
|
|
| model = model.to(device) |
| total_p = sum(p.numel() for p in model.parameters()) |
| train_p = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)") |
| trackio.log({"model/total_params": total_p, "model/trainable_params": train_p, |
| "model/trainable_pct": 100 * train_p / total_p}) |
|
|
| transform = model.vis.get_transform() |
| tokenizer = model.txt.tokenizer |
| train_ds = p1.ScienceQADataset("train", transform=transform, tokenizer=tokenizer, |
| max_len=cfg.max_text_len, max_opts=cfg.max_options) |
| eval_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples, |
| transform=transform, tokenizer=tokenizer, |
| max_len=cfg.max_text_len, max_opts=cfg.max_options) |
| coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options) |
| train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, |
| num_workers=2, collate_fn=coll, pin_memory=True, drop_last=True) |
| eval_dl = DataLoader(eval_ds, batch_size=cfg.batch_size, shuffle=False, |
| num_workers=2, collate_fn=coll, pin_memory=True) |
|
|
| backbone_params = [p for p in model.vis.parameters() if p.requires_grad] |
| text_params = [p for p in model.txt.parameters() if p.requires_grad] |
| bb_txt_ids = {id(p) for p in backbone_params + text_params} |
| core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids] |
|
|
| param_groups = [ |
| {"params": core_params, "lr": args.core_lr}, |
| {"params": backbone_params, "lr": args.backbone_lr}, |
| {"params": text_params, "lr": args.text_lr}, |
| ] |
| log.info(f"Optimizer: core={len(core_params)} @ {args.core_lr}, " |
| f"backbone={len(backbone_params)} @ {args.backbone_lr}, " |
| f"text={len(text_params)} @ {args.text_lr}") |
|
|
| optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay) |
| total_steps = cfg.epochs * len(train_dl) // cfg.grad_accum |
| warmup_steps = int(total_steps * cfg.warmup_ratio) |
|
|
def lr_lambda(step):
    """Return the LambdaLR multiplier applied to every param group at optimizer step ``step``.

    The schedule has two phases:
      * linear warmup: ``step / warmup_steps`` while ``step < warmup_steps``;
      * cosine decay: from 1.0 at the end of warmup down to 0.01 at ``total_steps``.

    ``warmup_steps`` and ``total_steps`` are read from the enclosing scope.
    ``max(..., 1)`` guards both divisions against a zero-length phase.
    """
    if step >= warmup_steps:
        # Fraction of the post-warmup (cosine) phase that has elapsed.
        decay_span = max(total_steps - warmup_steps, 1)
        frac_done = (step - warmup_steps) / decay_span
        # Cosine from 1.0 down to a 1% floor (0.01 + 0.99 * [1 .. 0]).
        return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * frac_done))
    # Warmup: ramp linearly from 0 toward the base learning rate.
    return step / max(warmup_steps, 1)
|
|
| scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
|
|
| diag_collector = DiagnosticCollector() |
| diag_collector.attach(model) |
|
|
| log.info(f"Phase 2: {cfg.epochs} epochs, {len(train_dl)} batches/epoch, ga={cfg.grad_accum}") |
| log.info(f"Visual diagnostics every {args.vis_interval} optimizer steps") |
| global_step = 0 |
| best_acc = ckpt.get("eval_acc", 0.0) |
| amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32 |
| trainable = [p for p in model.parameters() if p.requires_grad] |
|
|
| try: |
| for epoch in range(cfg.epochs): |
| model.train() |
| epoch_losses = defaultdict(list) |
| epoch_correct = 0 |
| epoch_total = 0 |
| optimizer.zero_grad() |
|
|
| for batch_idx, batch in enumerate(train_dl): |
| batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} |
| vis_tok = model.vis(batch["pixel_values"]).float() |
| txt_tok = model.txt(batch["input_ids"], batch["attention_mask"]).float() |
|
|
| with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type == "cuda"): |
| evidence, _, ev_mask = model.evidence(vis_tok, txt_tok, batch["attention_mask"]) |
| if model._use_rollout: |
| traj, z_final, z_proj = model.rollout(evidence) |
| else: |
| B = batch["batch_size"] |
| z0 = model.rollout.init_tokens.expand(B, -1, -1) + \ |
| model.rollout.z0_proj(F.adaptive_avg_pool1d( |
| evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1)) |
| z_final = z0 |
| z_proj = model.rollout.out_proj(z0).unsqueeze(1) |
|
|
| if model._use_jepa: |
| target_proj = model.target(vis_tok.detach(), txt_tok.detach(), batch["attention_mask"].detach()) |
| else: |
| target_proj = None |
|
|
| opt_emb = model.encode_options(batch["opt_input_ids"], batch["opt_attention_mask"]) |
| opt_emb = opt_emb.view(batch["batch_size"], cfg.max_options, -1) |
| logits = model.disc(z_final, opt_emb, batch["opt_mask"]) |
| task_loss = F.cross_entropy(logits, batch["labels"]) |
|
|
| if model._use_jepa and target_proj is not None: |
| losses = model.jepa_loss(z_proj, target_proj, task_loss) |
| else: |
| losses = {"total": task_loss, "jepa": torch.tensor(0.0), "task": task_loss, "reg": torch.tensor(0.0)} |
| loss = losses["total"] / cfg.grad_accum |
|
|
| loss.backward() |
|
|
| if (batch_idx + 1) % cfg.grad_accum == 0: |
| nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm) |
| optimizer.step(); scheduler.step(); optimizer.zero_grad() |
| model.update_target(global_step, total_steps) |
| global_step += 1 |
| if global_step % args.vis_interval == 0 and global_step > 0: |
| log.info(f"Generating visual diagnostics at step {global_step}...") |
| log_visual_diagnostics(model, batch, device, cfg, global_step, epoch, |
| diagnostics_collector=diag_collector, vis_interval=args.vis_interval) |
|
|
| preds = logits.argmax(dim=-1) |
| for k, v in losses.items(): |
| if isinstance(v, torch.Tensor): |
| epoch_losses[k].append(v.item()) |
| epoch_correct += (preds == batch["labels"]).sum().item() |
| epoch_total += batch["batch_size"] |
|
|
| if batch_idx % 50 == 0: |
| avg = {k: np.mean(v[-50:]) for k, v in epoch_losses.items()} |
| acc = epoch_correct / max(epoch_total, 1) * 100 |
| lrs = scheduler.get_last_lr() |
| log.info(f"P2 E{epoch} B{batch_idx}/{len(train_dl)} | " |
| f"loss={avg.get('total',0):.4f} jepa={avg.get('jepa',0):.4f} " |
| f"task={avg.get('task',0):.4f} | acc={acc:.1f}%") |
| trackio.log({ |
| "train/loss": avg.get("total", 0), "train/jepa_loss": avg.get("jepa", 0), |
| "train/task_loss": avg.get("task", 0), "train/reg_loss": avg.get("reg", 0), |
| "train/accuracy": acc, "train/lr": lrs[0] if lrs else 0, |
| "train/backbone_lr": lrs[1] if len(lrs) > 1 else 0, |
| "train/text_lr": lrs[2] if len(lrs) > 2 else 0, |
| "train/ema_momentum": model.target.mom, |
| "train/epoch": epoch, "train/step": global_step, |
| }) |
|
|
| eval_acc = p1.evaluate(model, eval_dl, device, cfg) |
| train_acc = epoch_correct / max(epoch_total, 1) * 100 |
| log.info(f"=== Phase 2 Epoch {epoch} | Train: {train_acc:.1f}% | Eval: {eval_acc:.1f}% ===") |
| trackio.log({"eval/accuracy": eval_acc, "eval/epoch": epoch, |
| "eval/train_accuracy": train_acc, "eval/best_accuracy": max(best_acc, eval_acc)}) |
|
|
| log.info(f"Generating epoch-end visual diagnostics...") |
| diag_batch = next(iter(eval_dl)) |
| diag_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in diag_batch.items()} |
| log_visual_diagnostics(model, diag_batch, device, cfg, global_step, epoch, |
| diagnostics_collector=diag_collector, vis_interval=args.vis_interval) |
|
|
| if eval_acc > best_acc: |
| best_acc = eval_acc |
| p1.save_checkpoint(model, cfg, epoch, eval_acc, is_best=True) |
| log.info(f"New best accuracy: {best_acc:.1f}%") |
|
|
| log.info(f"Phase 2 complete. Best eval accuracy: {best_acc:.1f}%") |
|
|
| finally: |
| |
| diag_collector.detach() |
| trackio.log({"final/best_accuracy": best_acc, "final/phase": 2, "final/total_steps": global_step}) |
| log.info("Finishing Trackio and syncing to Space...") |
| trackio.finish() |
| |
| try: |
| trackio.sync(project="MR-JEPA", space_id=args.trackio_space) |
| log.info(f"Trackio synced to https://huggingface.co/spaces/{args.trackio_space}") |
| except Exception as e: |
| log.warning(f"Trackio sync failed (data may still be available via finish): {e}") |
|
|
| if cfg.push_to_hub: |
| p1.push_results(cfg, best_acc) |
|
|
|
|
if __name__ == "__main__":
    # Script entry point (e.g. `python train_phase2.py --checkpoint ...`);
    # merely importing this module does not start training.
    main()
|
|