""" Target Encoder (EMA) for MR-JEPA. The target encoder generates the supervision signal for the JEPA objective. It is an exponential moving average (EMA) copy of the online encoder (evidence memory + rollout module). From I-JEPA: θ̄ ← m·θ̄ + (1-m)·θ where m follows a cosine schedule from 0.996 → 1.0 The target encoder processes the same inputs but with stop-gradient, producing target latent states z*_k that the online predictor must predict. From LeWorldModel: We also add SIGReg anti-collapse regularization to prevent the representation space from collapsing. """ import torch import torch.nn as nn import torch.nn.functional as F import math import copy from typing import Optional, Dict from ..configs.model_config import JEPAObjectiveConfig class TargetEncoder(nn.Module): """ EMA target encoder that generates JEPA targets. This module wraps a copy of the online encoder (evidence memory + rollout) and updates its weights via exponential moving average. The target latent trajectory is used as the ground truth for the JEPA prediction loss: ||z_predicted_k - sg(z*_k)||² """ def __init__( self, online_evidence_memory: nn.Module, online_rollout: nn.Module, config: JEPAObjectiveConfig, ): super().__init__() self.config = config # Deep copy of online modules self.target_evidence_memory = copy.deepcopy(online_evidence_memory) self.target_rollout = copy.deepcopy(online_rollout) # Freeze target encoder (no gradient) for param in self.target_evidence_memory.parameters(): param.requires_grad = False for param in self.target_rollout.parameters(): param.requires_grad = False # EMA schedule tracking self._current_momentum = config.ema_momentum_base @torch.no_grad() def update_ema( self, online_evidence_memory: nn.Module, online_rollout: nn.Module, step: int, total_steps: int, ): """ Update target encoder weights via EMA. From I-JEPA: cosine schedule from base momentum to 1.0 m(t) = 1 - (1 - m_base) * (1 + cos(π * t / T)) / 2 """ # Compute momentum if self.config.ema_schedule == "cosine": progress = step / max(total_steps, 1) momentum = self.config.ema_momentum_end - \ (self.config.ema_momentum_end - self.config.ema_momentum_base) * \ (1 + math.cos(math.pi * progress)) / 2 elif self.config.ema_schedule == "linear": progress = step / max(total_steps, 1) momentum = self.config.ema_momentum_base + \ (self.config.ema_momentum_end - self.config.ema_momentum_base) * progress else: # constant momentum = self.config.ema_momentum_base self._current_momentum = momentum # Update evidence memory for online_p, target_p in zip( online_evidence_memory.parameters(), self.target_evidence_memory.parameters() ): target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum) # Update rollout module for online_p, target_p in zip( online_rollout.parameters(), self.target_rollout.parameters() ): target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum) @torch.no_grad() def forward( self, visual_tokens: torch.Tensor, text_tokens: torch.Tensor, text_mask: torch.Tensor, **enriched_kwargs, ) -> Dict[str, torch.Tensor]: """ Generate target latent trajectory (no gradient). Returns: dict with: 'target_trajectory': [B, K+1, N_s, D] - target states 'target_evidence': [B, N_e, D] - target evidence tokens """ evidence_output = self.target_evidence_memory( visual_tokens=visual_tokens, text_tokens=text_tokens, text_mask=text_mask, **enriched_kwargs, ) target_evidence = evidence_output['evidence_tokens'] rollout_output = self.target_rollout( evidence_tokens=target_evidence, ) return { 'target_trajectory': rollout_output['trajectory'], 'target_evidence': target_evidence, } class SIGRegLoss(nn.Module): """ Sketched Isotropic Gaussian Regularizer (from LeWorldModel). Prevents representation collapse by encouraging latent embeddings to match an isotropic Gaussian distribution. Uses random projections + Epps-Pulley test statistic. SIGReg(Z) = (1/M) Σ_m T(Z @ u_m) """ def __init__(self, hidden_dim: int, num_projections: int = 1024): super().__init__() self.num_projections = num_projections self.register_buffer( 'projections', F.normalize(torch.randn(hidden_dim, num_projections), dim=0) ) def _epps_pulley_statistic(self, h: torch.Tensor) -> torch.Tensor: """ Compute Epps-Pulley test statistic for univariate normality. Simplified: variance + kurtosis penalty. """ h_mean = h.mean() h_std = h.std() + 1e-6 h_norm = (h - h_mean) / h_std variance = h_norm.var() kurtosis = ((h_norm ** 4).mean() - 3).abs() return (variance - 1.0) ** 2 + 0.5 * kurtosis def forward(self, z: torch.Tensor) -> torch.Tensor: """ Compute SIGReg loss. Args: z: Latent embeddings [B, N, D] or [B*N, D] Returns: Scalar SIGReg loss """ if z.dim() == 3: B, N, D = z.shape z_flat = z.reshape(B * N, D) else: z_flat = z projections = z_flat @ self.projections # [B*N, M] losses = [] for m in range(min(self.num_projections, 64)): losses.append(self._epps_pulley_statistic(projections[:, m])) return torch.stack(losses).mean() class VICRegLoss(nn.Module): """ VICReg-style regularization (alternative to SIGReg). Three terms: - Variance: keep feature std above a threshold - Invariance: (handled by prediction loss) - Covariance: decorrelate features """ def __init__(self, var_weight: float = 1.0, cov_weight: float = 0.04): super().__init__() self.var_weight = var_weight self.cov_weight = cov_weight def forward(self, z: torch.Tensor) -> torch.Tensor: if z.dim() == 3: z = z.reshape(-1, z.size(-1)) # Variance: penalize if std drops below 1 std = z.std(dim=0) var_loss = F.relu(1.0 - std).mean() # Covariance: penalize off-diagonal correlations z_centered = z - z.mean(dim=0, keepdim=True) N = z_centered.size(0) cov = (z_centered.T @ z_centered) / (N - 1) D = cov.size(0) off_diag = cov.flatten()[:-1].view(D - 1, D + 1)[:, 1:].flatten() cov_loss = (off_diag ** 2).mean() return self.var_weight * var_loss + self.cov_weight * cov_loss class JEPALoss(nn.Module): """ Complete JEPA objective for MR-JEPA. Supports three loss functions (controlled by config.jepa_loss_fn): - smooth_l1: SmoothL1Loss (hybrid default, robust to outliers) - mse: MSE / L2 (original I-JEPA) - cosine: Cosine similarity loss (purist default) Plus anti-collapse regularization: L_total = L_JEPA + λ * SIGReg(Z) + L_task + α * L_gen """ def __init__(self, config: JEPAObjectiveConfig, hidden_dim: int): super().__init__() self.config = config # Anti-collapse if config.use_sigreg: self.sigreg = SIGRegLoss(hidden_dim, config.sigreg_num_projections) if config.use_vicreg: self.vicreg = VICRegLoss(config.vicreg_var_weight, config.vicreg_cov_weight) def compute_jepa_loss( self, predicted_trajectory: torch.Tensor, # [B, K+1, N_s, D] target_trajectory: torch.Tensor, # [B, K+1, N_s, D] ) -> torch.Tensor: """ Compute prediction loss between online and target trajectories. Only compute loss for steps k=1..K (not z₀, which is deterministic). Loss function selected by config.jepa_loss_fn. """ # Skip z₀ (step 0) — only supervise predicted states pred = predicted_trajectory[:, 1:] # [B, K, N_s, D] target = target_trajectory[:, 1:].detach() # [B, K, N_s, D] loss_fn = self.config.jepa_loss_fn if loss_fn == "smooth_l1": return F.smooth_l1_loss(pred, target) elif loss_fn == "mse": return F.mse_loss(pred, target) elif loss_fn == "cosine": # Cosine similarity loss: 1 - cos(pred, target), averaged pred_flat = pred.reshape(-1, pred.size(-1)) target_flat = target.reshape(-1, target.size(-1)) cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=-1) return (1 - cos_sim).mean() else: raise ValueError(f"Unknown JEPA loss function: {loss_fn}") def compute_regularization( self, trajectory: torch.Tensor, # [B, K+1, N_s, D] ) -> torch.Tensor: """Compute anti-collapse regularization.""" reg_loss = torch.tensor(0.0, device=trajectory.device) if self.config.use_sigreg: B, Kp1, N_s, D = trajectory.shape for k in range(Kp1): reg_loss = reg_loss + self.sigreg(trajectory[:, k]) reg_loss = reg_loss / Kp1 reg_loss = self.config.sigreg_weight * reg_loss if self.config.use_vicreg: B, Kp1, N_s, D = trajectory.shape vicreg_loss = torch.tensor(0.0, device=trajectory.device) for k in range(Kp1): vicreg_loss = vicreg_loss + self.vicreg(trajectory[:, k]) vicreg_loss = vicreg_loss / Kp1 reg_loss = reg_loss + vicreg_loss return reg_loss def forward( self, predicted_trajectory: torch.Tensor, target_trajectory: torch.Tensor, task_loss: torch.Tensor, gen_loss: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Compute total MR-JEPA loss. Returns dict with individual loss components for logging. """ # JEPA prediction loss jepa_loss = self.compute_jepa_loss(predicted_trajectory, target_trajectory) # Anti-collapse regularization reg_loss = self.compute_regularization(predicted_trajectory) # Total loss total = ( self.config.jepa_loss_weight * jepa_loss + self.config.task_loss_weight * task_loss + reg_loss ) losses = { 'total_loss': total, 'jepa_loss': jepa_loss, 'task_loss': task_loss, 'reg_loss': reg_loss, } if gen_loss is not None: total = total + self.config.generative_loss_weight * gen_loss losses['total_loss'] = total losses['gen_loss'] = gen_loss return losses