import copy
import functools
import os
import time

import blobfile as bf
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from . import dist_util, logger
from .fp16_util import MixedPrecisionTrainer
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0


def visualize(img):
    """Min-max normalize ``img`` into the range [0, 1] for display.

    A constant input (max == min) is mapped to all zeros instead of the
    NaNs a 0/0 division would produce.
    """
    _min = img.min()
    _max = img.max()
    value_range = _max - _min
    if value_range == 0:
        # Avoid 0/0 -> NaN; (img - _min) is already all zeros in this case.
        value_range = 1
    return (img - _min) / value_range


class TrainLoop:
    """Training loop that optimizes ``model`` with the diffusion segmentation loss.

    Handles micro-batching, mixed precision via ``MixedPrecisionTrainer``,
    EMA parameter tracking, linear learning-rate annealing, resuming from
    checkpoints, and periodic checkpointing.
    """

    def __init__(
        self,
        *,
        model,
        classifier,
        diffusion,
        data,
        dataloader,
        prior,
        posterior,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        total_steps=0,
    ):
        """Store the configuration, load any resume checkpoint and build the optimizer.

        Args:
            model: network being trained; may or may not be wrapped in DDP.
            classifier: stored on the instance; this loop passes ``None`` as the
                classifier to the loss function.
            diffusion: object providing ``training_losses_segmentation`` and
                ``num_timesteps``.
            data: stored on the instance; the loop itself reads ``dataloader``.
            dataloader: iterable of ``(batch, cond)`` pairs; restarted whenever
                it is exhausted.
            prior, posterior: forwarded to ``training_losses_segmentation``.
            batch_size: per-process batch size.
            microbatch: micro-batch size; a value <= 0 means use ``batch_size``.
            lr: base learning rate for AdamW.
            ema_rate: a float, or a comma-separated string of floats.
            log_interval: flush logged key/values every this many steps.
            save_interval: checkpoint every this many steps (step 0 excluded).
            resume_checkpoint: path of a model checkpoint to resume from, or
                a falsy value to start fresh.
            use_fp16, fp16_scale_growth: forwarded to ``MixedPrecisionTrainer``.
            schedule_sampler: timestep sampler; defaults to ``UniformSampler``.
            weight_decay: AdamW weight decay.
            lr_anneal_steps: length of the linear LR decay to 0; 0 disables it.
            total_steps: training stops once ``step + resume_step`` reaches this
                value (with the default 0 no optimization step is run).
        """
        self.model = model
        self.dataloader = dataloader
        self.classifier = classifier
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps
        self.total_steps = total_steps
        self.prior = prior
        self.posterior = posterior

        self.step = 0
        self.resume_step = 0
        if isinstance(self.model, th.nn.DataParallel):
            # This case might not be hit with DDP, but left for safety.
            self.global_batch = self.batch_size
        else:
            self.global_batch = self.batch_size * dist_util.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        self.use_ddp = isinstance(self.model, DDP)
        self.ddp_model = self.model
        if not self.use_ddp and dist_util.get_world_size() > 1:
            logger.warn(
                "Running with world_size > 1 but model is not wrapped in DDP."
            )

    def _load_and_sync_parameters(self):
        """Load model weights on rank 0 (if resuming) and broadcast them to all ranks.

        Also sets ``self.resume_step`` from the checkpoint file name.
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist_util.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    resume_checkpoint, map_location=dist_util.dev()
                )
                if isinstance(self.model, DDP):
                    self.model.module.load_state_dict(state_dict)
                else:
                    self.model.load_state_dict(state_dict)

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for ``rate``.

        Loads them from a matching EMA checkpoint on rank 0 when one exists,
        otherwise starts from a copy of the current master parameters, then
        syncs the result across ranks.
        """
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist_util.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore optimizer state from ``optNNNNNN.pt`` next to the resume checkpoint, if present."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self):
        """Train until ``step + resume_step`` reaches ``total_steps``.

        The dataloader is cycled: once exhausted, a fresh iterator is created.
        Logged values are flushed every ``log_interval`` steps, checkpoints are
        written every ``save_interval`` steps (step 0 excluded), and a final
        checkpoint is written when the loop finishes.
        """
        data_iter = iter(self.dataloader)

        while self.step + self.resume_step < self.total_steps:
            try:
                batch, cond = next(data_iter)
            except StopIteration:
                # Re-initialize data loader when it runs out.
                data_iter = iter(self.dataloader)
                batch, cond = next(data_iter)

            self.run_step(batch, cond)

            if self.step % self.log_interval == 0:
                logger.dumpkvs()

            if self.step > 0 and self.step % self.save_interval == 0:
                self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", ""):
                    return

            self.step += 1

        # save() restricts file writes to rank 0 internally and ends with
        # dist_util.barrier(), so every rank must enter it. Calling it from
        # rank 0 only would leave rank 0 waiting at a barrier the other ranks
        # never reach.
        self.save()

    def run_step(self, batch, cond):
        """Run one optimization step on ``batch`` concatenated with ``cond`` along dim 1."""
        # The conditioning tensor is stacked onto the batch along dim 1, so
        # the loss receives a single tensor and an empty kwargs dict.
        batch = th.cat((batch, cond), dim=1)
        self.forward_backward(batch, {})
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        """Accumulate gradients over micro-batches of ``batch``.

        Args:
            batch: input tensor, split along dim 0 into chunks of ``self.microbatch``.
            cond: dict of tensors sliced the same way and passed as ``model_kwargs``.
        """
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses_segmentation,
                self.ddp_model,
                None,  # classifier is None for this task
                self.prior,
                self.posterior,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            # Only let DDP all-reduce gradients on the last micro-batch.
            if last_batch or not self.use_ddp:
                losses1 = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses1 = compute_losses()

            losses, _ = losses1

            # A loss-aware sampler must be fed the per-sample losses, otherwise
            # it never leaves its initial weighting.
            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

    def _update_ema(self):
        """Blend the current master parameters into each EMA copy at its rate."""
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        """Linearly decay the learning rate to 0 over ``lr_anneal_steps`` steps.

        Disabled when ``lr_anneal_steps`` is 0. Because the loop length is set
        by ``total_steps``, the progress fraction is clamped to 1 so the LR
        stays at 0 instead of going negative past ``lr_anneal_steps``.
        """
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        frac_done = min(frac_done, 1.0)
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step and the number of samples seen so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write model, EMA and optimizer checkpoints for the current step.

        Called on every rank: state dicts are built on all ranks, only rank 0
        writes files, and all ranks then reach ``dist_util.barrier()``.
        """
        current_step = self.step + self.resume_step

        def save_checkpoint(rate, params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {rate} at step {current_step}...")
                if not rate:
                    filename = f"savedmodel{current_step:06d}.pt"
                else:
                    filename = f"ema_{rate}_{current_step:06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        # Save the master parameters.
        save_checkpoint(0, self.mp_trainer.master_params)
        # Save the EMA parameters.
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        # Save the optimizer state.
        if dist_util.get_rank() == 0:
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{current_step:06d}.pt"), "wb"
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist_util.barrier()


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps. Returns 0 when no step can be parsed.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    """Return the directory checkpoints are written to (the logger's directory)."""
    return logger.get_dir()


def find_resume_checkpoint():
    """Return None: automatic checkpoint discovery is not implemented here.

    Callers fall back to the explicitly configured ``resume_checkpoint``.
    """
    return None


def find_ema_checkpoint(main_checkpoint, step, rate):
    """Return the path of ``ema_{rate}_{step:06d}.pt`` next to ``main_checkpoint``, or None."""
    if main_checkpoint is None:
        return None
    filename = f"ema_{rate}_{step:06d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)
    if bf.exists(path):
        return path
    return None


def log_loss_dict(diffusion, ts, losses):
    """Log the mean of each loss term, overall and bucketed by timestep quartile."""
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Also log per-quartile means of the diffusion timestep.
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)