JensLundsgaard commited on
Commit
062d27a
·
verified ·
1 Parent(s): cef47ad

Upload raffael_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. raffael_model.py +431 -0
raffael_model.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Complete High-Quality ConvLSTM Autoencoder
3
+ - Uses true ConvLSTM (not regular LSTM)
4
+ - Complete Encoder (2D CNN + ConvLSTM) with flattened latents
5
+ - Complete Decoder (ConvLSTM + ConvTranspose)
6
+ - Optional Empty/Non-empty Classifier
7
+ - Works with 128x128 input images
8
+ - Latent format: (B, T, N) where N is flattened spatial dimensions
9
+ - Includes ResNet-style residual connections in CNN layers
10
+ """
11
+ import torch
12
+ import torch.nn as nn
13
+ from raffael_conv_lstm import ConvLSTM
14
+ from huggingface_hub import PyTorchModelHubMixin
15
+
16
+
17
class ResidualBlock(nn.Module):
    """
    ResNet-style residual block with optional stride-2 downsampling.

    Main path: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: identity, or a 1x1 projection whenever the channel count
    or the spatial resolution changes.
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super(ResidualBlock, self).__init__()

        stride = 2 if downsample else 1

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = (in_channels != out_channels) or downsample
        if needs_projection:
            # 1x1 conv + BN matches the skip tensor to the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x)); same spatial contract as input unless downsampling."""
        skip = self.shortcut(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        return self.relu(y + skip)
55
+
56
+
57
class ResidualUpBlock(nn.Module):
    """
    Residual block that doubles spatial resolution.

    Main path: ConvTranspose2d (x2 upsample) -> BN -> ReLU -> 3x3 conv -> BN.
    Skip path: a second ConvTranspose2d so the identity tensor matches the
    upsampled shape before the residual addition.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualUpBlock, self).__init__()

        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut must upsample as well to stay shape-compatible.
        self.shortcut = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x)) at twice the input resolution."""
        skip = self.shortcut(x)

        y = self.relu(self.bn1(self.upsample(x)))
        y = self.bn2(self.conv(y))

        return self.relu(y + skip)
87
+
88
+
89
class Encoder(nn.Module):
    """
    Encoder: per-frame 2D CNN compression, optional ConvLSTM temporal
    modeling, then a linear projection to a flat latent.

    Given x of shape (B, T, 1, H, W) the forward pass returns:
      z_seq:  (B, T, latent_size) -- compressed latent sequence
      z_last: (B, latent_size)    -- latent of the final timestep
    """

    def __init__(self, input_channels=1, hidden_dim=256, num_layers=2, latent_size=4096, use_convlstm=True):
        super(Encoder, self).__init__()

        self.hidden_dim = hidden_dim
        self.latent_size = latent_size
        self.use_convlstm = use_convlstm

        # Frame-wise spatial compression with residual blocks:
        # 128x128 -> 64x64 -> 32x32 -> 16x16
        self.spatial_cnn = nn.Sequential(
            ResidualBlock(input_channels, 64, downsample=True),   # 128 -> 64
            ResidualBlock(64, 128, downsample=True),              # 64 -> 32
            ResidualBlock(128, 256, downsample=True),             # 32 -> 16
        )

        # Temporal modeling over the 16x16 feature maps:
        # (B, T, 256, 16, 16) -> (B, T, hidden_dim, 16, 16)
        self.convlstm = ConvLSTM(
            input_dim=256,
            hidden_dim=hidden_dim,
            kernel_size=(3, 3),
            num_layers=num_layers,
            batch_first=True,
            return_all_layers=False
        )

        # Regularization applied just before latent compression.
        self.dropout = nn.Dropout(0.1)

        # Flattened 16x16 feature map -> fixed-size latent vector.
        # NOTE(review): this layer is sized for hidden_dim channels; when
        # use_convlstm is False the CNN output (256 channels) is fed in
        # directly, which only matches if hidden_dim == 256 -- confirm.
        self.latent_compress = nn.Linear(hidden_dim * 16 * 16, latent_size)

    def forward(self, x):
        """
        Args:
            x: (B, T, 1, H, W) input video sequence; any H/W, resized to
               128x128 internally.

        Returns:
            z_seq:  (B, T, latent_size) compressed latent sequence
            z_last: (B, latent_size) last-timestep compressed latent
        """
        B, T, C, H, W = x.shape

        # Fold time into the batch axis and normalize resolution.
        frames = x.view(B * T, C, H, W)
        if (H, W) != (128, 128):
            frames = torch.nn.functional.interpolate(
                frames, size=(128, 128), mode='bilinear', align_corners=True
            )

        # Per-frame spatial compression.
        feats = self.spatial_cnn(frames)                 # (B*T, 256, 16, 16)
        feats = feats.view(B, T, *feats.shape[1:])       # (B, T, 256, 16, 16)

        # Temporal modeling (skipped in the ablation variant).
        if self.use_convlstm:
            layer_outputs, _ = self.convlstm(feats)      # list of per-layer outputs
            h_seq = layer_outputs[0]                     # (B, T, hidden_dim, 16, 16)
        else:
            h_seq = feats

        # Flatten spatial dims and project to the fixed-size latent.
        bt = B * T
        flat = self.dropout(h_seq.reshape(bt, -1))       # (B*T, hidden_dim*16*16)
        z_seq = self.latent_compress(flat).view(B, T, self.latent_size)

        return z_seq, z_seq[:, -1]
174
+
175
+
176
class Decoder(nn.Module):
    """
    Decoder: linear latent expansion, optional ConvLSTM temporal decoding,
    then transposed-conv spatial reconstruction.

    Given z_seq of shape (B, T, latent_size) the forward pass returns
    x_rec of shape (B, T, 1, 128, 128) with sigmoid-activated pixels.
    """

    def __init__(self, seq_len, latent_size=4096, latent_dim=256, hidden_dim=256, num_layers=2, use_convlstm=True):
        super(Decoder, self).__init__()
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.latent_size = latent_size
        self.use_convlstm = use_convlstm

        # Flat latent -> 16x16 spatial map with latent_dim channels.
        self.latent_expand = nn.Linear(latent_size, latent_dim * 16 * 16)

        # Temporal decoding over the 16x16 maps.
        self.convlstm = ConvLSTM(
            input_dim=latent_dim,
            hidden_dim=hidden_dim,
            kernel_size=(3, 3),
            num_layers=num_layers,
            batch_first=True,
            return_all_layers=False
        )

        # Spatial reconstruction with residual up-blocks:
        # 16x16 -> 32x32 -> 64x64 -> 128x128.
        # NOTE(review): sized for hidden_dim input channels; when
        # use_convlstm is False it receives latent_dim channels instead,
        # which only matches if latent_dim == hidden_dim -- confirm.
        self.spatial_decoder = nn.Sequential(
            ResidualUpBlock(hidden_dim, 128),   # 16 -> 32
            ResidualUpBlock(128, 64),           # 32 -> 64
            ResidualUpBlock(64, 32),            # 64 -> 128
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid()  # pixels assumed normalized to [0, 1]
        )

    def forward(self, z_seq):
        """
        Args:
            z_seq: (B, T, latent_size) compressed latent sequence from
                the encoder.

        Returns:
            x_rec: (B, T, 1, 128, 128) reconstructed video sequence.
        """
        B, T, L = z_seq.shape

        # Expand the flat latent into a 16x16 spatial map per timestep.
        expanded = self.latent_expand(z_seq.reshape(B * T, L))
        z_spatial = expanded.view(B, T, self.latent_dim, 16, 16)

        # Temporal decoding (skipped in the ablation variant).
        if self.use_convlstm:
            layer_outputs, _ = self.convlstm(z_spatial)
            h_seq = layer_outputs[0]            # (B, T, hidden_dim, 16, 16)
        else:
            h_seq = z_spatial

        # Per-timestep spatial reconstruction.
        _, _, C, H, W = h_seq.shape
        frames = self.spatial_decoder(h_seq.reshape(B * T, C, H, W))
        return frames.view(B, T, 1, 128, 128)
250
+
251
+
252
class LatentClassifier(nn.Module):
    """
    Empty / non-empty well classifier.

    An MLP head over the encoder's last-timestep latent: two
    Linear -> BatchNorm -> ReLU -> Dropout stages followed by a final
    Linear that emits raw class logits.
    """

    def __init__(self, latent_size=4096, num_classes=2, dropout=0.3):
        super(LatentClassifier, self).__init__()

        def _stage(n_in, n_out):
            # One hidden stage of the classification head.
            return [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]

        # Input is already flattened to (B, latent_size).
        self.head = nn.Sequential(
            *_stage(latent_size, 512),
            *_stage(512, 256),
            nn.Linear(256, num_classes)
        )

    def forward(self, z_last):
        """
        Args:
            z_last: (B, latent_size) last-timestep compressed latent.

        Returns:
            (B, num_classes) classification logits.
        """
        return self.head(z_last)
285
+
286
+
287
class ConvLSTMAutoencoder(nn.Module, PyTorchModelHubMixin):
    """
    Complete ConvLSTM Autoencoder.

    Includes Encoder, Decoder, and an optional empty/non-empty classifier.
    Compatible with the HuggingFace Hub (PyTorchModelHubMixin).
    Works with 128x128 images; other sizes are resized internally and the
    reconstruction is resized back to the original input resolution.
    """

    def __init__(
        self,
        config=None,
        seq_len=20,
        input_channels=1,
        encoder_hidden_dim=256,
        encoder_layers=2,
        decoder_hidden_dim=128,
        decoder_layers=2,
        latent_size=4096,
        use_classifier=True,
        num_classes=2,
        use_latent_split=False,
        # Ablation parameters
        dropout_rate=0.1,
        use_convlstm=True,
        use_residual=True,
        use_batchnorm=True
    ):
        super(ConvLSTMAutoencoder, self).__init__()
        self.seq_len = seq_len
        self.use_classifier = use_classifier
        self.encoder_hidden_dim = encoder_hidden_dim
        self.latent_size = latent_size
        self.use_latent_split = use_latent_split
        # Ablation settings, stored for reproducibility.
        # NOTE(review): encoder_hidden_dim, decoder_hidden_dim,
        # encoder_layers/decoder_layers, input_channels, dropout_rate,
        # use_residual and use_batchnorm are recorded here but never
        # forwarded to the Encoder/Decoder below (which fall back to their
        # own defaults) -- confirm whether that is intentional before
        # wiring them through, as changing it would alter the architecture
        # of existing checkpoints.
        self.dropout_rate = dropout_rate
        self.use_convlstm = use_convlstm
        self.use_residual = use_residual
        self.use_batchnorm = use_batchnorm

        if config is not None:
            # config may arrive as a plain dict (from HuggingFace) or as an
            # attribute-style object; either overrides the keyword defaults.
            if isinstance(config, dict):
                self.seq_len = config.get('seq_len', seq_len)
                self.use_classifier = config.get('use_classifier', use_classifier)
                self.encoder_hidden_dim = config.get('encoder_hidden_dim', encoder_hidden_dim)
                self.latent_size = config.get('latent_size', latent_size)
                self.use_latent_split = config.get('use_latent_split', use_latent_split)
                self.dropout_rate = config.get('dropout_rate', dropout_rate)
                self.use_convlstm = config.get('use_convlstm', use_convlstm)
                self.use_residual = config.get('use_residual', use_residual)
                self.use_batchnorm = config.get('use_batchnorm', use_batchnorm)
            else:
                self.seq_len = config.seq_len
                self.use_classifier = config.use_classifier
                self.encoder_hidden_dim = config.encoder_hidden_dim
                self.latent_size = config.latent_size
                self.use_latent_split = config.use_latent_split
                self.dropout_rate = config.dropout_rate
                self.use_convlstm = config.use_convlstm
                self.use_residual = config.use_residual
                self.use_batchnorm = config.use_batchnorm

        # Core components.
        self.encoder = Encoder(
            latent_size=self.latent_size,
            use_convlstm=self.use_convlstm
        )

        self.decoder = Decoder(
            seq_len=self.seq_len,
            latent_size=self.latent_size,
            use_convlstm=self.use_convlstm
        )

        # Optional classifier on the last-timestep latent.
        if use_classifier:
            self.classifier = LatentClassifier(
                latent_size=latent_size,
                num_classes=num_classes
            )

    def forward(self, x, return_all=False, hidden=None):
        """
        Args:
            x: (B, T, 1, H, W) input video sequence (any size; resized
                internally).
            return_all: whether to return a dict with all intermediate
                results instead of the (reconstruction, z_seq) tuple.
            hidden: accepted for interface compatibility but currently
                unused (hidden-state passing is not implemented).

        Returns:
            Tuple (reconstruction, z_seq) where:
              - reconstruction: (B, T, 1, H, W) reconstructed video
                (same size as input)
              - z_seq: (B, T, latent_size) compressed latent sequence

            If return_all is True, a dict with keys:
              - "reconstruction": (B, T, 1, H, W)
              - "z_seq": (B, T, latent_size)
              - "z_last": (B, latent_size)
              - "logits": (B, num_classes), only when the classifier is
                enabled
        """
        B, T, C, orig_H, orig_W = x.shape

        # Encode (resizes to 128x128 internally). Encoder.forward always
        # returns exactly (z_seq, z_last); the old code passed
        # return_all=True through and unpacked four values, which raised a
        # TypeError because Encoder takes no such argument.
        z_seq, z_last = self.encoder(x)

        # Decode back to 128x128 frames. Same fix: Decoder.forward has no
        # return_all parameter and returns only the reconstruction.
        x_rec = self.decoder(z_seq)

        # Resize reconstruction back to the original input size if needed.
        if orig_H != 128 or orig_W != 128:
            x_rec_flat = x_rec.view(B * T, C, 128, 128)
            x_rec_flat = torch.nn.functional.interpolate(
                x_rec_flat, size=(orig_H, orig_W), mode='bilinear', align_corners=True
            )
            x_rec = x_rec_flat.view(B, T, C, orig_H, orig_W)

        if return_all:
            output = {
                "reconstruction": x_rec,
                "z_seq": z_seq,
                "z_last": z_last,
            }
            # Optional classification from the last-timestep latent.
            if self.use_classifier:
                output["logits"] = self.classifier(z_last)
            return output

        # Default: (reconstruction, latent sequence) tuple.
        return x_rec, z_seq

    def encode(self, x):
        """Encode only; returns (z_seq, z_last) for latent extraction."""
        z_seq, z_last = self.encoder(x)
        return z_seq, z_last

    def decode(self, z_seq):
        """Decode only; reconstructs (B, T, 1, 128, 128) from a latent sequence."""
        return self.decoder(z_seq)