|
|
""" |
|
|
Joint Embedding Predictive Architecture (JEPA) for PDE dynamics. |
|
|
|
|
|
Spatial JEPA: encoder produces spatial feature maps, predictor operates |
|
|
on spatial features, loss computed on spatial latent representations. |
|
|
Collapse is discouraged by an EMA target encoder plus VICReg regularization applied to spatially averaged latents.
|
|
""" |
|
|
import copy |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm stages with a shortcut.

    The first 3x3 convolution applies ``stride``. When the stride or the
    channel count changes, the shortcut is a strided 1x1 convolution followed
    by BatchNorm; otherwise it is the identity. GELU is applied after the
    first stage and after the residual sum.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the first 3x3 convolution and of the shortcut.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Main path. Attribute names form the state_dict keys, and creation
        # order fixes the RNG consumption during initialization.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)

        needs_projection = stride != 1 or in_ch != out_ch
        if needs_projection:
            # Project the input so its shape matches the main path's output.
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        shortcut = self.skip(x)
        out = F.gelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.gelu(out + shortcut)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialEncoder(nn.Module):
    """ResNet-style encoder mapping input fields to a spatial latent map.

    A full-resolution 3x3 conv stem (BatchNorm + GELU) is followed by three
    stride-2 ``ConvBlock`` stages. Channel widths go
    base_ch -> 2*base_ch -> 4*base_ch -> latent_channels.

    Input:  [B, in_channels, H, W]
    Output: [B, latent_channels, ceil(H/8), ceil(W/8)]
            (exactly H/8, W/8 when H and W are divisible by 8)
    """

    def __init__(self, in_channels, latent_channels=128, base_ch=32):
        super().__init__()
        width1 = base_ch * 2
        width2 = base_ch * 4
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, base_ch, 3, padding=1),
            nn.BatchNorm2d(base_ch),
            nn.GELU(),
        )
        # Each stage halves the spatial resolution (stride 2).
        self.layer1 = ConvBlock(base_ch, width1, stride=2)
        self.layer2 = ConvBlock(width1, width2, stride=2)
        self.layer3 = ConvBlock(width2, latent_channels, stride=2)

    def forward(self, x):
        h = x
        for stage in (self.stem, self.layer1, self.layer2, self.layer3):
            h = stage(h)
        return h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialPredictor(nn.Module):
    """Convolutional predictor acting on spatial latent maps.

    Three 3x3 convolutions with padding 1, so the spatial size is preserved:
    latent -> hidden (BN, GELU) -> hidden (BN, GELU) -> latent.
    No activation is applied to the output.

    Input/Output: [B, latent_channels, H', W']
    """

    def __init__(self, latent_channels=128, hidden_channels=256):
        super().__init__()
        layers = []
        for c_in in (latent_channels, hidden_channels):
            layers += [
                nn.Conv2d(c_in, hidden_channels, 3, padding=1),
                nn.BatchNorm2d(hidden_channels),
                nn.GELU(),
            ]
        layers.append(nn.Conv2d(hidden_channels, latent_channels, 3, padding=1))
        # One Sequential named ``net``, so state_dict keys stay net.<index>.*
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vicreg_loss(z_pred, z_target, sim_w=25.0, var_w=25.0, cov_w=1.0):
    """VICReg loss between two batches of feature vectors.

    Combines three terms:
    - an invariance term (MSE between ``z_pred`` and ``z_target``);
    - a variance hinge that penalizes any feature whose per-batch standard
      deviation falls below 1;
    - a covariance penalty on the off-diagonal entries of each batch's
      feature covariance matrix.

    The variance and covariance terms are applied to both inputs. When
    ``z_target`` is detached, its terms add only a constant to the loss.

    Args:
        z_pred: [B, D] predicted latent.
        z_target: [B, D] target latent (typically detached by the caller).
        sim_w, var_w, cov_w: weights of the invariance, variance and
            covariance terms.

    Returns:
        (total, metrics): ``total`` is the weighted scalar loss tensor, and
        ``metrics`` maps "sim", "var", "cov" to the unweighted component
        values as Python floats.
    """
    B, D = z_pred.shape
    # Unbiased normalizer shared by the variance and covariance estimates.
    # It is clamped so that a batch of one gives finite values instead of
    # the NaN that ``Tensor.var`` would return.
    denom = max(B - 1, 1)

    sim_loss = F.mse_loss(z_pred, z_target)

    zp = z_pred - z_pred.mean(dim=0)
    zt = z_target - z_target.mean(dim=0)

    # Per-feature variance. For B >= 2 this equals z.var(dim=0).
    std_p = torch.sqrt(zp.pow(2).sum(dim=0) / denom + 1e-4)
    std_t = torch.sqrt(zt.pow(2).sum(dim=0) / denom + 1e-4)
    var_loss = F.relu(1 - std_p).mean() + F.relu(1 - std_t).mean()

    cov_p = (zp.T @ zp) / denom
    cov_t = (zt.T @ zt) / denom
    off_diag = ~torch.eye(D, dtype=torch.bool, device=z_pred.device)
    cov_loss = cov_p[off_diag].pow(2).sum() / D + cov_t[off_diag].pow(2).sum() / D

    total = sim_w * sim_loss + var_w * var_loss + cov_w * cov_loss
    return total, {"sim": sim_loss.item(), "var": var_loss.item(), "cov": cov_loss.item()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JEPA(nn.Module):
    """Spatial JEPA for PDE dynamics prediction.

    The online encoder and predictor learn to predict the target encoder's
    representation of the next frame. The target encoder starts as a copy of
    the online encoder. Its parameters then track the online encoder via an
    exponential moving average (see ``update_target``).

    Args:
        in_channels: number of input field channels.
        latent_channels: spatial latent feature map channels.
        base_ch: encoder base width.
        pred_hidden: predictor hidden channels.
        ema_decay: starting EMA decay; must lie in [0, 1].

    Raises:
        ValueError: if ``ema_decay`` is outside [0, 1].
    """

    def __init__(
        self,
        in_channels,
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
        ema_decay=0.996,
    ):
        super().__init__()
        self._check_decay(ema_decay)
        self.online_encoder = SpatialEncoder(in_channels, latent_channels, base_ch)
        self.predictor = SpatialPredictor(latent_channels, pred_hidden)
        # The target starts as an exact copy of the online encoder. It is never
        # trained by backprop; its parameters only move via update_target().
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.ema_decay = ema_decay

        for p in self.target_encoder.parameters():
            p.requires_grad_(False)

    @staticmethod
    def _check_decay(decay):
        """Raise ValueError unless 0 <= decay <= 1; otherwise the EMA lerp would extrapolate."""
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"ema_decay must be in [0, 1], got {decay!r}")

    @torch.no_grad()
    def update_target(self):
        """EMA update of target encoder parameters.

        For each parameter: target <- decay * target + (1 - decay) * online.
        Buffers (e.g. BatchNorm running statistics) are not copied here.
        The target encoder's BatchNorm statistics come from its own
        train-mode forward passes.
        """
        for pt, po in zip(self.target_encoder.parameters(), self.online_encoder.parameters()):
            pt.lerp_(po, 1 - self.ema_decay)

    def set_ema_decay(self, decay):
        """Update EMA decay (e.g. cosine schedule from 0.996 to 1.0).

        Raises:
            ValueError: if ``decay`` is outside [0, 1].
        """
        self._check_decay(decay)
        self.ema_decay = decay

    def forward(self, x_input, x_target):
        """
        Args:
            x_input: current frame(s) [B, C, H, W]
            x_target: next frame(s) [B, C, H, W]

        Returns:
            z_pred: predicted spatial latent [B, lat_ch, H', W']
            z_target: target spatial latent [B, lat_ch, H', W'], computed
                without gradient tracking.
        """
        z_online = self.online_encoder(x_input)
        z_pred = self.predictor(z_online)

        with torch.no_grad():
            z_target = self.target_encoder(x_target)

        return z_pred, z_target

    def compute_loss(self, x_input, x_target, vicreg_weight=0.1):
        """Full forward pass plus loss computation.

        The loss is the spatial MSE between predicted and target latent maps
        plus ``vicreg_weight`` times a VICReg loss. VICReg is computed on
        channel vectors after spatial averaging, which keeps the covariance
        matrix small (D = latent_channels).

        Args:
            x_input: current frame(s) [B, C, H, W].
            x_target: next frame(s) [B, C, H, W].
            vicreg_weight: weight of the VICReg term (default 0.1).

        Returns:
            loss: scalar.
            metrics: dict with "sim", "var", "cov" (unweighted VICReg
                components) and "spatial_mse", all as Python floats.
        """
        z_pred, z_target = self(x_input, x_target)

        spatial_mse = F.mse_loss(z_pred, z_target.detach())

        # Average over the two spatial axes -> [B, latent_channels].
        zp_avg = z_pred.mean(dim=(-2, -1))
        zt_avg = z_target.mean(dim=(-2, -1))

        vicreg, vicreg_m = vicreg_loss(zp_avg, zt_avg.detach())

        loss = spatial_mse + vicreg_weight * vicreg
        metrics = {
            "sim": vicreg_m["sim"],
            "var": vicreg_m["var"],
            "cov": vicreg_m["cov"],
            "spatial_mse": spatial_mse.item(),
        }
        return loss, metrics
|
|
|