| | import copy |
| | from pdb import set_trace as st |
| | import functools |
| | import os |
| | import numpy as np |
| |
|
| | import blobfile as bf |
| | import torch as th |
| | import torch.distributed as dist |
| | from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
| | from torch.optim import AdamW |
| |
|
| | from . import dist_util, logger |
| | from .fp16_util import MixedPrecisionTrainer |
| | from .nn import update_ema |
| | from .resample import LossAwareSampler, UniformSampler |
| |
|
| | from pathlib import Path |
| |
|
| | |
| | |
| | |
| | INITIAL_LOG_LOSS_SCALE = 20.0 |
| |
|
| | |
| | |
| | |
| | |
| |
|
| |
|
| | class TrainLoop: |
| |
|
| | def __init__( |
| | self, |
| | *, |
| | model, |
| | diffusion, |
| | data, |
| | batch_size, |
| | microbatch, |
| | lr, |
| | ema_rate, |
| | log_interval, |
| | save_interval, |
| | resume_checkpoint, |
| | use_fp16=False, |
| | fp16_scale_growth=1e-3, |
| | schedule_sampler=None, |
| | weight_decay=0.0, |
| | lr_anneal_steps=0, |
| | use_amp=False, |
| | model_name='ddpm', |
| | **kwargs |
| | ): |
| |
|
| | self.kwargs = kwargs |
| | self.pool_512 = th.nn.AdaptiveAvgPool2d((512, 512)) |
| | self.pool_256 = th.nn.AdaptiveAvgPool2d((256, 256)) |
| | self.pool_128 = th.nn.AdaptiveAvgPool2d((128, 128)) |
| | self.pool_64 = th.nn.AdaptiveAvgPool2d((64, 64)) |
| |
|
| | self.use_amp = use_amp |
| | self.model_name = model_name |
| | self.model = model |
| | self.diffusion = diffusion |
| | self.data = data |
| | self.batch_size = batch_size |
| | self.microbatch = microbatch if microbatch > 0 else batch_size |
| | self.lr = lr |
| | self.ema_rate = ([ema_rate] if isinstance(ema_rate, float) else |
| | [float(x) for x in ema_rate.split(",")]) |
| | self.log_interval = log_interval |
| | self.save_interval = save_interval |
| | self.resume_checkpoint = resume_checkpoint |
| | self.use_fp16 = use_fp16 |
| | self.fp16_scale_growth = fp16_scale_growth |
| | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) |
| | self.weight_decay = weight_decay |
| | self.lr_anneal_steps = lr_anneal_steps |
| |
|
| | self.step = 0 |
| | self.resume_step = 0 |
| | self.global_batch = self.batch_size * dist.get_world_size() |
| |
|
| | self.sync_cuda = th.cuda.is_available() |
| | self._setup_model() |
| | self._load_model() |
| | self._setup_opt() |
| | |
| | def _load_model(self): |
| | self._load_and_sync_parameters() |
| | |
| | def _setup_opt(self): |
| | self.opt = AdamW(self.mp_trainer.master_params, |
| | lr=self.lr, |
| | weight_decay=self.weight_decay) |
| | |
| | def _setup_model(self): |
| |
|
| | self.mp_trainer = MixedPrecisionTrainer( |
| | model=self.model, |
| | use_fp16=self.use_fp16, |
| | fp16_scale_growth=self.fp16_scale_growth, |
| | use_amp=self.use_amp, |
| | model_name=self.model_name |
| | ) |
| |
|
| | if self.resume_step: |
| | self._load_optimizer_state() |
| | |
| | |
| | self.ema_params = [ |
| | self._load_ema_parameters(rate) for rate in self.ema_rate |
| | ] |
| | else: |
| | self.ema_params = [ |
| | copy.deepcopy(self.mp_trainer.master_params) |
| | for _ in range(len(self.ema_rate)) |
| | ] |
| | |
| | |
| |
|
| | |
| | if th.cuda.is_available(): |
| | self.use_ddp = True |
| | self.ddpm_model = self.model |
| | self.ddp_model = DDP( |
| | self.model, |
| | device_ids=[dist_util.dev()], |
| | output_device=dist_util.dev(), |
| | broadcast_buffers=False, |
| | bucket_cap_mb=128, |
| | find_unused_parameters=False, |
| | ) |
| | else: |
| | if dist.get_world_size() > 1: |
| | logger.warn("Distributed training requires CUDA. " |
| | "Gradients will not be synchronized properly!") |
| | self.use_ddp = False |
| | self.ddp_model = self.model |
| | |
| |
|
| |
|
| | def _load_and_sync_parameters(self): |
| | resume_checkpoint, resume_step = find_resume_checkpoint( |
| | ) or self.resume_checkpoint |
| |
|
| | if resume_checkpoint: |
| | if not Path(resume_checkpoint).exists(): |
| | logger.log( |
| | f"failed to load model from checkpoint: {resume_checkpoint}, not exist" |
| | ) |
| | return |
| |
|
| | |
| | self.resume_step = resume_step |
| | if dist.get_rank() == 0: |
| | logger.log( |
| | f"loading model from checkpoint: {resume_checkpoint}...") |
| | |
| | |
| | self.model.load_state_dict( |
| | dist_util.load_state_dict( |
| | resume_checkpoint, |
| | map_location=dist_util.dev(), |
| | )) |
| |
|
| | |
| | dist_util.sync_params(self.model.parameters()) |
| | |
| |
|
| | def _load_ema_parameters(self, |
| | rate, |
| | model=None, |
| | mp_trainer=None, |
| | model_name='ddpm'): |
| |
|
| | if mp_trainer is None: |
| | mp_trainer = self.mp_trainer |
| | if model is None: |
| | model = self.model |
| |
|
| | ema_params = copy.deepcopy(mp_trainer.master_params) |
| |
|
| | main_checkpoint, _ = find_resume_checkpoint( |
| | self.resume_checkpoint, model_name) or self.resume_checkpoint |
| | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, |
| | rate, model_name) |
| | if ema_checkpoint: |
| |
|
| | if dist_util.get_rank() == 0: |
| |
|
| | if not Path(ema_checkpoint).exists(): |
| | logger.log( |
| | f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist" |
| | ) |
| | return |
| |
|
| | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") |
| |
|
| | map_location = { |
| | 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank() |
| | } |
| |
|
| | state_dict = dist_util.load_state_dict( |
| | ema_checkpoint, map_location=map_location) |
| |
|
| | model_ema_state_dict = model.state_dict() |
| |
|
| | for k, v in state_dict.items(): |
| | if k in model_ema_state_dict.keys() and v.size( |
| | ) == model_ema_state_dict[k].size(): |
| | model_ema_state_dict[k] = v |
| |
|
| | |
| | |
| |
|
| | else: |
| | print('ignore key: ', k, ": ", v.size()) |
| |
|
| | ema_params = mp_trainer.state_dict_to_master_params( |
| | model_ema_state_dict) |
| |
|
| | del state_dict |
| |
|
| | |
| | if dist_util.get_world_size() > 1: |
| | dist_util.sync_params(ema_params) |
| | |
| | |
| | return ema_params |
| |
|
| | def _load_ema_parameters_freezeAE( |
| | self, |
| | rate, |
| | model, |
| | |
| | model_name='rec'): |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | main_checkpoint, _ = find_resume_checkpoint( |
| | self.resume_checkpoint, model_name) or self.resume_checkpoint |
| | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, |
| | rate, model_name) |
| | if ema_checkpoint: |
| |
|
| | if dist_util.get_rank() == 0: |
| |
|
| | if not Path(ema_checkpoint).exists(): |
| | logger.log( |
| | f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist" |
| | ) |
| | return |
| |
|
| | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") |
| |
|
| | map_location = { |
| | 'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank() |
| | } |
| |
|
| | state_dict = dist_util.load_state_dict( |
| | ema_checkpoint, map_location=map_location) |
| |
|
| | model_ema_state_dict = model.state_dict() |
| |
|
| | for k, v in state_dict.items(): |
| | if k in model_ema_state_dict.keys() and v.size( |
| | ) == model_ema_state_dict[k].size(): |
| | model_ema_state_dict[k] = v |
| | else: |
| | print('ignore key: ', k, ": ", v.size()) |
| |
|
| | ema_params = mp_trainer.state_dict_to_master_params( |
| | model_ema_state_dict) |
| |
|
| | del state_dict |
| |
|
| | |
| | if dist_util.get_world_size() > 1: |
| | dist_util.sync_params(ema_params) |
| | |
| | |
| | return ema_params |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | def _load_optimizer_state(self): |
| | main_checkpoint, _ = find_resume_checkpoint() or self.resume_checkpoint |
| | opt_checkpoint = bf.join(bf.dirname(main_checkpoint), |
| | f"opt{self.resume_step:06}.pt") |
| | if bf.exists(opt_checkpoint): |
| | logger.log( |
| | f"loading optimizer state from checkpoint: {opt_checkpoint}") |
| | state_dict = dist_util.load_state_dict( |
| | opt_checkpoint, map_location=dist_util.dev()) |
| | self.opt.load_state_dict(state_dict) |
| |
|
| | def run_loop(self): |
| | while (not self.lr_anneal_steps |
| | or self.step + self.resume_step < self.lr_anneal_steps): |
| | batch, cond = next(self.data) |
| | self.run_step(batch, cond) |
| | if self.step % self.log_interval == 0: |
| | logger.dumpkvs() |
| | if self.step % self.save_interval == 0: |
| | self.save() |
| | |
| | if os.environ.get("DIFFUSION_TRAINING_TEST", |
| | "") and self.step > 0: |
| | return |
| | self.step += 1 |
| | |
| | if (self.step - 1) % self.save_interval != 0: |
| | self.save() |
| |
|
| | def run_step(self, batch, cond): |
| | self.forward_backward(batch, cond) |
| | took_step = self.mp_trainer.optimize(self.opt) |
| | if took_step: |
| | self._update_ema() |
| | self._anneal_lr() |
| | self.log_step() |
| |
|
| | def forward_backward(self, batch, cond): |
| | self.mp_trainer.zero_grad() |
| | for i in range(0, batch.shape[0], self.microbatch): |
| |
|
| | |
| | with th.autocast(device_type=dist_util.dev(), |
| | dtype=th.float16, |
| | enabled=self.mp_trainer.use_amp): |
| |
|
| | micro = batch[i:i + self.microbatch].to(dist_util.dev()) |
| | micro_cond = { |
| | k: v[i:i + self.microbatch].to(dist_util.dev()) |
| | for k, v in cond.items() |
| | } |
| | last_batch = (i + self.microbatch) >= batch.shape[0] |
| | t, weights = self.schedule_sampler.sample( |
| | micro.shape[0], dist_util.dev()) |
| |
|
| | compute_losses = functools.partial( |
| | self.diffusion.training_losses, |
| | self.ddp_model, |
| | micro, |
| | t, |
| | model_kwargs=micro_cond, |
| | ) |
| |
|
| | if last_batch or not self.use_ddp: |
| | losses = compute_losses() |
| | else: |
| | with self.ddp_model.no_sync(): |
| | losses = compute_losses() |
| |
|
| | if isinstance(self.schedule_sampler, LossAwareSampler): |
| | self.schedule_sampler.update_with_local_losses( |
| | t, losses["loss"].detach()) |
| |
|
| | loss = (losses["loss"] * weights).mean() |
| | log_loss_dict(self.diffusion, t, |
| | {k: v * weights |
| | for k, v in losses.items()}) |
| |
|
| | self.mp_trainer.backward(loss) |
| |
|
| | def _update_ema(self): |
| | for rate, params in zip(self.ema_rate, self.ema_params): |
| | update_ema(params, self.mp_trainer.master_params, rate=rate) |
| |
|
| | def _anneal_lr(self): |
| | if not self.lr_anneal_steps: |
| | return |
| | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps |
| | lr = self.lr * (1 - frac_done) |
| | for param_group in self.opt.param_groups: |
| | param_group["lr"] = lr |
| |
|
| | def log_step(self): |
| | logger.logkv("step", self.step + self.resume_step) |
| | logger.logkv("samples", |
| | (self.step + self.resume_step + 1) * self.global_batch) |
| |
|
| | def save(self): |
| |
|
| | def save_checkpoint(rate, params): |
| | state_dict = self.mp_trainer.master_params_to_state_dict(params) |
| | if dist.get_rank() == 0: |
| | logger.log(f"saving model {rate}...") |
| | if not rate: |
| | filename = f"model{(self.step+self.resume_step):07d}.pt" |
| | else: |
| | filename = f"ema_{rate}_{(self.step+self.resume_step):07d}.pt" |
| | with bf.BlobFile(bf.join(get_blob_logdir(), filename), |
| | "wb") as f: |
| | th.save(state_dict, f) |
| |
|
| | save_checkpoint(0, self.mp_trainer.master_params) |
| | for rate, params in zip(self.ema_rate, self.ema_params): |
| | save_checkpoint(rate, params) |
| |
|
| | if dist.get_rank() == 0: |
| | with bf.BlobFile( |
| | bf.join(get_blob_logdir(), |
| | f"opt{(self.step+self.resume_step):07d}.pt"), |
| | "wb", |
| | ) as f: |
| | th.save(self.opt.state_dict(), f) |
| |
|
| | dist.barrier() |
| |
|
| |
|
| | def parse_resume_step_from_filename(filename): |
| | """ |
| | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the |
| | checkpoint's number of steps. |
| | """ |
| | |
| | split1 = Path(filename).stem[-7:] |
| | |
| | |
| | |
| | |
| | try: |
| | return int(split1) |
| | except ValueError: |
| | print('fail to load model step', split1) |
| | return 0 |
| |
|
| |
|
| | def get_blob_logdir(): |
| | |
| | |
| | return logger.get_dir() |
| |
|
| |
|
| | def find_resume_checkpoint(resume_checkpoint='', model_name='ddpm'): |
| | |
| | |
| |
|
| | if resume_checkpoint != '': |
| | step = parse_resume_step_from_filename(resume_checkpoint) |
| | split = resume_checkpoint.split("model") |
| | resume_ckpt_path = str( |
| | Path(split[0]) / f'model_{model_name}{step:07d}.pt') |
| | else: |
| | resume_ckpt_path = '' |
| | step = 0 |
| |
|
| | return resume_ckpt_path, step |
| |
|
| |
|
| | def find_ema_checkpoint(main_checkpoint, step, rate, model_name=''): |
| | if main_checkpoint is None: |
| | return None |
| | if model_name == '': |
| | filename = f"ema_{rate}_{(step):07d}.pt" |
| | else: |
| | filename = f"ema_{model_name}_{rate}_{(step):07d}.pt" |
| | path = bf.join(bf.dirname(main_checkpoint), filename) |
| | |
| | |
| | if bf.exists(path): |
| | print('fine ema model', path) |
| | return path |
| | else: |
| | print('fail to find ema model', path) |
| | return None |
| |
|
| |
|
| | def log_loss_dict(diffusion, ts, losses): |
| | for key, values in losses.items(): |
| | logger.logkv_mean(key, values.mean().item()) |
| | |
| | for sub_t, sub_loss in zip(ts.cpu().numpy(), |
| | values.detach().cpu().numpy()): |
| | quartile = int(4 * sub_t / diffusion.num_timesteps) |
| | logger.logkv_mean(f"{key}_q{quartile}", sub_loss) |
| |
|
| |
|
| | def log_rec3d_loss_dict(loss_dict): |
| | for key, values in loss_dict.items(): |
| | try: |
| | logger.logkv_mean(key, values.mean().item()) |
| | except: |
| | print('type error:', key) |
| | |
| |
|
| |
|
| | def calc_average_loss(all_loss_dicts, verbose=True): |
| | all_scores = {} |
| | mean_all_scores = {} |
| |
|
| | for loss_dict in all_loss_dicts: |
| | for k, v in loss_dict.items(): |
| | v = v.item() |
| | if k not in all_scores: |
| | |
| | all_scores[k] = [v] |
| | else: |
| | all_scores[k].append(v) |
| |
|
| | for k, v in all_scores.items(): |
| | mean = np.mean(v) |
| | std = np.std(v) |
| | if k in ['loss_lpis', 'loss_ssim']: |
| | mean = 1 - mean |
| | result_str = '{} average loss is {:.4f} +- {:.4f}'.format(k, mean, std) |
| | mean_all_scores[k] = mean |
| | if verbose: |
| | print(result_str) |
| |
|
| | val_scores_for_logging = { |
| | f'{k}_val': v |
| | for k, v in mean_all_scores.items() |
| | } |
| | return val_scores_for_logging |