"""
Visualization utilities for MR-JEPA.

Tools for analyzing and visualizing:
- Latent trajectory evolution (z₀ → z₁ → z₂ → z₃)
- Evidence gate activations per rollout step
- Attention maps between state and evidence
- t-SNE/UMAP of latent states across benchmarks
"""

import warnings
from typing import Optional, Dict, List, Union

import numpy as np
import torch


def _project_to_2d(points: np.ndarray) -> np.ndarray:
    """Project row vectors onto their top-2 principal components.

    Uses an SVD of the mean-centered data, so the cost scales with
    ``min(n_points, n_features)`` instead of requiring a
    ``n_features x n_features`` covariance matrix.

    Args:
        points: [n_points, n_features] array.

    Returns:
        [n_points, 2] float64 array. Columns for which no principal
        component exists (fewer than two points or features) are zero.
    """
    num_points = points.shape[0]
    coords = np.zeros((num_points, 2), dtype=np.float64)
    if num_points == 0:
        return coords

    # Work in float64: np.linalg.svd rejects float16 input.
    data = points.astype(np.float64)
    centered = data - data.mean(axis=0)
    u, s, vt = np.linalg.svd(centered, full_matrices=False)

    for comp in range(min(2, s.shape[0])):
        # The sign of a principal axis is arbitrary; pin it so that the
        # largest-magnitude loading is positive and plots are reproducible.
        pivot = int(np.argmax(np.abs(vt[comp])))
        sign = -1.0 if vt[comp, pivot] < 0 else 1.0
        coords[:, comp] = sign * u[:, comp] * s[comp]
    return coords


def visualize_trajectory(
    trajectory: torch.Tensor,  # [K+1, N_s, D]
    method: str = "pca",
    title: str = "Latent Trajectory Evolution",
) -> Dict[str, np.ndarray]:
    """
    Visualize the latent trajectory z₀→z₁→...→z_K.

    Pools each step's tokens into a centroid and projects the centroids
    into 2D. Returns coordinates that can be plotted with matplotlib;
    this function does not draw anything itself.

    Args:
        trajectory: [K+1, N_s, D] latent states for a single sample
        method: 'pca' or 'tsne'. Only PCA is implemented; any other
            value emits a warning and falls back to PCA.
        title: Plot title. Currently unused; kept for API compatibility.

    Returns:
        Dict with:
            'coords': [K+1, 2] projected centroids per step (zero-padded
                when fewer than two principal components exist)
            'centroids': [K+1, D] un-projected centroids
            'step_labels': list of 'z_k' labels, one per step

    Raises:
        ValueError: if ``trajectory`` is not 3-dimensional.
    """
    if trajectory.dim() != 3:
        raise ValueError(
            f"trajectory must have shape [K+1, N_s, D], got {tuple(trajectory.shape)}"
        )
    num_steps = trajectory.shape[0]

    states = trajectory.detach()
    # bfloat16 cannot be converted to NumPy and float16 is unsupported by
    # NumPy's linalg routines, so upcast low-precision inputs first.
    if states.dtype in (torch.float16, torch.bfloat16):
        states = states.float()

    # Pool each step's tokens into a single vector
    centroids = states.mean(dim=1).cpu().numpy()  # [K+1, D]

    if method != "pca":
        warnings.warn(
            f"Projection method {method!r} is not implemented; "
            "falling back to PCA.",
            stacklevel=2,
        )

    coords = _project_to_2d(centroids)

    return {
        'coords': coords,        # [K+1, 2]
        'centroids': centroids,  # [K+1, D] original
        'step_labels': [f'z_{k}' for k in range(num_steps)],
    }


def visualize_evidence_gates(
    model,
    sample_output: Dict[str, torch.Tensor],
) -> Dict[str, List[float]]:
    """
    Extract evidence gate activation statistics per rollout step.

    NOTE: this is currently a placeholder. Collecting real gate values
    requires forward hooks (or storing gate values during the forward
    pass), which is not implemented here, so both returned lists are
    always empty.

    Args:
        model: Object exposing ``latent_rollout.predictor_layers``, where
            each layer has an ``evidence_gate`` attribute (presumably an
            MRJEPAModel instance).
        sample_output: Forward pass output dict. Currently unused.

    Returns:
        Dict with keys 'mean_gate_values' and 'gate_entropy', each an
        empty list.
    """
    gate_stats = {
        'mean_gate_values': [],
        'gate_entropy': [],
    }

    # Walk the predictor layers' evidence gates. Nothing is recorded yet;
    # this is the point where hooks would be installed.
    for layer in model.latent_rollout.predictor_layers:
        if hasattr(layer.evidence_gate, 'gate_proj'):
            # Could install hooks here for detailed analysis
            pass

    return gate_stats


def compute_trajectory_metrics(
    trajectory: torch.Tensor,  # [B, K+1, N_s, D]
) -> Dict[str, Union[float, List[float]]]:
    """
    Compute analytical metrics on the latent trajectory.

    Useful for ablation analysis:
    - Inter-step distance: how much the state changes per step
    - Trajectory length: total path length in latent space
    - Convergence rate: diminishing step sizes indicate convergence
    - State diversity: variance within each step's tokens

    Args:
        trajectory: [B, K+1, N_s, D] latent states for a batch.

    Returns:
        Dict with:
            'step_distances': K batch-averaged L2 distances between
                consecutive step centroids
            'trajectory_length': sum of 'step_distances'
            'convergence_rate': last step distance divided by the first
                (first clamped to at least 1e-6); 1.0 when there are no
                steps to compare
            'state_diversity': K+1 values, the sample variance across
                tokens averaged over batch and feature dimensions
                (0.0 when there is a single token per step)
            'avg_step_distance': 'trajectory_length' divided by K
                (K clamped to at least 1)

    Raises:
        ValueError: if ``trajectory`` is not 4-dimensional.
    """
    if trajectory.dim() != 4:
        raise ValueError(
            f"trajectory must have shape [B, K+1, N_s, D], got {tuple(trajectory.shape)}"
        )
    num_steps, num_tokens = trajectory.shape[1], trajectory.shape[2]
    states = trajectory.detach()

    # Pool to centroids
    centroids = states.mean(dim=2)  # [B, K+1, D]

    # Inter-step distances, computed for all steps at once so there is a
    # single device-to-host transfer instead of one per step.
    deltas = centroids[:, 1:] - centroids[:, :-1]  # [B, K, D]
    step_distances = torch.norm(deltas, dim=-1).mean(dim=0).tolist()  # K floats

    # Trajectory length
    total_length = sum(step_distances)

    # Convergence rate (ratio of last step distance to first)
    if step_distances:
        convergence = step_distances[-1] / max(step_distances[0], 1e-6)
    else:
        convergence = 1.0

    # State diversity per step. The sample variance of a single token is
    # undefined (NaN), so report zero spread in that case.
    if num_tokens > 1:
        diversity = states.var(dim=2).mean(dim=(0, 2)).tolist()  # K+1 floats
    else:
        diversity = [0.0] * num_steps

    return {
        'step_distances': step_distances,
        'trajectory_length': total_length,
        'convergence_rate': convergence,
        'state_diversity': diversity,
        'avg_step_distance': total_length / max(num_steps - 1, 1),
    }