| |
| """ |
| Masked Autoencoder (MAE) pretraining with 3D Swin Transformer for OPSCC CT scans. |
| Asymmetry-aware reconstruction + overfitting monitoring via cosine similarity. |
| |
| Run example: |
| python train_mae_swin3d.py --data-dir /path/to/your/nii_folder --output-dir ./checkpoints |
| """ |
|
|
| """ |
| Self-Supervised Learning for OPSCC CT using 3D Swin Transformer MAE |
| with asymmetry-aware reconstruction and overfitting monitoring |
| """ |
|
|
| import argparse |
| import json |
| import pickle |
| import warnings |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
|
|
| import numpy as np |
| from scipy import ndimage |
| import nibabel as nib |
| from tqdm import tqdm |
|
|
| warnings.filterwarnings("ignore", category=UserWarning) |
|
|
|
|
| |
| |
| |
|
|
class DropPath(nn.Module):
    """Stochastic depth: randomly drop entire samples during training.

    Each sample in the batch is kept with probability ``1 - drop_prob``;
    kept samples are rescaled by ``1 / keep_prob`` so the expectation of
    the output matches the input. Identity in eval mode.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # No-op when dropping is disabled or the module is in eval mode.
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        mask = mask.add_(keep_prob).floor_()
        return x.div(keep_prob) * mask
|
|
|
|
| |
| |
| |
|
|
class AirwayAsymmetryDetector:
    """Slice-wise airway asymmetry metrics for a (D, H, W) CT volume.

    Intensities are assumed normalized to roughly [0, 1] (air ~0, soft
    tissue ~0.2-0.7) -- TODO confirm against the preprocessing pipeline.
    For each axial slice: estimate a body midline, segment interior air
    pockets (candidate airway), and derive left/right asymmetry scores.
    """

    def __init__(self, exclude_inferior_fraction=0.15, exclude_superior_fraction=0.10):
        # Fractions of the slice axis at each end whose slices get a zero
        # 'hybrid' score (the other metrics are still computed everywhere).
        self.exclude_inferior_fraction = exclude_inferior_fraction
        self.exclude_superior_fraction = exclude_superior_fraction

    def find_midline(self, slice_2d):
        """Return the column index that best splits the slice symmetrically.

        Searches within +/- width/8 of the geometric center and picks the
        split whose mirrored left/right strips have the smallest mean
        absolute intensity difference.
        """
        h, w = slice_2d.shape
        search_range = w // 8
        center = w // 2
        best_midline = center
        best_symmetry = float('inf')
        for mid in range(center - search_range, center + search_range):
            # Width of the widest strip that fits on both sides of `mid`.
            compare_width = min(mid, w - mid)
            if compare_width < 10:
                continue
            left = slice_2d[:, mid - compare_width:mid]
            right = np.flip(slice_2d[:, mid:mid + compare_width], axis=1)
            diff = np.abs(left - right).mean()
            if diff < best_symmetry:
                best_symmetry = diff
                best_midline = mid
        return best_midline

    def detect_airway(self, slice_2d, air_thresh=0.1):
        """Binary mask of interior air pockets (candidate airway) in a slice.

        Air is intensity < air_thresh. Connected components touching any
        slice border are discarded (outside-body air), as are components of
        20 pixels or fewer.
        """
        binary = slice_2d < air_thresh
        labeled, num_features = ndimage.label(binary)
        # Labels present on a border row/column belong to background air.
        edge_labels = set(labeled[0,:].flatten()) | set(labeled[-1,:].flatten()) | \
            set(labeled[:,0].flatten()) | set(labeled[:,-1].flatten())
        airway_mask = np.zeros_like(binary)
        for label_id in range(1, num_features + 1):
            if label_id not in edge_labels:
                component = labeled == label_id
                if component.sum() > 20:
                    airway_mask |= component
        return airway_mask

    def forward(self, volume):
        """Compute per-slice asymmetry metrics for a (D, H, W) volume.

        Returns a dict of length-D numpy arrays:
          - 'effacement': normalized |left - right| airway area imbalance
          - 'mass_effect': mean mirrored soft-tissue intensity difference
          - 'midline_shift': midline offset (pixels) from the image center
          - 'hybrid': 0.5*effacement + 0.5*mass_effect, zeroed outside the
            [inferior_cutoff, superior_cutoff] slice range
          - 'midlines': estimated midline column per slice
        """
        d, h, w = volume.shape
        inferior_cutoff = int(d * self.exclude_inferior_fraction)
        superior_cutoff = int(d * (1 - self.exclude_superior_fraction))
        results = {'effacement': [], 'mass_effect': [], 'midline_shift': [], 'hybrid': [], 'midlines': []}
        for z in range(d):
            slice_2d = volume[z]
            midline = self.find_midline(slice_2d)
            midline_shift = midline - w // 2
            results['midlines'].append(midline)
            airway_mask = self.detect_airway(slice_2d)
            left_air = airway_mask[:, :midline].sum()
            right_air = airway_mask[:, midline:].sum()
            total = left_air + right_air
            # 0 when no airway was detected on this slice.
            effacement = abs(left_air - right_air) / max(total, 1) if total > 0 else 0
            compare_width = min(midline, w - midline)
            mass_effect = 0
            if compare_width > 0:
                # Soft-tissue window (0.2, 0.7) in normalized intensity.
                soft_tissue = (slice_2d > 0.2) & (slice_2d < 0.7)
                left = slice_2d[:, midline-compare_width:midline] * soft_tissue[:, midline-compare_width:midline]
                right = np.flip(slice_2d[:, midline:midline+compare_width], axis=1) * np.flip(soft_tissue[:, midline:midline+compare_width], axis=1)
                mass_effect = np.abs(left - right).mean()
            in_range = inferior_cutoff <= z <= superior_cutoff
            hybrid = (0.5 * effacement + 0.5 * mass_effect) if in_range else 0
            results['effacement'].append(effacement)
            results['mass_effect'].append(mass_effect)
            results['midline_shift'].append(midline_shift)
            results['hybrid'].append(hybrid)
        return {k: np.array(v) for k, v in results.items()}
|
|
|
|
class GlobalSoftTissueAsymmetryDetector:
    """Count small hypodense soft-tissue blobs left/right of the midline.

    Per axial slice: candidate regions are hypodense pixels (< 0.35) inside
    the soft-tissue intensity window (0.2, 0.7), cleaned with morphological
    opening then closing. Connected components with an area strictly between
    10 and 150 pixels are tallied as left or right of the slice midline by
    their centroid column.
    """

    def __init__(self, exclude_inferior_fraction=0.15, exclude_superior_fraction=0.10):
        # Stored for interface symmetry with AirwayAsymmetryDetector; the
        # counts below are computed for every slice regardless.
        self.exclude_inferior_fraction = exclude_inferior_fraction
        self.exclude_superior_fraction = exclude_superior_fraction

    def forward(self, volume, midlines=None):
        """Return per-slice left/right hypodense-blob counts and their gap."""
        depth, height, width = volume.shape
        if midlines is None:
            midlines = [width // 2] * depth
        left_counts, right_counts, gaps = [], [], []
        for z in range(depth):
            axial = volume[z]
            mid = midlines[z]
            tissue = (axial > 0.2) & (axial < 0.7)
            # Opening removes speckle; closing fills small holes.
            blobs = ndimage.binary_closing(
                ndimage.binary_opening((axial < 0.35) & tissue, iterations=1),
                iterations=2,
            )
            labeled, count = ndimage.label(blobs)
            n_left = n_right = 0
            for lbl in range(1, count + 1):
                region = labeled == lbl
                area = region.sum()
                if 10 < area < 150:
                    cx = np.argwhere(region)[:, 1].mean()
                    if cx < mid:
                        n_left += 1
                    else:
                        n_right += 1
            left_counts.append(n_left)
            right_counts.append(n_right)
            gaps.append(abs(n_left - n_right))
        return {
            'left_hypo': np.array(left_counts),
            'right_hypo': np.array(right_counts),
            'hypo_asymmetry': np.array(gaps),
        }
|
|
|
|
| |
| |
| |
|
|
def window_partition3d(x, window_size=(4,4,4)):
    """Split a (B, C, D, H, W) volume into non-overlapping 3D token windows.

    The volume is zero-padded at the far end of each spatial dim so it tiles
    evenly, then rearranged into per-window token sequences.

    Returns:
        windows: (B * n_windows, ws_d*ws_h*ws_w, C) tensor.
        pads: (pad_d, pad_h, pad_w) padding that was applied.
    """
    B, C, D, H, W = x.shape
    wd, wh, ww = window_size
    pads = ((wd - D % wd) % wd, (wh - H % wh) % wh, (ww - W % ww) % ww)
    x = F.pad(x, (0, pads[2], 0, pads[1], 0, pads[0]))
    Dp, Hp, Wp = D + pads[0], H + pads[1], W + pads[2]
    grid = x.reshape(B, C, Dp // wd, wd, Hp // wh, wh, Wp // ww, ww)
    grid = grid.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
    windows = grid.reshape(-1, C, wd * wh * ww).permute(0, 2, 1).contiguous()
    return windows, pads
|
|
|
|
def window_reverse3d(windows, window_size, B, D, H, W, pads):
    """Inverse of window_partition3d: reassemble windows into (B, C, D, H, W).

    Args:
        windows: (B * n_windows, window_volume, C) token windows.
        window_size: (ws_d, ws_h, ws_w) used when partitioning.
        B, D, H, W: original (pre-padding) batch and spatial sizes.
        pads: (pad_d, pad_h, pad_w) returned by window_partition3d.
    """
    wd, wh, ww = window_size
    Dp = D + pads[0]
    Hp = H + pads[1]
    Wp = W + pads[2]
    grid = windows.reshape(B, Dp // wd, Hp // wh, Wp // ww, wd, wh, ww, -1)
    grid = grid.permute(0, 7, 1, 4, 2, 5, 3, 6).contiguous()
    volume = grid.reshape(B, -1, Dp, Hp, Wp)
    # Crop away the padding to restore the original spatial extent.
    return volume[:, :, :D, :H, :W]
|
|
|
|
class WindowAttention3D(nn.Module):
    """Multi-head self-attention within a 3D window, with relative position bias.

    Operates on token windows of shape (num_windows*B, N, C) where
    N <= ws_d*ws_h*ws_w. A learned bias, indexed by the precomputed pairwise
    3D offsets between window positions, is added to the attention logits.
    """

    def __init__(self, dim, window_size=(4,4,4), num_heads=3, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Pairwise relative 3D coordinates between all positions in a window.
        coords_d = torch.arange(window_size[0])
        coords_h = torch.arange(window_size[1])
        coords_w = torch.arange(window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()

        # Shift offsets to be non-negative, then encode the 3 axes into one
        # lookup index (mixed-radix encoding over (2*ws-1) offsets per axis).
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 2] += window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * window_size[2] - 1)
        # Fix: register as a buffer so the index tensor follows the module
        # across .to(device)/.cuda(); a plain tensor attribute stays on CPU
        # and breaks GPU indexing. persistent=False keeps it out of the
        # state_dict, so previously saved checkpoints still load unchanged.
        self.register_buffer(
            "relative_position_index", relative_coords.sum(-1), persistent=False
        )

        max_rel_pos = self.relative_position_index.max().item()
        self.relative_position_bias_table = nn.Parameter(torch.zeros((max_rel_pos + 1, num_heads)))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Attend within each window.

        Args:
            x: (num_windows*B, N, C) window tokens.
            mask: optional (nW, N, N) additive attention mask.

        Returns:
            (num_windows*B, N, C) tensor.
        """
        B_, N, C = x.shape
        # Bias lookup cropped to the first N positions (supports N smaller
        # than the full window volume).
        rel_index = self.relative_position_index[:N, :N]
        relative_position_bias = self.relative_position_bias_table[rel_index.reshape(-1)]
        relative_position_bias = relative_position_bias.view(N, N, -1).permute(2, 0, 1).contiguous()

        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
|
|
class SwinTransformerBlock3D(nn.Module):
    """Swin-style transformer block on (B, C, D, H, W) feature maps.

    Pre-norm windowed attention followed by a pre-norm MLP, each wrapped in
    a residual connection with optional stochastic depth.

    NOTE(review): `shift_size` is stored but never applied -- no cyclic
    shift is performed, so every block uses regular (non-shifted) windows.
    """

    def __init__(self, dim, num_heads, window_size=(4,4,4), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(dim=dim, window_size=window_size, num_heads=num_heads,
                                      qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), act_layer(), nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim), nn.Dropout(drop)
        )

    @staticmethod
    def _channels_last(t):
        # (B, C, D, H, W) -> (B, D, H, W, C) so LayerNorm sees C last.
        return t.permute(0, 2, 3, 4, 1)

    @staticmethod
    def _channels_first(t):
        # (B, D, H, W, C) -> (B, C, D, H, W)
        return t.permute(0, 4, 1, 2, 3)

    def forward(self, x):
        # Attention branch: norm -> windowed attention -> merge -> residual.
        residual = x
        normed = self._channels_first(self.norm1(self._channels_last(x)))
        windows, pads = window_partition3d(normed, self.window_size)
        attended = self.attn(windows)
        B, _, D, H, W = normed.shape
        merged = window_reverse3d(attended, self.window_size, B, D, H, W, pads)
        x = residual + self.drop_path(merged)

        # MLP branch: norm -> MLP (on channels-last layout) -> residual.
        mlp_out = self._channels_first(self.mlp(self.norm2(self._channels_last(x))))
        return x + self.drop_path(mlp_out)
|
|
|
|
class PatchEmbed3D(nn.Module):
    """Embed a volume into patch tokens via a strided 3D convolution.

    Maps (B, in_chans, D, H, W) to (B, embed_dim, D/p, H/p, W/p); each
    output voxel is the linear projection of one non-overlapping patch.
    """

    def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=96):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        embedded = self.proj(x)
        return embedded
|
|
|
|
class PatchMerging3D(nn.Module):
    """Downsample a (B, C, D, H, W) feature map by 2x in each spatial dim.

    Each 2x2x2 neighborhood is flattened into one token of 8*C features and
    linearly projected to 2*C channels. Odd spatial sizes are zero-padded
    up to even first.
    """

    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)

    def forward(self, x):
        B, C, D, H, W = x.shape
        # Pad each odd dimension so it halves cleanly.
        if (D % 2) or (H % 2) or (W % 2):
            x = F.pad(x, (0, W % 2, 0, H % 2, 0, D % 2))
        _, _, Dp, Hp, Wp = x.shape
        tokens = x.permute(0, 2, 3, 4, 1)
        tokens = tokens.view(B, Dp // 2, 2, Hp // 2, 2, Wp // 2, 2, C)
        tokens = tokens.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        tokens = tokens.view(B, Dp // 2, Hp // 2, Wp // 2, 8 * C)
        reduced = self.reduction(tokens)
        return reduced.permute(0, 4, 1, 2, 3).contiguous()
|
|
|
|
class SwinTransformer3D(nn.Module):
    """Hierarchical 3D Swin encoder producing one global feature vector.

    Stages of SwinTransformerBlock3D are interleaved with PatchMerging3D
    downsampling (channels double per stage). The final feature map is
    global-average-pooled and layer-normalized; output shape is
    (B, embed_dim * 2**(len(depths)-1)).
    """

    def __init__(self, in_chans=1, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=(4,4,4), mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed3D(in_chans=in_chans, embed_dim=embed_dim)
        # Stochastic-depth rates increase linearly from 0 to drop_path_rate
        # across ALL blocks of the network (standard Swin schedule).
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        dim = embed_dim
        block_idx = 0  # global block counter into dpr
        for i_layer in range(len(depths)):
            # Fix: index dpr by the global block counter (block_idx + i); the
            # original used the within-stage index i, so every stage reused
            # the first few rates of the schedule. Also forward mlp_ratio
            # and qkv_bias, which were accepted but silently ignored.
            blocks = nn.ModuleList([
                SwinTransformerBlock3D(dim=dim, num_heads=num_heads[i_layer], window_size=window_size,
                                       mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                                       drop=drop_rate, attn_drop=attn_drop_rate,
                                       drop_path=dpr[block_idx + i])
                for i in range(depths[i_layer])
            ])
            block_idx += depths[i_layer]
            self.layers.append(blocks)
            if i_layer < len(depths)-1:
                self.layers.append(PatchMerging3D(dim))
                dim *= 2
        self.norm = nn.LayerNorm(dim)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        # Dimensionality of the pooled output feature vector.
        self.feature_dim = dim

    def forward(self, x):
        """Encode (B, in_chans, D, H, W) into (B, feature_dim)."""
        x = self.patch_embed(x)
        for layer in self.layers:
            # self.layers alternates stage block-lists and merging layers.
            if isinstance(layer, PatchMerging3D):
                x = layer(x)
            else:
                for blk in layer:
                    x = blk(x)
        # Global average pool, then LayerNorm over the channel vector.
        x = self.avgpool(x).flatten(1)
        x = self.norm(x)
        return x
|
|
|
|
| |
| |
| |
|
|
class MAE_Swin3D(nn.Module):
    """MAE-style model: Swin3D encoder plus lightweight decoding heads.

    From the single pooled encoder feature vector it decodes a dense
    reconstruction of the input volume, per-slice airway asymmetry metrics
    (4 channels) and per-slice lymph-node metrics (3 channels).

    NOTE(review): the final decoder Linear maps decoder_dim to every voxel
    (~1e6 outputs at the default shape), which is a very large layer.
    """

    def __init__(self, input_shape=(60, 128, 128)):
        super().__init__()
        self.input_shape = input_shape
        self.encoder = SwinTransformer3D(in_chans=1)
        decoder_dim = 512
        self.decoder = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, decoder_dim),
            nn.ReLU(),
            nn.Linear(decoder_dim, np.prod(input_shape))
        )
        self.airway_head = nn.Linear(self.encoder.feature_dim, 4 * input_shape[0])
        self.lymph_head = nn.Linear(self.encoder.feature_dim, 3 * input_shape[0])

    def forward(self, x):
        """Encode a (B, 1, D, H, W) volume and decode all heads."""
        features = self.encoder(x)
        flat = self.decoder(features)
        volume = flat.view(-1, 1, *self.input_shape)
        airway = self.airway_head(features).view(-1, self.input_shape[0], 4)
        lymph = self.lymph_head(features).view(-1, self.input_shape[0], 3)
        return {
            'reconstruction': volume,
            'airway_pred': airway,
            'lymph_pred': lymph,
            'features': features,
        }
|
|
|
|
| |
| |
| |
|
|
def augment_volume(volume):
    """Randomly augment a (B, 1, D, H, W) volume tensor.

    Used by the cosine-similarity overfitting monitor to produce a perturbed
    view of each volume. Applies, each with its own probability: intensity
    shift/scale, Gaussian noise, flips, in-plane rotation, random resized
    crop, smoothing, and random erasing; output is clamped to [0, 1].

    NOTE(review): results under a fixed seed depend on the exact order of
    the torch.rand/randn/randint calls below -- do not reorder.
    """
    aug = volume.clone()
    device = aug.device

    # Random intensity shift in [-0.2, 0.2), 70% probability.
    if torch.rand(1) > 0.3:
        shift = (torch.rand(1).to(device) - 0.5) * 0.4
        aug += shift

    # Random intensity scale in [0.7, 1.3), 70% probability.
    if torch.rand(1) > 0.3:
        scale = 0.7 + torch.rand(1).to(device) * 0.6
        aug *= scale

    # Additive Gaussian noise (std 0.1), 70% probability.
    if torch.rand(1) > 0.3:
        noise = torch.randn_like(aug) * 0.1
        aug += noise

    # Left-right flip. NOTE(review): this mirrors laterality, which
    # interacts with left/right asymmetry -- presumably acceptable since
    # this function only feeds the representation-similarity monitor.
    if torch.rand(1) > 0.5:
        aug = torch.flip(aug, dims=[-1])

    # Anterior-posterior flip.
    if torch.rand(1) > 0.5:
        aug = torch.flip(aug, dims=[-2])

    # In-plane rotation by 90/180/270 degrees. Assumes square slices
    # (H == W); a rectangular slice would change shape here -- TODO confirm
    # inputs are 128x128.
    if torch.rand(1) > 0.7:
        k = torch.randint(1, 4, (1,)).item()
        aug = torch.rot90(aug, k, dims=[-2, -1])

    # Random resized crop: keep 80-95% per axis, then resize back.
    if torch.rand(1) > 0.5:
        _, _, D, H, W = aug.shape
        crop_d = int(D * (0.80 + torch.rand(1).item() * 0.15))
        crop_h = int(H * (0.80 + torch.rand(1).item() * 0.15))
        crop_w = int(W * (0.80 + torch.rand(1).item() * 0.15))
        start_d = torch.randint(0, D - crop_d + 1, (1,)).item()
        start_h = torch.randint(0, H - crop_h + 1, (1,)).item()
        start_w = torch.randint(0, W - crop_w + 1, (1,)).item()
        aug = aug[:, :, start_d:start_d+crop_d, start_h:start_h+crop_h, start_w:start_w+crop_w]
        aug = F.interpolate(aug, size=(D, H, W), mode='trilinear', align_corners=False)

    # Light smoothing via a 3x3x3 box filter, 30% probability.
    if torch.rand(1) > 0.7:
        kernel_size = 3
        padding = kernel_size // 2
        aug = F.avg_pool3d(aug, kernel_size=kernel_size, stride=1, padding=padding)

    # Random erasing: overwrite a small box (5-15% per axis) with the mean.
    if torch.rand(1) > 0.7:
        _, _, D, H, W = aug.shape
        erase_d = int(D * (0.05 + torch.rand(1).item() * 0.10))
        erase_h = int(H * (0.05 + torch.rand(1).item() * 0.10))
        erase_w = int(W * (0.05 + torch.rand(1).item() * 0.10))
        start_d = torch.randint(0, D - erase_d + 1, (1,)).item()
        start_h = torch.randint(0, H - erase_h + 1, (1,)).item()
        start_w = torch.randint(0, W - erase_w + 1, (1,)).item()
        aug[:, :, start_d:start_d+erase_d, start_h:start_h+erase_h, start_w:start_w+erase_w] = aug.mean()

    # Keep intensities in the normalized CT range.
    aug = torch.clamp(aug, 0, 1)
    return aug
|
|
|
|
| |
| |
| |
|
|
class OPSCCDataset(Dataset):
    """Dataset of cropped OPSCC CT volumes with precomputed asymmetry metrics.

    Recursively finds ``cropped_volume.nii.gz`` files under ``data_dir``.
    Per-volume airway and lymph-node asymmetry metrics are computed once and
    optionally cached to ``.asymmetry_cache.pkl`` inside ``data_dir``.

    Each item is a dict of float tensors:
      - 'volume': (1, D, H, W) CT volume
      - 'airway_metrics': (4, D) effacement / mass effect / midline shift / hybrid
      - 'lymphnode_metrics': (3, D) left / right hypodense counts and their gap
    """

    def __init__(self, data_dir: str, cache_asymmetry: bool = True):
        self.data_dir = Path(data_dir)
        # Fix: sort for a deterministic ordering -- glob order is
        # filesystem-dependent, which would silently mismatch an
        # index-keyed cache between runs.
        self.volume_paths = sorted(self.data_dir.glob("**/cropped_volume.nii.gz"))
        print(f"Found {len(self.volume_paths)} volumes")

        self.cache_file = self.data_dir / ".asymmetry_cache.pkl"
        self.cache_asymmetry = cache_asymmetry
        self.asymmetry_cache = {}
        self.airway_detector = AirwayAsymmetryDetector()
        self.lymphnode_detector = GlobalSoftTissueAsymmetryDetector()

        if self.cache_asymmetry:
            if self.cache_file.is_file():
                try:
                    with open(self.cache_file, 'rb') as f:
                        self.asymmetry_cache = pickle.load(f)
                    print(f"Loaded asymmetry cache ({len(self.asymmetry_cache)} entries)")
                except Exception:
                    print("Cache load failed → recomputing")
                    self._precompute_asymmetry()
            else:
                print("Computing asymmetry metrics...")
                self._precompute_asymmetry()
                try:
                    with open(self.cache_file, 'wb') as f:
                        pickle.dump(self.asymmetry_cache, f)
                    print("Cache saved")
                except Exception as e:
                    print(f"Cache save failed: {e}")

    def _cache_key(self, idx: int) -> str:
        # Fix: key cache entries by path so they survive file additions or
        # removals that would shift integer indices.
        return str(self.volume_paths[idx])

    def _precompute_asymmetry(self):
        """Compute and cache asymmetry metrics for every volume."""
        for idx, path in enumerate(tqdm(self.volume_paths, desc="Asymmetry")):
            volume = self._load_volume(path)
            self.asymmetry_cache[self._cache_key(idx)] = self._compute_asymmetry(volume)

    def _load_volume(self, path: Path) -> np.ndarray:
        """Load a NIfTI volume as float32 with the slice axis first (D, H, W)."""
        img = nib.load(str(path))
        volume = img.get_fdata().astype(np.float32)
        # Heuristic: if the last axis is shortest, assume (H, W, D) on disk
        # and move slices to the front -- TODO confirm against preprocessing.
        if volume.ndim == 3 and volume.shape[2] < volume.shape[0]:
            volume = np.transpose(volume, (2, 0, 1))
        return volume

    def _compute_asymmetry(self, volume: np.ndarray) -> dict:
        """Run both detectors; the lymph detector reuses the airway midlines."""
        airway = self.airway_detector.forward(volume)
        lymphnode = self.lymphnode_detector.forward(volume, airway['midlines'].tolist())
        return {'airway': airway, 'lymphnode': lymphnode}

    def __len__(self) -> int:
        return len(self.volume_paths)

    def __getitem__(self, idx: int) -> dict:
        path = self.volume_paths[idx]
        volume = self._load_volume(path)

        metrics = None
        if self.cache_asymmetry:
            # Prefer path-keyed entries; fall back to legacy index-keyed caches.
            metrics = self.asymmetry_cache.get(self._cache_key(idx), self.asymmetry_cache.get(idx))
        if metrics is None:
            metrics = self._compute_asymmetry(volume)

        airway_tensor = np.stack([
            metrics['airway']['effacement'],
            metrics['airway']['mass_effect'],
            metrics['airway']['midline_shift'],
            metrics['airway']['hybrid']
        ], axis=0)

        lymph_tensor = np.stack([
            metrics['lymphnode']['left_hypo'],
            metrics['lymphnode']['right_hypo'],
            metrics['lymphnode']['hypo_asymmetry']
        ], axis=0)

        return {
            'volume': torch.from_numpy(volume).unsqueeze(0).float(),
            'airway_metrics': torch.from_numpy(airway_tensor).float(),
            'lymphnode_metrics': torch.from_numpy(lymph_tensor).float(),
        }
|
|
|
|
| |
| |
| |
|
|
class MAEAsymmetryLoss(nn.Module):
    """Masked reconstruction loss weighted by slice asymmetry, plus head losses.

    The reconstruction MSE is evaluated only on a random voxel mask
    (``mask_ratio`` of voxels) and up-weighted on slices with high airway
    'hybrid' asymmetry. Auxiliary MSE terms supervise the airway and
    lymph-node metric heads.
    """

    def __init__(self, mask_ratio=0.75, asymmetry_boost=5.0):
        super().__init__()
        self.mse = nn.MSELoss(reduction='none')
        self.mask_ratio = mask_ratio            # fraction of voxels in the recon loss
        self.asymmetry_boost = asymmetry_boost  # extra weight on asymmetric slices

    def forward(self, outputs, batch):
        """Return the scalar total loss for one batch.

        Args:
            outputs: model dict with 'reconstruction' (B,1,D,H,W),
                'airway_pred' (B,D,4) and 'lymph_pred' (B,D,3).
            batch: dataset dict with 'volume' (B,1,D,H,W),
                'airway_metrics' (B,4,D) and 'lymphnode_metrics' (B,3,D).
        """
        recon = outputs['reconstruction']
        device = recon.device
        # Fix: align targets with the model outputs' device. The DataLoader
        # yields CPU tensors and callers may pass the raw batch through,
        # which previously raised a cross-device error on GPU.
        target = batch['volume'].to(device)
        airway_metrics = batch['airway_metrics'].to(device)
        lymphnode_metrics = batch['lymphnode_metrics'].to(device)

        B, C, D, H, W = target.shape
        num_voxels = D * H * W
        # Random per-voxel mask: the loss covers ~mask_ratio of the voxels.
        mask = torch.rand(B, num_voxels, device=device) < self.mask_ratio
        mask = mask.view(B, 1, D, H, W).expand_as(recon)

        diff = self.mse(recon, target) * mask.float()

        # Per-slice weights from the normalized hybrid asymmetry score
        # (channel 3 of the airway metrics).
        hybrid = airway_metrics[:, 3, :]
        hybrid_norm = hybrid / (hybrid.max(dim=1, keepdim=True)[0] + 1e-6)
        slice_weights = 1.0 + self.asymmetry_boost * hybrid_norm
        weights = slice_weights.unsqueeze(1).unsqueeze(3).unsqueeze(4).expand_as(diff)

        recon_loss = (diff * weights).sum() / (mask.sum() + 1e-6)

        # Heads predict (B, D, K); metrics are stored (B, K, D), hence permute.
        airway_loss = F.mse_loss(outputs['airway_pred'], airway_metrics.permute(0, 2, 1))
        lymph_loss = F.mse_loss(outputs['lymph_pred'], lymphnode_metrics.permute(0, 2, 1))

        return recon_loss + airway_loss + lymph_loss
|
|
|
|
| |
| |
| |
|
|
class TrainerWithMonitoring:
    """Training loop for MAE_Swin3D with representation-collapse monitoring.

    Tracks average loss per epoch and, every ``monitor_every`` epochs, the
    cosine similarity between encoder features of clean and augmented inputs
    (values near 1.0 suggest the encoder ignores the augmentations, i.e.
    possible collapse). Supports checkpointing, best-model tracking and
    early stopping on the training loss.
    """

    def __init__(self, model, train_loader, device, lr=1e-4, output_dir=None):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.loss_fn = MAEAsymmetryLoss()

        self.output_dir = Path(output_dir) if output_dir else None
        if self.output_dir:
            self.output_dir.mkdir(parents=True, exist_ok=True)

        # Populated only on monitoring epochs; dumped to history.json.
        self.history = {
            'epoch': [],
            'loss': [],
            'cosine_sim_mean': [],
            'cosine_sim_std': [],
        }

    def compute_cosine_similarity(self, n_samples=50):
        """Mean/std cosine similarity of features for clean vs augmented volumes.

        NOTE: ``n_samples`` counts loader *batches*, not individual volumes.
        """
        self.model.eval()
        similarities = []
        with torch.no_grad():
            for i, batch in enumerate(self.train_loader):
                if i >= n_samples:
                    break
                volume = batch['volume'].to(self.device)
                feat1 = self.model.encoder(volume)
                feat2 = self.model.encoder(augment_volume(volume))
                sim = (F.normalize(feat1, dim=1) * F.normalize(feat2, dim=1)).sum(dim=1)
                similarities.extend(sim.cpu().numpy().tolist())
        self.model.train()
        # Plain floats so the history stays trivially JSON-serializable.
        return float(np.mean(similarities)), float(np.std(similarities))

    def save_checkpoint(self, epoch, is_best=False):
        """Save a full training checkpoint; refresh best_model.pt when is_best."""
        if not self.output_dir:
            return
        path = self.output_dir / f"checkpoint_epoch_{epoch:03d}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history,
        }, path)
        print(f"Checkpoint saved: {path.name}")

        if is_best:
            best_path = self.output_dir / "best_model.pt"
            torch.save(self.model.state_dict(), best_path)
            print(f"Best model updated: {best_path.name}")

    def train(self, n_epochs=100, monitor_every=5, save_every=10,
              early_stop_patience=20, early_stop_after=30):
        """Run training and return the history dict.

        Early stopping triggers after ``early_stop_after`` epochs once the
        loss has not improved for ``early_stop_patience`` consecutive epochs.
        """
        best_loss = float('inf')
        patience_counter = 0
        best_epoch = 0

        for epoch in range(1, n_epochs + 1):
            self.model.train()
            total_loss = 0.0
            num_batches = 0

            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}", leave=False):
                # Fix: move every tensor the loss needs onto the training
                # device and pass THAT dict to the loss. Previously the raw
                # CPU batch was passed, which raises a cross-device error
                # when training on GPU.
                device_batch = {
                    'volume': batch['volume'].to(self.device),
                    'airway_metrics': batch['airway_metrics'].to(self.device),
                    'lymphnode_metrics': batch['lymphnode_metrics'].to(self.device),
                }

                self.optimizer.zero_grad()
                outputs = self.model(device_batch['volume'])
                loss = self.loss_fn(outputs, device_batch)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            is_best = avg_loss < best_loss
            if is_best:
                best_loss = avg_loss
                best_epoch = epoch
                patience_counter = 0
            else:
                patience_counter += 1

            if epoch % monitor_every == 0 or epoch == 1:
                cos_mean, cos_std = self.compute_cosine_similarity()
                self.history['epoch'].append(epoch)
                self.history['loss'].append(avg_loss)
                self.history['cosine_sim_mean'].append(cos_mean)
                self.history['cosine_sim_std'].append(cos_std)

                msg = f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | CosSim: {cos_mean:.3f}±{cos_std:.3f}"
                if is_best:
                    msg += " ★"
                print(msg)

                # Near-identical features for clean vs augmented inputs means
                # the encoder is insensitive to the augmentations.
                if cos_mean > 0.95:
                    print(f" WARNING: Cosine similarity very high ({cos_mean:.3f}) — possible collapse")
            else:
                msg = f"Epoch {epoch:3d} | Loss: {avg_loss:.4f}"
                if is_best:
                    msg += " ★"
                print(msg)

            if epoch % save_every == 0:
                self.save_checkpoint(epoch, is_best=is_best)
            elif is_best:
                self.save_checkpoint(epoch, is_best=True)

            if epoch > early_stop_after and patience_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch}")
                break

        if self.output_dir:
            torch.save(self.model.state_dict(), self.output_dir / "final_model.pt")
            with open(self.output_dir / "history.json", 'w') as f:
                json.dump(self.history, f, indent=2)

        print(f"Best loss: {best_loss:.4f} at epoch {best_epoch}")
        return self.history
|
|
|
|
| |
| |
| |
|
|
def main():
    """CLI entry point: parse args, build dataset/loader/model, run pretraining."""
    parser = argparse.ArgumentParser(description="3D Swin MAE pretraining")
    parser.add_argument("--data-dir", type=str, required=True, help="Folder containing cropped_volume.nii.gz files")
    parser.add_argument("--output-dir", type=str, default="./checkpoints", help="Folder to save models and logs")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--monitor-every", type=int, default=5)
    parser.add_argument("--save-every", type=int, default=10)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--early-after", type=int, default=30)
    parser.add_argument("--no-cache", action="store_true")
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Device: {device}")

    dataset = OPSCCDataset(data_dir=args.data_dir, cache_asymmetry=not args.no_cache)

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=use_cuda,
    )

    trainer = TrainerWithMonitoring(
        model=MAE_Swin3D(),
        train_loader=loader,
        device=device,
        lr=args.lr,
        output_dir=args.output_dir,
    )
    trainer.train(
        n_epochs=args.epochs,
        monitor_every=args.monitor_every,
        save_every=args.save_every,
        early_stop_patience=args.patience,
        early_stop_after=args.early_after,
    )

    print("\nNote: Volumes are expected to be cropped, resized to ~60×128×128, intensities [0,1].")
|
|
|
|
| if __name__ == "__main__": |
| main() |