# AsymmetryNet / train_mae_swin3d.py
# (repository-page residue preserved as comments: uploaded by jdmayfield,
#  "Create train_mae_swin3d.py", commit 529f7c0 verified)
#!/usr/bin/env python
"""
Masked Autoencoder (MAE) pretraining with 3D Swin Transformer for OPSCC CT scans.
Asymmetry-aware reconstruction + overfitting monitoring via cosine similarity.
Run example:
python train_mae_swin3d.py --data-dir /path/to/your/nii_folder --output-dir ./checkpoints
"""
"""
Self-Supervised Learning for OPSCC CT using 3D Swin Transformer MAE
with asymmetry-aware reconstruction and overfitting monitoring
"""
import argparse
import json
import pickle
import warnings
from datetime import datetime
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy import ndimage
import nibabel as nib
from tqdm import tqdm
warnings.filterwarnings("ignore", category=UserWarning)
# ==============================================================================
# Drop Path
# ==============================================================================
class DropPath(nn.Module):
    """Stochastic depth: randomly drop entire samples in a residual branch.

    Acts as the identity during evaluation or when drop_prob == 0.
    Surviving samples are rescaled by 1 / keep_prob so the expected
    output matches the input.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all other dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
        return x.div(keep_prob) * mask
# ==============================================================================
# Asymmetry Detectors
# ==============================================================================
class AirwayAsymmetryDetector:
    """Per-slice airway asymmetry metrics for an axial CT volume.

    Intensities are assumed normalized to [0, 1] (air near 0). The bottom
    `exclude_inferior_fraction` and top `exclude_superior_fraction` of
    slices are zeroed out in the hybrid score.
    """

    def __init__(self, exclude_inferior_fraction=0.15, exclude_superior_fraction=0.10):
        self.exclude_inferior_fraction = exclude_inferior_fraction
        self.exclude_superior_fraction = exclude_superior_fraction

    def find_midline(self, slice_2d):
        """Return the column index that best splits the slice symmetrically."""
        h, w = slice_2d.shape
        center = w // 2
        best_col, best_score = center, float('inf')
        # Scan candidate midlines within +/- w/8 of the image center; keep
        # the first strict improvement (ties resolve to the earliest column).
        for col in range(center - w // 8, center + w // 8):
            half = min(col, w - col)
            if half < 10:  # too narrow a strip to compare meaningfully
                continue
            left_half = slice_2d[:, col - half:col]
            right_half = np.flip(slice_2d[:, col:col + half], axis=1)
            score = np.abs(left_half - right_half).mean()
            if score < best_score:
                best_score, best_col = score, col
        return best_col

    def detect_airway(self, slice_2d, air_thresh=0.1):
        """Binary mask of interior air pockets (candidate airway lumen)."""
        air = slice_2d < air_thresh
        labeled, n_comp = ndimage.label(air)
        # Any component touching the image border is outside air, not airway.
        border = set(labeled[0, :].flatten()) | set(labeled[-1, :].flatten()) | \
                 set(labeled[:, 0].flatten()) | set(labeled[:, -1].flatten())
        mask = np.zeros_like(air)
        for comp_id in range(1, n_comp + 1):
            if comp_id in border:
                continue
            comp = labeled == comp_id
            if comp.sum() > 20:  # drop tiny speckle components
                mask |= comp
        return mask

    def forward(self, volume):
        """Compute per-slice metrics for a (D, H, W) volume.

        Returns a dict of 1-D arrays: effacement, mass_effect,
        midline_shift, hybrid, midlines (one entry per slice).
        """
        d, h, w = volume.shape
        lo = int(d * self.exclude_inferior_fraction)
        hi = int(d * (1 - self.exclude_superior_fraction))
        out = {key: [] for key in ('effacement', 'mass_effect', 'midline_shift', 'hybrid', 'midlines')}
        for z in range(d):
            axial = volume[z]
            mid = self.find_midline(axial)
            out['midlines'].append(mid)
            out['midline_shift'].append(mid - w // 2)
            # Airway volume imbalance across the detected midline.
            airway = self.detect_airway(axial)
            left_air = airway[:, :mid].sum()
            right_air = airway[:, mid:].sum()
            total_air = left_air + right_air
            effacement = abs(left_air - right_air) / max(total_air, 1) if total_air > 0 else 0
            # Soft-tissue mirror difference over the widest comparable strip.
            half = min(mid, w - mid)
            mass_effect = 0
            if half > 0:
                soft = (axial > 0.2) & (axial < 0.7)
                left_soft = axial[:, mid - half:mid] * soft[:, mid - half:mid]
                right_soft = np.flip(axial[:, mid:mid + half], axis=1) * np.flip(soft[:, mid:mid + half], axis=1)
                mass_effect = np.abs(left_soft - right_soft).mean()
            hybrid = (0.5 * effacement + 0.5 * mass_effect) if lo <= z <= hi else 0
            out['effacement'].append(effacement)
            out['mass_effect'].append(mass_effect)
            out['hybrid'].append(hybrid)
        return {k: np.array(v) for k, v in out.items()}
class GlobalSoftTissueAsymmetryDetector:
    """Count small hypodense soft-tissue foci on each side of the midline.

    A crude left/right lymph-node asymmetry proxy on [0, 1]-normalized
    axial slices. The exclude_* fractions are stored for API symmetry with
    AirwayAsymmetryDetector but are not used by forward().
    """

    def __init__(self, exclude_inferior_fraction=0.15, exclude_superior_fraction=0.10):
        self.exclude_inferior_fraction = exclude_inferior_fraction
        self.exclude_superior_fraction = exclude_superior_fraction

    def forward(self, volume, midlines=None):
        """Per-slice left/right counts for a (D, H, W) volume.

        Returns a dict of 1-D arrays: left_hypo, right_hypo, hypo_asymmetry.
        """
        d, h, w = volume.shape
        if midlines is None:
            midlines = [w // 2] * d  # fall back to the geometric center
        out = {'left_hypo': [], 'right_hypo': [], 'hypo_asymmetry': []}
        for z in range(d):
            axial = volume[z]
            mid = midlines[z]
            soft = (axial > 0.2) & (axial < 0.7)
            # Hypodense foci inside soft tissue, cleaned with morphology.
            foci = (axial < 0.35) & soft
            foci = ndimage.binary_opening(foci, iterations=1)
            foci = ndimage.binary_closing(foci, iterations=2)
            labeled, n_comp = ndimage.label(foci)
            n_left = n_right = 0
            for comp_id in range(1, n_comp + 1):
                comp = labeled == comp_id
                area = comp.sum()
                if not (10 < area < 150):
                    continue  # only moderately sized foci count
                # Side is decided by the component's x-centroid.
                if np.argwhere(comp)[:, 1].mean() < mid:
                    n_left += 1
                else:
                    n_right += 1
            out['left_hypo'].append(n_left)
            out['right_hypo'].append(n_right)
            out['hypo_asymmetry'].append(abs(n_left - n_right))
        return {k: np.array(v) for k, v in out.items()}
# ==============================================================================
# 3D Swin Transformer Components
# ==============================================================================
def window_partition3d(x, window_size=(4,4,4)):
    """Split a (B, C, D, H, W) volume into non-overlapping 3-D windows.

    Spatial dims are zero-padded up to a multiple of the window size.

    Returns:
        windows: (num_windows * B, window_volume, C) token sequences.
        pads: (pad_d, pad_h, pad_w) amounts applied, needed to invert.
    """
    B, C, D, H, W = x.shape
    wd, wh, ww = window_size
    pads = ((wd - D % wd) % wd, (wh - H % wh) % wh, (ww - W % ww) % ww)
    # F.pad takes (before, after) pairs from the LAST dim backwards: W, H, D.
    x = F.pad(x, (0, pads[2], 0, pads[1], 0, pads[0]))
    Dp, Hp, Wp = D + pads[0], H + pads[1], W + pads[2]
    x = x.reshape(B, C, Dp // wd, wd, Hp // wh, wh, Wp // ww, ww)
    x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
    windows = x.reshape(-1, C, wd * wh * ww).permute(0, 2, 1).contiguous()
    return windows, pads
def window_reverse3d(windows, window_size, B, D, H, W, pads):
    """Inverse of window_partition3d: reassemble windows into (B, C, D, H, W).

    Any padding recorded in `pads` is cropped away after reassembly.
    """
    pd, ph, pw = pads
    wd, wh, ww = window_size
    Dp, Hp, Wp = D + pd, H + ph, W + pw
    vol = windows.reshape(B, Dp // wd, Hp // wh, Wp // ww, wd, wh, ww, -1)
    vol = vol.permute(0, 7, 1, 4, 2, 5, 3, 6).contiguous()
    vol = vol.reshape(B, -1, Dp, Hp, Wp)
    # Drop the zero padding introduced before partitioning.
    return vol[:, :, :D, :H, :W]
class WindowAttention3D(nn.Module):
    """Multi-head self-attention over the tokens of a single 3-D window,
    with a learned relative-position bias (Swin style).

    Expects input of shape (num_windows * B, N, C) where
    N == prod(window_size).
    """
    def __init__(self, dim, window_size=(4,4,4), num_heads=3, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scale queries by 1/sqrt(head_dim) unless an explicit scale is given.
        self.scale = qk_scale or head_dim ** -0.5
        # Build an (N, N) table of flattened relative (d, h, w) offsets
        # between every pair of token positions inside one window.
        coords_d = torch.arange(window_size[0])
        coords_h = torch.arange(window_size[1])
        coords_w = torch.arange(window_size[2])
        coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to be non-negative, then mix the three axes into a
        # single id (strides chosen so distinct offsets get distinct ids).
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 2] += window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * window_size[1] - 1) * (2 * window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * window_size[2] - 1)
        # NOTE(review): stored as a plain attribute, not via register_buffer,
        # so it is absent from state_dict and not moved by .to(device).
        # CPU index tensors are accepted for advanced indexing, but
        # register_buffer(..., persistent=False) would be more robust — confirm.
        self.relative_position_index = relative_coords.sum(-1)
        max_rel_pos = self.relative_position_index.max().item()
        # One learnable bias per relative offset per head.
        self.relative_position_bias_table = nn.Parameter(torch.zeros((max_rel_pos + 1, num_heads)))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x, mask=None):
        """Attend within each window.

        Args:
            x: (num_windows * B, N, C) window token sequences.
            mask: optional (num_windows, N, N) additive attention mask.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B_, N, C = x.shape
        # Slice the index table in case N is smaller than the full window volume.
        rel_index = self.relative_position_index[:N, :N]
        relative_position_bias = self.relative_position_bias_table[rel_index.view(-1)]
        relative_position_bias = relative_position_bias.view(N, N, -1).permute(2, 0, 1).contiguous()
        # (3, B_, heads, N, head_dim) then split into q, k, v.
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # Broadcast the per-window mask over the batch dimension.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class SwinTransformerBlock3D(nn.Module):
    """Swin block: windowed self-attention + MLP, each behind a pre-LayerNorm
    and a stochastic-depth residual connection.

    Operates on channels-first volumes (B, C, D, H, W); LayerNorm is applied
    by temporarily permuting channels to the last dimension.

    NOTE(review): ``shift_size`` is stored but never applied — no cyclic
    shift or attention mask is performed here, so every block uses regular
    (unshifted) windows. Confirm whether shifted-window attention was intended.
    """
    def __init__(self, dim, num_heads, window_size=(4,4,4), shift_size=(0,0,0),
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size  # unused — see class note
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(dim=dim, window_size=window_size, num_heads=num_heads,
                                      qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), act_layer(), nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim), nn.Dropout(drop)
        )
    def forward(self, x):
        """x: (B, C, D, H, W) -> same shape."""
        shortcut = x
        # Attention branch: pre-norm in channels-last layout, back to channels-first.
        x_norm = x.permute(0, 2, 3, 4, 1)
        x_norm = self.norm1(x_norm)
        x = x_norm.permute(0, 4, 1, 2, 3)
        # Window attention (windows are padded internally if dims don't divide).
        windows, pads = window_partition3d(x, self.window_size)
        attn_windows = self.attn(windows)
        x = window_reverse3d(attn_windows, self.window_size, x.shape[0], x.shape[2], x.shape[3], x.shape[4], pads)
        x = shortcut + self.drop_path(x)
        # MLP branch with its own pre-norm and residual.
        x_norm = x.permute(0, 2, 3, 4, 1)
        x_norm = self.norm2(x_norm)
        x_norm = x_norm.permute(0, 4, 1, 2, 3)
        x_mlp = self.mlp(x_norm.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
        x = x + self.drop_path(x_mlp)
        return x
class PatchEmbed3D(nn.Module):
    """Embed a volume into patch tokens with a strided 3-D convolution.

    Each spatial dim shrinks by its patch-size factor; channels become
    ``embed_dim``.
    """

    def __init__(self, patch_size=(4,4,4), in_chans=1, embed_dim=96):
        super().__init__()
        # kernel == stride: non-overlapping patches, one token per patch.
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        embedded = self.proj(x)
        return embedded
class PatchMerging3D(nn.Module):
    """Downsample by 2x per spatial dim: each 2x2x2 token neighborhood is
    concatenated (8*C channels) and linearly reduced to 2*C channels."""

    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)

    def forward(self, x):
        B, C, D, H, W = x.shape
        # Zero-pad any odd spatial dim so each can be halved cleanly.
        need = (D % 2, H % 2, W % 2)
        if any(need):
            x = F.pad(x, (0, need[2], 0, need[1], 0, need[0]))
        _, _, Dp, Hp, Wp = x.shape
        # Gather each 2x2x2 neighborhood into the channel dimension.
        tokens = x.permute(0, 2, 3, 4, 1)
        tokens = tokens.view(B, Dp // 2, 2, Hp // 2, 2, Wp // 2, 2, C)
        tokens = tokens.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
        merged = tokens.view(B, Dp // 2, Hp // 2, Wp // 2, 8 * C)
        merged = self.reduction(merged)
        return merged.permute(0, 4, 1, 2, 3).contiguous()
class SwinTransformer3D(nn.Module):
    """Hierarchical 3-D Swin encoder.

    Pipeline: patch embedding, alternating stages of transformer blocks and
    2x patch-merging downsampling, then global average pooling and a final
    LayerNorm over the pooled feature vector.

    Args:
        in_chans: input channels.
        embed_dim: channel width of the first stage (doubles at each merge).
        depths: number of blocks per stage.
        num_heads: attention heads per stage.
        window_size: 3-D attention window.
        mlp_ratio, qkv_bias: forwarded to every block.
        drop_rate, attn_drop_rate: dropout rates inside the blocks.
        drop_path_rate: maximum stochastic-depth rate; rates increase
            linearly across ALL blocks from 0 to this value.
    """
    def __init__(self, in_chans=1, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=(4,4,4), mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed3D(in_chans=in_chans, embed_dim=embed_dim)
        # Linearly increasing stochastic-depth schedule over all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        dim = embed_dim
        block_idx = 0  # cumulative block position across stages
        for i_layer in range(len(depths)):
            blocks = nn.ModuleList([
                # BUGFIX: index dpr by the cumulative block position, not the
                # within-stage position — previously `dpr[i]` restarted the
                # schedule every stage, so the higher drop-path rates of the
                # linspace were never used.  Also forward mlp_ratio/qkv_bias,
                # which were accepted but silently ignored before (defaults
                # match, so behavior at defaults is unchanged).
                SwinTransformerBlock3D(dim=dim, num_heads=num_heads[i_layer], window_size=window_size,
                                       mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                                       drop=drop_rate, attn_drop=attn_drop_rate,
                                       drop_path=dpr[block_idx + i])
                for i in range(depths[i_layer])
            ])
            block_idx += depths[i_layer]
            self.layers.append(blocks)
            if i_layer < len(depths) - 1:
                self.layers.append(PatchMerging3D(dim))
                dim *= 2
        self.norm = nn.LayerNorm(dim)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.feature_dim = dim  # final channel width, exposed to heads

    def forward(self, x):
        """x: (B, in_chans, D, H, W) -> pooled features (B, feature_dim)."""
        x = self.patch_embed(x)
        for layer in self.layers:
            if isinstance(layer, PatchMerging3D):
                x = layer(x)
            else:
                for blk in layer:
                    x = blk(x)
        # Pool the remaining spatial grid, then normalize the feature vector.
        x = self.avgpool(x).flatten(1)
        x = self.norm(x)
        return x
# ==============================================================================
# MAE Model
# ==============================================================================
class MAE_Swin3D(nn.Module):
    """MAE-style model: Swin3D encoder, a dense voxel-reconstruction
    decoder, and two auxiliary heads regressing per-slice asymmetry metrics.

    NOTE(review): the decoder is a single MLP emitting all D*H*W voxels,
    which is extremely large for the default 60x128x128 input — confirm
    this is intended rather than a patch-wise decoder.
    """

    def __init__(self, input_shape=(60, 128, 128)):
        super().__init__()
        self.input_shape = input_shape
        self.encoder = SwinTransformer3D(in_chans=1)
        hidden = 512
        # Reconstruct the full volume from the pooled encoder feature.
        self.decoder = nn.Sequential(
            nn.Linear(self.encoder.feature_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, np.prod(input_shape))
        )
        # Per-slice regression targets: 4 airway metrics, 3 lymph metrics.
        self.airway_head = nn.Linear(self.encoder.feature_dim, 4 * input_shape[0])
        self.lymph_head = nn.Linear(self.encoder.feature_dim, 3 * input_shape[0])

    def forward(self, x):
        feat = self.encoder(x)
        recon = self.decoder(feat).view(-1, 1, *self.input_shape)
        n_slices = self.input_shape[0]
        return {
            'reconstruction': recon,
            'airway_pred': self.airway_head(feat).view(-1, n_slices, 4),
            'lymph_pred': self.lymph_head(feat).view(-1, n_slices, 3),
            'features': feat
        }
# ==============================================================================
# Augmentations
# ==============================================================================
def augment_volume(volume):
    """Randomly augment a (B, C, D, H, W) volume; returns a new tensor.

    Stochastic pipeline: brightness shift, contrast scale, Gaussian noise,
    in-plane flips and rotation, random resized crop, blur, and random
    erasing. The result is clamped to [0, 1]; the input is not modified.
    """
    out = volume.clone()
    dev = out.device
    # --- intensity perturbations ---
    if torch.rand(1) > 0.3:
        out += (torch.rand(1).to(dev) - 0.5) * 0.4          # brightness shift
    if torch.rand(1) > 0.3:
        out *= 0.7 + torch.rand(1).to(dev) * 0.6            # contrast scale
    if torch.rand(1) > 0.3:
        out += torch.randn_like(out) * 0.1                  # gaussian noise
    # --- in-plane geometric perturbations ---
    if torch.rand(1) > 0.5:
        out = torch.flip(out, dims=[-1])
    if torch.rand(1) > 0.5:
        out = torch.flip(out, dims=[-2])
    if torch.rand(1) > 0.7:
        out = torch.rot90(out, torch.randint(1, 4, (1,)).item(), dims=[-2, -1])
    if torch.rand(1) > 0.5:
        # Random crop of 80-95% per axis, resized back to the original size.
        _, _, depth, height, width = out.shape
        cd = int(depth * (0.80 + torch.rand(1).item() * 0.15))
        ch = int(height * (0.80 + torch.rand(1).item() * 0.15))
        cw = int(width * (0.80 + torch.rand(1).item() * 0.15))
        sd = torch.randint(0, depth - cd + 1, (1,)).item()
        sh = torch.randint(0, height - ch + 1, (1,)).item()
        sw = torch.randint(0, width - cw + 1, (1,)).item()
        out = out[:, :, sd:sd + cd, sh:sh + ch, sw:sw + cw]
        out = F.interpolate(out, size=(depth, height, width), mode='trilinear', align_corners=False)
    if torch.rand(1) > 0.7:
        # Mild blur via a 3x3x3 box filter (stride 1 keeps the shape).
        out = F.avg_pool3d(out, kernel_size=3, stride=1, padding=1)
    if torch.rand(1) > 0.7:
        # Random erasing: overwrite a small box with the global mean.
        _, _, depth, height, width = out.shape
        ed = int(depth * (0.05 + torch.rand(1).item() * 0.10))
        eh = int(height * (0.05 + torch.rand(1).item() * 0.10))
        ew = int(width * (0.05 + torch.rand(1).item() * 0.10))
        sd = torch.randint(0, depth - ed + 1, (1,)).item()
        sh = torch.randint(0, height - eh + 1, (1,)).item()
        sw = torch.randint(0, width - ew + 1, (1,)).item()
        out[:, :, sd:sd + ed, sh:sh + eh, sw:sw + ew] = out.mean()
    return torch.clamp(out, 0, 1)
# ==============================================================================
# Dataset
# ==============================================================================
class OPSCCDataset(Dataset):
    """Dataset of cropped OPSCC CT volumes with precomputed per-slice
    airway and soft-tissue asymmetry metrics.

    Volumes are discovered as ``cropped_volume.nii.gz`` anywhere under
    ``data_dir``. Asymmetry metrics can be cached to a pickle file in the
    data directory.

    NOTE(review): the cache maps integer dataset indices to metrics, and
    indices depend on glob order — adding, removing, or renaming files
    silently misassigns cached metrics to the wrong volume.
    """
    def __init__(self, data_dir: str, cache_asymmetry: bool = True):
        self.data_dir = Path(data_dir)
        self.volume_paths = list(self.data_dir.glob("**/cropped_volume.nii.gz"))
        print(f"Found {len(self.volume_paths)} volumes")
        self.cache_file = self.data_dir / ".asymmetry_cache.pkl"
        self.cache_asymmetry = cache_asymmetry
        self.asymmetry_cache = {}
        self.airway_detector = AirwayAsymmetryDetector()
        self.lymphnode_detector = GlobalSoftTissueAsymmetryDetector()
        if self.cache_asymmetry:
            if self.cache_file.is_file():
                try:
                    # NOTE(review): pickle.load executes arbitrary code —
                    # only load caches from a trusted data directory.
                    with open(self.cache_file, 'rb') as f:
                        self.asymmetry_cache = pickle.load(f)
                    print(f"Loaded asymmetry cache ({len(self.asymmetry_cache)} entries)")
                except Exception:
                    print("Cache load failed → recomputing")
                    self._precompute_asymmetry()
            else:
                print("Computing asymmetry metrics...")
                self._precompute_asymmetry()
                try:
                    with open(self.cache_file, 'wb') as f:
                        pickle.dump(self.asymmetry_cache, f)
                    print("Cache saved")
                except Exception as e:
                    # Best-effort: training proceeds without a saved cache.
                    print(f"Cache save failed: {e}")
    def _precompute_asymmetry(self):
        """Compute and cache metrics for every volume, keyed by index."""
        for idx, path in enumerate(tqdm(self.volume_paths, desc="Asymmetry")):
            volume = self._load_volume(path)
            metrics = self._compute_asymmetry(volume)
            self.asymmetry_cache[idx] = metrics
    def _load_volume(self, path: Path) -> np.ndarray:
        """Load a NIfTI volume as float32, slice axis first."""
        img = nib.load(str(path))
        volume = img.get_fdata().astype(np.float32)
        # Heuristic: if the last axis is shorter than the first, treat it as
        # the slice axis and move it to the front — assumes an (H, W, D)
        # on-disk layout with D < H; TODO confirm against the preprocessing.
        if volume.ndim == 3 and volume.shape[2] < volume.shape[0]:
            volume = np.transpose(volume, (2, 0, 1))
        return volume
    def _compute_asymmetry(self, volume: np.ndarray) -> dict:
        """Run both detectors; lymph detector reuses the airway midlines."""
        airway = self.airway_detector.forward(volume)
        lymphnode = self.lymphnode_detector.forward(volume, airway['midlines'].tolist())
        return {'airway': airway, 'lymphnode': lymphnode}
    def __len__(self) -> int:
        return len(self.volume_paths)
    def __getitem__(self, idx: int) -> dict:
        """Return {'volume': (1, D, H, W), 'airway_metrics': (4, D),
        'lymphnode_metrics': (3, D)} as float tensors."""
        path = self.volume_paths[idx]
        volume = self._load_volume(path)
        # Fall back to on-the-fly computation when caching is off or missed.
        if self.cache_asymmetry and idx in self.asymmetry_cache:
            metrics = self.asymmetry_cache[idx]
        else:
            metrics = self._compute_asymmetry(volume)
        airway_tensor = np.stack([
            metrics['airway']['effacement'],
            metrics['airway']['mass_effect'],
            metrics['airway']['midline_shift'],
            metrics['airway']['hybrid']
        ], axis=0)
        lymph_tensor = np.stack([
            metrics['lymphnode']['left_hypo'],
            metrics['lymphnode']['right_hypo'],
            metrics['lymphnode']['hypo_asymmetry']
        ], axis=0)
        return {
            'volume': torch.from_numpy(volume).unsqueeze(0).float(),
            'airway_metrics': torch.from_numpy(airway_tensor).float(),
            'lymphnode_metrics': torch.from_numpy(lymph_tensor).float(),
        }
# ==============================================================================
# Loss
# ==============================================================================
class MAEAsymmetryLoss(nn.Module):
def __init__(self, mask_ratio=0.75, asymmetry_boost=5.0):
super().__init__()
self.mse = nn.MSELoss(reduction='none')
self.mask_ratio = mask_ratio
self.asymmetry_boost = asymmetry_boost
def forward(self, outputs, batch):
recon = outputs['reconstruction']
target = batch['volume']
B, C, D, H, W = target.shape
num_patches = D * H * W
mask = torch.rand(B, num_patches, device=target.device) < self.mask_ratio
mask = mask.view(B, 1, D, H, W).expand_as(recon)
diff = self.mse(recon, target) * mask.float()
hybrid = batch['airway_metrics'][:, 3, :]
hybrid_norm = hybrid / (hybrid.max(dim=1, keepdim=True)[0] + 1e-6)
slice_weights = 1.0 + self.asymmetry_boost * hybrid_norm
weights = slice_weights.unsqueeze(1).unsqueeze(3).unsqueeze(4).expand_as(diff)
recon_loss = (diff * weights).sum() / (mask.sum() + 1e-6)
airway_loss = F.mse_loss(outputs['airway_pred'], batch['airway_metrics'].permute(0, 2, 1))
lymph_loss = F.mse_loss(outputs['lymph_pred'], batch['lymphnode_metrics'].permute(0, 2, 1))
return recon_loss + airway_loss + lymph_loss
# ==============================================================================
# Trainer
# ==============================================================================
class TrainerWithMonitoring:
    """Training loop for the MAE model with cosine-similarity collapse
    monitoring, periodic checkpointing, and loss-based early stopping.

    Args:
        model: module exposing ``.encoder`` and returning the output dict
            MAEAsymmetryLoss expects.
        train_loader: yields dicts with 'volume', 'airway_metrics',
            'lymphnode_metrics'.
        device: torch device to train on.
        lr: AdamW learning rate.
        output_dir: optional folder for checkpoints / history.json.
    """

    def __init__(self, model, train_loader, device, lr=1e-4, output_dir=None):
        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.loss_fn = MAEAsymmetryLoss()
        self.output_dir = Path(output_dir) if output_dir else None
        if self.output_dir:
            self.output_dir.mkdir(parents=True, exist_ok=True)
        # Entries are appended only on monitoring epochs.
        self.history = {
            'epoch': [],
            'loss': [],
            'cosine_sim_mean': [],
            'cosine_sim_std': [],
        }

    def compute_cosine_similarity(self, n_samples=50):
        """Mean/std cosine similarity between encoder features of a volume
        and its augmented view, over up to ``n_samples`` batches.

        Values near 1.0 suggest the encoder is becoming trivially
        augmentation-invariant (feature collapse).
        """
        self.model.eval()
        similarities = []
        with torch.no_grad():
            for i, batch in enumerate(self.train_loader):
                if i >= n_samples:
                    break
                volume = batch['volume'].to(self.device)
                feat1 = self.model.encoder(volume)
                feat2 = self.model.encoder(augment_volume(volume))
                sim = (F.normalize(feat1, dim=1) * F.normalize(feat2, dim=1)).sum(dim=1)
                similarities.extend(sim.cpu().numpy().tolist())
        self.model.train()
        return np.mean(similarities), np.std(similarities)

    def save_checkpoint(self, epoch, is_best=False):
        """Write a full checkpoint; refresh best_model.pt when is_best."""
        if not self.output_dir:
            return
        path = self.output_dir / f"checkpoint_epoch_{epoch:03d}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history,
        }, path)
        print(f"Checkpoint saved: {path.name}")
        if is_best:
            best_path = self.output_dir / "best_model.pt"
            torch.save(self.model.state_dict(), best_path)
            print(f"Best model updated: {best_path.name}")

    def train(self, n_epochs=100, monitor_every=5, save_every=10,
              early_stop_patience=20, early_stop_after=30):
        """Run training and return the history dict."""
        best_loss = float('inf')
        patience_counter = 0
        best_epoch = 0
        for epoch in range(1, n_epochs + 1):
            self.model.train()
            total_loss = 0.0
            num_batches = 0
            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}", leave=False):
                # BUGFIX: move ALL tensors the loss reads onto the model's
                # device. Previously the raw CPU batch was passed to the
                # loss while the model outputs lived on self.device, which
                # raises a device-mismatch error on CUDA.
                device_batch = {
                    'volume': batch['volume'].to(self.device),
                    'airway_metrics': batch['airway_metrics'].to(self.device),
                    'lymphnode_metrics': batch['lymphnode_metrics'].to(self.device),
                }
                self.optimizer.zero_grad()
                outputs = self.model(device_batch['volume'])
                loss = self.loss_fn(outputs, device_batch)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
            is_best = avg_loss < best_loss
            if is_best:
                best_loss = avg_loss
                best_epoch = epoch
                patience_counter = 0
            else:
                patience_counter += 1
            if epoch % monitor_every == 0 or epoch == 1:
                cos_mean, cos_std = self.compute_cosine_similarity()
                self.history['epoch'].append(epoch)
                self.history['loss'].append(avg_loss)
                self.history['cosine_sim_mean'].append(cos_mean)
                self.history['cosine_sim_std'].append(cos_std)
                msg = f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | CosSim: {cos_mean:.3f}±{cos_std:.3f}"
                if is_best:
                    msg += " ★"
                print(msg)
                if cos_mean > 0.95:
                    print(f" WARNING: Cosine similarity very high ({cos_mean:.3f}) — possible collapse")
            else:
                msg = f"Epoch {epoch:3d} | Loss: {avg_loss:.4f}"
                if is_best:
                    msg += " ★"
                print(msg)
            # Periodic checkpoint; additionally snapshot any new best epoch.
            if epoch % save_every == 0:
                self.save_checkpoint(epoch, is_best=is_best)
            elif is_best:
                self.save_checkpoint(epoch, is_best=True)
            if epoch > early_stop_after and patience_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch}")
                break
        if self.output_dir:
            torch.save(self.model.state_dict(), self.output_dir / "final_model.pt")
            with open(self.output_dir / "history.json", 'w') as f:
                json.dump(self.history, f, indent=2)
        print(f"Best loss: {best_loss:.4f} at epoch {best_epoch}")
        return self.history
# ==============================================================================
# Main
# ==============================================================================
def main():
    """CLI entry point: parse arguments, build dataset/loader/model, train."""
    parser = argparse.ArgumentParser(description="3D Swin MAE pretraining")
    parser.add_argument("--data-dir", type=str, required=True, help="Folder containing cropped_volume.nii.gz files")
    parser.add_argument("--output-dir", type=str, default="./checkpoints", help="Folder to save models and logs")
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--monitor-every", type=int, default=5)
    parser.add_argument("--save-every", type=int, default=10)
    parser.add_argument("--patience", type=int, default=20)
    parser.add_argument("--early-after", type=int, default=30)
    parser.add_argument("--no-cache", action="store_true")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    dataset = OPSCCDataset(data_dir=args.data_dir, cache_asymmetry=not args.no_cache)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=device.type == "cuda",
    )

    trainer = TrainerWithMonitoring(
        model=MAE_Swin3D(),
        train_loader=loader,
        device=device,
        lr=args.lr,
        output_dir=args.output_dir,
    )
    trainer.train(
        n_epochs=args.epochs,
        monitor_every=args.monitor_every,
        save_every=args.save_every,
        early_stop_patience=args.patience,
        early_stop_after=args.early_after,
    )
    print("\nNote: Volumes are expected to be cropped, resized to ~60×128×128, intensities [0,1].")


if __name__ == "__main__":
    main()