| """ |
| Visualization utilities for MR-JEPA. |
| |
| Tools for analyzing and visualizing: |
| - Latent trajectory evolution (z₀ → z₁ → z₂ → z₃) |
| - Evidence gate activations per rollout step |
| - Attention maps between state and evidence |
| - t-SNE/UMAP of latent states across benchmarks |
| """ |
|
|
| import torch |
| import numpy as np |
| from typing import Optional, Dict, List |
|
|
|
|
def visualize_trajectory(
    trajectory: torch.Tensor,
    method: str = "pca",
    title: str = "Latent Trajectory Evolution",
) -> Dict[str, np.ndarray]:
    """
    Project the latent trajectory z₀→z₁→...→z_K into 2D for plotting.

    Each step's state tokens are averaged into one centroid. The centroids are
    mean-centered and then projected onto their top-2 principal directions.
    The returned coordinates can be plotted directly with matplotlib.

    Args:
        trajectory: [K+1, N_s, D] latent states for a single sample.
            float16/bfloat16 inputs are upcast to float32, because NumPy has
            no bfloat16 and half precision is too coarse for the SVD.
        method: 'pca' or 'tsne'. Both currently use the same linear (SVD/PCA)
            projection; no t-SNE is computed in this function.
        title: Plot title. Not used by this function; kept for API
            compatibility.

    Returns:
        Dict with:
            'coords': [K+1, 2] projected centroids per step. Axis signs are
                arbitrary, as in any PCA. Columns beyond the rank of the
                centered centroids are zero.
            'centroids': [K+1, D] per-step centroids.
            'step_labels': ['z_0', ..., 'z_K'].

    Raises:
        ValueError: if ``trajectory`` is not 3-dimensional.
    """
    if trajectory.dim() != 3:
        raise ValueError(
            "trajectory must have shape [K+1, N_s, D], "
            f"got {tuple(trajectory.shape)}"
        )
    K_plus_1 = trajectory.shape[0]

    states = trajectory.detach()
    if states.dtype in (torch.float16, torch.bfloat16):
        states = states.float()
    # One centroid per rollout step: average over the N_s state tokens.
    centroids = states.mean(dim=1).cpu().numpy()

    centered = centroids - centroids.mean(axis=0)
    # A thin SVD of the (K+1, D) matrix yields the same principal directions
    # as eigendecomposing the D×D covariance. It avoids materializing a D×D
    # matrix, the NaN np.cov gives for a single step, and eigh's failure
    # when D == 1.
    U, S, _ = np.linalg.svd(centered, full_matrices=False)
    coords = U[:, :2] * S[:2]
    if coords.shape[1] < 2:
        # Fewer than 2 steps or D < 2: pad so callers always get [K+1, 2].
        coords = np.pad(coords, ((0, 0), (0, 2 - coords.shape[1])))

    return {
        'coords': coords,
        'centroids': centroids,
        'step_labels': [f'z_{k}' for k in range(K_plus_1)],
    }
|
|
|
|
def visualize_evidence_gates(
    model,
    sample_output: Dict[str, torch.Tensor],
) -> Dict[str, np.ndarray]:
    """
    Collect evidence-gate statistics per rollout step (placeholder).

    Intended to report how much evidence is gated into each predictor layer
    of the latent rollout. Extraction is not implemented yet. The function
    walks ``model.latent_rollout.predictor_layers`` and checks each layer's
    ``evidence_gate`` for a ``gate_proj`` attribute, but records nothing.
    Both returned lists are therefore always empty.

    Args:
        model: MRJEPAModel instance exposing ``latent_rollout.predictor_layers``,
            where every layer has an ``evidence_gate`` attribute.
        sample_output: Forward pass output dict (currently unused).

    Returns:
        Dict with keys 'mean_gate_values' and 'gate_entropy', each an empty
        list until gate extraction is implemented.

    Raises:
        AttributeError: if ``model.latent_rollout.predictor_layers`` is missing
            or a layer has no ``evidence_gate``.
    """
    mean_values: List[float] = []
    entropies: List[float] = []

    # Reading ``layer.evidence_gate`` raises AttributeError for a layer that
    # lacks it; gates without ``gate_proj`` are skipped.
    gates_with_proj = [
        gate
        for gate in (layer.evidence_gate for layer in model.latent_rollout.predictor_layers)
        if hasattr(gate, 'gate_proj')
    ]
    # TODO(review): derive per-step mean activation / entropy from
    # `gates_with_proj` (e.g. by capturing gate_proj outputs during the
    # forward pass). Nothing is recorded yet.
    del gates_with_proj

    return {
        'mean_gate_values': mean_values,
        'gate_entropy': entropies,
    }
|
|
|
|
def compute_trajectory_metrics(
    trajectory: torch.Tensor,
) -> Dict[str, float]:
    """
    Compute analytical metrics on the latent trajectory.

    Useful for ablation analysis:
    - Inter-step distance: how much the state centroid moves per step
    - Trajectory length: total path length in latent space
    - Convergence rate: ratio of the last step size to the first; values
      below 1 mean step sizes shrink along the rollout
    - State diversity: variance across tokens within each step

    Args:
        trajectory: [B, K+1, N_s, D] batch of latent trajectories. A single
            sample of shape [K+1, N_s, D] is also accepted and treated as B=1.
            float16/bfloat16 inputs are upcast to float32 for the reductions.

    Returns:
        Dict with:
            'step_distances': list of K floats, the batch-mean L2 distance
                between consecutive step centroids.
            'trajectory_length': sum of 'step_distances'.
            'convergence_rate': last / first step distance (first clamped to
                >= 1e-6); 1.0 when there are no steps.
            'state_diversity': list of K+1 floats, the unbiased variance over
                tokens averaged over batch and feature dims. It is 0.0 when
                N_s <= 1, since a single token has no spread.
            'avg_step_distance': 'trajectory_length' / max(K, 1).

    Raises:
        ValueError: if ``trajectory`` is neither 3- nor 4-dimensional.
    """
    if trajectory.dim() == 3:
        trajectory = trajectory.unsqueeze(0)
    if trajectory.dim() != 4:
        raise ValueError(
            "trajectory must have shape [B, K+1, N_s, D] or [K+1, N_s, D], "
            f"got {tuple(trajectory.shape)}"
        )

    # Metrics are read-only summaries: no autograd graph needed.
    states = trajectory.detach()
    if states.dtype in (torch.float16, torch.bfloat16):
        states = states.float()
    B, K_plus_1, N_s, D = states.shape

    # [B, K+1, D]: one centroid per sample and step.
    centroids = states.mean(dim=2)

    # Distance between consecutive centroids, averaged over the batch. This
    # is vectorized over steps, so there is one host sync instead of K.
    step_distances: List[float] = (
        torch.norm(centroids[:, 1:] - centroids[:, :-1], dim=-1).mean(dim=0).tolist()
    )

    total_length = sum(step_distances)

    convergence = step_distances[-1] / max(step_distances[0], 1e-6) if step_distances else 1.0

    if N_s > 1:
        # Unbiased variance over tokens, then mean over batch and features.
        diversity: List[float] = states.var(dim=2).mean(dim=(0, 2)).tolist()
    else:
        # The unbiased variance of a single token is NaN; report zero spread.
        diversity = [0.0] * K_plus_1

    return {
        'step_distances': step_distances,
        'trajectory_length': total_length,
        'convergence_rate': convergence,
        'state_diversity': diversity,
        'avg_step_distance': total_length / max(K_plus_1 - 1, 1),
    }
|
|