# hanjang's picture
# Upload folder using huggingface_hub
# 24e5510 verified
#!/usr/bin/env python3
import os
import argparse
import torch
import numpy as np
import random
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import nibabel as nib
from monai.transforms import (
Compose, LoadImaged, ScaleIntensityd, Rand3DElasticd, RandAffined, RandGaussianSmoothd,
RandFlipd, RandScaleIntensityd, RandShiftIntensityd,
RandAdjustContrastd, RandGaussianSharpend, RandHistogramShiftd,
RandCoarseDropoutd, RandSpatialCropd, SpatialPadd,
EnsureChannelFirstd, Orientationd, RandGaussianNoised, NormalizeIntensityd,
RandCropByPosNegLabeld,
)
from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, load_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed, nnUNet_results
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnInteractive.trainer import nnInteractiveTrainer
class MENRTDataset(torch.utils.data.Dataset):
    """PyTorch dataset over the BraTS-MEN-RT directory layout.

    Each case lives in ``<data_dir>/<mode>/<case>/`` and provides a T1c
    image (``<case>_t1c.nii.gz``) and a GTV mask (``<case>_gtv.nii.gz``).
    """

    def __init__(self, data_dir, patch_size=(192,192,192), transform=None, mode='train'):
        """
        Args:
            data_dir: root directory containing the ``train``/``test`` subfolders
            patch_size: patch size (stored; actual cropping is done by the transform)
            transform: optional MONAI-style dict transform applied per item
            mode: 'train' or 'test' -- selects which subfolder to scan
        """
        self.data_dir = f'{data_dir}/{mode}'
        self.transform = transform
        self.patch_size = patch_size
        # Every sub-directory of the mode folder is one case; sort for determinism.
        root = self.data_dir
        self.cases = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )

    def __len__(self):
        """Number of cases found under the dataset folder."""
        return len(self.cases)

    def __getitem__(self, idx):
        """Load one case and re-key it into the trainer's expected format."""
        case_id = self.cases[idx]
        case_folder = os.path.join(self.data_dir, case_id)
        sample = {
            'image': os.path.join(case_folder, f'{case_id}_t1c.nii.gz'),
            'mask': os.path.join(case_folder, f'{case_id}_gtv.nii.gz'),
            'case': case_id,
        }
        if self.transform:
            sample = self.transform(sample)
            # Crop transforms with num_samples emit a list; keep the first sample.
            if isinstance(sample, list):
                sample = sample[0]
        # Convert to the format expected by nnInteractiveTrainer.
        if 'image' in sample and 'mask' in sample:
            return {
                'data': sample['image'],
                'target': sample['mask'],
                'case_id': sample['case'],
            }
        return sample
def get_transforms(config, mode='train'):
    """Build the MONAI preprocessing/augmentation pipeline.

    Args:
        config: namespace providing ``patch_size`` (used for padding and the
            final random crop).
        mode: 'train' returns the full augmentation pipeline; any other value
            returns the deterministic validation/test pipeline.

    Returns:
        A ``monai.transforms.Compose`` operating on dicts with keys
        ``'image'`` and ``'mask'``.
    """
    if mode == 'train':
        return Compose([
            LoadImaged(keys=['image', 'mask']),
            EnsureChannelFirstd(keys=['image', 'mask']),
            Orientationd(keys=['image', 'mask'], axcodes='RAS'),
            # Spacingd(keys=['image', 'mask'], pixdim=(1.0, 1.0, 1.0), mode=('bilinear', 'nearest')),
            # Pad so volumes smaller than the patch still yield a full patch.
            SpatialPadd(keys=['image', 'mask'], spatial_size=config.patch_size),
            RandFlipd(keys=['image', 'mask'], prob=0.5, spatial_axis=0),
            RandFlipd(keys=['image', 'mask'], prob=0.5, spatial_axis=1),
            RandFlipd(keys=['image', 'mask'], prob=0.5, spatial_axis=2),
            # Spatial transforms use nearest-neighbour on the mask to keep labels integral.
            RandAffined(
                keys=['image', 'mask'],
                rotate_range=(np.pi/6, np.pi/6, np.pi/6),  # +/-30 degrees
                scale_range=(0.1, 0.1, 0.1),  # +/-10%
                mode=('bilinear', 'nearest'),
                prob=0.2
            ),
            Rand3DElasticd(
                keys=['image', 'mask'],
                sigma_range=(5, 7),
                magnitude_range=(100, 200),
                mode=('bilinear', 'nearest'),
                prob=0.2
            ),
            # Intensity augmentations below apply to the image only.
            RandGaussianSmoothd(
                keys=['image'],
                prob=0.2,
                sigma_x=(0.5, 1.0),
                sigma_y=(0.5, 1.0),
                sigma_z=(0.5, 1.0)
            ),
            RandScaleIntensityd(keys=['image'], prob=0.2, factors=0.1),
            RandShiftIntensityd(keys=['image'], prob=0.2, offsets=0.1),
            RandAdjustContrastd(keys=['image'], prob=0.2, gamma=(0.8, 1.2)),
            # Normalize over non-zero voxels (brain region) per channel.
            NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
            # Gaussian noise (p=0.2), applied after normalization
            RandGaussianNoised(
                keys=['image'],
                prob=0.2,
                mean=0.0,
                std=0.05
            ),
            # Balanced foreground/background patch sampling around the GTV label.
            RandCropByPosNegLabeld(
                keys=['image', 'mask'],
                label_key='mask',
                spatial_size=config.patch_size,
                pos=1,
                neg=1,
                num_samples=1,
                allow_smaller=True,
            ),
        ])
    else:
        # Deterministic pipeline: load, reorient, pad, normalize -- no augmentation.
        return Compose([
            LoadImaged(keys=['image', 'mask']),
            EnsureChannelFirstd(keys=['image', 'mask']),
            Orientationd(keys=['image', 'mask'], axcodes='RAS'),
            # Spacingd(keys=['image', 'mask'], pixdim=(1.0, 1.0, 1.0), mode=('bilinear', 'nearest')),
            SpatialPadd(keys=['image', 'mask'], spatial_size=config.patch_size),
            NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        ])
def setup_trainer(config):
    """
    Set up the trainer, data loaders and TensorBoard writer.

    Args:
        config: parsed argument namespace (see ``parse_args``)

    Returns:
        (trainer, train_loader, val_loader, tensorboard_writer)
    """
    # Set random seeds for reproducibility
    if config.seed is not None:
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed_all(config.seed)
        np.random.seed(config.seed)
        random.seed(config.seed)
        # Deterministic cuDNN kernels: reproducible runs at some speed cost.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve nnUNet dataset paths and load the preprocessing plans
    dataset_name = maybe_convert_to_dataset_name(config.dataset_id)
    preprocessed_dataset_folder = join(nnUNet_preprocessed, dataset_name)
    plans_file = join(preprocessed_dataset_folder, "nnUNetPlans.json")
    plans = load_json(plans_file)
    dataset_json = load_json(join(preprocessed_dataset_folder, "dataset.json"))

    # Initialize trainer
    trainer = nnInteractiveTrainer(
        plans=plans,
        configuration="3d_fullres",
        fold=config.fold,
        dataset_json=dataset_json,
        device=device,
        point_radius=config.point_radius,
        preferred_scribble_thickness=config.scribble_thickness,
        interaction_decay=config.interaction_decay
    )
    trainer.initialize()

    # Load pretrained weights if specified
    if config.pretrained_weights:
        trainer.load_checkpoint(config.pretrained_weights)

    # Datasets and loaders
    train_transform = get_transforms(config, mode='train')
    val_transform = get_transforms(config, mode='test')
    train_dataset = MENRTDataset(
        config.data_dir,
        patch_size=config.patch_size,
        transform=train_transform,
        mode='train'
    )
    val_dataset = MENRTDataset(
        config.data_dir,
        patch_size=config.patch_size,
        transform=val_transform,
        mode='test'
    )
    # Honor the --pin_memory flag (it was parsed but previously ignored);
    # getattr keeps configs without the attribute working (defaults to True).
    pin_memory = getattr(config, 'pin_memory', True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,  # validate one full case at a time
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory
    )

    # Set up TensorBoard writer inside the trainer's output folder
    output_folder = trainer.output_folder
    log_dir = join(output_folder, "logs")
    maybe_mkdir_p(log_dir)
    tensorboard_writer = SummaryWriter(log_dir=log_dir)
    trainer.set_tensorboard_writer(tensorboard_writer)
    return trainer, train_loader, val_loader, tensorboard_writer
def _int_tuple(text):
    """argparse type: parse '128,128,128' into (128, 128, 128).

    The original ``type=tuple`` split the string into individual characters.
    """
    return tuple(int(part) for part in text.split(','))


def _float_tuple(text):
    """argparse type: parse '1.0,1.0,1.0' into (1.0, 1.0, 1.0)."""
    return tuple(float(part) for part in text.split(','))


def _str2bool(text):
    """argparse type: parse boolean flags (note ``bool('False')`` is True)."""
    if isinstance(text, bool):
        return text
    value = text.strip().lower()
    if value in ('1', 'true', 't', 'yes', 'y'):
        return True
    if value in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {text!r}')


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Train nnInteractive model on BraTS-MEN-RT dataset')
    # Dataset arguments
    parser.add_argument('--data_dir', type=str, required=True, help='Path to the data directory')
    parser.add_argument('--dataset_id', type=str, required=True, help='Dataset ID for nnUNet')
    parser.add_argument('--fold', type=int, default=0, help='Fold to train on')
    # Training arguments
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=1000, help='Number of epochs to train for')
    parser.add_argument('--patch_size', type=_int_tuple, default=(128, 128, 128),
                        help='Patch size for training, e.g. 128,128,128')
    parser.add_argument('--spacing', type=_float_tuple, default=(1.0, 1.0, 1.0),
                        help='Spacing for resampling, e.g. 1.0,1.0,1.0')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--pin_memory', type=_str2bool, default=True,
                        help='Use pinned memory for data loading (true/false)')
    # Interactive segmentation arguments
    parser.add_argument('--point_radius', type=int, default=4, help='Radius for point interactions')
    parser.add_argument('--scribble_thickness', type=int, default=2, help='Thickness for scribble interactions')
    parser.add_argument('--interaction_decay', type=float, default=0.9, help='Decay factor for interactions')
    # Model arguments
    parser.add_argument('--pretrained_weights', type=str, default=None, help='Path to pretrained weights')
    parser.add_argument('--continue_training', action='store_true', help='Continue training from last checkpoint')
    parser.add_argument('--only_val', action='store_true', help='Only run validation')
    return parser.parse_args()
def main():
    """Entry point: build trainer/loaders, then train and run final validation."""
    config = parse_args()
    trainer, train_loader, val_loader, tensorboard_writer = setup_trainer(config)
    # Print some info
    print(f"Training on {len(train_loader.dataset)} samples")
    print(f"Validating on {len(val_loader.dataset)} samples")
    # try/finally ensures the TensorBoard writer is flushed and closed even on
    # the validation-only early return or an exception (it was leaked before).
    try:
        if config.only_val:
            print("Running validation only...")
            trainer.run_validation()
            return
        # Training loop
        print(f"Starting training for {config.num_epochs} epochs...")
        trainer.on_train_start()
        trainer.train(config.num_epochs)
        # Final validation with best checkpoint
        trainer.load_checkpoint(join(trainer.output_folder, "checkpoint_best.pth"))
        trainer.run_validation()
    finally:
        tensorboard_writer.close()
    print("Training completed!")
if __name__ == "__main__":
    # Run as a script: python <this file> --data_dir ... --dataset_id ...
    main()