# Source: Ambiguous_BaselineRuns/cimd/scripts/segmentation_train_ddp.py
# Uploaded by siddharthdhara17 via huggingface_hub (commit 457db56, verified).
import argparse
import os
import sys
import torch as th
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop
def main():
    """Train the segmentation diffusion model with DistributedDataParallel.

    Meant to be started once per GPU by a launcher that sets the
    ``LOCAL_RANK`` environment variable (e.g. ``torchrun``); if it is
    missing, ``os.environ["LOCAL_RANK"]`` raises ``KeyError``.

    Steps:
      1. Parse CLI args.
      2. Initialise distributed state via ``dist_util.setup_dist()``.
      3. Build the model, diffusion process, prior and posterior.
      4. Wrap ``model`` in DDP.
      5. Shard the training split across ranks with ``DistributedSampler``.
      6. Run ``TrainLoop(...).run_loop()``.

    The default ``torch.distributed`` process group is destroyed on exit,
    including when training raises, so communicator resources are released.
    """
    args = create_argparser().parse_args()

    # --- DDP setup ---
    dist_util.setup_dist()
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        th.cuda.set_device(local_rank)

        # Only rank 0 creates the log directory and configures the logger.
        # NOTE(review): other ranks never call logger.configure(); confirm
        # where their logger.log() output ends up.
        if dist_util.get_rank() == 0:
            os.makedirs(args.log_dir, exist_ok=True)
            logger.configure(dir=args.log_dir)

        logger.log("creating model, diffusion, prior and posterior distribution...")
        model, diffusion, prior, posterior = create_model_and_diffusion(
            **args_to_dict(args, model_and_diffusion_defaults().keys())
        )
        model.to(dist_util.dev())
        prior.to(dist_util.dev())
        posterior.to(dist_util.dev())

        # Only `model` is wrapped in DDP. NOTE(review): if prior/posterior hold
        # trainable parameters, DDP does not all-reduce their gradients;
        # confirm TrainLoop handles that.
        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # Buffers are not re-broadcast from rank 0 on every forward pass.
            broadcast_buffers=False,
            bucket_cap_mb=128,
            # Requires every parameter to get a gradient each step.
            find_unused_parameters=False,
        )

        # NOTE(review): maxt is hard-coded to 1000; confirm it matches the
        # diffusion process's number of timesteps.
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion, maxt=1000
        )

        logger.log("creating data loader...")
        dataset = CustomLIDCDataset(
            data_root=args.data_dir,
            split="train",
            image_size=args.image_size,
            dataset_type=args.dataset_type,
            split_strategy=args.split_strategy,
        )
        # Each rank sees a disjoint shard of the dataset.
        # NOTE(review): DistributedSampler only reshuffles between epochs when
        # sampler.set_epoch(epoch) is called; confirm TrainLoop does this when
        # it re-iterates `dataloader`.
        sampler = DistributedSampler(
            dataset,
            num_replicas=dist_util.get_world_size(),
            rank=dist_util.get_rank(),
            shuffle=True,
        )
        dataloader = th.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            # Drop the trailing partial batch so every step uses a full batch.
            drop_last=True,
        )

        logger.log("training...")
        TrainLoop(
            model=model,
            diffusion=diffusion,
            classifier=None,
            prior=prior,
            posterior=posterior,
            data=iter(dataloader),
            # The loader itself is also passed, presumably so TrainLoop can
            # start a new pass once `data` is exhausted.
            dataloader=dataloader,
            batch_size=args.batch_size,
            microbatch=args.microbatch,
            lr=args.lr,
            ema_rate=args.ema_rate,
            log_interval=args.log_interval,
            save_interval=args.save_interval,
            resume_checkpoint=args.resume_checkpoint,
            use_fp16=args.use_fp16,
            fp16_scale_growth=args.fp16_scale_growth,
            schedule_sampler=schedule_sampler,
            weight_decay=args.weight_decay,
            lr_anneal_steps=args.lr_anneal_steps,
            total_steps=args.total_steps,
        ).run_loop()
    finally:
        # PyTorch recommends destroying the process group explicitly at
        # shutdown. The check makes this a no-op if setup never created one
        # or if something else already tore it down.
        if th.distributed.is_available() and th.distributed.is_initialized():
            th.distributed.destroy_process_group()
def create_argparser():
    """Return the ``argparse`` parser for the DDP segmentation training script.

    Script-level training defaults come first. ``model_and_diffusion_defaults()``
    is merged on top and wins on any shared key. The merged mapping is then
    registered on a fresh ``ArgumentParser`` through ``add_dict_to_argparser``.
    """
    training_defaults = {
        # Data / output locations.
        "data_dir": "./data/LIDC",
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "log_dir": "./results/training_logs",
        # Optimisation.
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 4,
        "microbatch": -1,
        "ema_rate": "0.9999",
        # Logging / checkpointing.
        "log_interval": 100,
        "save_interval": 5000,
        "resume_checkpoint": "",
        # Mixed precision.
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        # DataLoader worker processes per rank.
        "num_workers": 4,
        # Step budget forwarded to TrainLoop as `total_steps`.
        "total_steps": 50000,
    }
    # Unpacking merge: shared keys keep their original position but take the
    # model/diffusion value, exactly like dict.update().
    merged_defaults = {**training_defaults, **model_and_diffusion_defaults()}

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    return arg_parser
if __name__ == "__main__":
    # Launch training only when run as a script (e.g. via torchrun),
    # never as a side effect of importing this module.
    main()