AlexWortega commited on
Commit
889bf64
·
verified ·
1 Parent(s): 70ced22

Upload jepa.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. jepa.py +228 -0
jepa.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Joint Embedding Predictive Architecture (JEPA) for PDE dynamics.
3
+
4
+ Spatial JEPA: encoder produces spatial feature maps, predictor operates
5
+ on spatial features, loss computed on spatial latent representations.
6
+ Prevents collapse via VICReg regularization.
7
+ """
8
+ import copy
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ # ---------------------------------------------------------------------------
15
+ # Building blocks
16
+ # ---------------------------------------------------------------------------
17
+
18
+
19
class ConvBlock(nn.Module):
    """Residual block: two 3x3 convs with BatchNorm and a shortcut path.

    The shortcut is a 1x1 conv + BN projection whenever the channel count
    or stride changes; otherwise it is the identity.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # NOTE: submodules are registered in the same order as elsewhere in
        # this file (conv1, bn1, conv2, bn2, skip) so seeded init matches.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        needs_projection = in_ch != out_ch or stride != 1
        if needs_projection:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        out = self.conv1(x)
        out = F.gelu(self.bn1(out))
        out = self.bn2(self.conv2(out))
        # Residual add, then the final nonlinearity.
        return F.gelu(out + self.skip(x))
36
+
37
+
38
+ # ---------------------------------------------------------------------------
39
+ # Spatial Encoder (outputs feature maps, not vectors)
40
+ # ---------------------------------------------------------------------------
41
+
42
+
43
class SpatialEncoder(nn.Module):
    """ResNet-style encoder that keeps a spatial layout in its output.

    Input:  [B, C_in, H, W]
    Output: [B, latent_channels, H/8, W/8]
    """

    def __init__(self, in_channels, latent_channels=128, base_ch=32):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, base_ch, 3, padding=1),
            nn.BatchNorm2d(base_ch),
            nn.GELU(),
        )
        # Each stage halves the spatial resolution: /2, /4, /8 overall.
        self.layer1 = ConvBlock(base_ch, base_ch * 2, stride=2)
        self.layer2 = ConvBlock(base_ch * 2, base_ch * 4, stride=2)
        self.layer3 = ConvBlock(base_ch * 4, latent_channels, stride=2)

    def forward(self, x):
        out = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        return out
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # Spatial Predictor (conv-based, operates on feature maps)
71
+ # ---------------------------------------------------------------------------
72
+
73
+
74
class SpatialPredictor(nn.Module):
    """Small conv net mapping a latent map to a latent map of the same shape.

    Input/Output: [B, latent_channels, H', W']
    """

    def __init__(self, latent_channels=128, hidden_channels=256):
        super().__init__()
        # Three 3x3 convs with BN+GELU between them; padding keeps H'/W'.
        layers = [
            nn.Conv2d(latent_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, latent_channels, 3, padding=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
94
+
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # VICReg-style loss (prevents representation collapse)
98
+ # ---------------------------------------------------------------------------
99
+
100
+
101
def vicreg_loss(z_pred, z_target, sim_w=25.0, var_w=25.0, cov_w=1.0):
    """VICReg-style loss on flattened [B, D] latents.

    Three terms:
      - invariance: MSE between prediction and target;
      - variance: hinge keeping each dimension's batch std above 1;
      - covariance: penalizes off-diagonal batch-covariance entries.

    Args:
        z_pred: [B, D] predicted latent.
        z_target: [B, D] target latent (detached by the caller).
        sim_w, var_w, cov_w: weights for the three terms.

    Returns:
        (total_loss, dict with "sim"/"var"/"cov" scalar components)
    """
    batch, dim = z_pred.shape

    # Invariance term.
    sim_loss = F.mse_loss(z_pred, z_target)

    # Variance term: hinge at std=1 per dimension (epsilon avoids sqrt(0)).
    eps = 1e-4
    std_pred = torch.sqrt(z_pred.var(dim=0) + eps)
    std_tgt = torch.sqrt(z_target.var(dim=0) + eps)
    var_loss = F.relu(1 - std_pred).mean() + F.relu(1 - std_tgt).mean()

    # Covariance term: off-diagonal entries of the batch covariance matrix.
    centered_p = z_pred - z_pred.mean(0)
    centered_t = z_target - z_target.mean(0)
    denom = max(batch - 1, 1)
    cov_pred = (centered_p.T @ centered_p) / denom
    cov_tgt = (centered_t.T @ centered_t) / denom
    off_diag = ~torch.eye(dim, device=z_pred.device).bool()
    cov_loss = cov_pred[off_diag].pow(2).sum() / dim + cov_tgt[off_diag].pow(2).sum() / dim

    total = sim_w * sim_loss + var_w * var_loss + cov_w * cov_loss
    components = {"sim": sim_loss.item(), "var": var_loss.item(), "cov": cov_loss.item()}
    return total, components
131
+
132
+
133
+ # ---------------------------------------------------------------------------
134
+ # Full JEPA model
135
+ # ---------------------------------------------------------------------------
136
+
137
+
138
class JEPA(nn.Module):
    """Spatial JEPA for PDE dynamics prediction.

    Online encoder + predictor learn to predict the target encoder's
    representation of the next frame. The target encoder is an EMA
    copy of the online encoder.

    Args:
        in_channels: number of input field channels.
        latent_channels: spatial latent feature map channels.
        base_ch: encoder base width.
        pred_hidden: predictor hidden channels.
        ema_decay: starting EMA decay.
    """

    def __init__(
        self,
        in_channels,
        latent_channels=128,
        base_ch=32,
        pred_hidden=256,
        ema_decay=0.996,
    ):
        super().__init__()
        self.online_encoder = SpatialEncoder(in_channels, latent_channels, base_ch)
        self.predictor = SpatialPredictor(latent_channels, pred_hidden)
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.ema_decay = ema_decay

        # Target is updated only via EMA, never by gradients.
        for p in self.target_encoder.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update_target(self):
        """EMA update of target encoder.

        Parameters follow an exponential moving average of the online
        encoder's. Buffers (e.g. BatchNorm running statistics) are synced
        by direct copy: `parameters()` does not include buffers, so the
        original param-only update left the target's running stats frozen
        at their deepcopy-time values, making the target normalize with
        increasingly stale statistics as training progresses.
        """
        for pt, po in zip(self.target_encoder.parameters(), self.online_encoder.parameters()):
            pt.data.lerp_(po.data, 1 - self.ema_decay)
        for bt, bo in zip(self.target_encoder.buffers(), self.online_encoder.buffers()):
            bt.data.copy_(bo.data)

    def set_ema_decay(self, decay):
        """Update EMA decay (e.g. cosine schedule from 0.996 to 1.0)."""
        self.ema_decay = decay

    def forward(self, x_input, x_target):
        """
        Args:
            x_input: current frame(s) [B, C, H, W]
            x_target: next frame(s) [B, C, H, W]

        Returns:
            z_pred: predicted spatial latent [B, lat_ch, H', W']
            z_target: target spatial latent [B, lat_ch, H', W'] (no grad)
        """
        z_online = self.online_encoder(x_input)
        z_pred = self.predictor(z_online)

        # Target branch never receives gradients.
        with torch.no_grad():
            z_target = self.target_encoder(x_target)

        return z_pred, z_target

    def compute_loss(self, x_input, x_target):
        """Full forward + loss computation.

        VICReg is computed on channel vectors after spatial averaging
        to keep the covariance matrix small (D = latent_channels).

        Args:
            x_input: current frame(s) [B, C, H, W]
            x_target: next frame(s) [B, C, H, W]

        Returns:
            loss: scalar tensor.
            metrics: dict of float components (sim/var/cov/spatial_mse).
        """
        z_pred, z_target = self(x_input, x_target)

        # Spatial MSE loss (pixel-level prediction quality).
        spatial_mse = F.mse_loss(z_pred, z_target.detach())

        # VICReg on spatially-averaged channel vectors [B, lat_ch].
        zp_avg = z_pred.mean(dim=(-2, -1))
        zt_avg = z_target.mean(dim=(-2, -1))

        vicreg, vicreg_m = vicreg_loss(zp_avg, zt_avg.detach())

        # Combine: spatial MSE drives prediction, VICReg prevents collapse.
        loss = spatial_mse + 0.1 * vicreg
        metrics = {
            "sim": vicreg_m["sim"],
            "var": vicreg_m["var"],
            "cov": vicreg_m["cov"],
            "spatial_mse": spatial_mse.item(),
        }
        return loss, metrics