JensLundsgaard committed on
Commit dddfdd1 · verified · 1 Parent(s): 344e3be

Upload raffael_model.py with huggingface_hub

Files changed (1):
  1. raffael_model.py +418 -0
raffael_model.py ADDED
@@ -0,0 +1,418 @@
"""
Complete High-Quality ConvLSTM Autoencoder
- Uses true ConvLSTM (not regular LSTM)
- Complete Encoder (2D CNN + ConvLSTM) with flattened latents
- Complete Decoder (ConvLSTM + ConvTranspose)
- Optional Empty/Non-empty Classifier
- Works with 128x128 input images
- Latent format: (B, T, N) where N is flattened spatial dimensions
- Includes ResNet-style residual connections in CNN layers
"""
import torch
import torch.nn as nn
from raffael_conv_lstm import ConvLSTM
from huggingface_hub import PyTorchModelHubMixin

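# Interface note (inferred from how ConvLSTM is called below, not verified against
# raffael_conv_lstm itself): with batch_first=True and return_all_layers=False,
# ConvLSTM(...)(x) is expected to return (layer_output_list, last_state_list),
# where layer_output_list[0] is a tensor of shape (B, T, hidden_dim, H, W).
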
class ResidualBlock(nn.Module):
    """
    Residual block for encoder with optional downsampling
    """
    def __init__(self, in_channels, out_channels, downsample=False):
        super(ResidualBlock, self).__init__()

        stride = 2 if downsample else 1

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection shortcut if channels change or downsampling
        if in_channels != out_channels or downsample:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class ResidualUpBlock(nn.Module):
    """
    Residual block for decoder with upsampling
    """
    def __init__(self, in_channels, out_channels):
        super(ResidualUpBlock, self).__init__()

        # kernel_size=4, stride=2, padding=1 exactly doubles the spatial size:
        # H_out = (H_in - 1) * 2 - 2 * 1 + 4 = 2 * H_in
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut with upsampling
        self.shortcut = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.upsample(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Encoder(nn.Module):
    """
    Encoder: 2D CNN spatial compression + ConvLSTM temporal modeling + flatten to (B, T, N)
    Output: z_seq (B, T, latent_size) and z_last (B, latent_size)
    """

    def __init__(self, input_channels=1, hidden_dim=256, num_layers=2, latent_size=4096):
        super(Encoder, self).__init__()

        self.hidden_dim = hidden_dim
        self.latent_size = latent_size

        # Spatial convolution with residual connections: process each frame separately
        # 128x128 -> 64x64 -> 32x32 -> 16x16
        self.spatial_cnn = nn.Sequential(
            # Layer 1: 128 -> 64 (with downsampling)
            ResidualBlock(input_channels, 64, downsample=True),

            # Layer 2: 64 -> 32 (with downsampling)
            ResidualBlock(64, 128, downsample=True),

            # Layer 3: 32 -> 16 (with downsampling)
            ResidualBlock(128, 256, downsample=True),
        )

        # ConvLSTM: process temporal sequence
        # Input: (B, T, 256, 16, 16)
        # Output: (B, T, hidden_dim, 16, 16)
        self.convlstm = ConvLSTM(
            input_dim=256,
            hidden_dim=hidden_dim,
            kernel_size=(3, 3),
            num_layers=num_layers,
            batch_first=True,
            return_all_layers=False
        )

        # Dropout before latent compression
        self.dropout = nn.Dropout(0.1)

        # Linear layer to compress spatial latent to fixed size
        # Input: (B*T, hidden_dim * 16 * 16)
        # Output: (B*T, latent_size)
        self.latent_compress = nn.Linear(hidden_dim * 16 * 16, latent_size)

    def forward(self, x):
        """
        Args:
            x: (B, T, 1, H, W) - input video sequence (any size, will be resized to 128x128)

        Returns:
            z_seq: (B, T, latent_size) - compressed latent sequence
            z_last: (B, latent_size) - last timestep compressed latent
        """
        B, T, C, H, W = x.shape

        # Resize to 128x128 if needed
        x = x.view(B * T, C, H, W)  # (B*T, 1, H, W)
        if H != 128 or W != 128:
            x = torch.nn.functional.interpolate(x, size=(128, 128), mode='bilinear', align_corners=True)

        # Spatial compression: process each frame separately
        x = self.spatial_cnn(x)  # (B*T, 256, 16, 16)
        _, C2, H2, W2 = x.shape
        x = x.view(B, T, C2, H2, W2)  # (B, T, 256, 16, 16)

        # ConvLSTM processes temporal sequence
        lstm_out, _ = self.convlstm(x)  # list of (B, T, hidden_dim, 16, 16)
        h_seq = lstm_out[0]  # (B, T, hidden_dim, 16, 16)

        # Flatten and compress spatial dimensions with linear layer
        B, T, C, H, W = h_seq.shape
        h_flat = h_seq.view(B * T, C * H * W)  # (B*T, hidden_dim * 16 * 16)
        h_flat = self.dropout(h_flat)  # Apply dropout
        z_compressed = self.latent_compress(h_flat)  # (B*T, latent_size)
        z_seq = z_compressed.view(B, T, self.latent_size)  # (B, T, latent_size)

        # Take last timestep
        z_last = z_seq[:, -1]  # (B, latent_size)

        return z_seq, z_last


class Decoder(nn.Module):
    """
    Decoder: Linear expansion + ConvLSTM temporal decoding + ConvTranspose spatial reconstruction
    Input: z_seq (B, T, latent_size)
    Output: x_rec (B, T, 1, 128, 128)
    """

    def __init__(self, seq_len, latent_size=4096, latent_dim=256, hidden_dim=128, num_layers=2):
        super(Decoder, self).__init__()
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.latent_size = latent_size

        # Linear layer to expand compressed latent to spatial dimensions
        # Input: (B*T, latent_size)
        # Output: (B*T, latent_dim * 16 * 16)
        self.latent_expand = nn.Linear(latent_size, latent_dim * 16 * 16)

        # ConvLSTM decodes temporal dimension
        self.convlstm = ConvLSTM(
            input_dim=latent_dim,
            hidden_dim=hidden_dim,
            kernel_size=(3, 3),
            num_layers=num_layers,
            batch_first=True,
            return_all_layers=False
        )

        # Spatial decoding with residual connections: 16x16 -> 32x32 -> 64x64 -> 128x128
        self.spatial_decoder = nn.Sequential(
            # 16 -> 32 (with upsampling)
            ResidualUpBlock(hidden_dim, 128),

            # 32 -> 64 (with upsampling)
            ResidualUpBlock(128, 64),

            # 64 -> 128 (with upsampling)
            ResidualUpBlock(64, 32),

            # Final output layer
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid()  # Assume pixels normalized to [0,1]
        )

    def forward(self, z_seq):
        """
        Args:
            z_seq: (B, T, latent_size) - compressed latent sequence from encoder

        Returns:
            x_rec: (B, T, 1, 128, 128) - reconstructed video sequence
        """
        B, T, L = z_seq.shape

        # Expand compressed latent to spatial dimensions
        z_flat = z_seq.view(B * T, L)  # (B*T, latent_size)
        z_expanded = self.latent_expand(z_flat)  # (B*T, latent_dim * 16 * 16)
        z_spatial = z_expanded.view(B, T, self.latent_dim, 16, 16)  # (B, T, latent_dim, 16, 16)

        # ConvLSTM decodes temporal dimension
        lstm_out, _ = self.convlstm(z_spatial)  # list of (B, T, hidden_dim, 16, 16)
        h_seq = lstm_out[0]  # (B, T, hidden_dim, 16, 16)

        # Spatial decoding: process each timestep separately
        B, T, C, H, W = h_seq.shape
        h_seq = h_seq.view(B * T, C, H, W)  # (B*T, hidden_dim, 16, 16)
        x_rec = self.spatial_decoder(h_seq)  # (B*T, 1, 128, 128)
        x_rec = x_rec.view(B, T, 1, 128, 128)  # (B, T, 1, 128, 128)

        return x_rec


class LatentClassifier(nn.Module):
    """
    Empty / Non-empty Well Classifier
    Classifies based on last timestep latent
    """

    def __init__(self, latent_size=4096, num_classes=2, dropout=0.3):
        super(LatentClassifier, self).__init__()

        self.head = nn.Sequential(
            # Classification head - input is already flattened (B, latent_size)
            nn.Linear(latent_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),

            nn.Linear(256, num_classes)
        )

    def forward(self, z_last):
        """
        Args:
            z_last: (B, latent_size) - last timestep compressed latent

        Returns:
            logits: (B, num_classes) - classification logits
        """
        return self.head(z_last)


class ConvLSTMAutoencoder(nn.Module, PyTorchModelHubMixin):
    """
    Complete ConvLSTM Autoencoder
    Includes Encoder, Decoder, and optional Classifier
    Compatible with HuggingFace Hub
    Works with 128x128 images
    """
    def __init__(
        self,
        config=None,
        seq_len=20,
        input_channels=1,
        encoder_hidden_dim=256,
        encoder_layers=2,
        decoder_hidden_dim=128,
        decoder_layers=2,
        latent_size=4096,
        use_classifier=True,
        num_classes=2,
        use_latent_split=False,
        # Ablation parameters
        dropout_rate=0.1,
        use_convlstm=True,
        use_residual=True,
        use_batchnorm=True
    ):
        super(ConvLSTMAutoencoder, self).__init__()
        self.seq_len = seq_len
        self.use_classifier = use_classifier
        self.encoder_hidden_dim = encoder_hidden_dim
        self.latent_size = latent_size
        self.use_latent_split = use_latent_split
        # Store ablation settings for reproducibility
        self.dropout_rate = dropout_rate
        self.use_convlstm = use_convlstm
        self.use_residual = use_residual
        self.use_batchnorm = use_batchnorm
        if config is not None:
            # Handle config as dict (from HuggingFace) or object
            if isinstance(config, dict):
                self.seq_len = config.get('seq_len', seq_len)
                self.use_classifier = config.get('use_classifier', use_classifier)
                self.encoder_hidden_dim = config.get('encoder_hidden_dim', encoder_hidden_dim)
                self.latent_size = config.get('latent_size', latent_size)
                self.use_latent_split = config.get('use_latent_split', use_latent_split)
                self.dropout_rate = config.get('dropout_rate', dropout_rate)
                self.use_convlstm = config.get('use_convlstm', use_convlstm)
                self.use_residual = config.get('use_residual', use_residual)
                self.use_batchnorm = config.get('use_batchnorm', use_batchnorm)
            else:
                self.seq_len = config.seq_len
                self.use_classifier = config.use_classifier
                self.encoder_hidden_dim = config.encoder_hidden_dim
                self.latent_size = config.latent_size
                self.use_latent_split = config.use_latent_split
                self.dropout_rate = config.dropout_rate
                self.use_convlstm = config.use_convlstm
                self.use_residual = config.use_residual
                self.use_batchnorm = config.use_batchnorm

        # Core components
        self.encoder = Encoder(
            input_channels=input_channels,
            hidden_dim=encoder_hidden_dim,
            num_layers=encoder_layers,
            latent_size=latent_size
        )

        self.decoder = Decoder(
            seq_len=seq_len,
            latent_size=latent_size,
            latent_dim=encoder_hidden_dim,
            hidden_dim=decoder_hidden_dim,
            num_layers=decoder_layers
        )

        # Optional classifier
        if use_classifier:
            self.classifier = LatentClassifier(
                latent_size=latent_size,
                num_classes=num_classes
            )

    def forward(self, x, return_all=False):
        """
        Args:
            x: (B, T, 1, H, W) - input video sequence (any size, will be resized internally)
            return_all: whether to return all intermediate results

        Returns:
            Tuple of (reconstruction, lat_vec_seq) where:
                - reconstruction: (B, T, 1, H, W) - reconstructed video (same size as input)
                - lat_vec_seq: (B, T, latent_size) - compressed latent sequence

            If return_all is True, returns dict with keys:
                - reconstruction: (B, T, 1, H, W) - reconstructed video
                - z_seq: (B, T, latent_size) - compressed latent sequence
                - z_last: (B, latent_size) - last timestep compressed latent
                - logits: (B, num_classes) - classification logits (if enabled)
        """
        B, T, C, orig_H, orig_W = x.shape

        # Encode (will resize to 128x128 internally)
        z_seq, z_last = self.encoder(x)

        # Decode (outputs 128x128)
        x_rec = self.decoder(z_seq)

        # Resize back to original input size if needed
        if orig_H != 128 or orig_W != 128:
            x_rec_flat = x_rec.view(B * T, C, 128, 128)
            x_rec_flat = torch.nn.functional.interpolate(x_rec_flat, size=(orig_H, orig_W), mode='bilinear', align_corners=True)
            x_rec = x_rec_flat.view(B, T, C, orig_H, orig_W)

        if return_all:
            # Build output dictionary
            output = {
                "reconstruction": x_rec,
                "z_seq": z_seq,
                "z_last": z_last,
            }

            # Optional classification
            if self.use_classifier:
                logits = self.classifier(z_last)
                output["logits"] = logits

            return output
        else:
            # Return tuple: (reconstruction, latent_vector)
            return x_rec, z_seq

    def encode(self, x):
        """Encode only, for extracting latent"""
        z_seq, z_last = self.encoder(x)
        return z_seq, z_last

    def decode(self, z_seq):
        """Decode only, for reconstructing from latent"""
        return self.decoder(z_seq)
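

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch added by the editor, not part of
# the original upload). It assumes raffael_conv_lstm.ConvLSTM is importable and
# behaves as described in the interface note near the top of this file; the
# expected shapes follow from the docstrings above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = ConvLSTMAutoencoder(seq_len=4, latent_size=4096, use_classifier=True)
    model.eval()  # BatchNorm uses running stats; safe for tiny batches

    # (B, T, C, H, W); a non-128 input size exercises the internal resize path
    x = torch.randn(2, 4, 1, 96, 96)

    with torch.no_grad():
        # Default tuple interface
        x_rec, z_seq = model(x)
        print(x_rec.shape)  # torch.Size([2, 4, 1, 96, 96])
        print(z_seq.shape)  # torch.Size([2, 4, 4096])

        # Dict interface with classifier logits
        out = model(x, return_all=True)
        print(out["z_last"].shape)  # torch.Size([2, 4096])
        print(out["logits"].shape)  # torch.Size([2, 2])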