| import copy
|
| import functools
|
| import os
|
| import blobfile as bf
|
| import time
|
| import torch as th
|
| import torch.distributed as dist
|
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
| from torch.optim import AdamW
|
|
|
| from . import dist_util, logger
|
| from .fp16_util import MixedPrecisionTrainer
|
| from .nn import update_ema
|
| from .resample import LossAwareSampler, UniformSampler
|
|
|
|
|
|
|
|
|
# NOTE(review): not referenced by any code visible in this file; presumably
# the initial log2 loss scale for fp16 training -- confirm before removing.
INITIAL_LOG_LOSS_SCALE = 20.0


def visualize(img):
    """
    Min-max rescale ``img`` so that its values span [0, 1].

    :param img: an array-like object (e.g. a tensor or ndarray) exposing
                ``min()`` / ``max()`` and elementwise arithmetic.
    :return: ``(img - min) / (max - min)``; all zeros when every element of
             ``img`` is equal.
    """
    _min = img.min()
    span = img.max() - _min
    if span == 0:
        # A constant image has no dynamic range; dividing by zero would
        # produce NaNs (or raise), so map it to zeros instead.
        span = 1
    return (img - _min) / span
|
|
|
class TrainLoop:
    """
    Training driver for a diffusion segmentation model.

    Owns the optimizer, the mixed-precision trainer and the EMA copies of the
    parameters, handles checkpoint loading/saving, and exposes ``run_loop()``
    as the entry point.
    """

    def __init__(
        self,
        *,
        model,
        classifier,
        diffusion,
        data,
        dataloader,
        prior,
        posterior,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        total_steps=0,
    ):
        """
        Set up the trainer, optimizer and EMA state, resuming if possible.

        :param model: the network to train; may already be wrapped in DDP or
                      ``th.nn.DataParallel`` (both are checked for).
        :param classifier: stored on the instance; not used by this class.
        :param diffusion: object providing ``training_losses_segmentation``
                          and ``num_timesteps``.
        :param data: stored on the instance; not used by this class.
        :param dataloader: iterable yielding ``(batch, cond)`` pairs.
        :param prior: forwarded to ``diffusion.training_losses_segmentation``.
        :param posterior: forwarded to
                          ``diffusion.training_losses_segmentation``.
        :param batch_size: per-process batch size.
        :param microbatch: microbatch size; values <= 0 mean "use batch_size".
        :param lr: base learning rate.
        :param ema_rate: a float, an int, a comma-separated string of rates,
                         or an iterable of rates.
        :param log_interval: dump logged values every this many steps.
        :param save_interval: save a checkpoint every this many steps.
        :param resume_checkpoint: checkpoint path to resume from (falsy to
                                  start fresh).
        :param use_fp16: forwarded to ``MixedPrecisionTrainer``.
        :param fp16_scale_growth: forwarded to ``MixedPrecisionTrainer``.
        :param schedule_sampler: timestep sampler; defaults to
                                 ``UniformSampler(diffusion)``.
        :param weight_decay: AdamW weight decay.
        :param lr_anneal_steps: if non-zero, the learning rate is linearly
                                annealed to zero over this many steps.
        :param total_steps: ``run_loop()`` runs while
                            ``step + resume_step < total_steps``.
        """
        self.model = model
        self.dataloader = dataloader
        self.classifier = classifier
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = self._parse_ema_rate(ema_rate)
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.total_steps = total_steps

        self.prior = prior
        self.posterior = posterior

        self.step = 0
        self.resume_step = 0
        if isinstance(self.model, th.nn.DataParallel):
            # Single process: the batch is not multiplied across ranks.
            self.global_batch = self.batch_size
        else:
            self.global_batch = self.batch_size * dist_util.get_world_size()

        self.sync_cuda = th.cuda.is_available()

        # May set self.resume_step, which decides the branch below.
        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
        )

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            # Fresh run: every EMA track starts as a copy of the parameters.
            self.ema_params = [
                copy.deepcopy(self.mp_trainer.master_params)
                for _ in range(len(self.ema_rate))
            ]

        self.use_ddp = isinstance(self.model, DDP)
        self.ddp_model = self.model
        if not self.use_ddp and dist_util.get_world_size() > 1:
            logger.warn(
                "Running with world_size > 1 but model is not wrapped in DDP."
            )

    @staticmethod
    def _parse_ema_rate(ema_rate):
        """Normalize ``ema_rate`` to a list of floats."""
        if isinstance(ema_rate, float):
            return [ema_rate]
        if isinstance(ema_rate, int):
            return [float(ema_rate)]
        if isinstance(ema_rate, str):
            return [float(x) for x in ema_rate.split(",")]
        return [float(x) for x in ema_rate]

    def _load_and_sync_parameters(self):
        """
        Load model weights from the resume checkpoint (if any), then sync the
        parameters across processes. Sets ``self.resume_step`` when resuming.
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist_util.get_rank() == 0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            # Loaded on every rank: this is correct whether or not the helper
            # is a collective, and the sync below makes all ranks agree.
            state_dict = dist_util.load_state_dict(
                resume_checkpoint, map_location=dist_util.dev()
            )
            if isinstance(self.model, DDP):
                # Checkpoints hold the unwrapped module's keys.
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """
        Return the EMA parameters for ``rate``: loaded from the matching EMA
        checkpoint when one exists, otherwise a copy of the master params.
        """
        ema_params = copy.deepcopy(self.mp_trainer.master_params)

        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist_util.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
            state_dict = dist_util.load_state_dict(
                ema_checkpoint, map_location=dist_util.dev()
            )
            ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore the optimizer state saved next to the resume checkpoint."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def run_loop(self):
        """
        Train until ``step + resume_step`` reaches ``total_steps``, logging
        and checkpointing at the configured intervals, then save once more.
        """
        data_iter = iter(self.dataloader)

        while self.step + self.resume_step < self.total_steps:
            try:
                batch, cond = next(data_iter)
            except StopIteration:
                # The dataloader is exhausted: start another pass over it.
                data_iter = iter(self.dataloader)
                batch, cond = next(data_iter)

            self.run_step(batch, cond)

            if self.step % self.log_interval == 0:
                logger.dumpkvs()

            if self.step > 0 and self.step % self.save_interval == 0:
                self.save()
                # Integration tests stop after the first periodic save.
                if os.environ.get("DIFFUSION_TRAINING_TEST", ""):
                    return

            self.step += 1

        # save() limits file writes to rank 0 itself and ends with a barrier,
        # so every rank must call it; guarding the call by rank here would
        # leave rank 0 waiting at that barrier.
        self.save()

    def run_step(self, batch, cond):
        """Run one optimization step on ``batch`` concatenated with ``cond``."""
        # The conditioning is stacked onto the batch along dim 1; no separate
        # model kwargs are passed on.
        batch = th.cat((batch, cond), dim=1)
        self.forward_backward(batch, {})
        took_step = self.mp_trainer.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, cond):
        """
        Accumulate gradients over ``batch`` in chunks of ``self.microbatch``.

        :param batch: tensor sliced along dim 0 into microbatches.
        :param cond: dict of tensors sliced the same way and passed to the
                     loss as ``model_kwargs``.
        """
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            micro = batch[i : i + self.microbatch].to(dist_util.dev())
            micro_cond = {
                k: v[i : i + self.microbatch].to(dist_util.dev())
                for k, v in cond.items()
            }

            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses_segmentation,
                self.ddp_model,
                None,
                self.prior,
                self.posterior,
                micro,
                t,
                model_kwargs=micro_cond,
            )

            if last_batch or not self.use_ddp:
                losses1 = compute_losses()
            else:
                # Defer DDP gradient synchronization to the last microbatch.
                with self.ddp_model.no_sync():
                    losses1 = compute_losses()

            # NOTE(review): the sampler is never fed the per-sample losses
            # here, so a loss-aware schedule_sampler would not adapt --
            # confirm whether that is intended.
            losses, _ = losses1
            loss = (losses["loss"] * weights).mean()

            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)

    def _update_ema(self):
        """Update every EMA parameter set from the current master params."""
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer.master_params, rate=rate)

    def _anneal_lr(self):
        """
        Linearly anneal the learning rate to zero over ``lr_anneal_steps``.

        Does nothing when ``lr_anneal_steps`` is falsy.
        """
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        # The loop is bounded by total_steps, which may exceed
        # lr_anneal_steps; clamp so the learning rate never goes negative.
        lr = self.lr * max(0.0, 1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step and the number of samples processed so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """
        Write the model, EMA and optimizer checkpoints for the current step.

        Files are written by rank 0 only; all ranks must call this method
        because it finishes with a barrier.
        """

        def save_checkpoint(rate, params):
            # rate == 0 denotes the raw (non-EMA) parameters.
            state_dict = self.mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                current_step = self.step + self.resume_step
                logger.log(f"saving model {rate} at step {current_step}...")
                if not rate:
                    filename = f"savedmodel{current_step:06d}.pt"
                else:
                    filename = f"ema_{rate}_{current_step:06d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.mp_trainer.master_params)

        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)

        if dist_util.get_rank() == 0:
            current_step = self.step + self.resume_step
            with bf.BlobFile(
                bf.join(get_blob_logdir(), f"opt{current_step:06d}.pt"), "wb"
            ) as f:
                th.save(self.opt.state_dict(), f)

        dist_util.barrier()
|
|
|
|
def parse_resume_step_from_filename(filename):
    """
    Extract the step count from a checkpoint path such as
    path/to/modelNNNNNN.pt.

    The digits are taken from the text between the last occurrence of
    "model" and the next ".". Returns 0 when "model" is absent or that text
    is not an integer.
    """
    _, marker, tail = filename.rpartition("model")
    if not marker:
        # "model" does not occur anywhere in the name.
        return 0
    digits = tail.partition(".")[0]
    try:
        return int(digits)
    except ValueError:
        return 0
|
|
|
|
|
def get_blob_logdir():
    """Return the directory checkpoints are written to (the logger's dir)."""
    log_dir = logger.get_dir()
    return log_dir
|
|
|
|
|
def find_resume_checkpoint():
    """
    Hook for automatic checkpoint discovery.

    Always returns None, so callers that write
    ``find_resume_checkpoint() or <fallback>`` end up using their fallback.
    """
    return None
|
|
|
|
|
def find_ema_checkpoint(main_checkpoint, step, rate):
    """
    Locate the EMA checkpoint that sits next to ``main_checkpoint``.

    :param main_checkpoint: path of the main checkpoint, or None.
    :param step: step count embedded (zero-padded to 6 digits) in the name.
    :param rate: EMA rate embedded in the name.
    :return: the path ``<dir>/ema_{rate}_{step:06d}.pt`` if it exists,
             otherwise None.
    """
    if main_checkpoint is None:
        return None
    ema_name = f"ema_{rate}_{step:06d}.pt"
    candidate = bf.join(bf.dirname(main_checkpoint), ema_name)
    return candidate if bf.exists(candidate) else None
|
|
|
|
|
def log_loss_dict(diffusion, ts, losses):
    """
    Log the mean of each loss term, plus per-sample values bucketed by
    timestep quartile.

    :param diffusion: object exposing ``num_timesteps``.
    :param ts: tensor of timesteps, paired element-wise with each loss tensor.
    :param losses: dict mapping a loss name to a tensor of per-sample values.
    """
    for name in losses:
        per_sample = losses[name]
        logger.logkv_mean(name, per_sample.mean().item())
        timesteps = ts.cpu().numpy()
        sample_losses = per_sample.detach().cpu().numpy()
        for step_t, step_loss in zip(timesteps, sample_losses):
            # Bucket index is int(4 * t / num_timesteps).
            bucket = int(4 * step_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{name}_q{bucket}", step_loss)
|
|