Ambiguous_BaselineRuns / cimd /scripts /segmentation_train_v4.py
siddharthdhara17's picture
Upload folder using huggingface_hub
457db56 verified
"""
Train a diffusion model on images with proper distributed training support.
"""
import sys
import os
import argparse
import numpy as np
import socket
from datetime import datetime
sys.path.append("..")
sys.path.append(".")
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from scripts.metrics import model_size
import torch as th
from torch import nn
import torch.distributed as dist
import torch.multiprocessing as mp
from guided_diffusion.train_util import TrainLoop
def expanduservars(path: str) -> str:
    """Return *path* with environment variables and a leading ``~`` expanded.

    Environment variables are substituted first, so a variable whose value
    itself starts with ``~`` is also expanded to the user's home directory.
    """
    with_vars = os.path.expandvars(path)
    return os.path.expanduser(with_vars)
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    The probing socket is closed before returning, so the port is only
    known to have been free at the time of the call.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        # Binding to port 0 lets the OS choose any available port.
        probe.bind(('', 0))
        probe.listen(1)
        _host, chosen_port = probe.getsockname()
        return chosen_port
def setup_distributed(rank, world_size, port):
    """Join the default process group for this rank.

    Sets ``MASTER_ADDR`` to ``localhost`` and ``MASTER_PORT`` to *port* in
    the environment, then initializes the process group with the ``nccl``
    backend when CUDA is available and ``gloo`` otherwise. With CUDA, the
    current device is set to the device whose index equals *rank*.
    """
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(port),
    })

    use_cuda = th.cuda.is_available()
    backend = 'nccl' if use_cuda else 'gloo'
    dist.init_process_group(backend, rank=rank, world_size=world_size)

    if use_cuda:
        th.cuda.set_device(rank)
def cleanup_distributed():
    """Destroy the default process group if one is currently initialized.

    Safe to call more than once and safe to call when distributed training
    was never set up (or set-up failed): in those cases it does nothing
    instead of raising, which makes it usable from error-handling paths.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def get_device(rank):
    """Return ``cuda:<rank>`` when CUDA is available, otherwise the CPU device."""
    device_spec = f'cuda:{rank}' if th.cuda.is_available() else 'cpu'
    return th.device(device_spec)
def worker_init_fn(worker_id):
    """Initialise a DataLoader worker: seed NumPy and force float32 defaults.

    NumPy's global RNG is seeded from torch's initial seed, reduced modulo
    2**32. *worker_id* is accepted for the ``worker_init_fn`` signature but
    is not used directly.
    """
    numpy_seed = th.initial_seed() % (2 ** 32)
    np.random.seed(numpy_seed)
    # Keep tensors created inside the worker in float32.
    th.set_default_dtype(th.float32)
def _build_train_dataset(args):
    """Construct the training dataset selected by ``args``.

    When ``args.use_mose_dataset`` is false, data is read from
    ``args.data_dir``: ``CustomLIDCDataset`` for the ``"lidc"`` and
    ``"multiannotator"`` dataset types, and the legacy ``LIDCDataset`` loader
    for anything else. When it is true, the NPY loaders are used with
    ``args.lidc_data_folder`` / ``args.msmri_data_folder``.

    Raises:
        ValueError: if ``args.use_mose_dataset`` is set and
            ``args.dataset_type`` is neither ``"lidc"`` nor ``"msmri"``.
    """
    if not args.use_mose_dataset:
        if args.dataset_type in {"lidc", "multiannotator"}:
            ds = CustomLIDCDataset(
                data_root=args.data_dir,
                split="train",
                image_size=args.image_size,
                dataset_type=args.dataset_type,
                split_strategy=args.split_strategy,
            )
            logger.log(
                f"Using {args.dataset_type} dataset from: {args.data_dir} "
                f"(strategy={args.split_strategy})"
            )
        else:
            # Backward-compatible fallback to legacy folder-structured loader.
            from guided_diffusion.lidcloader import LIDCDataset
            ds = LIDCDataset(args.data_dir, test_flag=False)
            logger.log(f"Using converted legacy dataset from: {args.data_dir}")
        return ds

    # NPY-format datasets; loaders are imported lazily so they are only
    # required when actually selected.
    if args.dataset_type == "lidc":
        from guided_diffusion.lidcloader_mose import lidc_Dataloader
        ds = lidc_Dataloader(
            data_folder=args.lidc_data_folder,
            transform_train=None,
            transform_test=None
        ).train_ds
        logger.log(f"Using LIDC NPY dataset from: {args.lidc_data_folder}")
    elif args.dataset_type == "msmri":
        from guided_diffusion.msmri_dataset_mose import msmri_Dataloader
        ds = msmri_Dataloader(
            data_folder=args.msmri_data_folder,
            transform_train=None,
            transform_test=None
        ).train_ds
        logger.log(f"Using MSMRI NPY dataset from: {args.msmri_data_folder}")
    else:
        raise ValueError(f"Unknown dataset type: {args.dataset_type}")
    return ds


def _train_on_rank(rank, world_size, args):
    """Build model, data pipeline and run the training loop for one process.

    Assumes the process group has already been initialized by the caller
    when ``world_size > 1``.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than ``world_size`` so
            that the per-process batch size would be zero.
    """
    distributed = world_size > 1

    # args.batch_size is the global batch size; each process gets an equal
    # share. Validate up front so a bad value fails before the model and
    # dataset are built rather than inside DataLoader.
    batch_size = args.batch_size // world_size if distributed else args.batch_size
    if batch_size < 1:
        raise ValueError(
            f"batch_size ({args.batch_size}) must be at least world_size "
            f"({world_size}) so every process gets at least one sample"
        )

    # Configure logging (only for rank 0)
    if rank == 0:
        if args.use_mose_dataset:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            run_name = f"train_{args.dataset_type}_{timestamp}"
            out_dir = f"./results/{args.dataset_type}/{run_name}"
            logger.configure(dir=out_dir)
            logger.log(f"Log dir: {out_dir}")
        else:
            logger.configure()
        logger.log(f"Arguments: {args.__dict__}")
        logger.log(f"World size: {world_size}, Rank: {rank}")
        logger.log("creating model, diffusion, prior and posterior distribution...")

    model, diffusion, prior, posterior = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )

    # Move models to appropriate device
    if distributed:
        device = get_device(rank)
    elif th.cuda.is_available():
        device = th.device("cuda")
    else:
        device = th.device("cpu")
    model.to(device)
    prior.to(device)
    posterior.to(device)

    # Ensure all models use float32
    model = model.float()
    prior = prior.float()
    posterior = posterior.float()

    # Setup distributed model
    # NOTE(review): only `model` is wrapped in DDP; prior/posterior are not.
    # Confirm TrainLoop handles their gradient synchronization.
    if distributed:
        model = th.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = th.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[rank],
            output_device=rank,
            find_unused_parameters=True
        )

    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion, maxt=1000)

    if rank == 0:
        # Model-size reporting is best-effort; failure must not stop training.
        try:
            model_size(model, diffusion, prior, posterior, logger)
        except Exception as e:
            logger.log(f"Could not compute model size: {e}")

    # Create dataset
    logger.log("Loading dataset...")
    ds = _build_train_dataset(args)
    logger.log(f"Dataset size: {len(ds)} samples")

    # Setup distributed sampler
    if distributed:
        train_sampler = th.utils.data.distributed.DistributedSampler(
            ds,
            rank=rank,
            num_replicas=world_size,
            shuffle=True
        )
    else:
        train_sampler = None

    # Create data loader. `shuffle` must be off whenever a sampler is given.
    datal = th.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        drop_last=True,
        pin_memory=th.cuda.is_available(),
        num_workers=args.mp_loaders,
        worker_init_fn=worker_init_fn
    )
    data = iter(datal)

    if rank == 0:
        logger.log("Starting training...")

    # Start training
    TrainLoop(
        model=model,
        diffusion=diffusion,
        classifier=None,
        data=data,
        dataloader=datal,
        prior=prior,
        posterior=posterior,
        batch_size=batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        total_steps=args.total_steps,
    ).run_loop()


def run_train(rank, world_size, port, args):
    """Main training function for each process.

    Args:
        rank: index of this process; also used as the CUDA device index in
            multi-process runs.
        world_size: total number of processes. Distributed set-up and
            tear-down only happen when this is greater than 1.
        port: rendezvous port passed to ``setup_distributed``; unused when
            ``world_size`` is 1.
        args: parsed command-line namespace from ``create_argparser``.

    Raises:
        ValueError: for an unknown dataset type or a per-process batch size
            of zero (see the helpers above).
    """
    # Set default tensor type to float32
    th.set_default_dtype(th.float32)

    # Setup distributed training only if world_size > 1
    distributed = world_size > 1
    if distributed:
        setup_distributed(rank, world_size, port)

    try:
        _train_on_rank(rank, world_size, args)
    finally:
        # Always release the process group, even when training raised.
        if distributed:
            cleanup_distributed()
def create_argparser():
    """Build the command-line parser for this training script.

    Script-level defaults are declared here and then merged with
    ``model_and_diffusion_defaults()``; every resulting key is registered on
    the parser through ``add_dict_to_argparser``.
    """
    script_defaults = {
        "data_dir": "./data/training",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        # Global batch size; run_train divides it by world_size when
        # several processes are used.
        "batch_size": 8,
        "microbatch": -1,  # -1 disables microbatches
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 100,
        "save_interval": 25000,
        "mp_loaders": 4,
        "resume_checkpoint": '',
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        # False: read from data_dir; True: use the NPY folders below.
        "use_mose_dataset": False,
        # run_train checks for "lidc", "multiannotator" and "msmri".
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "lidc_data_folder": "./data/lidc_npy",  # only read if use_mose_dataset
        "msmri_data_folder": "./data/msmri_npy",  # only read if use_mose_dataset
        "world_size": 1,  # number of GPUs/processes requested
        "total_steps": 50000,
    }
    # On a key collision the model/diffusion default replaces the value above.
    script_defaults.update(model_and_diffusion_defaults())

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, script_defaults)
    return arg_parser
def main():
    """Parse arguments and launch single- or multi-process training.

    The effective number of processes is ``args.world_size`` clamped to the
    range ``[1, number of visible CUDA devices]``; without CUDA a single CPU
    process is used. More than one process is launched through
    ``torch.multiprocessing.spawn``; otherwise ``run_train`` is called
    directly in this process.
    """
    # Set default tensor type to float32 globally
    th.set_default_dtype(th.float32)
    args = create_argparser().parse_args()

    # Clean up SLURM environment variables that might interfere
    os.environ.pop("SLURM_JOBID", None)

    # Check available GPUs
    if th.cuda.is_available():
        available_gpus = th.cuda.device_count()
        # Clamp to at least one process so a zero/negative request does not
        # report "Using 0 GPU(s)" while still training on one device.
        world_size = max(1, min(args.world_size, available_gpus))
        if world_size != args.world_size:
            print(
                f"Requested world_size={args.world_size} but {available_gpus} "
                f"GPU(s) are available; adjusting to {world_size}"
            )
        print(f"Using {world_size} GPU(s)")
    else:
        world_size = 1
        print("Using CPU")

    if world_size > 1:
        # Multi-GPU distributed training using mp.spawn
        port = find_free_port()
        print(f"Starting distributed training on port {port}")
        # mp.spawn passes the process index as the first argument (rank).
        mp.spawn(run_train, args=(world_size, port, args), nprocs=world_size, join=True)
    else:
        # Single GPU/CPU training
        print("Starting single process training")
        # run_train only reads the port when world_size > 1, so no socket
        # needs to be opened to pick one here.
        run_train(0, 1, 0, args)


if __name__ == "__main__":
    main()