#!/usr/bin/env python3
"""
MR-JEPA Phase 2 Training — Perception Fine-tuning + SOTA Visual Diagnostics
Loads the best Phase 1 checkpoint and unfreezes:
- Last 6 DINOv3-L layers (LR: 1e-5)
- Last 4 Qwen3-Embedding layers (LR: 1e-5)
- Reasoning core continues at 1e-4
Visual diagnostics logged to Trackio (state-of-the-art techniques for JEPA-style models):
1. PCA Feature Maps (V-JEPA 2.1 style) — patch features → RGB via PCA
2. Multi-head Attention Heatmaps (DINO style) — CLS attention per head overlaid on image
3. RankMe Score (anti-collapse) — effective rank of embedding matrix
4. Per-dimension Variance (VICReg style) — collapse detection per dim
5. Latent Trajectory PCA — z₀→z₁→z₂→z₃ projected to 2D
6. Temporal Straightness (LeWM Eq. 9) — trajectory coherence metric
7. Evidence Gate Activation Heatmaps — gate values per rollout step
8. Token Norm Maps (DINOv2-Reg style) — artifact detection
9. Cross-Attention Weights in Perceiver — which evidence each query attends to
10. Eigenspectrum Plot — singular value distribution of latent space
All images are persisted to the HF Space JorgeAV/MR-JEPA-Trackio via the space_id parameter.
Usage:
python train_phase2.py --checkpoint checkpoints/hybrid_main_best.pt
python train_phase2.py --epochs 10 --backbone_lr 1e-5
Prerequisites:
Phase 1 must be complete with a saved checkpoint at JorgeAV/MR-JEPA
"""
import os
import math
import logging
import argparse
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
# ── Matplotlib non-interactive backend (MUST be before pyplot import) ──
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA as SklearnPCA
from PIL import Image as PILImage
import io
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger("mrjepa-p2")
def fig_to_pil(fig: plt.Figure) -> PILImage.Image:
"""Convert matplotlib figure to PIL Image for Trackio logging.
Trackio accepts PIL.Image but NOT matplotlib.figure.Figure directly."""
buf = io.BytesIO()
fig.savefig(buf, format='png', dpi=120, bbox_inches='tight', pad_inches=0.1)
buf.seek(0)
img = PILImage.open(buf).copy()
buf.close()
return img
# ══════════════════════════════════════════════════════════════════════════
# SOTA VISUALIZATION FUNCTIONS (Papers: I-JEPA, V-JEPA 2.1, DINO,
# LeWorldModel, RankMe, VICReg, DINOv2+Registers)
# ══════════════════════════════════════════════════════════════════════════
def rankme(Z: torch.Tensor, epsilon: float = 1e-7) -> float:
"""
RankMe: effective rank of embedding matrix via Shannon entropy of singular values.
Source: "RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised
Representations by Their Rank" (arxiv:2210.02885)
Z: (N, D) — batch of embeddings
Returns scalar effective rank (higher = less collapsed)
"""
if Z.dim() == 3:
Z = Z.reshape(-1, Z.size(-1))
Z_centered = Z - Z.mean(0, keepdim=True)
try:
_, S, _ = torch.linalg.svd(Z_centered.float(), full_matrices=False)
p = S / (S.sum() + epsilon) + epsilon
return torch.exp(-torch.sum(p * torch.log(p))).item()
except Exception:  # SVD can fail on non-finite inputs; report full collapse
return 0.0
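# RankMe in one line: RankMe(Z) = exp(-Σ_i p_i log p_i) with p_i = σ_i / Σ_j σ_j,
# i.e. the exponentiated Shannon entropy of the normalized singular values σ of
# the centered batch (exactly what the function above computes).
# Hedged sanity check (illustrative magnitudes, not asserted by this script):
#   rankme(torch.randn(256, 512))   # near-full-rank Gaussian → on the order of 200
#   rankme(torch.ones(256, 512))    # rank-1 constant batch  → ≈ 1 (total collapse)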
def vicreg_collapse_stats(Z: torch.Tensor) -> dict:
"""
VICReg-style per-dimension variance monitoring.
Source: "VICReg: Variance-Invariance-Covariance Regularization" (arxiv:2105.04906)
Z: (N, D) — monitor each dimension's std
Collapsed dims have std < 0.1
"""
if Z.dim() == 3:
Z = Z.reshape(-1, Z.size(-1))
std_per_dim = Z.float().std(0) # (D,)
return {
"min_std": std_per_dim.min().item(),
"mean_std": std_per_dim.mean().item(),
"max_std": std_per_dim.max().item(),
"collapsed_dims": (std_per_dim < 0.1).sum().item(),
"total_dims": std_per_dim.size(0),
"std_values": std_per_dim.detach().cpu().numpy(),
}
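# How to read these stats (the 0.1 cutoff is this script's convention, following
# VICReg practice): a dimension counts as "collapsed" when its batch std drops
# below 0.1; a growing collapsed_dims count or a mean_std drifting toward 0 signals
# dimensional collapse, while healthy runs keep dims near the VICReg target γ = 1.0.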
def temporal_straightness(z_seq: torch.Tensor) -> float:
"""
Temporal straightness metric from LeWorldModel (Eq. 9, Appendix H, Fig. 17).
Source: "Le World Model" (arxiv:2603.19312)
Measures geometric coherence of latent trajectory.
z_seq: (B, K+1, D) or (B, K+1, N, D) — latent trajectory
Returns mean cosine similarity between consecutive velocity vectors.
Higher = more coherent (straighter) trajectory.
"""
if z_seq.dim() == 4:
z_seq = z_seq.mean(dim=2) # Pool tokens → (B, K+1, D)
v = z_seq[:, 1:] - z_seq[:, :-1] # (B, K, D)
if v.size(1) < 2:
return 0.0
v_norm = F.normalize(v.float(), dim=-1)
cos_sim = (v_norm[:, :-1] * v_norm[:, 1:]).sum(-1) # (B, K-1)
return cos_sim.mean().item()
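# Equivalent formula: straightness = mean_{b,k} cos(v_{b,k}, v_{b,k+1}) with
# velocity v_{b,k} = z_{b,k+1} - z_{b,k}. A perfectly straight latent path scores
# 1.0, uncorrelated steps score ≈ 0, and a path that doubles back scores negative.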
def pca_feature_map(patch_features: torch.Tensor, grid_h: int, grid_w: int) -> np.ndarray:
"""
PCA RGB feature map visualization (V-JEPA 2.1 style).
Source: "Revisiting Feature Prediction for Learning Visual Representations
from Video" (arxiv:2603.14482), Section 2.2, Figs. 1, 3.
Maps first 3 principal components of patch features → RGB channels.
Semantic objects share PCA colors regardless of position.
patch_features: (N_patches, D)
Returns: (grid_h, grid_w, 3) uint8 array suitable for display
"""
feats = patch_features.float().detach().cpu().numpy()
n_components = min(3, feats.shape[0], feats.shape[1])
if n_components < 3:
return np.zeros((grid_h, grid_w, 3), dtype=np.uint8)
pca = SklearnPCA(n_components=3)
pca_feats = pca.fit_transform(feats) # (N, 3)
# Normalize each component to [0, 1]
for c in range(3):
vmin, vmax = pca_feats[:, c].min(), pca_feats[:, c].max()
if vmax - vmin > 1e-8:
pca_feats[:, c] = (pca_feats[:, c] - vmin) / (vmax - vmin)
else:
pca_feats[:, c] = 0.5
n_patches = pca_feats.shape[0]
if n_patches != grid_h * grid_w:
actual_side = int(math.sqrt(n_patches))
grid_h = grid_w = actual_side
pca_img = (pca_feats[:grid_h * grid_w].reshape(grid_h, grid_w, 3) * 255).astype(np.uint8)
return pca_img
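# Hedged usage sketch (the grid size depends on the backbone's patch layout):
#   rgb = pca_feature_map(vis_tok[0], 16, 16)   # e.g. 256 ViT patches → 16×16 RGB map
#   plt.imshow(rgb)                             # regions sharing PCA colors ≈ same semantics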
def eigenspectrum_plot(Z: torch.Tensor, title: str = "Eigenspectrum") -> plt.Figure:
"""
Plot singular value spectrum of embedding matrix.
Source: "RankMe" (arxiv:2210.02885), used in DINO, I-JEPA for collapse monitoring.
Healthy representations show flat spectrum; collapsed ones show sharp decay.
"""
if Z.dim() == 3:
Z = Z.reshape(-1, Z.size(-1))
Z_c = Z.float() - Z.float().mean(0, keepdim=True)
try:
_, S, _ = torch.linalg.svd(Z_c, full_matrices=False)
s = S.detach().cpu().numpy()
except Exception:  # fall back to a flat dummy spectrum if SVD fails
s = np.ones(10)
fig, ax = plt.subplots(figsize=(6, 4))
ax.semilogy(s[:min(100, len(s))], 'b-', linewidth=1.5)
ax.fill_between(range(min(100, len(s))), s[:min(100, len(s))], alpha=0.15, color='blue')
ax.set_xlabel("Singular Value Index", fontsize=10)
ax.set_ylabel("Singular Value (log)", fontsize=10)
ax.set_title(title, fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.tight_layout()
return fig
def per_dim_variance_plot(std_values: np.ndarray, title: str = "Per-Dimension Std") -> plt.Figure:
"""
VICReg-style per-dimension standard deviation bar plot.
Source: "VICReg" (arxiv:2105.04906), Section 4.
Each bar = std of one embedding dimension across the batch.
Target: all dims near γ=1.0. Dims with std→0 are collapsed.
"""
fig, ax = plt.subplots(figsize=(8, 3))
n = min(len(std_values), 200)
colors = ['red' if v < 0.1 else ('orange' if v < 0.3 else 'steelblue') for v in std_values[:n]]
ax.bar(range(n), std_values[:n], color=colors, width=1.0, edgecolor='none')
ax.axhline(y=1.0, color='green', linestyle='--', alpha=0.7, label='Target γ=1.0')
ax.axhline(y=0.1, color='red', linestyle='--', alpha=0.5, label='Collapse threshold')
ax.set_xlabel("Dimension Index", fontsize=9)
ax.set_ylabel("Std", fontsize=9)
ax.set_title(title, fontsize=11, fontweight='bold')
ax.legend(fontsize=8)
ax.set_xlim(-0.5, n - 0.5)
fig.tight_layout()
return fig
def trajectory_pca_plot(trajectory: torch.Tensor, title: str = "Latent Trajectory") -> plt.Figure:
"""
Latent trajectory visualization via PCA projection.
Source: I-JEPA + LeWorldModel trajectory analysis.
trajectory: (K+1, N_tokens, D) — single sample trajectory
Projects step centroids to 2D, draws arrows showing reasoning evolution.
"""
K_plus_1, N_s, D = trajectory.shape
centroids = trajectory.mean(dim=1).float().detach().cpu().numpy()
if K_plus_1 < 2:
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_title(title)
return fig
centered = centroids - centroids.mean(axis=0)
try:
pca = SklearnPCA(n_components=2)
coords = pca.fit_transform(centered)
except Exception:  # PCA can degenerate for near-identical steps; use raw dims
coords = centered[:, :2]
fig, ax = plt.subplots(figsize=(6, 6))
colors = plt.cm.viridis(np.linspace(0, 1, K_plus_1))
for k in range(K_plus_1 - 1):
ax.annotate("", xy=coords[k+1], xytext=coords[k],
arrowprops=dict(arrowstyle="->", color=colors[k], lw=2.5))
for k in range(K_plus_1):
ax.scatter(coords[k, 0], coords[k, 1], c=[colors[k]], s=150,
zorder=5, edgecolors='black', linewidth=1.5)
label = 'z₀' if k == 0 else f'z_{k}'
ax.annotate(label, (coords[k, 0], coords[k, 1]),
textcoords="offset points", xytext=(10, 10),
fontsize=12, fontweight='bold', color=colors[k])
ax.set_xlabel("PC1", fontsize=10)
ax.set_ylabel("PC2", fontsize=10)
ax.set_title(title, fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.tight_layout()
return fig
def attention_heatmap_overlay(
attn_weights: torch.Tensor,
original_image: np.ndarray,
grid_h: int = 16, grid_w: int = 16,
title: str = "Attention Heatmap",
num_heads_to_show: int = 4,
) -> plt.Figure:
"""
DINO-style multi-head self-attention heatmap overlay.
Source: "Emerging Properties in Self-Supervised Vision Transformers"
(arxiv:2104.14294), Section 4.2.2, Fig. 3.
"""
if attn_weights.dim() == 2:
attn_weights = attn_weights.unsqueeze(0)
n_heads = min(attn_weights.size(0), num_heads_to_show)
fig, axes = plt.subplots(1, n_heads + 1, figsize=(4 * (n_heads + 1), 4))
if n_heads == 0:  # single panel: subplots returns a bare Axes, not an array
axes = [axes]
axes[0].imshow(original_image)
axes[0].set_title("Input Image", fontsize=10)
axes[0].axis('off')
head_colors = ['Reds', 'Blues', 'Greens', 'Purples', 'Oranges', 'YlOrRd']
for h in range(n_heads):
attn = attn_weights[h].float().detach().cpu().numpy()
if attn.ndim == 2:
attn_map = attn[0]  # full (N, N) matrix: keep the CLS-token query row
else:
attn_map = attn  # already a per-query attention vector over tokens
n_tokens = attn_map.shape[0]
side = int(math.sqrt(n_tokens))
if side * side != n_tokens:
side = grid_h
attn_2d = attn_map[:side*side].reshape(side, side)
attn_2d = (attn_2d - attn_2d.min()) / (attn_2d.max() - attn_2d.min() + 1e-8)
attn_resized = np.array(PILImage.fromarray((attn_2d * 255).astype(np.uint8)).resize(
(original_image.shape[1], original_image.shape[0]), PILImage.BILINEAR)) / 255.0
axes[h + 1].imshow(original_image)
axes[h + 1].imshow(attn_resized, cmap=head_colors[h % len(head_colors)], alpha=0.6)
axes[h + 1].set_title(f"Head {h}", fontsize=10)
axes[h + 1].axis('off')
fig.suptitle(title, fontsize=12, fontweight='bold')
fig.tight_layout()
return fig
def evidence_gate_heatmap(gate_values: list, title: str = "Evidence Gate Activations") -> plt.Figure:
"""Evidence gate activation visualization per rollout step."""
K = len(gate_values)
if K == 0:
fig, ax = plt.subplots()
ax.text(0.5, 0.5, "No gates recorded", ha='center', va='center')
return fig
fig, axes = plt.subplots(1, K, figsize=(4 * K, 4))
if K == 1:
axes = [axes]
for k, gv in enumerate(gate_values):
if isinstance(gv, torch.Tensor):
gv = gv.float().detach().cpu().numpy()
if gv.ndim == 3:
gv = gv.mean(0)
mean_gate = gv.mean(axis=-1) if gv.ndim == 2 else gv
im = axes[k].imshow(mean_gate.reshape(1, -1) if mean_gate.ndim == 1 else mean_gate,
cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
axes[k].set_title(f'Step {k+1}: μ={mean_gate.mean():.3f}', fontsize=10)
axes[k].set_xlabel("Token Index")
fig.colorbar(im, ax=axes[k], fraction=0.046, pad=0.04)
fig.suptitle(title, fontsize=12, fontweight='bold')
fig.tight_layout()
return fig
def token_norm_map(tokens: torch.Tensor, grid_h: int = 16, grid_w: int = 16,
title: str = "Token Norm Map") -> plt.Figure:
"""
Token norm map for artifact detection (DINOv2+Registers style).
Source: "Vision Transformers Need Registers" (arxiv:2309.16588).
"""
if tokens.dim() == 3:
tokens = tokens[0]
norms = tokens.float().norm(dim=-1).detach().cpu().numpy()
n_tokens = len(norms)
side = int(math.sqrt(n_tokens))
if side * side != n_tokens:
side = grid_h
norm_map = norms[:side*side].reshape(side, side)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
im = axes[0].imshow(norm_map, cmap='hot', aspect='auto')
axes[0].set_title(f"{title}\nμ={norms.mean():.2f}, σ={norms.std():.2f}", fontsize=10)
fig.colorbar(im, ax=axes[0])
axes[1].hist(norms, bins=50, color='steelblue', edgecolor='white', alpha=0.8)
axes[1].axvline(norms.mean(), color='red', linestyle='--', label=f'Mean: {norms.mean():.2f}')
axes[1].axvline(norms.mean() + 2*norms.std(), color='orange', linestyle='--',
label=f'+2σ: {norms.mean() + 2*norms.std():.2f}')
axes[1].set_title("Norm Distribution", fontsize=10)
axes[1].set_xlabel("||token||₂")
axes[1].legend(fontsize=8)
fig.suptitle(title, fontsize=12, fontweight='bold')
fig.tight_layout()
return fig
def cross_attention_weights_plot(attn_weights: torch.Tensor,
title: str = "Perceiver Cross-Attention") -> plt.Figure:
"""Cross-attention weights in the Perceiver Resampler."""
if attn_weights.dim() == 3:
attn = attn_weights.float().mean(0).detach().cpu().numpy()
else:
attn = attn_weights.float().detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(attn, cmap='viridis', ax=ax, xticklabels=False, yticklabels=False)
ax.set_xlabel("Key Tokens (Evidence: Visual | Text)", fontsize=10)
ax.set_ylabel("Query Tokens (Latent)", fontsize=10)
ax.set_title(title, fontsize=12, fontweight='bold')
fig.tight_layout()
return fig
def rollout_comparison_grid(images: list, pca_maps: list, titles: list = None,
suptitle: str = "Latent Rollout PCA Maps") -> plt.Figure:
"""
Grid comparing original images with PCA feature maps at each rollout step.
Source: LeWorldModel Fig. 7 + V-JEPA 2.1 Fig. 3.
"""
n_samples = min(len(images), 4)
n_steps = len(pca_maps[0]) if pca_maps else 0
cols = 1 + n_steps
# squeeze=False keeps axes as a 2-D array even for a single row or column
fig, axes = plt.subplots(n_samples, cols, figsize=(3 * cols, 3 * n_samples), squeeze=False)
for i in range(n_samples):
axes[i, 0].imshow(images[i])
axes[i, 0].set_title("Input" if i == 0 else "", fontsize=9)
axes[i, 0].axis('off')
for k in range(n_steps):
if k < len(pca_maps[i]):
axes[i, k+1].imshow(pca_maps[i][k])
axes[i, k+1].set_title(f"z_{k}" if i == 0 else "", fontsize=9)
axes[i, k+1].axis('off')
fig.suptitle(suptitle, fontsize=13, fontweight='bold')
fig.tight_layout()
return fig
# ══════════════════════════════════════════════════════════════════════════
# HOOK-BASED DIAGNOSTIC COLLECTOR
# ══════════════════════════════════════════════════════════════════════════
class DiagnosticCollector:
"""Collects intermediate activations via hooks for visualization."""
def __init__(self):
self.gate_values = []
self.cross_attn_weights = []
self.hooks = []
def attach(self, model):
self.clear()
if hasattr(model, 'rollout') and hasattr(model.rollout, 'predictor'):
for i, block in enumerate(model.rollout.predictor):
if hasattr(block, 'gate') and block.gate is not None:
def make_gate_hook(layer_idx):
def hook(module, input, output):
self.gate_values.append(output.detach().cpu())
return hook
h = block.gate.proj.register_forward_hook(make_gate_hook(i))
self.hooks.append(h)
if hasattr(model, 'evidence') and hasattr(model.evidence, 'layers'):
last_layer = model.evidence.layers[-1]
if hasattr(last_layer, 'xa'):
def xa_hook(module, input, output):
if isinstance(output, tuple) and len(output) > 1:
self.cross_attn_weights.append(output[1].detach().cpu())
h = last_layer.xa.register_forward_hook(xa_hook)
self.hooks.append(h)
def clear(self):
self.gate_values = []
self.cross_attn_weights = []
def detach(self):
for h in self.hooks:
h.remove()
self.hooks = []
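# Hedged lifecycle sketch (mirrors how main() below uses the collector):
#   collector = DiagnosticCollector()
#   collector.attach(model)     # registers forward hooks on gates / cross-attn
#   ...forward passes fill collector.gate_values and collector.cross_attn_weights...
#   collector.clear()           # reset buffers between logging rounds
#   collector.detach()          # remove hooks when training ends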
# ══════════════════════════════════════════════════════════════════════════
# MAIN VISUALIZATION LOGGING FUNCTION
# ══════════════════════════════════════════════════════════════════════════
def log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
diagnostics_collector=None, vis_interval=100):
"""
Generate and log all visual diagnostics to Trackio.
Implements SOTA visualizations from:
- I-JEPA (attention maps), V-JEPA 2.1 (PCA feature maps)
- LeWorldModel (trajectory, temporal straightness)
- RankMe (effective rank), VICReg (per-dim variance)
- DINOv2+Registers (token norm maps)
"""
import trackio
model.eval()
log_dict = {}
try:
with torch.no_grad():
vis_tok = model.vis(batch["pixel_values"].to(device)).float()
txt_tok = model.txt(batch["input_ids"].to(device),
batch["attention_mask"].to(device)).float()
evidence, kv, ev_mask = model.evidence(vis_tok, txt_tok,
batch["attention_mask"].to(device))
if model._use_rollout:
traj, z_final, z_proj = model.rollout(evidence)
else:
traj = evidence.unsqueeze(1)
z_final = evidence
z_proj = evidence.unsqueeze(1)
# ── RankMe Score (anti-collapse) ──
evidence_rank = rankme(evidence)
z_final_rank = rankme(z_final)
vis_rank = rankme(vis_tok)
log_dict["diagnostics/rankme_evidence"] = evidence_rank
log_dict["diagnostics/rankme_z_final"] = z_final_rank
log_dict["diagnostics/rankme_visual"] = vis_rank
# ── VICReg Collapse Stats ──
ev_stats = vicreg_collapse_stats(evidence)
zf_stats = vicreg_collapse_stats(z_final)
log_dict["diagnostics/evidence_min_std"] = ev_stats["min_std"]
log_dict["diagnostics/evidence_mean_std"] = ev_stats["mean_std"]
log_dict["diagnostics/evidence_collapsed_dims"] = ev_stats["collapsed_dims"]
log_dict["diagnostics/z_final_min_std"] = zf_stats["min_std"]
log_dict["diagnostics/z_final_mean_std"] = zf_stats["mean_std"]
log_dict["diagnostics/z_final_collapsed_dims"] = zf_stats["collapsed_dims"]
# ── Temporal Straightness (LeWM Eq. 9) ──
if traj.dim() == 4:
straightness = temporal_straightness(traj)
log_dict["diagnostics/temporal_straightness"] = straightness
centroids = traj.mean(dim=2)
for k in range(centroids.size(1) - 1):
dist = torch.norm(centroids[:, k+1] - centroids[:, k], dim=-1).mean().item()
log_dict[f"diagnostics/step_distance_z{k}_to_z{k+1}"] = dist
# ── PCA Feature Maps (V-JEPA 2.1) ──
n_vis_patches = vis_tok.size(1)
grid_side = int(math.sqrt(n_vis_patches))
pca_ctx = pca_feature_map(vis_tok[0], grid_side, grid_side)
fig_pca_ctx = plt.figure(figsize=(4, 4))
plt.imshow(pca_ctx); plt.title("Context Encoder PCA", fontsize=11, fontweight='bold')
plt.axis('off'); plt.tight_layout()
log_dict["visuals/pca/context_encoder"] = trackio.Image(fig_to_pil(fig_pca_ctx),
caption=f"step={global_step}")
plt.close(fig_pca_ctx)
pca_ev = pca_feature_map(evidence[0], 8, 8)
fig_pca_ev = plt.figure(figsize=(4, 4))
plt.imshow(pca_ev); plt.title("Evidence Memory PCA", fontsize=11, fontweight='bold')
plt.axis('off'); plt.tight_layout()
log_dict["visuals/pca/evidence_memory"] = trackio.Image(fig_to_pil(fig_pca_ev),
caption=f"step={global_step}")
plt.close(fig_pca_ev)
if traj.dim() == 4 and traj.size(1) > 1:
rollout_pcas = []
for k in range(traj.size(1)):
n_tok = traj.size(2)
side_k = max(int(math.sqrt(n_tok)), 1)
rollout_pcas.append(pca_feature_map(traj[0, k], side_k, side_k))
fig_rollout = plt.figure(figsize=(4 * len(rollout_pcas), 4))
for k, pca_k in enumerate(rollout_pcas):
ax = fig_rollout.add_subplot(1, len(rollout_pcas), k + 1)
ax.imshow(pca_k); ax.set_title(f"z_{k}", fontsize=10); ax.axis('off')
fig_rollout.suptitle("Rollout PCA per Step", fontsize=12, fontweight='bold')
fig_rollout.tight_layout()
log_dict["visuals/pca/rollout_steps"] = trackio.Image(fig_to_pil(fig_rollout),
caption=f"step={global_step}")
plt.close(fig_rollout)
# ── Eigenspectrum Plot ──
fig_eigen_ev = eigenspectrum_plot(evidence, "Evidence Eigenspectrum")
log_dict["visuals/eigenspectrum/evidence"] = trackio.Image(fig_to_pil(fig_eigen_ev),
caption=f"RankMe={evidence_rank:.1f}, step={global_step}")
plt.close(fig_eigen_ev)
fig_eigen_zf = eigenspectrum_plot(z_final, "z_final Eigenspectrum")
log_dict["visuals/eigenspectrum/z_final"] = trackio.Image(fig_to_pil(fig_eigen_zf),
caption=f"RankMe={z_final_rank:.1f}, step={global_step}")
plt.close(fig_eigen_zf)
# ── Per-Dimension Variance Plot (VICReg) ──
fig_vardim_ev = per_dim_variance_plot(ev_stats["std_values"],
f"Evidence Per-Dim Std (collapsed={ev_stats['collapsed_dims']})")
log_dict["visuals/collapse/evidence_std"] = trackio.Image(fig_to_pil(fig_vardim_ev),
caption=f"step={global_step}")
plt.close(fig_vardim_ev)
fig_vardim_zf = per_dim_variance_plot(zf_stats["std_values"],
f"z_final Per-Dim Std (collapsed={zf_stats['collapsed_dims']})")
log_dict["visuals/collapse/z_final_std"] = trackio.Image(fig_to_pil(fig_vardim_zf),
caption=f"step={global_step}")
plt.close(fig_vardim_zf)
# ── Latent Trajectory PCA ──
if traj.dim() == 4 and traj.size(1) > 1:
fig_traj = trajectory_pca_plot(traj[0],
f"Latent Trajectory (straightness={straightness:.3f})")
log_dict["visuals/trajectory/pca"] = trackio.Image(fig_to_pil(fig_traj),
caption=f"K={traj.size(1)-1}, step={global_step}")
plt.close(fig_traj)
# ── Token Norm Maps (DINOv2+Registers) ──
fig_norm_vis = token_norm_map(vis_tok[0], grid_side, grid_side, "Visual Token Norms")
log_dict["visuals/norms/visual_tokens"] = trackio.Image(fig_to_pil(fig_norm_vis),
caption=f"step={global_step}")
plt.close(fig_norm_vis)
fig_norm_ev = token_norm_map(evidence[0], 8, 8, "Evidence Token Norms")
log_dict["visuals/norms/evidence_tokens"] = trackio.Image(fig_to_pil(fig_norm_ev),
caption=f"step={global_step}")
plt.close(fig_norm_ev)
# ── Cross-Attention Weights (hooks) ──
if diagnostics_collector and diagnostics_collector.cross_attn_weights:
attn_w = diagnostics_collector.cross_attn_weights[-1]
if attn_w is not None and attn_w.dim() >= 2:
if attn_w.dim() > 2:
attn_w = attn_w[0]  # drop the leading batch dim; head-averaging happens in the plot fn
fig_xattn = cross_attention_weights_plot(attn_w, "Perceiver Cross-Attention (Last Layer)")
log_dict["visuals/attention/perceiver_xattn"] = trackio.Image(fig_to_pil(fig_xattn),
caption=f"step={global_step}")
plt.close(fig_xattn)
# ── Evidence Gate Heatmaps ──
if diagnostics_collector and diagnostics_collector.gate_values:
gate_vals = [gv[0] if gv.dim() == 3 else gv
for gv in diagnostics_collector.gate_values[-cfg.K:]]
if gate_vals:
fig_gates = evidence_gate_heatmap(gate_vals, "Evidence Gate Activations")
log_dict["visuals/gates/activations"] = trackio.Image(fig_to_pil(fig_gates),
caption=f"step={global_step}")
plt.close(fig_gates)
for k, gv in enumerate(gate_vals):
if isinstance(gv, torch.Tensor):
log_dict[f"diagnostics/gate_mean_step{k+1}"] = gv.float().mean().item()
# ── Input Image with Feature Similarity Heatmap (DINO style) ──
img_tensor = batch["pixel_values"][0]
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img_denorm = (img_tensor.cpu() * std + mean).clamp(0, 1)
img_np = (img_denorm.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
vis_feats = vis_tok[0]
mean_feat = vis_feats.mean(0, keepdim=True)
sim = F.cosine_similarity(vis_feats, mean_feat.expand_as(vis_feats), dim=-1).detach().cpu()
n_patches = sim.size(0)
side = int(math.sqrt(n_patches))
if side * side == n_patches:
fig_attn_overlay, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(img_np); ax[0].set_title("Input Image", fontsize=10); ax[0].axis('off')
attn_resized = np.array(PILImage.fromarray(
((sim.numpy().reshape(side, side) - sim.min().item()) /
(sim.max().item() - sim.min().item() + 1e-8) * 255).astype(np.uint8)
).resize((img_np.shape[1], img_np.shape[0]), PILImage.BILINEAR)) / 255.0
ax[1].imshow(img_np); ax[1].imshow(attn_resized, cmap='jet', alpha=0.5)
ax[1].set_title("Feature Similarity Heatmap", fontsize=10); ax[1].axis('off')
fig_attn_overlay.suptitle("Visual Feature Heatmap (DINO-style)", fontsize=12, fontweight='bold')
fig_attn_overlay.tight_layout()
log_dict["visuals/attention/feature_heatmap"] = trackio.Image(fig_to_pil(fig_attn_overlay),
caption=f"step={global_step}")
plt.close(fig_attn_overlay)
# ── Rollout Comparison Grid ──
if traj.dim() == 4 and traj.size(1) > 1:
n_show = min(batch["pixel_values"].size(0), 3)
images_for_grid = []
pcas_for_grid = []
for i in range(n_show):
img_i = batch["pixel_values"][i].cpu()
img_i_denorm = (img_i * std + mean).clamp(0, 1)
images_for_grid.append((img_i_denorm.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
step_pcas = []
for k in range(traj.size(1)):
n_tok = traj.size(2)
side_k = max(int(math.sqrt(n_tok)), 1)
step_pcas.append(pca_feature_map(traj[i, k], side_k, side_k))
pcas_for_grid.append(step_pcas)
fig_grid = rollout_comparison_grid(images_for_grid, pcas_for_grid,
suptitle=f"Rollout PCA Grid (K={traj.size(1)-1})")
log_dict["visuals/rollout/comparison_grid"] = trackio.Image(fig_to_pil(fig_grid),
caption=f"step={global_step}")
plt.close(fig_grid)
trackio.log(log_dict)
log.info(f"Logged {len(log_dict)} visual diagnostics at step {global_step}")
except Exception as e:
log.warning(f"Visual diagnostics failed: {e}")
import traceback
traceback.print_exc()
finally:
model.train()
plt.close('all')
if diagnostics_collector:
diagnostics_collector.clear()
# ══════════════════════════════════════════════════════════════════════════
# PHASE 2 TRAINING
# ══════════════════════════════════════════════════════════════════════════
def download_phase1_checkpoint(hub_model_id: str, run_name: str = "hybrid_main"):
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id=hub_model_id, filename=f"checkpoints/{run_name}_best.pt", repo_type="model")
log.info(f"Downloaded checkpoint: {path}")
return path
def main():
parser = argparse.ArgumentParser(description="MR-JEPA Phase 2 Training")
parser.add_argument("--checkpoint", type=str, default=None)
parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
parser.add_argument("--run_name", default="hybrid_main_phase2")
parser.add_argument("--phase1_run", default="hybrid_main")
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--grad_accum", type=int, default=8)
parser.add_argument("--core_lr", type=float, default=1e-4)
parser.add_argument("--backbone_lr", type=float, default=1e-5)
parser.add_argument("--text_lr", type=float, default=1e-5)
parser.add_argument("--unfreeze_visual_layers", type=int, default=6)
parser.add_argument("--unfreeze_text_layers", type=int, default=4)
parser.add_argument("--max_eval_samples", type=int, default=500)
parser.add_argument("--vis_interval", type=int, default=100)
parser.add_argument("--output_dir", default="./outputs/mrjepa_phase2")
parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio",
help="HF Space ID for persistent Trackio dashboard")
args = parser.parse_args()
log.info("Downloading Phase 1 training script...")
from huggingface_hub import hf_hub_download
p1_script = hf_hub_download(repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model")
import importlib.util
spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
p1 = importlib.util.module_from_spec(spec)
spec.loader.exec_module(p1)
if args.checkpoint and os.path.exists(args.checkpoint):
ckpt_path = args.checkpoint
else:
ckpt_path = download_phase1_checkpoint(args.hub_model_id, args.phase1_run)
log.info(f"Loading Phase 1 checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
saved_cfg = ckpt["config"]
cfg = p1.Config()
for k, v in saved_cfg.items():
if hasattr(cfg, k):
setattr(cfg, k, v)
cfg.phase = 2
cfg.epochs = args.epochs
cfg.batch_size = args.batch_size
cfg.grad_accum = args.grad_accum
cfg.lr = args.core_lr
cfg.backbone_lr = args.backbone_lr
cfg.output_dir = args.output_dir
cfg.run_name = args.run_name
cfg.freeze_backbone = True
cfg.freeze_text = True
cfg.max_eval_samples = args.max_eval_samples
cfg.resolve()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"Device: {device}")
os.makedirs(cfg.output_dir, exist_ok=True)
# ── Initialize Trackio with persistent HF Space ──
import trackio
trackio.init(
name=args.run_name,
project="MR-JEPA",
space_id=args.trackio_space,
config={
"phase": 2, "epochs": args.epochs,
"core_lr": args.core_lr, "backbone_lr": args.backbone_lr, "text_lr": args.text_lr,
"batch_size": args.batch_size, "grad_accum": args.grad_accum,
"unfreeze_visual_layers": args.unfreeze_visual_layers,
"unfreeze_text_layers": args.unfreeze_text_layers,
"phase1_best_acc": ckpt.get("eval_acc", "unknown"),
"vis_interval": args.vis_interval,
"backbone": cfg.backbone, "K": cfg.K, "use_jepa": cfg.use_jepa, "loss_fn": cfg.loss_fn,
}
)
log.info(f"Trackio initialized → Space: https://huggingface.co/spaces/{args.trackio_space}")
log.info("Building model...")
model = p1.MRJEPAModel(cfg)
model.evidence.load_state_dict(ckpt["evidence"])
model.rollout.load_state_dict(ckpt["rollout"])
model.disc.load_state_dict(ckpt["disc"])
model.target.t_ev.load_state_dict(ckpt["target_ev"])
model.target.t_ro.load_state_dict(ckpt["target_ro"])
log.info(f"Loaded Phase 1 weights (epoch={ckpt.get('epoch','?')}, eval_acc={ckpt.get('eval_acc','?')}%)")
log.info(f"Unfreezing last {args.unfreeze_visual_layers} visual layers, "
f"last {args.unfreeze_text_layers} text layers")
model.vis.unfreeze_last(args.unfreeze_visual_layers)
model.txt.unfreeze_last(args.unfreeze_text_layers)
model = model.to(device)
total_p = sum(p.numel() for p in model.parameters())
train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)")
trackio.log({"model/total_params": total_p, "model/trainable_params": train_p,
"model/trainable_pct": 100 * train_p / total_p})
transform = model.vis.get_transform()
tokenizer = model.txt.tokenizer
train_ds = p1.ScienceQADataset("train", transform=transform, tokenizer=tokenizer,
max_len=cfg.max_text_len, max_opts=cfg.max_options)
eval_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples,
transform=transform, tokenizer=tokenizer,
max_len=cfg.max_text_len, max_opts=cfg.max_options)
coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options)
train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
num_workers=2, collate_fn=coll, pin_memory=True, drop_last=True)
eval_dl = DataLoader(eval_ds, batch_size=cfg.batch_size, shuffle=False,
num_workers=2, collate_fn=coll, pin_memory=True)
backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
text_params = [p for p in model.txt.parameters() if p.requires_grad]
bb_txt_ids = {id(p) for p in backbone_params + text_params}
core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids]
param_groups = [
{"params": core_params, "lr": args.core_lr},
{"params": backbone_params, "lr": args.backbone_lr},
{"params": text_params, "lr": args.text_lr},
]
log.info(f"Optimizer: core={len(core_params)} @ {args.core_lr}, "
f"backbone={len(backbone_params)} @ {args.backbone_lr}, "
f"text={len(text_params)} @ {args.text_lr}")
optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)
total_steps = cfg.epochs * len(train_dl) // cfg.grad_accum
warmup_steps = int(total_steps * cfg.warmup_ratio)
def lr_lambda(step):
if step < warmup_steps:
return step / max(warmup_steps, 1)
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))
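# Schedule shape (from the lambda above): linear warmup 0 → 1 over warmup_steps,
# then cosine decay from 1.0 down to a 0.01 floor of each group's base LR.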
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
diag_collector = DiagnosticCollector()
diag_collector.attach(model)
log.info(f"Phase 2: {cfg.epochs} epochs, {len(train_dl)} batches/epoch, ga={cfg.grad_accum}")
log.info(f"Visual diagnostics every {args.vis_interval} optimizer steps")
global_step = 0
best_acc = ckpt.get("eval_acc", 0.0)
amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
trainable = [p for p in model.parameters() if p.requires_grad]
try:
for epoch in range(cfg.epochs):
model.train()
epoch_losses = defaultdict(list)
epoch_correct = 0
epoch_total = 0
optimizer.zero_grad()
for batch_idx, batch in enumerate(train_dl):
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
vis_tok = model.vis(batch["pixel_values"]).float()
txt_tok = model.txt(batch["input_ids"], batch["attention_mask"]).float()
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type == "cuda"):
evidence, _, ev_mask = model.evidence(vis_tok, txt_tok, batch["attention_mask"])
if model._use_rollout:
traj, z_final, z_proj = model.rollout(evidence)
else:
B = batch["batch_size"]
z0 = model.rollout.init_tokens.expand(B, -1, -1) + \
model.rollout.z0_proj(F.adaptive_avg_pool1d(
evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
z_final = z0
z_proj = model.rollout.out_proj(z0).unsqueeze(1)
if model._use_jepa:
target_proj = model.target(vis_tok.detach(), txt_tok.detach(), batch["attention_mask"].detach())
else:
target_proj = None
opt_emb = model.encode_options(batch["opt_input_ids"], batch["opt_attention_mask"])
opt_emb = opt_emb.view(batch["batch_size"], cfg.max_options, -1)
logits = model.disc(z_final, opt_emb, batch["opt_mask"])
task_loss = F.cross_entropy(logits, batch["labels"])
if model._use_jepa and target_proj is not None:
losses = model.jepa_loss(z_proj, target_proj, task_loss)
else:
losses = {"total": task_loss, "jepa": torch.tensor(0.0), "task": task_loss, "reg": torch.tensor(0.0)}
loss = losses["total"] / cfg.grad_accum
loss.backward()
if (batch_idx + 1) % cfg.grad_accum == 0:
nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
optimizer.step(); scheduler.step(); optimizer.zero_grad()
model.update_target(global_step, total_steps)
global_step += 1
if global_step % args.vis_interval == 0 and global_step > 0:
log.info(f"Generating visual diagnostics at step {global_step}...")
log_visual_diagnostics(model, batch, device, cfg, global_step, epoch,
diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
preds = logits.argmax(dim=-1)
for k, v in losses.items():
if isinstance(v, torch.Tensor):
epoch_losses[k].append(v.item())
epoch_correct += (preds == batch["labels"]).sum().item()
epoch_total += batch["batch_size"]
if batch_idx % 50 == 0:
avg = {k: np.mean(v[-50:]) for k, v in epoch_losses.items()}
acc = epoch_correct / max(epoch_total, 1) * 100
lrs = scheduler.get_last_lr()
log.info(f"P2 E{epoch} B{batch_idx}/{len(train_dl)} | "
f"loss={avg.get('total',0):.4f} jepa={avg.get('jepa',0):.4f} "
f"task={avg.get('task',0):.4f} | acc={acc:.1f}%")
trackio.log({
"train/loss": avg.get("total", 0), "train/jepa_loss": avg.get("jepa", 0),
"train/task_loss": avg.get("task", 0), "train/reg_loss": avg.get("reg", 0),
"train/accuracy": acc, "train/lr": lrs[0] if lrs else 0,
"train/backbone_lr": lrs[1] if len(lrs) > 1 else 0,
"train/text_lr": lrs[2] if len(lrs) > 2 else 0,
"train/ema_momentum": model.target.mom,
"train/epoch": epoch, "train/step": global_step,
})
eval_acc = p1.evaluate(model, eval_dl, device, cfg)
train_acc = epoch_correct / max(epoch_total, 1) * 100
log.info(f"=== Phase 2 Epoch {epoch} | Train: {train_acc:.1f}% | Eval: {eval_acc:.1f}% ===")
trackio.log({"eval/accuracy": eval_acc, "eval/epoch": epoch,
"eval/train_accuracy": train_acc, "eval/best_accuracy": max(best_acc, eval_acc)})
log.info(f"Generating epoch-end visual diagnostics...")
diag_batch = next(iter(eval_dl))
diag_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in diag_batch.items()}
log_visual_diagnostics(model, diag_batch, device, cfg, global_step, epoch,
diagnostics_collector=diag_collector, vis_interval=args.vis_interval)
if eval_acc > best_acc:
best_acc = eval_acc
p1.save_checkpoint(model, cfg, epoch, eval_acc, is_best=True)
log.info(f"New best accuracy: {best_acc:.1f}%")
log.info(f"Phase 2 complete. Best eval accuracy: {best_acc:.1f}%")
finally:
# ── Ensure Trackio data is persisted even if training crashes ──
diag_collector.detach()
trackio.log({"final/best_accuracy": best_acc, "final/phase": 2, "final/total_steps": global_step})
log.info("Finishing Trackio and syncing to Space...")
trackio.finish()
# Belt-and-suspenders: explicit sync to ensure all images are uploaded
try:
trackio.sync(project="MR-JEPA", space_id=args.trackio_space)
log.info(f"Trackio synced to https://huggingface.co/spaces/{args.trackio_space}")
except Exception as e:
log.warning(f"Trackio sync failed (data may still be available via finish): {e}")
if cfg.push_to_hub:
p1.push_results(cfg, best_acc)
if __name__ == "__main__":
main()