# MR-JEPA / mr_jepa/utils/visualization.py
# JorgeAV — "Initial MR-JEPA codebase: architecture, training, evaluation, and tests"
# commit dba2c56 (verified)
"""
Visualization utilities for MR-JEPA.
Tools for analyzing and visualizing:
- Latent trajectory evolution (z₀ → z₁ → z₂ → z₃)
- Evidence gate activations per rollout step
- Attention maps between state and evidence
- t-SNE/UMAP of latent states across benchmarks
"""
from typing import Dict, List, Optional, Union

import numpy as np
import torch
def visualize_trajectory(
    trajectory: torch.Tensor,  # [K+1, N_s, D]
    method: str = "pca",
    title: str = "Latent Trajectory Evolution",
) -> Dict[str, Union[np.ndarray, List[str], str]]:
    """
    Project a latent trajectory z₀→z₁→...→z_K into 2D for plotting.

    Each step's tokens are mean-pooled into one centroid vector. The centroids
    are then projected onto their top-2 principal components. The returned
    coordinates can be plotted directly with matplotlib
    (e.g. ``plt.plot(*out['coords'].T)``).

    Args:
        trajectory: [K+1, N_s, D] latent states for a single sample.
        method: 'pca' or 'tsne'. t-SNE is not implemented (it would need an
            extra dependency), so every method currently uses PCA.
        title: Plot title. It is passed through in the result so the caller
            can label the plot.

    Returns:
        Dict with:
            'coords': [K+1, 2] float64 projected centroids per step. When
                fewer than two principal directions exist (a single step, or
                D == 1), the missing columns are zero.
            'centroids': [K+1, D] mean-pooled states (float32 or float64).
            'step_labels': ['z_0', ..., 'z_K'].
            'title': the ``title`` argument.

    Raises:
        ValueError: if ``trajectory`` is not 3-dimensional.
    """
    if trajectory.dim() != 3:
        raise ValueError(
            "trajectory must have shape [K+1, N_s, D], "
            f"got {tuple(trajectory.shape)}"
        )
    K_plus_1 = trajectory.shape[0]

    states = trajectory.detach()
    # numpy has no bfloat16, and torch cannot average integer tensors, so
    # promote anything that is not float32/float64 before pooling.
    if states.dtype not in (torch.float32, torch.float64):
        states = states.float()
    # Pool each step's tokens into a single vector.
    centroids = states.mean(dim=1).cpu().numpy()  # [K+1, D]

    coords = np.zeros((K_plus_1, 2), dtype=np.float64)
    if K_plus_1 > 0:
        # PCA via SVD of the centered data: projecting onto the top right
        # singular vectors equals U[:, :r] * S[:r]. This matches the
        # covariance-eigendecomposition result up to a per-axis sign. It also
        # avoids building a D x D covariance matrix, and it stays well defined
        # for one step or D == 1, where np.cov / eigh break down.
        centered = centroids.astype(np.float64)
        centered = centered - centered.mean(axis=0)
        U, S, _ = np.linalg.svd(centered, full_matrices=False)
        n_comp = min(2, S.shape[0])
        coords[:, :n_comp] = U[:, :n_comp] * S[:n_comp]

    return {
        'coords': coords,  # [K+1, 2]
        'centroids': centroids,  # [K+1, D] original
        'step_labels': [f'z_{k}' for k in range(K_plus_1)],
        'title': title,
    }
def visualize_evidence_gates(
    model,
    sample_output: Dict[str, torch.Tensor],
) -> Dict[str, List[float]]:
    """
    Extract evidence gate activations per rollout step (placeholder).

    The intent is to show how much evidence flows into each step of the
    rollout. Early steps may attend more to visual evidence, while later
    steps may rely more on accumulated reasoning.

    Gate values are not yet captured: that requires forward hooks on the
    gates, or the gates storing their activations during the forward pass.
    Until that exists, this function returns empty statistics. It emits a
    ``RuntimeWarning`` so callers do not mistake the empty lists for real
    measurements.

    Args:
        model: MRJEPAModel instance. The function reads
            ``model.latent_rollout.predictor_layers``.
        sample_output: Forward pass output dict (currently unused).

    Returns:
        Dict with 'mean_gate_values' and 'gate_entropy', each currently an
        empty list.
    """
    import warnings

    gate_stats: Dict[str, List[float]] = {
        'mean_gate_values': [],
        'gate_entropy': [],
    }

    # Count predictor layers whose evidence gate exposes a ``gate_proj``;
    # these are where hooks would be installed for detailed analysis. Layers
    # without an ``evidence_gate`` are skipped rather than raising.
    n_gated = sum(
        1
        for layer in model.latent_rollout.predictor_layers
        if hasattr(getattr(layer, 'evidence_gate', None), 'gate_proj')
    )

    warnings.warn(
        "visualize_evidence_gates is a placeholder: gate values are not "
        "captured during the forward pass, so empty statistics are returned "
        f"({n_gated} gated predictor layer(s) found).",
        RuntimeWarning,
        stacklevel=2,
    )
    return gate_stats
def compute_trajectory_metrics(
    trajectory: torch.Tensor,  # [B, K+1, N_s, D]
) -> Dict[str, Union[float, List[float]]]:
    """
    Compute analytical metrics on a batch of latent trajectories.

    Useful for ablation analysis:
    - Inter-step distance: how much the state changes per step
    - Trajectory length: total path length in latent space
    - Convergence rate: diminishing step sizes indicate convergence
    - State diversity: variance within each step's tokens

    Args:
        trajectory: [B, K+1, N_s, D] latent states.

    Returns:
        Dict with:
            'step_distances': K floats. Each is the L2 distance between
                consecutive mean-pooled step centroids, averaged over the
                batch.
            'trajectory_length': sum of 'step_distances' (0.0 when K == 0).
            'convergence_rate': last step distance / first step distance,
                with the first clamped to >= 1e-6. It is 1.0 when K == 0.
                Values < 1 mean the steps shrink.
            'state_diversity': K+1 floats. Each is the per-step unbiased
                variance across tokens, averaged over batch and feature dims.
                It is 0.0 when N_s < 2, where the unbiased variance is
                undefined (NaN).
            'avg_step_distance': trajectory_length / max(K, 1).

    Raises:
        ValueError: if ``trajectory`` is not 4-dimensional.
    """
    if trajectory.dim() != 4:
        raise ValueError(
            "trajectory must have shape [B, K+1, N_s, D], "
            f"got {tuple(trajectory.shape)}"
        )
    B, K_plus_1, N_s, D = trajectory.shape

    states = trajectory.detach()
    # Reduced-precision floats lose accuracy in norms/variances, and integer
    # tensors cannot be averaged, so compute the metrics in at least float32.
    if states.dtype not in (torch.float32, torch.float64):
        states = states.float()

    # Pool tokens to one centroid per step.
    centroids = states.mean(dim=2)  # [B, K+1, D]

    # All consecutive differences at once: [B, K, D] -> norms [B, K] -> [K].
    deltas = centroids[:, 1:] - centroids[:, :-1]
    step_distances = torch.linalg.vector_norm(deltas, dim=-1).mean(dim=0).tolist()

    total_length = float(sum(step_distances))

    # Ratio of last step distance to first; < 1 means steps are shrinking.
    if step_distances:
        convergence = step_distances[-1] / max(step_distances[0], 1e-6)
    else:
        convergence = 1.0

    # Variance across tokens (dim=2) per step, averaged over batch and features.
    if N_s < 2:
        diversity = [0.0] * K_plus_1
    else:
        diversity = states.var(dim=2).mean(dim=(0, 2)).tolist()

    return {
        'step_distances': step_distances,
        'trajectory_length': total_length,
        'convergence_rate': convergence,
        'state_diversity': diversity,
        'avg_step_distance': total_length / max(K_plus_1 - 1, 1),
    }