jdmayfield committed on
Commit
529f7c0
·
verified ·
1 Parent(s): 40f37af

Create train_mae_swin3d.py

Browse files
Files changed (1) hide show
  1. train_mae_swin3d.py +755 -0
train_mae_swin3d.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Masked Autoencoder (MAE) pretraining with 3D Swin Transformer for OPSCC CT scans.
4
+ Asymmetry-aware reconstruction + overfitting monitoring via cosine similarity.
5
+
6
+ Run example:
7
+ python train_mae_swin3d.py --data-dir /path/to/your/nii_folder --output-dir ./checkpoints
8
+ """
9
+
10
+ """
11
+ Self-Supervised Learning for OPSCC CT using 3D Swin Transformer MAE
12
+ with asymmetry-aware reconstruction and overfitting monitoring
13
+ """
14
+
15
+ import argparse
16
+ import json
17
+ import pickle
18
+ import warnings
19
+ from datetime import datetime
20
+ from pathlib import Path
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from torch.utils.data import Dataset, DataLoader
26
+
27
+ import numpy as np
28
+ from scipy import ndimage
29
+ import nibabel as nib
30
+ from tqdm import tqdm
31
+
32
+ warnings.filterwarnings("ignore", category=UserWarning)
33
+
34
+
35
+ # ==============================================================================
36
+ # Drop Path
37
+ # ==============================================================================
38
+
39
class DropPath(nn.Module):
    """Per-sample stochastic depth gate.

    During training each sample in the batch is either zeroed entirely or
    rescaled by ``1 / (1 - drop_prob)``. In eval mode, or when ``drop_prob``
    is 0, the input is returned untouched.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        active = self.training and self.drop_prob != 0.
        if not active:
            return x
        survive = 1 - self.drop_prob
        # One gate value per sample, broadcast over every remaining dimension.
        gate_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        gate = torch.rand(gate_shape, dtype=x.dtype, device=x.device)
        # floor(survive + U[0, 1)) is 1 with probability `survive`, else 0.
        gate.add_(survive).floor_()
        return x.div(survive) * gate
52
+
53
+
54
+ # ==============================================================================
55
+ # Asymmetry Detectors
56
+ # ==============================================================================
57
+
58
class AirwayAsymmetryDetector:
    """Slice-wise left/right asymmetry heuristics for a 3D volume.

    The volume is processed one 2D slice at a time along its first axis.
    The fixed thresholds (0.1 for air, 0.2-0.7 for soft tissue) presumably
    assume intensities normalised to [0, 1] — confirm against the
    preprocessing that produced the volumes.
    """

    def __init__(self, exclude_inferior_fraction=0.15, exclude_superior_fraction=0.10):
        # Fractions of the slice stack (from either end) whose hybrid score
        # is forced to 0 in `forward`.
        self.exclude_inferior_fraction = exclude_inferior_fraction
        self.exclude_superior_fraction = exclude_superior_fraction

    def find_midline(self, slice_2d):
        """Return the column that minimises the mean left/right mirror difference.

        Candidates lie within ``w // 8`` columns of the geometric centre.
        Candidates that leave fewer than 10 comparable columns are skipped;
        if none qualifies the geometric centre is returned. Ties keep the
        first (left-most) candidate.
        """
        h, w = slice_2d.shape
        search_range = w // 8
        center = w // 2
        best_midline = center
        best_symmetry = float('inf')
        for mid in range(center - search_range, center + search_range):
            compare_width = min(mid, w - mid)
            if compare_width < 10:
                continue
            left = slice_2d[:, mid - compare_width:mid]
            right = np.flip(slice_2d[:, mid:mid + compare_width], axis=1)
            diff = np.abs(left - right).mean()
            if diff < best_symmetry:
                best_symmetry = diff
                best_midline = mid
        return best_midline

    def detect_airway(self, slice_2d, air_thresh=0.1):
        """Return a boolean mask of enclosed low-intensity regions.

        A connected component of ``slice_2d < air_thresh`` is kept when it
        does not touch the slice border and has more than 20 pixels.
        """
        binary = slice_2d < air_thresh
        labeled, num_features = ndimage.label(binary)

        # Component sizes in a single pass (index 0 is the background).
        sizes = np.bincount(labeled.ravel(), minlength=num_features + 1)
        keep = sizes > 20
        keep[0] = False

        # Anything touching the border is treated as outside air, not airway.
        border_labels = np.unique(np.concatenate([
            labeled[0, :].ravel(), labeled[-1, :].ravel(),
            labeled[:, 0].ravel(), labeled[:, -1].ravel(),
        ]))
        keep[border_labels] = False

        # Look up every pixel's label in the keep table.
        return keep[labeled]

    def forward(self, volume):
        """Compute per-slice asymmetry metrics.

        Args:
            volume: 3D array indexed as ``volume[z]`` -> 2D slice.

        Returns:
            Dict of 1D arrays, one entry per slice: ``effacement``,
            ``mass_effect``, ``midline_shift``, ``hybrid``, ``midlines``.
        """
        d, h, w = volume.shape
        inferior_cutoff = int(d * self.exclude_inferior_fraction)
        superior_cutoff = int(d * (1 - self.exclude_superior_fraction))
        results = {'effacement': [], 'mass_effect': [], 'midline_shift': [], 'hybrid': [], 'midlines': []}
        for z in range(d):
            slice_2d = volume[z]
            midline = self.find_midline(slice_2d)
            midline_shift = midline - w // 2
            results['midlines'].append(midline)

            # Effacement: relative imbalance of airway area on each side.
            airway_mask = self.detect_airway(slice_2d)
            left_air = airway_mask[:, :midline].sum()
            right_air = airway_mask[:, midline:].sum()
            total = left_air + right_air
            effacement = abs(left_air - right_air) / max(total, 1) if total > 0 else 0

            # Mass effect: mean absolute mirror difference of soft-tissue
            # intensities; pixels outside the soft-tissue band count as 0.
            compare_width = min(midline, w - midline)
            mass_effect = 0
            if compare_width > 0:
                soft_tissue = (slice_2d > 0.2) & (slice_2d < 0.7)
                lo, hi = midline - compare_width, midline + compare_width
                left = slice_2d[:, lo:midline] * soft_tissue[:, lo:midline]
                right = np.flip(slice_2d[:, midline:hi], axis=1) * np.flip(soft_tissue[:, midline:hi], axis=1)
                mass_effect = np.abs(left - right).mean()

            # The hybrid score is only reported inside the inclusive range.
            in_range = inferior_cutoff <= z <= superior_cutoff
            hybrid = (0.5 * effacement + 0.5 * mass_effect) if in_range else 0

            results['effacement'].append(effacement)
            results['mass_effect'].append(mass_effect)
            results['midline_shift'].append(midline_shift)
            results['hybrid'].append(hybrid)
        return {k: np.array(v) for k, v in results.items()}
123
+
124
+
125
class GlobalSoftTissueAsymmetryDetector:
    """Counts small hypodense soft-tissue blobs on each side of the midline."""

    def __init__(self, exclude_inferior_fraction=0.15, exclude_superior_fraction=0.10):
        # Stored for parity with the airway detector; `forward` does not read them.
        self.exclude_inferior_fraction = exclude_inferior_fraction
        self.exclude_superior_fraction = exclude_superior_fraction

    def forward(self, volume, midlines=None):
        """Return per-slice blob counts.

        Args:
            volume: 3D array indexed as ``volume[z]`` -> 2D slice.
            midlines: optional per-slice midline columns; defaults to the
                geometric centre column for every slice.

        Returns:
            Dict of 1D arrays (one entry per slice): ``left_hypo``,
            ``right_hypo`` and ``hypo_asymmetry`` (absolute count difference).
        """
        depth, _, width = volume.shape
        if midlines is None:
            midlines = [width // 2] * depth

        left_counts, right_counts, asymmetries = [], [], []
        for z, plane in enumerate(volume):
            split_col = midlines[z]

            # Hypodense candidates: inside the soft-tissue band and below 0.35.
            in_band = (plane > 0.2) & (plane < 0.7)
            candidates = in_band & (plane < 0.35)
            # Opening removes specks; closing bridges small gaps.
            candidates = ndimage.binary_opening(candidates, iterations=1)
            candidates = ndimage.binary_closing(candidates, iterations=2)

            components, n_components = ndimage.label(candidates)
            n_left = 0
            n_right = 0
            for comp_id in range(1, n_components + 1):
                member = components == comp_id
                if not (10 < member.sum() < 150):
                    continue
                # Side is decided by the mean column index of the component.
                if np.nonzero(member)[1].mean() < split_col:
                    n_left += 1
                else:
                    n_right += 1

            left_counts.append(n_left)
            right_counts.append(n_right)
            asymmetries.append(abs(n_left - n_right))

        return {
            'left_hypo': np.array(left_counts),
            'right_hypo': np.array(right_counts),
            'hypo_asymmetry': np.array(asymmetries),
        }
157
+
158
+
159
+ # ==============================================================================
160
+ # 3D Swin Transformer Components
161
+ # ==============================================================================
162
+
163
def window_partition3d(x, window_size=(4,4,4)):
    """Split a (B, C, D, H, W) tensor into non-overlapping 3D windows.

    Each spatial axis is zero-padded at its far end up to the next multiple
    of the window size.

    Returns:
        windows: tensor of shape (B * num_windows, ws_d * ws_h * ws_w, C).
        pads: ``(pad_d, pad_h, pad_w)`` that was appended to each axis.
    """
    batch, channels, depth, height, width = x.shape
    win_d, win_h, win_w = window_size

    # (-n) % w is the amount needed to reach the next multiple of w.
    pads = ((-depth) % win_d, (-height) % win_h, (-width) % win_w)
    # F.pad takes pairs starting from the last dimension: W, then H, then D.
    x = F.pad(x, (0, pads[2], 0, pads[1], 0, pads[0]))

    n_d = (depth + pads[0]) // win_d
    n_h = (height + pads[1]) // win_h
    n_w = (width + pads[2]) // win_w

    grid = x.reshape(batch, channels, n_d, win_d, n_h, win_h, n_w, win_w)
    # Window-grid axes first, then channels, then within-window offsets.
    grid = grid.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
    tokens = grid.reshape(-1, channels, win_d * win_h * win_w)
    return tokens.permute(0, 2, 1).contiguous(), pads
175
+
176
+
177
def window_reverse3d(windows, window_size, B, D, H, W, pads):
    """Reassemble windows into a (B, C, D, H, W) tensor, cropping the padding.

    Args:
        windows: tensor of shape (B * num_windows, ws_d * ws_h * ws_w, C).
        window_size: ``(ws_d, ws_h, ws_w)`` used when partitioning.
        B, D, H, W: batch size and unpadded spatial extent.
        pads: ``(pad_d, pad_h, pad_w)`` that was added before partitioning.
    """
    win_d, win_h, win_w = window_size
    full_d, full_h, full_w = (size + pad for size, pad in zip((D, H, W), pads))

    grid = windows.reshape(
        B, full_d // win_d, full_h // win_h, full_w // win_w, win_d, win_h, win_w, -1
    )
    # Channels to dim 1; interleave each grid axis with its in-window axis.
    grid = grid.permute(0, 7, 1, 4, 2, 5, 3, 6).contiguous()
    volume = grid.reshape(B, -1, full_d, full_h, full_w)
    return volume[:, :, :D, :H, :W]
186
+
187
+
188
class WindowAttention3D(nn.Module):
    """Multi-head self-attention within a 3D window, with a learned relative position bias.

    Args:
        dim: channel dimension of the input tokens.
        window_size: ``(ws_d, ws_h, ws_w)`` of the attention window.
        num_heads: number of attention heads; must divide ``dim``.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: optional override of the ``head_dim ** -0.5`` query scale.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, window_size=(4,4,4), num_heads=3, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Pairwise relative offsets between all positions in one window.
        coords_d = torch.arange(window_size[0])
        coords_h = torch.arange(window_size[1])
        coords_w = torch.arange(window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()

        # Shift offsets to start at 0 ...
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 2] += window_size[2] - 1

        # ... then flatten (d, h, w) offsets into one index into the bias table.
        relative_coords[:, :, 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)
        # Registered as a buffer so it follows .to(device) with the module.
        # Non-persistent so the state_dict keys stay as they were when this
        # was a plain attribute.
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        max_rel_pos = relative_position_index.max().item()
        self.relative_position_bias_table = nn.Parameter(torch.zeros((max_rel_pos + 1, num_heads)))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Apply window attention.

        Args:
            x: tokens of shape (num_windows * B, N, C).
            mask: optional additive mask of shape (num_windows, N, N).

        Returns:
            Tensor of the same shape as ``x``.
        """
        B_, N, C = x.shape
        rel_index = self.relative_position_index[:N, :N]
        # reshape (not view): the slice is non-contiguous when N is smaller
        # than the full window volume.
        relative_position_bias = self.relative_position_bias_table[rel_index.reshape(-1)]
        relative_position_bias = relative_position_bias.view(N, N, -1).permute(2, 0, 1).contiguous()

        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
247
+
248
+
249
class SwinTransformerBlock3D(nn.Module):
    """Windowed-attention transformer block operating on (B, C, D, H, W) tensors.

    Two pre-norm residual branches: window attention, then an MLP.

    NOTE(review): ``shift_size`` is stored but ``forward`` never reads it, so
    every block attends over the same unshifted window grid.
    """

    def __init__(self, dim, num_heads, window_size=(4,4,4), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim=dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            act_layer(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        residual = x

        # The norm layer acts on the last dim: go channels-last and back.
        normed = self.norm1(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
        batch, _, depth, height, width = normed.shape

        tokens, pads = window_partition3d(normed, self.window_size)
        attended = window_reverse3d(
            self.attn(tokens), self.window_size, batch, depth, height, width, pads
        )
        x = residual + self.drop_path(attended)

        # MLP branch runs channels-last; only the result is moved back.
        mlp_out = self.mlp(self.norm2(x.permute(0, 2, 3, 4, 1))).permute(0, 4, 1, 2, 3)
        return x + self.drop_path(mlp_out)
286
+
287
+
288
class PatchEmbed3D(nn.Module):
    """Embed non-overlapping 3D patches with a single strided convolution.

    Maps (B, in_chans, D, H, W) to (B, embed_dim, D', H', W'), where each
    output position covers one ``patch_size`` block of the input.
    """

    def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=96):
        super().__init__()
        # kernel == stride, so patches tile the volume without overlap.
        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        embedded = self.proj(x)
        return embedded
295
+
296
+
297
class PatchMerging3D(nn.Module):
    """Halve each spatial axis and double the channel count.

    Every 2x2x2 neighbourhood is concatenated along channels (8 * dim) and
    linearly reduced to 2 * dim. Odd-sized axes are zero-padded by one.
    """

    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)

    def forward(self, x):
        batch, channels, depth, height, width = x.shape

        # Pad each odd axis at its far end so it splits evenly into pairs.
        odd_d, odd_h, odd_w = depth % 2, height % 2, width % 2
        if odd_d or odd_h or odd_w:
            x = F.pad(x, (0, odd_w, 0, odd_h, 0, odd_d))
        half_d, half_h, half_w = (size // 2 for size in x.shape[2:])

        # Channels-last, then split every spatial axis into (half, 2).
        x = x.permute(0, 2, 3, 4, 1)
        x = x.view(batch, half_d, 2, half_h, 2, half_w, 2, channels)
        # Gather the three "2" axes next to the channels and fold them in.
        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        x = x.view(batch, half_d, half_h, half_w, 8 * channels)

        merged = self.reduction(x)
        return merged.permute(0, 4, 1, 2, 3).contiguous()
315
+
316
+
317
class SwinTransformer3D(nn.Module):
    """Hierarchical windowed-attention encoder producing one feature vector per volume.

    Args:
        in_chans: number of input channels.
        embed_dim: channel width after patch embedding; doubles after each merge.
        depths: number of blocks in each stage.
        num_heads: attention heads in each stage (same length as ``depths``).
        window_size: 3D attention window.
        mlp_ratio: hidden/in ratio of each block's MLP.
        qkv_bias: whether QKV projections have a bias.
        drop_rate: dropout in projections and MLPs.
        attn_drop_rate: dropout on attention weights.
        drop_path_rate: maximum stochastic-depth rate, ramped linearly from 0
            over all blocks of the network.

    Raises:
        ValueError: if ``depths`` and ``num_heads`` differ in length.

    NOTE(review): blocks are constructed without a ``shift_size``, so all
    of them use the default unshifted windows.
    """

    def __init__(self, in_chans=1, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=(4,4,4), mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0.1):
        super().__init__()
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length")

        self.patch_embed = PatchEmbed3D(in_chans=in_chans, embed_dim=embed_dim)
        # One drop-path rate per block across the WHOLE network.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.layers = nn.ModuleList()
        dim = embed_dim
        block_offset = 0  # index into dpr of this stage's first block
        last_stage = len(depths) - 1
        for i_layer, depth in enumerate(depths):
            blocks = nn.ModuleList([
                SwinTransformerBlock3D(
                    dim=dim,
                    num_heads=num_heads[i_layer],
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    # Offset by earlier stages so the schedule keeps ramping
                    # instead of restarting from 0 in every stage.
                    drop_path=dpr[block_offset + i],
                )
                for i in range(depth)
            ])
            block_offset += depth
            self.layers.append(blocks)
            if i_layer < last_stage:
                self.layers.append(PatchMerging3D(dim))
                dim *= 2

        self.norm = nn.LayerNorm(dim)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.feature_dim = dim

    def forward(self, x):
        """Encode (B, in_chans, D, H, W) into (B, feature_dim)."""
        x = self.patch_embed(x)
        for layer in self.layers:
            if isinstance(layer, PatchMerging3D):
                x = layer(x)
            else:
                for blk in layer:
                    x = blk(x)
        # Global average pool over space, then normalise the pooled vector.
        x = self.avgpool(x).flatten(1)
        x = self.norm(x)
        return x
350
+
351
+
352
+ # ==============================================================================
353
+ # MAE Model
354
+ # ==============================================================================
355
+
356
class MAE_Swin3D(nn.Module):
    """Encoder plus three heads driven by the pooled encoder feature.

    Heads:
        * ``decoder``: MLP that outputs a full volume of ``input_shape``.
        * ``airway_head``: 4 values per slice.
        * ``lymph_head``: 3 values per slice.
    """

    def __init__(self, input_shape=(60, 128, 128)):
        super().__init__()
        self.input_shape = input_shape
        self.encoder = SwinTransformer3D(in_chans=1)

        feat_dim = self.encoder.feature_dim
        n_slices = input_shape[0]
        decoder_dim = 512
        self.decoder = nn.Sequential(
            nn.Linear(feat_dim, decoder_dim),
            nn.ReLU(),
            nn.Linear(decoder_dim, np.prod(input_shape)),
        )
        self.airway_head = nn.Linear(feat_dim, 4 * n_slices)
        self.lymph_head = nn.Linear(feat_dim, 3 * n_slices)

    def forward(self, x):
        """Return a dict with ``reconstruction``, ``airway_pred``, ``lymph_pred``, ``features``."""
        n_slices = self.input_shape[0]
        feat = self.encoder(x)
        reconstruction = self.decoder(feat).view(-1, 1, *self.input_shape)
        airway_pred = self.airway_head(feat).view(-1, n_slices, 4)
        lymph_pred = self.lymph_head(feat).view(-1, n_slices, 3)
        return {
            'reconstruction': reconstruction,
            'airway_pred': airway_pred,
            'lymph_pred': lymph_pred,
            'features': feat,
        }
382
+
383
+
384
+ # ==============================================================================
385
+ # Augmentations
386
+ # ==============================================================================
387
+
388
def augment_volume(volume):
    """Return a randomly augmented copy of a (B, C, D, H, W) volume, clamped to [0, 1].

    Each augmentation fires independently: intensity shift, intensity scale,
    Gaussian noise, flips along the last two axes, in-plane 90-degree
    rotation, random crop-and-resize, 3x3x3 box blur and random erasing.
    The input tensor is not modified.
    """
    aug = volume.clone()
    device = aug.device

    def _fires(threshold):
        # Draws from the global CPU generator, one value per decision.
        return bool(torch.rand(1) > threshold)

    def _random_box(extent, base, span):
        # All three sizes are drawn first, then all three start offsets.
        sizes = [int(n * (base + torch.rand(1).item() * span)) for n in extent]
        starts = [torch.randint(0, n - s + 1, (1,)).item() for n, s in zip(extent, sizes)]
        return starts, sizes

    # --- intensity ---
    if _fires(0.3):
        aug.add_((torch.rand(1).to(device) - 0.5) * 0.4)
    if _fires(0.3):
        aug.mul_(0.7 + torch.rand(1).to(device) * 0.6)
    if _fires(0.3):
        aug.add_(torch.randn_like(aug) * 0.1)

    # --- flips / rotation ---
    if _fires(0.5):
        aug = torch.flip(aug, dims=[-1])
    if _fires(0.5):
        aug = torch.flip(aug, dims=[-2])
    if _fires(0.7):
        quarter_turns = torch.randint(1, 4, (1,)).item()
        aug = torch.rot90(aug, quarter_turns, dims=[-2, -1])

    # --- crop 80-95% of each axis, then resize back ---
    if _fires(0.5):
        extent = tuple(aug.shape[2:])
        (z0, y0, x0), (dz, dy, dx) = _random_box(extent, 0.80, 0.15)
        aug = aug[:, :, z0:z0 + dz, y0:y0 + dy, x0:x0 + dx]
        aug = F.interpolate(aug, size=extent, mode='trilinear', align_corners=False)

    # --- box blur ---
    if _fires(0.7):
        kernel_size = 3
        aug = F.avg_pool3d(aug, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    # --- erase a 5-15% box per axis with the global mean ---
    if _fires(0.7):
        (z0, y0, x0), (dz, dy, dx) = _random_box(tuple(aug.shape[2:]), 0.05, 0.10)
        aug[:, :, z0:z0 + dz, y0:y0 + dy, x0:x0 + dx] = aug.mean()

    return torch.clamp(aug, 0, 1)
442
+
443
+
444
+ # ==============================================================================
445
+ # Dataset
446
+ # ==============================================================================
447
+
448
class OPSCCDataset(Dataset):
    """Dataset of ``cropped_volume.nii.gz`` files with cached asymmetry metrics.

    Each item is a dict with:
        * ``volume``: float tensor with a leading channel axis.
        * ``airway_metrics``: (4, n_slices) float tensor
          (effacement, mass_effect, midline_shift, hybrid).
        * ``lymphnode_metrics``: (3, n_slices) float tensor
          (left_hypo, right_hypo, hypo_asymmetry).

    Cached metrics are keyed by each volume's path relative to ``data_dir``.
    A positional key would silently attach metrics to the wrong volume when
    files are added or removed, or when the filesystem lists them in a
    different order.

    Args:
        data_dir: folder searched recursively for volumes.
        cache_asymmetry: when True, metrics are computed once and stored in
            ``<data_dir>/.asymmetry_cache.pkl``.
    """

    def __init__(self, data_dir: str, cache_asymmetry: bool = True):
        self.data_dir = Path(data_dir)
        # Sorted so dataset indices are reproducible between runs.
        self.volume_paths = sorted(self.data_dir.glob("**/cropped_volume.nii.gz"))
        print(f"Found {len(self.volume_paths)} volumes")

        self.cache_file = self.data_dir / ".asymmetry_cache.pkl"
        self.cache_asymmetry = cache_asymmetry
        self.asymmetry_cache = {}
        self.airway_detector = AirwayAsymmetryDetector()
        self.lymphnode_detector = GlobalSoftTissueAsymmetryDetector()

        if self.cache_asymmetry:
            self._load_cache()
            missing = [p for p in self.volume_paths
                       if self._cache_key(p) not in self.asymmetry_cache]
            if missing:
                print("Computing asymmetry metrics...")
                self._precompute_asymmetry(missing)
                # Saved whenever anything was (re)computed, including after
                # a failed or stale cache load.
                self._save_cache()

    def _cache_key(self, path: Path) -> str:
        """Stable cache key: the path relative to ``data_dir``, POSIX style."""
        return path.relative_to(self.data_dir).as_posix()

    def _load_cache(self) -> None:
        """Load cached metrics, keeping only entries for volumes currently present."""
        if not self.cache_file.is_file():
            return
        try:
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # point --data-dir at folders whose cache file you trust.
            with open(self.cache_file, 'rb') as f:
                loaded = pickle.load(f)
        except Exception:
            # Best effort: an unreadable cache just means recomputing.
            print("Cache load failed → recomputing")
            return
        if not isinstance(loaded, dict):
            print("Cache load failed → recomputing")
            return
        valid_keys = {self._cache_key(p) for p in self.volume_paths}
        # Older caches keyed by integer index share no key with valid_keys,
        # so they are dropped here and rebuilt by the caller.
        self.asymmetry_cache = {k: v for k, v in loaded.items() if k in valid_keys}
        print(f"Loaded asymmetry cache ({len(self.asymmetry_cache)} entries)")

    def _save_cache(self) -> None:
        """Write the cache to disk; failure is reported but not fatal."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.asymmetry_cache, f)
            print("Cache saved")
        except Exception as e:
            print(f"Cache save failed: {e}")

    def _precompute_asymmetry(self, paths=None):
        """Compute and cache metrics for ``paths`` (default: every volume)."""
        if paths is None:
            paths = self.volume_paths
        for path in tqdm(paths, desc="Asymmetry"):
            volume = self._load_volume(path)
            self.asymmetry_cache[self._cache_key(path)] = self._compute_asymmetry(volume)

    def _load_volume(self, path: Path) -> np.ndarray:
        """Load a NIfTI file as float32.

        If the last axis is shorter than the first it is moved to the front,
        presumably to bring the slice axis first — confirm against the data
        layout. No intensity normalisation or resizing is done here.
        """
        img = nib.load(str(path))
        volume = img.get_fdata().astype(np.float32)
        if volume.ndim == 3 and volume.shape[2] < volume.shape[0]:
            volume = np.transpose(volume, (2, 0, 1))
        return volume

    def _compute_asymmetry(self, volume: np.ndarray) -> dict:
        """Run both detectors; the lymph-node detector reuses the airway midlines."""
        airway = self.airway_detector.forward(volume)
        lymphnode = self.lymphnode_detector.forward(volume, airway['midlines'].tolist())
        return {'airway': airway, 'lymphnode': lymphnode}

    def __len__(self) -> int:
        return len(self.volume_paths)

    def __getitem__(self, idx: int) -> dict:
        path = self.volume_paths[idx]
        volume = self._load_volume(path)

        metrics = self.asymmetry_cache.get(self._cache_key(path)) if self.cache_asymmetry else None
        if metrics is None:
            metrics = self._compute_asymmetry(volume)

        airway_tensor = np.stack([
            metrics['airway']['effacement'],
            metrics['airway']['mass_effect'],
            metrics['airway']['midline_shift'],
            metrics['airway']['hybrid']
        ], axis=0)

        lymph_tensor = np.stack([
            metrics['lymphnode']['left_hypo'],
            metrics['lymphnode']['right_hypo'],
            metrics['lymphnode']['hypo_asymmetry']
        ], axis=0)

        return {
            'volume': torch.from_numpy(volume).unsqueeze(0).float(),
            'airway_metrics': torch.from_numpy(airway_tensor).float(),
            'lymphnode_metrics': torch.from_numpy(lymph_tensor).float(),
        }
527
+
528
+
529
+ # ==============================================================================
530
+ # Loss
531
+ # ==============================================================================
532
+
533
class MAEAsymmetryLoss(nn.Module):
    """Masked, asymmetry-weighted reconstruction loss plus two metric-regression terms.

    Args:
        mask_ratio: probability that a voxel contributes to the
            reconstruction term. The mask is redrawn on every call.
        asymmetry_boost: extra weight for slices with a high normalised
            hybrid asymmetry score (slice weight = 1 + boost * score).
    """

    def __init__(self, mask_ratio=0.75, asymmetry_boost=5.0):
        super().__init__()
        self.mse = nn.MSELoss(reduction='none')
        self.mask_ratio = mask_ratio
        self.asymmetry_boost = asymmetry_boost

    def forward(self, outputs, batch):
        """Compute the total loss.

        Args:
            outputs: dict with ``reconstruction`` (B, 1, D, H, W),
                ``airway_pred`` (B, D, 4) and ``lymph_pred`` (B, D, 3).
            batch: dict with ``volume`` (B, 1, D, H, W), ``airway_metrics``
                (B, 4, D) and ``lymphnode_metrics`` (B, 3, D). Tensors may
                be on any device; they are moved to the predictions' device.

        Returns:
            Scalar tensor: reconstruction + airway + lymph-node losses.

        Raises:
            ValueError: if the reconstruction and target shapes differ.
        """
        recon = outputs['reconstruction']
        device = recon.device
        # Loader batches arrive on CPU while predictions live on the model's
        # device; align them here so the loss works either way.
        target = batch['volume'].to(device)
        airway_metrics = batch['airway_metrics'].to(device)
        lymph_metrics = batch['lymphnode_metrics'].to(device)

        if recon.shape != target.shape:
            raise ValueError(
                f"reconstruction shape {tuple(recon.shape)} does not match "
                f"target shape {tuple(target.shape)}"
            )

        B, C, D, H, W = target.shape
        num_voxels = D * H * W
        # Random per-voxel mask selecting where reconstruction error counts.
        mask = torch.rand(B, num_voxels, device=device) < self.mask_ratio
        mask = mask.view(B, 1, D, H, W).expand_as(recon)

        diff = self.mse(recon, target) * mask.float()

        # Channel 3 of the airway metrics is the per-slice hybrid score;
        # normalise by the per-sample maximum and turn it into slice weights.
        hybrid = airway_metrics[:, 3, :]
        hybrid_norm = hybrid / (hybrid.max(dim=1, keepdim=True)[0] + 1e-6)
        slice_weights = 1.0 + self.asymmetry_boost * hybrid_norm
        weights = slice_weights.unsqueeze(1).unsqueeze(3).unsqueeze(4).expand_as(diff)

        recon_loss = (diff * weights).sum() / (mask.sum() + 1e-6)

        airway_loss = F.mse_loss(outputs['airway_pred'], airway_metrics.permute(0, 2, 1))
        lymph_loss = F.mse_loss(outputs['lymph_pred'], lymph_metrics.permute(0, 2, 1))

        return recon_loss + airway_loss + lymph_loss
562
+
563
+
564
+ # ==============================================================================
565
+ # Trainer
566
+ # ==============================================================================
567
+
568
class TrainerWithMonitoring:
    """Training loop with periodic cosine-similarity monitoring, checkpoints and early stopping.

    Args:
        model: module whose forward returns the dict expected by the loss and
            which exposes an ``encoder`` submodule.
        train_loader: iterable of batch dicts (``volume``, ``airway_metrics``,
            ``lymphnode_metrics``).
        device: device the model and batches are moved to.
        lr: AdamW learning rate.
        output_dir: folder for checkpoints and history; nothing is written
            when it is falsy.
    """

    def __init__(self, model, train_loader, device, lr=1e-4, output_dir=None):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.loss_fn = MAEAsymmetryLoss()

        self.output_dir = Path(output_dir) if output_dir else None
        if self.output_dir:
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # Only filled on monitoring epochs (see `train`).
        self.history = {
            'epoch': [],
            'loss': [],
            'cosine_sim_mean': [],
            'cosine_sim_std': [],
        }

    def _to_device(self, batch):
        """Return a copy of ``batch`` with every tensor moved to ``self.device``."""
        return {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }

    def compute_cosine_similarity(self, n_samples=50):
        """Mean/std cosine similarity between features of volumes and their augmentations.

        Args:
            n_samples: maximum number of loader *batches* to evaluate.

        Returns:
            ``(mean, std)`` as Python floats; both are NaN when the loader
            yields no batches.
        """
        self.model.eval()
        similarities = []
        try:
            with torch.no_grad():
                for i, batch in enumerate(self.train_loader):
                    if i >= n_samples:
                        break
                    volume = batch['volume'].to(self.device)
                    feat1 = self.model.encoder(volume)
                    feat2 = self.model.encoder(augment_volume(volume))
                    sim = (F.normalize(feat1, dim=1) * F.normalize(feat2, dim=1)).sum(dim=1)
                    similarities.extend(sim.cpu().numpy().tolist())
        finally:
            # Always restore training mode, even if monitoring fails.
            self.model.train()
        if not similarities:
            return float('nan'), float('nan')
        # Plain floats keep the history JSON-serialisable.
        return float(np.mean(similarities)), float(np.std(similarities))

    def save_checkpoint(self, epoch, is_best=False):
        """Write a full checkpoint for ``epoch`` and, if ``is_best``, update ``best_model.pt``."""
        if not self.output_dir:
            return
        path = self.output_dir / f"checkpoint_epoch_{epoch:03d}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history,
        }, path)
        print(f"Checkpoint saved: {path.name}")

        if is_best:
            best_path = self.output_dir / "best_model.pt"
            torch.save(self.model.state_dict(), best_path)
            print(f"Best model updated: {best_path.name}")

    def train(self, n_epochs=100, monitor_every=5, save_every=10,
              early_stop_patience=20, early_stop_after=30):
        """Run training.

        Args:
            n_epochs: maximum number of epochs.
            monitor_every: cosine similarity is computed on epoch 1 and on
                every multiple of this value.
            save_every: a checkpoint is written on every multiple of this
                value, and additionally whenever the loss improves.
            early_stop_patience: epochs without improvement tolerated.
            early_stop_after: early stopping is only considered after this epoch.

        Returns:
            The ``history`` dict.
        """
        best_loss = float('inf')
        patience_counter = 0
        best_epoch = 0

        for epoch in range(1, n_epochs + 1):
            self.model.train()
            total_loss = 0.0
            num_batches = 0

            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}", leave=False):
                # Inputs and targets must share the model's device; passing
                # the raw loader batch to the loss breaks on CUDA.
                device_batch = self._to_device(batch)

                self.optimizer.zero_grad()
                outputs = self.model(device_batch['volume'])
                loss = self.loss_fn(outputs, device_batch)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            is_best = avg_loss < best_loss
            if is_best:
                best_loss = avg_loss
                best_epoch = epoch
                patience_counter = 0
            else:
                patience_counter += 1

            if epoch % monitor_every == 0 or epoch == 1:
                cos_mean, cos_std = self.compute_cosine_similarity()
                self.history['epoch'].append(epoch)
                self.history['loss'].append(avg_loss)
                self.history['cosine_sim_mean'].append(cos_mean)
                self.history['cosine_sim_std'].append(cos_std)

                msg = f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | CosSim: {cos_mean:.3f}±{cos_std:.3f}"
                if is_best:
                    msg += " ★"
                print(msg)

                if cos_mean > 0.95:
                    print(f"  WARNING: Cosine similarity very high ({cos_mean:.3f}) — possible collapse")
            else:
                msg = f"Epoch {epoch:3d} | Loss: {avg_loss:.4f}"
                if is_best:
                    msg += " ★"
                print(msg)

            if epoch % save_every == 0:
                self.save_checkpoint(epoch, is_best=is_best)
            elif is_best:
                self.save_checkpoint(epoch, is_best=True)

            if epoch > early_stop_after and patience_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch}")
                break

        if self.output_dir:
            torch.save(self.model.state_dict(), self.output_dir / "final_model.pt")
            with open(self.output_dir / "history.json", 'w') as f:
                json.dump(self.history, f, indent=2)

        print(f"Best loss: {best_loss:.4f} at epoch {best_epoch}")
        return self.history
696
+
697
+
698
+ # ==============================================================================
699
+ # Main
700
+ # ==============================================================================
701
+
702
def main():
    """Command-line entry point: build the dataset, model and trainer, then train.

    Raises:
        SystemExit: if no volumes are found under ``--data-dir``.
    """
    parser = argparse.ArgumentParser(description="3D Swin MAE pretraining")
    parser.add_argument("--data-dir", type=str, required=True, help="Folder containing cropped_volume.nii.gz files")
    parser.add_argument("--output-dir", type=str, default="./checkpoints", help="Folder to save models and logs")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--monitor-every", type=int, default=5)
    parser.add_argument("--save-every", type=int, default=10)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--early-after", type=int, default=30)
    parser.add_argument("--no-cache", action="store_true")

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    dataset = OPSCCDataset(
        data_dir=args.data_dir,
        cache_asymmetry=not args.no_cache
    )
    # A shuffling DataLoader cannot be built over an empty dataset; fail
    # here with a message that names the actual problem.
    if len(dataset) == 0:
        raise SystemExit(f"No cropped_volume.nii.gz files found under {args.data_dir}")

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=device.type == "cuda"
    )

    model = MAE_Swin3D()

    trainer = TrainerWithMonitoring(
        model=model,
        train_loader=loader,
        device=device,
        lr=args.lr,
        output_dir=args.output_dir
    )

    trainer.train(
        n_epochs=args.epochs,
        monitor_every=args.monitor_every,
        save_every=args.save_every,
        early_stop_patience=args.patience,
        early_stop_after=args.early_after
    )

    print("\nNote: Volumes are expected to be cropped, resized to ~60×128×128, intensities [0,1].")


if __name__ == "__main__":
    main()