"""
Target Encoder (EMA) for MR-JEPA.

The target encoder generates the supervision signal for the JEPA objective.
It is an exponential moving average (EMA) copy of the online encoder
(evidence memory + rollout module).

From I-JEPA:
    θ̄ ← m·θ̄ + (1 - m)·θ
where the momentum m follows a cosine schedule from 0.996 → 1.0.

The target encoder processes the same inputs as the online encoder, but
under stop-gradient, producing the target latent states z*_k that the
online predictor must predict.

From LeWorldModel: SIGReg anti-collapse regularization is also added
to prevent the representation space from collapsing.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import copy |
| from typing import Optional, Dict |
|
|
| from ..configs.model_config import JEPAObjectiveConfig |
|
|
|
|
class TargetEncoder(nn.Module):
    """
    EMA target encoder that generates JEPA targets.

    Holds deep copies of the online evidence-memory and rollout modules,
    freezes them (``requires_grad = False``) and moves their parameters
    toward the online parameters with an exponential moving average:

        target <- m * target + (1 - m) * online

    The target latent trajectory is used as the ground truth for the
    JEPA prediction loss: ||z_predicted_k - sg(z*_k)||²

    Only parameters are averaged; buffers keep the values that were copied
    at construction time.
    """

    def __init__(
        self,
        online_evidence_memory: nn.Module,
        online_rollout: nn.Module,
        config: "JEPAObjectiveConfig",
    ):
        """
        Args:
            online_evidence_memory: Online evidence-memory module to copy.
            online_rollout: Online rollout module to copy.
            config: Objective config; this class reads ``ema_momentum_base``,
                ``ema_momentum_end`` and ``ema_schedule`` from it.
        """
        super().__init__()
        self.config = config

        # Independent copies: the in-place EMA updates below must never
        # write into the online weights.
        self.target_evidence_memory = copy.deepcopy(online_evidence_memory)
        self.target_rollout = copy.deepcopy(online_rollout)

        # The target side is never trained by backprop.
        for module in (self.target_evidence_memory, self.target_rollout):
            for param in module.parameters():
                param.requires_grad = False

        # Momentum used by the most recent update (exposed for logging).
        self._current_momentum = config.ema_momentum_base

    def _scheduled_momentum(self, step: int, total_steps: int) -> float:
        """Return the EMA momentum for ``step`` under the configured schedule."""
        base = self.config.ema_momentum_base
        end = self.config.ema_momentum_end

        # Clamp to [0, 1]: stepping past ``total_steps`` must hold the final
        # momentum. Unclamped, the linear schedule exceeds ``end`` (a momentum
        # above 1 gives the online weights a negative coefficient) and the
        # cosine schedule swings back toward ``base``.
        progress = min(max(step / max(total_steps, 1), 0.0), 1.0)

        if self.config.ema_schedule == "cosine":
            return end - (end - base) * (1 + math.cos(math.pi * progress)) / 2
        if self.config.ema_schedule == "linear":
            return base + (end - base) * progress
        # Any other schedule name means a constant momentum.
        return base

    @staticmethod
    @torch.no_grad()
    def _ema_blend(online: nn.Module, target: nn.Module, momentum: float) -> None:
        """In place: target <- momentum * target + (1 - momentum) * online."""
        for online_p, target_p in zip(online.parameters(), target.parameters()):
            target_p.mul_(momentum).add_(online_p.detach(), alpha=1 - momentum)

    @torch.no_grad()
    def update_ema(
        self,
        online_evidence_memory: nn.Module,
        online_rollout: nn.Module,
        step: int,
        total_steps: int,
    ):
        """
        Update target encoder weights via EMA.

        From I-JEPA: cosine schedule from the base momentum to the end momentum,
            m(t) = m_end - (m_end - m_base) * (1 + cos(π * t / T)) / 2
        with t / T clamped to [0, 1]. ``"linear"`` interpolates linearly
        between the two values; any other schedule keeps the base momentum.

        Args:
            online_evidence_memory: Current online evidence-memory module.
            online_rollout: Current online rollout module.
            step: Current training step.
            total_steps: Length of the momentum schedule in steps.
        """
        momentum = self._scheduled_momentum(step, total_steps)
        self._current_momentum = momentum

        self._ema_blend(online_evidence_memory, self.target_evidence_memory, momentum)
        self._ema_blend(online_rollout, self.target_rollout, momentum)

    @torch.no_grad()
    def forward(
        self,
        visual_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        text_mask: torch.Tensor,
        **enriched_kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Generate target latent trajectory (no gradient).

        Extra keyword arguments are forwarded unchanged to the target
        evidence memory.

        Returns:
            dict with:
                'target_trajectory': [B, K+1, N_s, D] - target states
                'target_evidence': [B, N_e, D] - target evidence tokens
        """
        evidence_output = self.target_evidence_memory(
            visual_tokens=visual_tokens,
            text_tokens=text_tokens,
            text_mask=text_mask,
            **enriched_kwargs,
        )
        target_evidence = evidence_output['evidence_tokens']

        rollout_output = self.target_rollout(
            evidence_tokens=target_evidence,
        )

        return {
            'target_trajectory': rollout_output['trajectory'],
            'target_evidence': target_evidence,
        }
|
|
|
|
class SIGRegLoss(nn.Module):
    """
    Sketched Isotropic Gaussian Regularizer (from LeWorldModel).

    Prevents representation collapse by encouraging latent embeddings
    to match an isotropic Gaussian distribution.

    Embeddings are projected onto fixed random unit directions and each
    1-D projection is scored with a simplified normality statistic:
        SIGReg(Z) = (1/M) Σ_m T(Z @ u_m)
    """

    def __init__(
        self,
        hidden_dim: int,
        num_projections: int = 1024,
        max_active_projections: Optional[int] = 64,
    ):
        """
        Args:
            hidden_dim: Embedding dimension D.
            num_projections: Number of random unit directions stored.
            max_active_projections: Upper bound on how many of the stored
                directions are evaluated per call (the first ones are used).
                ``None`` evaluates all of them.
        """
        super().__init__()
        self.num_projections = num_projections
        self.max_active_projections = max_active_projections
        # Columns are unit-norm directions; a buffer so they follow
        # ``.to(device)`` and are saved in the state dict without being trained.
        self.register_buffer(
            'projections',
            F.normalize(torch.randn(hidden_dim, num_projections), dim=0)
        )

    def _epps_pulley_statistic(self, h: torch.Tensor) -> torch.Tensor:
        """
        Simplified normality statistic: variance + kurtosis penalty.

        Args:
            h: Projected samples, shape [n] or [n, M]; statistics are taken
               over dim 0, so each column is scored independently.
        Returns:
            ``(var(h) - 1)² + 0.5 * |kurtosis(h) - 3|`` — a scalar for 1-D
            input, shape [M] for 2-D input.
        """
        mean = h.mean(dim=0, keepdim=True)
        std = h.std(dim=0, keepdim=True) + 1e-6
        standardized = (h - mean) / std

        # The variance penalty must use the raw projection: the standardized
        # values have unit variance by construction, so scoring them would
        # make this term identically ~0 and blind to scale collapse.
        variance = h.var(dim=0)
        # Kurtosis is scale-free, so it is measured on standardized values.
        kurtosis = ((standardized ** 4).mean(dim=0) - 3).abs()

        return (variance - 1.0) ** 2 + 0.5 * kurtosis

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Compute SIGReg loss.

        Args:
            z: Latent embeddings [..., D]; all leading dims are flattened
               into one sample axis (e.g. [B, N, D] or [B*N, D]).
        Returns:
            Scalar SIGReg loss (0 when fewer than two samples are given,
            since variance and kurtosis are undefined there).
        """
        z_flat = z.reshape(-1, z.size(-1))
        if z_flat.size(0) < 2:
            return z_flat.new_zeros(())

        if self.max_active_projections is None:
            active = self.num_projections
        else:
            active = min(self.num_projections, self.max_active_projections)

        # Slice before the matmul so unused directions are never projected.
        directions = self.projections[:, :active].to(z_flat.dtype)
        projected = z_flat @ directions  # [n, active]

        return self._epps_pulley_statistic(projected).mean()
|
|
|
|
class VICRegLoss(nn.Module):
    """
    VICReg-style regularization (alternative to SIGReg).

    Three terms:
    - Variance: keep feature std above a threshold
    - Invariance: (handled by prediction loss)
    - Covariance: decorrelate features
    """

    def __init__(
        self,
        var_weight: float = 1.0,
        cov_weight: float = 0.04,
        eps: float = 1e-4,
    ):
        """
        Args:
            var_weight: Weight of the variance hinge term.
            cov_weight: Weight of the off-diagonal covariance term.
            eps: Added to the variance before the square root so the
                gradient stays finite when a feature has zero variance.
        """
        super().__init__()
        self.var_weight = var_weight
        self.cov_weight = cov_weight
        self.eps = eps

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted variance + covariance penalty.

        Args:
            z: Embeddings [..., D]; leading dims are flattened into samples.
        Returns:
            Scalar loss. Returns 0 when fewer than two samples are given
            (sample variance/covariance are undefined there).
        """
        z = z.reshape(-1, z.size(-1))
        num_samples, num_features = z.shape
        if num_samples < 2:
            return z.new_zeros(())

        # sqrt(var + eps) instead of std(): the derivative of std() is 0/0
        # (NaN) for a constant feature, which is exactly the collapsed case
        # this term exists to push away from.
        std = torch.sqrt(z.var(dim=0) + self.eps)
        var_loss = F.relu(1.0 - std).mean()

        if num_features < 2:
            # A single feature has no off-diagonal covariance entries.
            cov_loss = z.new_zeros(())
        else:
            centered = z - z.mean(dim=0, keepdim=True)
            cov = (centered.T @ centered) / (num_samples - 1)
            off_diag_mask = ~torch.eye(
                num_features, dtype=torch.bool, device=cov.device
            )
            cov_loss = cov[off_diag_mask].pow(2).mean()

        return self.var_weight * var_loss + self.cov_weight * cov_loss
|
|
|
|
class JEPALoss(nn.Module):
    """
    Complete JEPA objective for MR-JEPA.

    Supports three loss functions (controlled by config.jepa_loss_fn):
    - smooth_l1: SmoothL1Loss (hybrid default, robust to outliers)
    - mse: MSE / L2 (original I-JEPA)
    - cosine: Cosine similarity loss (purist default)

    Plus anti-collapse regularization:
        L_total = w_jepa * L_JEPA + w_task * L_task + L_reg [+ w_gen * L_gen]
    where L_reg = sigreg_weight * SIGReg(Z) and/or VICReg(Z), depending on
    ``config.use_sigreg`` / ``config.use_vicreg``.
    """

    def __init__(self, config: "JEPAObjectiveConfig", hidden_dim: int):
        """
        Args:
            config: Objective config (loss choice, loss weights, regularizer
                switches and hyper-parameters).
            hidden_dim: Latent dimension D, used to size the SIGReg projections.
        """
        super().__init__()
        self.config = config

        # Regularizers are only instantiated when enabled.
        if config.use_sigreg:
            self.sigreg = SIGRegLoss(hidden_dim, config.sigreg_num_projections)
        if config.use_vicreg:
            self.vicreg = VICRegLoss(config.vicreg_var_weight, config.vicreg_cov_weight)

    def compute_jepa_loss(
        self,
        predicted_trajectory: torch.Tensor,
        target_trajectory: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute prediction loss between online and target trajectories.

        Only compute loss for steps k=1..K (not z₀, which is deterministic).
        Loss function selected by config.jepa_loss_fn.

        Raises:
            ValueError: if the loss name is unknown or the two trajectories
                do not have the same shape.
        """
        loss_fn = self.config.jepa_loss_fn
        if loss_fn not in ("smooth_l1", "mse", "cosine"):
            raise ValueError(f"Unknown JEPA loss function: {loss_fn}")

        # The element-wise losses would silently broadcast mismatched shapes
        # and return a meaningless value, so fail loudly instead.
        if predicted_trajectory.shape != target_trajectory.shape:
            raise ValueError(
                "predicted_trajectory and target_trajectory must have the same "
                f"shape, got {tuple(predicted_trajectory.shape)} and "
                f"{tuple(target_trajectory.shape)}"
            )

        pred = predicted_trajectory[:, 1:]
        # Stop-gradient on the target side.
        target = target_trajectory[:, 1:].detach()

        if pred.numel() == 0:
            # Trajectory holds only z₀: nothing to predict. Averaging over an
            # empty tensor would give NaN and poison the total loss.
            return pred.new_zeros(())

        if loss_fn == "smooth_l1":
            return F.smooth_l1_loss(pred, target)
        if loss_fn == "mse":
            return F.mse_loss(pred, target)

        # cosine: one similarity per latent vector along the feature axis.
        pred_flat = pred.reshape(-1, pred.size(-1))
        target_flat = target.reshape(-1, target.size(-1))
        cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=-1)
        return (1 - cos_sim).mean()

    def compute_regularization(
        self,
        trajectory: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute anti-collapse regularization.

        Each enabled regularizer is applied to every step of the trajectory
        (indexed along dim 1) and averaged over steps. The SIGReg average is
        scaled by ``config.sigreg_weight``; VICReg carries its own weights.
        """
        reg_loss = torch.tensor(0.0, device=trajectory.device)
        if not (self.config.use_sigreg or self.config.use_vicreg):
            return reg_loss

        states = trajectory.unbind(dim=1)

        if self.config.use_sigreg:
            sigreg_loss = torch.stack([self.sigreg(s) for s in states]).mean()
            reg_loss = reg_loss + self.config.sigreg_weight * sigreg_loss

        if self.config.use_vicreg:
            vicreg_loss = torch.stack([self.vicreg(s) for s in states]).mean()
            reg_loss = reg_loss + vicreg_loss

        return reg_loss

    def forward(
        self,
        predicted_trajectory: torch.Tensor,
        target_trajectory: torch.Tensor,
        task_loss: torch.Tensor,
        gen_loss: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total MR-JEPA loss.

        Regularization is computed on the predicted (online) trajectory.

        Returns dict with individual loss components for logging:
        'total_loss', 'jepa_loss', 'task_loss', 'reg_loss', and 'gen_loss'
        when a generative loss is supplied.
        """
        jepa_loss = self.compute_jepa_loss(predicted_trajectory, target_trajectory)
        reg_loss = self.compute_regularization(predicted_trajectory)

        total = (
            self.config.jepa_loss_weight * jepa_loss +
            self.config.task_loss_weight * task_loss +
            reg_loss
        )

        losses = {
            'total_loss': total,
            'jepa_loss': jepa_loss,
            'task_loss': task_loss,
            'reg_loss': reg_loss,
        }

        if gen_loss is not None:
            losses['total_loss'] = total + self.config.generative_loss_weight * gen_loss
            losses['gen_loss'] = gen_loss

        return losses
|
|