File size: 4,435 Bytes
dba2c56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Visualization utilities for MR-JEPA.

Tools for analyzing and visualizing:
- Latent trajectory evolution (z₀ → z₁ → z₂ → z₃)
- Evidence gate activations per rollout step
- Attention maps between state and evidence
- t-SNE/UMAP of latent states across benchmarks
"""

from typing import Dict, List, Optional, Union

import numpy as np
import torch


def visualize_trajectory(
    trajectory: torch.Tensor,  # [K+1, N_s, D]
    method: str = "pca",
    title: str = "Latent Trajectory Evolution",
) -> Dict[str, np.ndarray]:
    """
    Project the latent trajectory z₀→z₁→...→z_K into 2D for plotting.

    Each step's N_s state tokens are mean-pooled into one centroid, and the
    K+1 centroids are projected onto their top two principal components.
    The returned coordinates can be plotted directly with matplotlib.

    Args:
        trajectory: [K+1, N_s, D] latent states for a single sample.
        method: 'pca' or 'tsne'. t-SNE is not implemented; every method
            currently uses PCA, computed via SVD (no sklearn dependency).
        title: Plot title; returned unchanged under 'title' for the caller.

    Returns:
        Dict with:
            'coords': [K+1, 2] float64 projected centroids per step. The sign
                of each axis is arbitrary (as for any PCA). Columns are
                zero-padded when fewer than two principal components exist
                (a single step, or D < 2).
            'centroids': [K+1, D] pooled centroids before projection.
            'step_labels': ['z_0', ..., 'z_K'].
            'title': the ``title`` argument.

    Raises:
        ValueError: if ``trajectory`` is not 3-dimensional.
    """
    if trajectory.dim() != 3:
        raise ValueError(
            "trajectory must have shape [K+1, N_s, D], "
            f"got {tuple(trajectory.shape)}"
        )
    num_steps = trajectory.shape[0]

    states = trajectory.detach()
    # NumPy has no bfloat16 (.numpy() rejects it) and np.linalg does not
    # accept float16, so upcast half-precision / integer inputs first.
    if not states.is_floating_point() or states.dtype in (torch.float16, torch.bfloat16):
        states = states.float()

    # Pool each step's tokens into a single vector.
    centroids = states.mean(dim=1).cpu().numpy()  # [K+1, D]
    centered = (centroids - centroids.mean(axis=0)).astype(np.float64)

    # PCA via SVD of the centered data: U * S are the principal-component
    # scores. Unlike eigendecomposing np.cov(...), this needs no D x D matrix
    # and stays well-defined for a single step or D == 1.
    U, S, _ = np.linalg.svd(centered, full_matrices=False)
    n_components = min(2, S.shape[0])
    coords = np.zeros((num_steps, 2), dtype=np.float64)
    coords[:, :n_components] = U[:, :n_components] * S[:n_components]

    return {
        'coords': coords,           # [K+1, 2]
        'centroids': centroids,     # [K+1, D] original
        'step_labels': [f'z_{k}' for k in range(num_steps)],
        'title': title,
    }


def visualize_evidence_gates(
    model,
    sample_output: Dict[str, torch.Tensor],
) -> Dict[str, np.ndarray]:
    """
    Collect evidence-gate activation statistics per rollout step.

    The goal is to show how much evidence flows into each rollout step
    (e.g. whether early steps lean on visual evidence while later steps rely
    on accumulated reasoning).

    Currently a placeholder: gate values are not captured during the forward
    pass, so both statistic lists are always returned empty. The function
    only walks ``model.latent_rollout.predictor_layers`` and checks which
    layers' ``evidence_gate`` exposes a ``gate_proj`` attribute, which is where
    hooks would be installed. ``sample_output`` is not read yet.

    Args:
        model: MRJEPAModel instance
        sample_output: Forward pass output dict (currently unused)

    Returns:
        Dict with keys 'mean_gate_values' and 'gate_entropy', each currently
        an empty list.

    Raises:
        AttributeError: if ``model`` lacks ``latent_rollout.predictor_layers``
            or a predictor layer lacks ``evidence_gate``.
    """
    predictor_layers = model.latent_rollout.predictor_layers
    for predictor_layer in predictor_layers:
        gate = predictor_layer.evidence_gate
        if not hasattr(gate, 'gate_proj'):
            continue
        # TODO: install a hook on gate.gate_proj to record per-step values.

    return {
        'mean_gate_values': [],
        'gate_entropy': [],
    }


def compute_trajectory_metrics(
    trajectory: torch.Tensor,  # [B, K+1, N_s, D]
) -> Dict[str, Union[float, List[float]]]:
    """
    Compute analytical metrics on a batch of latent trajectories.

    Useful for ablation analysis. Each step's N_s tokens are mean-pooled into
    a centroid, and per-step quantities are averaged over the batch.

    Args:
        trajectory: [B, K+1, N_s, D] latent states.

    Returns:
        Dict with:
            'step_distances': list of K floats, the batch-mean L2 distance
                between consecutive step centroids (how much the state
                changes per step).
            'trajectory_length': sum of 'step_distances' (total path length).
            'convergence_rate': last step distance / first step distance
                (first clamped to >= 1e-6); values below 1 mean the steps
                shrink. 1.0 when there are no steps (K == 0).
            'state_diversity': list of K+1 floats, the unbiased variance
                across tokens averaged over batch and feature dims. 0.0 when
                N_s == 1, where the unbiased variance is undefined (NaN).
            'avg_step_distance': trajectory_length / max(K, 1).

    Raises:
        ValueError: if ``trajectory`` is not 4-dimensional.
    """
    if trajectory.dim() != 4:
        raise ValueError(
            "trajectory must have shape [B, K+1, N_s, D], "
            f"got {tuple(trajectory.shape)}"
        )
    _, num_steps, num_tokens, _ = trajectory.shape

    # Read-only analysis: drop autograd, and upcast half-precision / integer
    # inputs so norms and variances are computed in at least float32.
    states = trajectory.detach()
    if not states.is_floating_point() or states.dtype in (torch.float16, torch.bfloat16):
        states = states.float()

    # Pool to centroids.
    centroids = states.mean(dim=2)  # [B, K+1, D]

    # Inter-step distances for all steps at once: [B, K] -> batch mean -> K
    # floats. A single .tolist() replaces one host sync per step.
    deltas = centroids[:, 1:] - centroids[:, :-1]
    step_distances = torch.linalg.vector_norm(deltas, dim=-1).mean(dim=0).tolist()

    # Trajectory length.
    total_length = sum(step_distances)

    # Convergence rate (ratio of last step distance to first).
    convergence = (
        step_distances[-1] / max(step_distances[0], 1e-6) if step_distances else 1.0
    )

    # State diversity per step: variance across tokens -> [B, K+1, D], then
    # averaged over batch and feature dims. With a single token the unbiased
    # variance would be NaN; a single token has no spread, so report 0.0.
    if num_tokens > 1:
        diversity = states.var(dim=2).mean(dim=(0, 2)).tolist()
    else:
        diversity = [0.0] * num_steps

    return {
        'step_distances': step_distances,
        'trajectory_length': total_length,
        'convergence_rate': convergence,
        'state_diversity': diversity,
        'avg_step_distance': total_length / max(num_steps - 1, 1),
    }