| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
import cProfile as profile
import functools
import pstats

import blobfile as bf
import numpy as np
import torch
from torch.optim import AdamW
from tqdm import tqdm

import utils.logger as logger
from diffusion.fp16_util import MixedPrecisionTrainer
from diffusion.resample import LossAwareSampler, create_named_schedule_sampler
from utils.misc import dev, load_state_dict

INITIAL_LOG_LOSS_SCALE = 20.0
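
# TrainLoop wraps the diffusion model, optimizer, and data loader into a single
# training driver: it handles checkpoint resuming, mixed-precision gradients,
# learning-rate annealing, periodic logging, and checkpoint saving.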
class TrainLoop:
    def __init__(
        self, args, train_platform, model, diffusion, data, writer, rank=0, world_size=1
    ):
        self.args = args
        self.dataset = args.dataset
        self.train_platform = train_platform
        self.model = model
        self.diffusion = diffusion
        self.cond_mode = model.module.cond_mode if world_size > 1 else model.cond_mode
        self.data = data
        self.batch_size = args.batch_size
        self.microbatch = args.batch_size  # deprecating this option
        self.lr = args.lr
        self.log_interval = args.log_interval
        self.save_interval = args.save_interval
        self.resume_checkpoint = args.resume_checkpoint
        self.use_fp16 = False  # deprecating this option
        self.fp16_scale_growth = 1e-3  # deprecating this option
        self.weight_decay = args.weight_decay
        self.lr_anneal_steps = args.lr_anneal_steps
        self.rank = rank
        self.world_size = world_size

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size
        self.num_steps = args.num_steps
        self.num_epochs = self.num_steps // len(self.data) + 1

        chunks = list(range(self.num_steps))
        num_chunks = int(self.num_steps / 10)
        chunks = np.array_split(chunks, num_chunks)
        self.chunks = np.reshape(chunks[10_000::10], (-1))

        self.sync_cuda = torch.cuda.is_available()
        self.writer = writer

        self._load_and_sync_parameters()

        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=self.fp16_scale_growth,
        )

        self.save_dir = args.save_dir
        self.overwrite = args.overwrite

        self.opt = AdamW(
            self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
        )
        if self.resume_step:
            self._load_optimizer_state()

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{self.rank}")
        else:
            # Fall back to CPU so self.device is always defined.
            self.device = torch.device("cpu")

        self.schedule_sampler_type = "uniform"
        self.schedule_sampler = create_named_schedule_sampler(
            self.schedule_sampler_type, diffusion
        )
        self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None
        self.use_ddp = True
        self.ddp_model = self.model
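
    # Resume support: restore model weights (and, separately, the optimizer
    # state) from an existing checkpoint before training starts.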
    def _load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                load_state_dict(resume_checkpoint, map_location=dev())
            )

    def _load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = load_state_dict(opt_checkpoint, map_location=dev())
            self.opt.load_state_dict(state_dict)

    def _print_stats(self, logger):
        if (self.step % 100 == 0 and self.step > 0) and self.rank == 0:
            v = logger.get_current().name2val
            v = v["loss"]
            print("step[{}]: loss[{:0.5f}]".format(self.step + self.resume_step, v))

    def _write_to_logger(self, logger):
        if (self.step % self.log_interval == 0) and self.rank == 0:
            for k, v in logger.get_current().name2val.items():
                if k == "loss":
                    print(
                        "step[{}]: loss[{:0.5f}]".format(
                            self.step + self.resume_step, v
                        )
                    )
                    self.writer.add_scalar(f"./Train/{k}", v, self.step)
                if k in ["step", "samples"] or "_q" in k:
                    continue
                else:
                    self.train_platform.report_scalar(
                        name=k, value=v, iteration=self.step, group_name="Loss"
                    )
                    self.writer.add_scalar(f"./Train/{k}", v, self.step)
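
    # Main loop: iterate over epochs and batches, move each batch to the
    # device, take one optimizer step, and log/save at the configured
    # intervals. Rank 0 additionally cProfiles the first 1000 steps.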
    def run_loop(self):
        for _ in range(self.num_epochs):
            if self.rank == 0:
                prof = profile.Profile()
                prof.enable()
            for motion, cond in tqdm(self.data, disable=(self.rank != 0)):
                if not (
                    not self.lr_anneal_steps
                    or self.step + self.resume_step < self.lr_anneal_steps
                ):
                    break
                motion = motion.to(self.device)
                cond["y"] = {
                    key: val.to(self.device) if torch.is_tensor(val) else val
                    for key, val in cond["y"].items()
                }
                self.run_step(motion, cond)
                self._print_stats(logger)
                self._write_to_logger(logger)
                if (self.step % self.save_interval == 0) and self.rank == 0:
                    self.save()
                self.step += 1
                if (self.step == 1000) and self.rank == 0:
                    prof.disable()
                    stats = pstats.Stats(prof).strip_dirs().sort_stats("cumtime")
                    stats.print_stats(10)
            if not (
                not self.lr_anneal_steps
                or self.step + self.resume_step < self.lr_anneal_steps
            ):
                break
        # Save the last checkpoint if it wasn't already saved.
        if ((self.step - 1) % self.save_interval != 0) and self.rank == 0:
            self.save()

    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        self.mp_trainer.optimize(self.opt)
        self._anneal_lr()
        if self.rank == 0:
            self.log_step()
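
    # One forward/backward pass: sample diffusion timesteps from the schedule
    # sampler, compute the diffusion training losses, apply the importance
    # weights, log them, and accumulate gradients via the mixed-precision
    # trainer.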
    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            # Eliminates the microbatch feature
            assert i == 0
            assert self.microbatch == self.batch_size
            micro = batch
            micro_cond = cond
            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], batch.device)
            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,
            )
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()
            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )
            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)
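
    # Linearly decay the learning rate to zero over lr_anneal_steps (no-op if
    # annealing is disabled).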
    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
    def ckpt_file_name(self):
        return f"model{(self.step+self.resume_step):09d}.pt"

    def save(self):
        def save_checkpoint(params):
            state_dict = self.mp_trainer.master_params_to_state_dict(params)

            # Do not save CLIP weights
            clip_weights = [e for e in state_dict.keys() if e.startswith("clip_model.")]
            for e in clip_weights:
                del state_dict[e]

            logger.log("saving model...")
            filename = self.ckpt_file_name()
            with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f:
                torch.save(state_dict, f)

        save_checkpoint(self.mp_trainer.master_params)

        with bf.BlobFile(
            bf.join(self.save_dir, f"opt{(self.step+self.resume_step):09d}.pt"),
            "wb",
        ) as f:
            torch.save(self.opt.state_dict(), f)

def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    # You can change this to be a separate path to save checkpoints to
    # a blobstore or some external drive.
    return logger.get_dir()


def find_resume_checkpoint():
    # On your infrastructure, you may want to override this to automatically
    # discover the latest checkpoint on your blob storage, etc.
    return None

def log_loss_dict(diffusion, ts, losses):
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)