File size: 3,848 Bytes
457db56 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | import argparse
import os
import sys
import torch as th
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.custom_lidc_dataset import CustomLIDCDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
args_to_dict,
add_dict_to_argparser,
)
from guided_diffusion.train_util import TrainLoop
def main():
    """Train the diffusion model together with its prior/posterior networks under DDP.

    Steps:
      1. Parse options from ``create_argparser()``.
      2. Call ``dist_util.setup_dist()`` (presumably initialises the process
         group — confirm in dist_util).
      3. Bind this process to the GPU named by the ``LOCAL_RANK`` environment
         variable.
      4. Build model/diffusion/prior/posterior and wrap the model in
         ``DistributedDataParallel``.
      5. Shard the ``CustomLIDCDataset`` "train" split across ranks with a
         shuffling ``DistributedSampler``.
      6. Run ``TrainLoop.run_loop()``.

    The process group, if still initialised, is destroyed once training
    returns.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment (presumably
            it is exported by the distributed launcher, e.g. torchrun — confirm).
    """
    args = create_argparser().parse_args()

    # --- DDP Setup ---
    dist_util.setup_dist()
    local_rank = int(os.environ["LOCAL_RANK"])
    th.cuda.set_device(local_rank)

    # Only rank 0 creates the log directory and configures the logger.
    if dist_util.get_rank() == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        logger.configure(dir=args.log_dir)

    logger.log("creating model, diffusion, prior and posterior distribution...")
    model, diffusion, prior, posterior = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.to(dist_util.dev())
    prior.to(dist_util.dev())
    posterior.to(dist_util.dev())
    # NOTE(review): only `model` is wrapped in DDP. If TrainLoop also optimises
    # `prior`/`posterior`, their gradients are not all-reduced across ranks and
    # the per-rank copies can drift apart — confirm.
    model = DDP(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    # `maxt` presumably bounds the timesteps the sampler draws (assumed
    # [0, maxt) — confirm in guided_diffusion.resample). Tie it to the
    # configured number of diffusion steps rather than a hard-coded 1000 so
    # that a non-default --diffusion_steps cannot yield out-of-range timesteps.
    # The fallback keeps the previous value when the option is absent.
    maxt = getattr(args, "diffusion_steps", 1000)
    schedule_sampler = create_named_schedule_sampler(
        args.schedule_sampler, diffusion, maxt=maxt
    )

    logger.log("creating data loader...")
    dataset = CustomLIDCDataset(
        data_root=args.data_dir,
        split="train",
        image_size=args.image_size,
        dataset_type=args.dataset_type,
        split_strategy=args.split_strategy,
    )
    # Each rank sees a disjoint shard of the dataset.
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist_util.get_world_size(),
        rank=dist_util.get_rank(),
        shuffle=True,
    )
    # NOTE(review): DistributedSampler only changes its shuffle order between
    # epochs when sampler.set_epoch(epoch) is called; nothing in this function
    # does so — confirm TrainLoop handles it via the `dataloader` argument.
    dataloader = th.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

    logger.log("training...")
    TrainLoop(
        model=model,
        diffusion=diffusion,
        classifier=None,
        prior=prior,
        posterior=posterior,
        data=iter(dataloader),
        dataloader=dataloader,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        # Overall step budget for the run (see --total_steps).
        total_steps=args.total_steps,
    ).run_loop()

    # Release the distributed process group once training has finished.
    # Guarded so this is a no-op if something else already tore it down.
    if th.distributed.is_available() and th.distributed.is_initialized():
        th.distributed.destroy_process_group()
def create_argparser():
    """Build the command-line parser for the training script.

    Script-level options are merged with ``model_and_diffusion_defaults()``.
    They cover data location/split, optimisation, logging and checkpoint
    cadence, data-loader workers, and the total step budget. When a key
    appears in both, the model/diffusion default wins.

    The merged mapping is registered on a fresh parser through
    ``add_dict_to_argparser`` (presumably one ``--<name>`` flag per key —
    confirm in script_util).

    Returns:
        argparse.ArgumentParser: the populated parser.
    """
    training_options = {
        "data_dir": "./data/LIDC",
        "dataset_type": "lidc",
        "split_strategy": "all_annotations",
        "log_dir": "./results/training_logs",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 4,
        "microbatch": -1,
        # Kept as a string; presumably parsed downstream (e.g. as a
        # comma-separated list of rates) — confirm in TrainLoop.
        "ema_rate": "0.9999",
        "log_interval": 100,
        "save_interval": 5000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        "num_workers": 4,
        # Overall training step budget forwarded to TrainLoop.
        "total_steps": 50000,
    }
    # Later mapping overrides earlier keys while keeping first-seen ordering,
    # matching dict.update semantics.
    all_options = {**training_options, **model_and_diffusion_defaults()}

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, all_options)
    return arg_parser
if __name__ == "__main__":
    # Start training only when this file is executed directly, not on import.
    main()
|