MR-JEPA / mr_jepa /models /target_encoder.py
JorgeAV's picture
fix: target_encoder.py — respect config.jepa_loss_fn (smooth_l1/mse/cosine) instead of hardcoded MSE
ac52da3 verified
"""
Target Encoder (EMA) for MR-JEPA.
The target encoder generates the supervision signal for the JEPA objective.
It is an exponential moving average (EMA) copy of the online encoder
(evidence memory + rollout module).
From I-JEPA:
θ̄ ← m·θ̄ + (1-m)·θ
where m follows a cosine schedule from 0.996 → 1.0
The target encoder processes the same inputs but with stop-gradient,
producing target latent states z*_k that the online predictor must predict.
From LeWorldModel: We also add SIGReg anti-collapse regularization
to prevent the representation space from collapsing.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from typing import Optional, Dict
from ..configs.model_config import JEPAObjectiveConfig
class TargetEncoder(nn.Module):
    """
    EMA target encoder that generates JEPA targets.

    Holds deep copies of the online evidence-memory and rollout modules with
    gradients disabled, and moves their parameters towards the online
    parameters with an exponential moving average:

        θ̄ ← m·θ̄ + (1 - m)·θ

    The target latent trajectory is used as the ground truth for the
    JEPA prediction loss: ||z_predicted_k - sg(z*_k)||²
    """

    def __init__(
        self,
        online_evidence_memory: nn.Module,
        online_rollout: nn.Module,
        config: JEPAObjectiveConfig,
    ):
        """
        Args:
            online_evidence_memory: Online evidence-memory module to copy.
            online_rollout: Online rollout module to copy.
            config: Objective config. This class reads ``ema_schedule``,
                ``ema_momentum_base`` and ``ema_momentum_end`` from it.
        """
        super().__init__()
        self.config = config

        # Independent copies, so later optimizer steps on the online modules
        # do not touch the target weights.
        self.target_evidence_memory = copy.deepcopy(online_evidence_memory)
        self.target_rollout = copy.deepcopy(online_rollout)

        # Only requires_grad is changed here; the copies otherwise keep the
        # state of the online modules at construction time.
        # NOTE(review): train/eval mode is not forced, so stochastic layers
        # (if any) follow whatever mode the parent module is put in - confirm
        # that this is intended for target generation.
        for module in (self.target_evidence_memory, self.target_rollout):
            for param in module.parameters():
                param.requires_grad = False

        # Last momentum value applied by update_ema (kept for inspection).
        self._current_momentum = config.ema_momentum_base

    def _scheduled_momentum(self, step: int, total_steps: int) -> float:
        """Return the EMA momentum for ``step`` under ``config.ema_schedule``."""
        base = self.config.ema_momentum_base
        schedule = self.config.ema_schedule
        if schedule not in ("cosine", "linear"):
            # Any other schedule name means a constant momentum.
            return base

        end = self.config.ema_momentum_end
        # Clamp to [0, 1]: without it, a step past total_steps pushes the
        # linear schedule beyond ``end`` (making the blend weight 1 - m
        # negative) and makes the cosine schedule fall back towards ``base``.
        progress = min(max(step / max(total_steps, 1), 0.0), 1.0)
        if schedule == "cosine":
            return end - (end - base) * (1 + math.cos(math.pi * progress)) / 2
        return base + (end - base) * progress

    @staticmethod
    def _blend_parameters(
        online: nn.Module, target: nn.Module, momentum: float
    ) -> None:
        """Apply target ← m·target + (1 - m)·online, parameter by parameter."""
        for online_p, target_p in zip(online.parameters(), target.parameters()):
            target_p.data.mul_(momentum).add_(online_p.data, alpha=1 - momentum)

    @torch.no_grad()
    def update_ema(
        self,
        online_evidence_memory: nn.Module,
        online_rollout: nn.Module,
        step: int,
        total_steps: int,
    ):
        """
        Update target encoder weights via EMA.

        With p = clamp(step / max(total_steps, 1), 0, 1) the momentum is:
            cosine:  m = m_end - (m_end - m_base) * (1 + cos(π·p)) / 2
            linear:  m = m_base + (m_end - m_base) * p
            other:   m = m_base
        so both schedules start at m_base and stay at m_end once
        ``step`` reaches ``total_steps``.

        Args:
            online_evidence_memory: Online module whose parameters are blended
                into ``target_evidence_memory``.
            online_rollout: Online module whose parameters are blended into
                ``target_rollout``.
            step: Current training step.
            total_steps: Length of the momentum schedule in steps.
        """
        momentum = self._scheduled_momentum(step, total_steps)
        self._current_momentum = momentum

        self._blend_parameters(
            online_evidence_memory, self.target_evidence_memory, momentum
        )
        self._blend_parameters(online_rollout, self.target_rollout, momentum)

    @torch.no_grad()
    def forward(
        self,
        visual_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        text_mask: torch.Tensor,
        **enriched_kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Generate target latent trajectory (no gradient).

        Args:
            visual_tokens: Forwarded to the target evidence memory.
            text_tokens: Forwarded to the target evidence memory.
            text_mask: Forwarded to the target evidence memory.
            **enriched_kwargs: Extra keyword arguments forwarded unchanged to
                the target evidence memory.

        Returns:
            dict with:
                'target_trajectory': [B, K+1, N_s, D] - target states
                'target_evidence': [B, N_e, D] - target evidence tokens
        """
        evidence_output = self.target_evidence_memory(
            visual_tokens=visual_tokens,
            text_tokens=text_tokens,
            text_mask=text_mask,
            **enriched_kwargs,
        )
        target_evidence = evidence_output['evidence_tokens']

        # The rollout is driven by the evidence tokens only.
        rollout_output = self.target_rollout(
            evidence_tokens=target_evidence,
        )

        return {
            'target_trajectory': rollout_output['trajectory'],
            'target_evidence': target_evidence,
        }
class SIGRegLoss(nn.Module):
    """
    Sketched Isotropic Gaussian Regularizer (from LeWorldModel).

    Prevents representation collapse by encouraging latent embeddings
    to match an isotropic Gaussian distribution.

    Embeddings are projected onto fixed random unit directions and a
    univariate normality statistic is averaged over the projections:

        SIGReg(Z) = (1/M) Σ_m T(Z @ u_m)
    """

    def __init__(
        self,
        hidden_dim: int,
        num_projections: int = 1024,
        max_active_projections: int = 64,
    ):
        """
        Args:
            hidden_dim: Size of the last dimension of the embeddings.
            num_projections: Number of random directions stored in the
                ``projections`` buffer.
            max_active_projections: Upper bound on how many of the stored
                directions are evaluated per call (the leading columns of
                the buffer are used).

        Raises:
            ValueError: If either projection count is smaller than 1.
        """
        super().__init__()
        if num_projections < 1 or max_active_projections < 1:
            raise ValueError(
                "num_projections and max_active_projections must be >= 1"
            )
        self.num_projections = num_projections
        self.max_active_projections = max_active_projections
        # Columns are unit-norm directions; registered as a buffer so they
        # move with the module and are saved in its state dict.
        self.register_buffer(
            'projections',
            F.normalize(torch.randn(hidden_dim, num_projections), dim=0)
        )

    def _epps_pulley_statistic(self, h: torch.Tensor) -> torch.Tensor:
        """
        Compute the simplified normality statistic (variance + kurtosis
        penalty) along dimension 0.

        Args:
            h: Projected samples, either [n] (one projection) or [n, M]
               (one projection per column).

        Returns:
            A scalar for 1-D input, or one value per column for 2-D input.
        """
        h_mean = h.mean(dim=0)
        h_std = h.std(dim=0) + 1e-6  # epsilon avoids division by zero
        h_norm = (h - h_mean) / h_std
        # NOTE(review): h_norm is already standardised, so its variance is
        # ~1 by construction and this term is ~0 unless h is (nearly)
        # constant. Kept as-is; confirm whether the raw variance was meant.
        variance = h_norm.var(dim=0)
        # 3 is the fourth standardised moment of a Gaussian.
        kurtosis = ((h_norm ** 4).mean(dim=0) - 3).abs()
        return (variance - 1.0) ** 2 + 0.5 * kurtosis

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Compute SIGReg loss.

        Args:
            z: Latent embeddings [B, N, D] or [B*N, D]

        Returns:
            Scalar SIGReg loss. Zero when fewer than two rows are given,
            because the sample statistics are undefined in that case.
        """
        z_flat = z.reshape(-1, z.size(-1)) if z.dim() == 3 else z

        if z_flat.size(0) < 2:
            # Unbiased std of a single sample is NaN; return a zero that is
            # still attached to the autograd graph.
            return z_flat.sum() * 0.0

        # Project only onto the directions that are actually scored instead
        # of multiplying by the whole buffer and discarding most columns.
        num_active = min(self.num_projections, self.max_active_projections)
        projected = z_flat @ self.projections[:, :num_active]  # [B*N, M']

        # One vectorised pass over all columns replaces the per-column loop.
        return self._epps_pulley_statistic(projected).mean()
class VICRegLoss(nn.Module):
    """
    VICReg-style regularization (alternative to SIGReg).

    Three terms:
    - Variance: keep feature std above a threshold
    - Invariance: (handled by prediction loss)
    - Covariance: decorrelate features
    """

    def __init__(self, var_weight: float = 1.0, cov_weight: float = 0.04):
        """
        Args:
            var_weight: Multiplier of the variance (hinge-on-std) term.
            cov_weight: Multiplier of the covariance (off-diagonal) term.
        """
        super().__init__()
        self.var_weight = var_weight
        self.cov_weight = cov_weight

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted variance + covariance penalty.

        Args:
            z: Embeddings of shape [B, N, D] (flattened to [B*N, D]) or
               already [n, D].

        Returns:
            Scalar loss. Zero when fewer than two rows are given (sample
            statistics are undefined); the covariance term is zero when
            there is a single feature (no off-diagonal entries).
        """
        if z.dim() == 3:
            z = z.reshape(-1, z.size(-1))

        num_samples, num_features = z.shape
        if num_samples < 2:
            # Unbiased std and the 1/(n-1) covariance are NaN for one row;
            # return a zero that stays attached to the autograd graph.
            return z.sum() * 0.0

        # Variance: penalize if std drops below 1
        std = z.std(dim=0)
        var_loss = F.relu(1.0 - std).mean()

        # Covariance: penalize off-diagonal correlations
        if num_features < 2:
            # A 1x1 covariance matrix has no off-diagonal entries; the mean
            # over an empty selection would be NaN.
            cov_loss = z.new_zeros(())
        else:
            z_centered = z - z.mean(dim=0, keepdim=True)
            cov = (z_centered.T @ z_centered) / (num_samples - 1)
            # Dropping the last element and viewing as (D-1, D+1) puts every
            # diagonal entry in column 0, so [:, 1:] keeps the off-diagonals.
            off_diag = (
                cov.flatten()[:-1]
                .view(num_features - 1, num_features + 1)[:, 1:]
                .flatten()
            )
            cov_loss = (off_diag ** 2).mean()

        return self.var_weight * var_loss + self.cov_weight * cov_loss
class JEPALoss(nn.Module):
    """
    Complete JEPA objective for MR-JEPA.

    Supports three loss functions (controlled by config.jepa_loss_fn):
    - smooth_l1: SmoothL1Loss (hybrid default, robust to outliers)
    - mse: MSE / L2 (original I-JEPA)
    - cosine: Cosine similarity loss (purist default)

    Plus anti-collapse regularization. The total computed in ``forward`` is:

        L_total = w_jepa * L_JEPA + w_task * L_task + L_reg [+ w_gen * L_gen]

    where L_reg = sigreg_weight * SIGReg(Z) (if enabled) + VICReg(Z)
    (if enabled), each averaged over the trajectory steps.
    """

    def __init__(self, config: JEPAObjectiveConfig, hidden_dim: int):
        """
        Args:
            config: Objective config; selects the loss function, the loss
                weights and which regularizers are built.
            hidden_dim: Embedding size, passed to the SIGReg regularizer.
        """
        super().__init__()
        self.config = config

        # Anti-collapse regularizers are only instantiated when enabled.
        if config.use_sigreg:
            self.sigreg = SIGRegLoss(hidden_dim, config.sigreg_num_projections)
        if config.use_vicreg:
            self.vicreg = VICRegLoss(
                config.vicreg_var_weight, config.vicreg_cov_weight
            )

    def compute_jepa_loss(
        self,
        predicted_trajectory: torch.Tensor,  # [B, K+1, N_s, D]
        target_trajectory: torch.Tensor,  # [B, K+1, N_s, D]
    ) -> torch.Tensor:
        """
        Compute prediction loss between online and target trajectories.

        Only compute loss for steps k=1..K (not z₀, which is deterministic).
        Loss function selected by config.jepa_loss_fn.

        Returns:
            Scalar loss. Zero (still attached to the autograd graph) when the
            trajectory has no steps beyond z₀, since there is nothing to
            supervise.

        Raises:
            ValueError: If config.jepa_loss_fn is not one of
                "smooth_l1", "mse" or "cosine".
        """
        loss_fn = self.config.jepa_loss_fn
        if loss_fn not in ("smooth_l1", "mse", "cosine"):
            raise ValueError(f"Unknown JEPA loss function: {loss_fn}")

        # Skip z₀ (step 0) — only supervise predicted states
        pred = predicted_trajectory[:, 1:]  # [B, K, N_s, D]
        # detach() enforces the stop-gradient on the target branch.
        target = target_trajectory[:, 1:].detach()  # [B, K, N_s, D]

        if pred.numel() == 0:
            # A mean over an empty slice is NaN and would poison the total
            # loss; return an exact zero instead.
            return pred.sum() * 0.0

        if loss_fn == "smooth_l1":
            return F.smooth_l1_loss(pred, target)
        if loss_fn == "mse":
            return F.mse_loss(pred, target)

        # Cosine similarity loss: 1 - cos(pred, target), averaged over all
        # (batch, step, token) vectors.
        pred_flat = pred.reshape(-1, pred.size(-1))
        target_flat = target.reshape(-1, target.size(-1))
        cos_sim = F.cosine_similarity(pred_flat, target_flat, dim=-1)
        return (1 - cos_sim).mean()

    def compute_regularization(
        self,
        trajectory: torch.Tensor,  # [B, K+1, N_s, D]
    ) -> torch.Tensor:
        """
        Compute anti-collapse regularization.

        Each enabled regularizer is evaluated on every step of the
        trajectory (including step 0) and averaged over steps. The SIGReg
        average is scaled by config.sigreg_weight; the VICReg weights are
        applied inside the VICReg module itself.

        Returns:
            Scalar regularization loss (zero if no regularizer is enabled).
        """
        reg_loss = torch.tensor(0.0, device=trajectory.device)
        num_steps = trajectory.shape[1]

        if self.config.use_sigreg:
            for k in range(num_steps):
                reg_loss = reg_loss + self.sigreg(trajectory[:, k])
            reg_loss = reg_loss / num_steps
            reg_loss = self.config.sigreg_weight * reg_loss

        if self.config.use_vicreg:
            vicreg_loss = torch.tensor(0.0, device=trajectory.device)
            for k in range(num_steps):
                vicreg_loss = vicreg_loss + self.vicreg(trajectory[:, k])
            vicreg_loss = vicreg_loss / num_steps
            reg_loss = reg_loss + vicreg_loss

        return reg_loss

    def forward(
        self,
        predicted_trajectory: torch.Tensor,
        target_trajectory: torch.Tensor,
        task_loss: torch.Tensor,
        gen_loss: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute total MR-JEPA loss.

        Args:
            predicted_trajectory: Online trajectory [B, K+1, N_s, D].
            target_trajectory: Target trajectory [B, K+1, N_s, D].
            task_loss: Precomputed task loss, weighted by
                config.task_loss_weight.
            gen_loss: Optional generative loss, weighted by
                config.generative_loss_weight when given.

        Returns:
            dict with 'total_loss', 'jepa_loss', 'task_loss', 'reg_loss' and,
            only when gen_loss is provided, 'gen_loss'. The component
            entries are unweighted, for logging.
        """
        # JEPA prediction loss
        jepa_loss = self.compute_jepa_loss(predicted_trajectory, target_trajectory)

        # Anti-collapse regularization is applied to the online
        # (predicted) trajectory only.
        reg_loss = self.compute_regularization(predicted_trajectory)

        total = (
            self.config.jepa_loss_weight * jepa_loss +
            self.config.task_loss_weight * task_loss +
            reg_loss
        )
        if gen_loss is not None:
            total = total + self.config.generative_loss_weight * gen_loss

        losses = {
            'total_loss': total,
            'jepa_loss': jepa_loss,
            'task_loss': task_loss,
            'reg_loss': reg_loss,
        }
        if gen_loss is not None:
            losses['gen_loss'] = gen_loss

        return losses