Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import os | |
| import argparse | |
| import torch | |
| import numpy as np | |
| import random | |
| from torch.utils.data import DataLoader | |
| from torch.utils.tensorboard import SummaryWriter | |
| import nibabel as nib | |
| from monai.transforms import ( | |
| Compose, LoadImaged, ScaleIntensityd, Rand3DElasticd, RandAffined, RandGaussianSmoothd, | |
| RandFlipd, RandScaleIntensityd, RandShiftIntensityd, | |
| RandAdjustContrastd, RandGaussianSharpend, RandHistogramShiftd, | |
| RandCoarseDropoutd, RandSpatialCropd, SpatialPadd, | |
| EnsureChannelFirstd, Orientationd, RandGaussianNoised, NormalizeIntensityd, | |
| RandCropByPosNegLabeld, | |
| ) | |
| from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json | |
| from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed, nnUNet_results | |
| from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name | |
| from nnInteractive.trainer import nnInteractiveTrainer | |
class MENRTDataset(torch.utils.data.Dataset):
    """Per-case dataset over a BraTS-MEN-RT style directory tree.

    Every sub-directory of ``<data_dir>/<mode>`` is treated as one case. For a
    case named ``X`` the sample points at ``X/X_t1c.nii.gz`` (image) and
    ``X/X_gtv.nii.gz`` (segmentation). Only the paths are assembled here;
    nothing is read from disk unless ``transform`` does so.
    """

    def __init__(self, data_dir, patch_size=(192,192,192), transform=None, mode='train'):
        """
        Index the cases of one split.

        Args:
            data_dir: Root directory that contains one folder per split
            patch_size: Patch size, stored on the instance (not used by the
                dataset itself)
            transform: Optional callable applied to the sample dictionary
            mode: Name of the split sub-folder, e.g. 'train' or 'test'
        """
        self.data_dir = f'{data_dir}/{mode}'
        self.transform = transform
        self.patch_size = patch_size

        # Keep directories only; stray files in the split folder are ignored.
        case_names = []
        for entry in os.listdir(self.data_dir):
            if os.path.isdir(os.path.join(self.data_dir, entry)):
                case_names.append(entry)
        case_names.sort()
        self.cases = case_names

    def __len__(self):
        """Number of case directories found in the split folder."""
        return len(self.cases)

    def __getitem__(self, idx):
        """
        Build (and optionally transform) the sample for case number ``idx``.

        Returns:
            ``{'data', 'target', 'case_id'}`` when the (transformed) sample
            still has both 'image' and 'mask'; otherwise the sample is
            returned unchanged.
        """
        case_name = self.cases[idx]
        case_root = os.path.join(self.data_dir, case_name)
        sample = {
            'image': os.path.join(case_root, f'{case_name}_t1c.nii.gz'),
            'mask': os.path.join(case_root, f'{case_name}_gtv.nii.gz'),
            'case': case_name,
        }

        if self.transform:
            sample = self.transform(sample)

        # Some transforms hand back a list of samples; keep the first one.
        if isinstance(sample, list):
            sample = sample[0]

        if 'image' not in sample or 'mask' not in sample:
            return sample

        # Rename the keys to the data/target/case_id layout.
        return {
            'data': sample['image'],
            'target': sample['mask'],
            'case_id': sample['case'],
        }
def get_transforms(config, mode='train'):
    """
    Assemble the MONAI transform pipeline for one data split.

    Args:
        config: Object providing ``patch_size`` (used for padding and cropping)
        mode: 'train' builds the augmentation pipeline; any other value builds
            the deterministic load/orient/pad/normalize pipeline
    Returns:
        A ``Compose`` of dictionary transforms over the 'image' and 'mask' keys
    """
    both_keys = ['image', 'mask']
    image_only = ['image']
    # Interpolation per key: bilinear for the image, nearest for the mask.
    interpolation = ('bilinear', 'nearest')

    # Deterministic front end shared by both pipelines.
    # Resampling to a fixed spacing (Spacingd) is currently disabled.
    steps = [
        LoadImaged(keys=both_keys),
        EnsureChannelFirstd(keys=both_keys),
        Orientationd(keys=both_keys, axcodes='RAS'),
        SpatialPadd(keys=both_keys, spatial_size=config.patch_size),
    ]

    if mode != 'train':
        steps.append(NormalizeIntensityd(keys=image_only, nonzero=True, channel_wise=True))
        return Compose(steps)

    # Random flips along each of the three spatial axes.
    steps.extend(
        RandFlipd(keys=both_keys, prob=0.5, spatial_axis=axis)
        for axis in (0, 1, 2)
    )

    thirty_degrees = np.pi/6
    steps.append(RandAffined(
        keys=both_keys,
        rotate_range=(thirty_degrees, thirty_degrees, thirty_degrees),  # ±30°
        scale_range=(0.1, 0.1, 0.1),  # ±10%
        mode=interpolation,
        prob=0.2,
    ))
    steps.append(Rand3DElasticd(
        keys=both_keys,
        sigma_range=(5, 7),
        magnitude_range=(100, 200),
        mode=interpolation,
        prob=0.2,
    ))

    # Intensity augmentations touch the image only.
    blur_sigma = (0.5, 1.0)
    steps.append(RandGaussianSmoothd(
        keys=image_only,
        prob=0.2,
        sigma_x=blur_sigma,
        sigma_y=blur_sigma,
        sigma_z=blur_sigma,
    ))
    steps.append(RandScaleIntensityd(keys=image_only, prob=0.2, factors=0.1))
    steps.append(RandShiftIntensityd(keys=image_only, prob=0.2, offsets=0.1))
    steps.append(RandAdjustContrastd(keys=image_only, prob=0.2, gamma=(0.8, 1.2)))
    steps.append(NormalizeIntensityd(keys=image_only, nonzero=True, channel_wise=True))

    # Gaussian noise (p=0.2), added after normalization.
    steps.append(RandGaussianNoised(keys=image_only, prob=0.2, mean=0.0, std=0.05))

    # Final patch extraction, sampled with a 1:1 pos/neg ratio on the mask.
    steps.append(RandCropByPosNegLabeld(
        keys=both_keys,
        label_key='mask',
        spatial_size=config.patch_size,
        pos=1,
        neg=1,
        num_samples=1,
        allow_smaller=True,
    ))
    return Compose(steps)
def setup_trainer(config):
    """
    Set up the trainer, the data loaders and the TensorBoard writer.

    Args:
        config: Configuration object (the namespace produced by ``parse_args``).
            Reads ``seed``, ``dataset_id``, ``fold``, ``point_radius``,
            ``scribble_thickness``, ``interaction_decay``,
            ``pretrained_weights``, ``data_dir``, ``patch_size``,
            ``batch_size``, ``num_workers`` and, when present, ``pin_memory``.
    Returns:
        Tuple ``(trainer, train_loader, val_loader, tensorboard_writer)``
    """
    # Set random seeds for reproducibility
    # NOTE(review): this seeds torch / numpy / random only; confirm whether the
    # MONAI random transforms need their own seeding for full reproducibility.
    if config.seed is not None:
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed_all(config.seed)
        np.random.seed(config.seed)
        random.seed(config.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up paths
    dataset_name = maybe_convert_to_dataset_name(config.dataset_id)
    preprocessed_dataset_folder = join(nnUNet_preprocessed, dataset_name)
    plans_file = join(preprocessed_dataset_folder, "nnUNetPlans.json")

    # Load plans and dataset JSON
    plans = load_json(plans_file)
    dataset_json = load_json(join(preprocessed_dataset_folder, "dataset.json"))

    # Create and initialize the trainer
    trainer = nnInteractiveTrainer(
        plans=plans,
        configuration="3d_fullres",
        fold=config.fold,
        dataset_json=dataset_json,
        device=device,
        point_radius=config.point_radius,
        preferred_scribble_thickness=config.scribble_thickness,
        interaction_decay=config.interaction_decay
    )
    trainer.initialize()

    # Load pretrained weights if specified
    if config.pretrained_weights:
        trainer.load_checkpoint(config.pretrained_weights)

    # Setup data
    train_transform = get_transforms(config, mode='train')
    val_transform = get_transforms(config, mode='test')
    train_dataset = MENRTDataset(
        config.data_dir,
        patch_size=config.patch_size,
        transform=train_transform,
        mode='train'
    )
    val_dataset = MENRTDataset(
        config.data_dir,
        patch_size=config.patch_size,
        transform=val_transform,
        mode='test'
    )

    # Honor the --pin_memory option instead of hard-coding True; fall back to
    # True for config objects that do not define the attribute.
    pin_memory = getattr(config, 'pin_memory', True)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=pin_memory
    )
    # Validation runs one case at a time, in a fixed order.
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory
    )

    # Set up TensorBoard writer under <trainer.output_folder>/logs
    output_folder = trainer.output_folder
    log_dir = join(output_folder, "logs")
    maybe_mkdir_p(log_dir)
    tensorboard_writer = SummaryWriter(log_dir=log_dir)
    trainer.set_tensorboard_writer(tensorboard_writer)

    # The loaders are returned to the caller; they are not attached to the
    # trainer in this function.
    return trainer, train_loader, val_loader, tensorboard_writer
def _str_to_bool(text):
    """Convert a command-line token such as 'true' / 'False' / '0' to a bool."""
    lowered = text.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')


def _triple_of(cast):
    """Return an argparse ``type`` callable parsing 'a,b,c' into a 3-tuple of ``cast``."""
    def parse(text):
        # Accept "1,2,3", "(1, 2, 3)" and "1 2 3" (as one quoted argument).
        tokens = text.strip().strip('()[]').replace(',', ' ').split()
        try:
            values = tuple(cast(token) for token in tokens)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f'expected three {cast.__name__} values, got {text!r}'
            ) from None
        if len(values) != 3:
            raise argparse.ArgumentTypeError(
                f'expected exactly three values, got {len(values)} in {text!r}'
            )
        return values
    return parse


def parse_args():
    """
    Parse command line arguments.

    ``--patch_size`` and ``--spacing`` take three comma-separated numbers
    (``type=tuple`` would split the argument string into single characters).
    ``--pin_memory`` takes an explicit true/false token (``type=bool`` would
    treat every non-empty string, including 'False', as True).

    Returns:
        argparse.Namespace with the parsed options
    """
    parser = argparse.ArgumentParser(description='Train nnInteractive model on BraTS-MEN-RT dataset')

    # Dataset arguments
    parser.add_argument('--data_dir', type=str, required=True, help='Path to the data directory')
    parser.add_argument('--dataset_id', type=str, required=True, help='Dataset ID for nnUNet')
    parser.add_argument('--fold', type=int, default=0, help='Fold to train on')

    # Training arguments
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=1000, help='Number of epochs to train for')
    parser.add_argument('--patch_size', type=_triple_of(int), default=(128, 128, 128),
                        help='Patch size for training, e.g. 128,128,128')
    parser.add_argument('--spacing', type=_triple_of(float), default=(1.0, 1.0, 1.0),
                        help='Spacing for resampling, e.g. 1.0,1.0,1.0')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--pin_memory', type=_str_to_bool, default=True,
                        help='Use pinned memory for data loading (true/false)')

    # Interactive segmentation arguments
    parser.add_argument('--point_radius', type=int, default=4, help='Radius for point interactions')
    parser.add_argument('--scribble_thickness', type=int, default=2, help='Thickness for scribble interactions')
    parser.add_argument('--interaction_decay', type=float, default=0.9, help='Decay factor for interactions')

    # Model arguments
    parser.add_argument('--pretrained_weights', type=str, default=None, help='Path to pretrained weights')
    parser.add_argument('--continue_training', action='store_true', help='Continue training from last checkpoint')
    parser.add_argument('--only_val', action='store_true', help='Only run validation')

    return parser.parse_args()
def main():
    """
    Main training function.

    Parses the command line, builds the trainer and loaders, then either runs
    validation only (``--only_val``) or trains for ``--num_epochs`` epochs and
    validates once more with ``checkpoint_best.pth`` from the trainer's output
    folder.
    """
    config = parse_args()
    trainer, train_loader, val_loader, tensorboard_writer = setup_trainer(config)

    # Close the TensorBoard writer on every exit path: the validation-only
    # early return and any exception raised during training included.
    try:
        # Print some info
        print(f"Training on {len(train_loader.dataset)} samples")
        print(f"Validating on {len(val_loader.dataset)} samples")

        if config.only_val:
            print("Running validation only...")
            trainer.run_validation()
            return

        # Training loop
        print(f"Starting training for {config.num_epochs} epochs...")
        trainer.on_train_start()
        trainer.train(config.num_epochs)

        # Final validation with best checkpoint
        trainer.load_checkpoint(join(trainer.output_folder, "checkpoint_best.pth"))
        trainer.run_validation()
    finally:
        tensorboard_writer.close()

    print("Training completed!")
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()