# Spaces: Running (status-banner text captured together with this file; not part of the module)
| import os | |
| import copy | |
| import functools | |
| import blobfile as bf | |
| import torch | |
| import torch.distributed as dist | |
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP | |
| from torch.optim import AdamW | |
| from . import dist_util, logger | |
| from .fp16_util import ( | |
| make_master_params, | |
| master_params_to_model_params, | |
| model_grads_to_master_grads, | |
| unflatten_master_params, | |
| zero_grad, | |
| ) | |
| from .nn import update_ema | |
| from .resample import LossAwareSampler, UniformSampler | |
| import wandb | |
| from tqdm import tqdm | |
# Starting value of TrainLoop.lg_loss_scale: the base-2 log of the fp16 loss
# scale (optimize_fp16 divides gradients by 2**lg_loss_scale, lowers it by 1 on
# non-finite gradients and raises it by fp16_scale_growth after each update).
INITIAL_LOG_LOSS_SCALE = 20.0
class TrainLoop:
    """Training loop for a diffusion model fed with tuple batches.

    ``forward_backward`` indexes each batch drawn from ``data`` as
    ``batch[0]..batch[3]`` = (selfies_ids, caption_state, caption_mask,
    corrupted_selfies_ids) and slices all four along dim 0
    (NOTE(review): shapes/dtypes are not visible here -- confirm against the
    data loader). The loop optionally clips gradients, linearly anneals the
    learning rate, keeps EMA copies of the parameters, logs the loss to wandb
    on rank 0 and periodically writes checkpoints to ``checkpoint_path``.

    Requires an initialized ``torch.distributed`` process group.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
        checkpoint_path="",
        gradient_clipping=-1.0,
        eval_data=None,
        eval_interval=-1,
    ):
        """Configure the loop, place the model and build optimizer/EMA state.

        Args:
            model: the network to train (an ``nn.Module``).
            diffusion: object whose ``training_losses`` returns a loss dict.
            data: iterator of training batches (consumed with ``next``).
            batch_size: per-process batch size.
            microbatch: microbatch size; ``<= 0`` means use the whole batch.
            lr: base learning rate; multiplied by the world size.
            ema_rate: EMA decay as a float, or a comma-separated string.
            log_interval: print/wandb-log the loss every this many steps.
            save_interval: save checkpoints every this many steps.
            resume_checkpoint: checkpoint path (resuming is not wired up).
            use_fp16: select the fp16 optimizer path (see ``optimize_fp16``).
            fp16_scale_growth: per-update increment of ``lg_loss_scale``.
            schedule_sampler: timestep sampler; defaults to ``UniformSampler``.
            weight_decay: AdamW weight decay.
            lr_anneal_steps: anneal budget; 0 trains forever at constant LR.
            checkpoint_path: directory checkpoints are written to.
            gradient_clipping: max grad norm; ``<= 0`` disables clipping.
            eval_data: optional eval iterator (evaluation is only a stub).
            eval_interval: eval cadence in steps; ``<= 0`` disables it.
        """
        print('Initiating train loop')
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        self.rank = rank
        self.world_size = world_size
        self.diffusion = diffusion
        self.data = data
        self.eval_data = eval_data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        # Linear LR scaling with the number of processes.
        self.lr = lr * world_size
        self.ema_rate = (
            [ema_rate]
            if isinstance(ema_rate, float)
            else [float(x) for x in ema_rate.split(",")]
        )
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps
        self.gradient_clipping = gradient_clipping
        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * world_size
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = torch.cuda.is_available()
        self.checkpoint_path = checkpoint_path

        # Device for the model and every tensor fed to it: this rank's GPU
        # when CUDA is available, otherwise the CPU. (An unconditional
        # ``model.to(rank)`` fails on CPU-only hosts.)
        self.device = rank if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        # DDP wrapping is currently disabled, so the "DDP" model is the bare
        # model. To re-enable:
        #     self.ddp_model = DDP(self.model, device_ids=[self.rank],
        #                          find_unused_parameters=False)
        self.ddp_model = self.model
        # ``no_sync()`` only exists on a real DDP wrapper, so derive the flag
        # from the actual type; it is also defined on every code path now.
        self.use_ddp = isinstance(self.ddp_model, DDP)

        self.model_params = list(self.ddp_model.parameters())
        self.master_params = self.model_params
        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        # Resuming is not wired up (``_load_and_sync_parameters`` is never
        # called, so ``resume_step`` is always 0 here). Always start the EMA
        # copies from the current parameters so ``ema_params`` is defined.
        self.ema_params = [
            copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
        ]
        print('Finish initiating train loop')

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint on rank 0, then sync.

        Sets ``resume_step`` from the checkpoint filename. Not called by
        ``__init__`` at the moment.
        """
        resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            if dist.get_rank() == 0:
                print(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()
                    )
                )
        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA params for ``rate``, loaded from disk on rank 0 if found.

        Falls back to a copy of the current master params. Not called by the
        current ``__init__`` (resuming is disabled).
        """
        ema_params = copy.deepcopy(self.master_params)
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev()
                )
                ema_params = self._state_dict_to_master_params(state_dict)
        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Load ``opt{resume_step:06}.pt`` next to the main checkpoint, if present."""
        main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            self.opt.load_state_dict(state_dict)

    def _setup_fp16(self):
        """Build fp32 master params and convert the model to fp16 (unused here)."""
        self.master_params = make_master_params(self.model_params)
        self.model.convert_to_fp16()

    def run_loop(self):
        """Run training steps until the annealing budget is used up.

        Stops once ``step + resume_step`` reaches
        ``lr_anneal_steps // world_size``; never stops when
        ``lr_anneal_steps`` is 0. Saves every ``save_interval`` steps and
        once more at the end if the last step was not already saved. When
        ``DIFFUSION_TRAINING_TEST`` is set, returns (without the final save)
        after the first step that finishes with ``step > 0``.
        """
        max_steps = self.lr_anneal_steps // self.world_size
        pbar = tqdm(total=max_steps)
        print('Start running train loop')
        try:
            while not self.lr_anneal_steps or self.step + self.resume_step < max_steps:
                pbar.set_description(f"Step: {self.step + self.resume_step}")
                batch = next(self.data)
                self.run_step(batch, cond=None)
                # Evaluation is a stub. eval_interval <= 0 disables it (the
                # default -1 used to match every step; 0 divided by zero).
                if (
                    self.eval_data is not None
                    and self.eval_interval > 0
                    and self.step % self.eval_interval == 0
                ):
                    print("eval on validation set")
                if self.step % self.save_interval == 0 and self.step != 0:
                    self.save()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return
                self.step += 1
                pbar.update(1)
        finally:
            # Close the bar on every exit path (early return, exceptions).
            pbar.close()
        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, cond):
        """One optimization step: forward/backward, optimizer update, logging."""
        self.forward_backward(batch, cond)
        if self.use_fp16:
            self.optimize_fp16()
        else:
            self.optimize_normal()
        self.log_step()

    def forward_only(self, batch, cond):
        """Compute losses over microbatches without building autograd graphs.

        Unlike ``forward_backward``, this expects ``batch`` to be one tensor
        and ``cond`` a dict of tensors sliced alongside it. It is not called by
        ``run_loop``; since ``log_loss_dict`` does not log by default, the
        losses are currently discarded.
        """
        with torch.no_grad():
            zero_grad(self.model_params)
            for i in range(0, batch.shape[0], self.microbatch):
                micro = batch[i : i + self.microbatch].to(dist_util.dev())
                micro_cond = {
                    k: v[i : i + self.microbatch].to(dist_util.dev())
                    for k, v in cond.items()
                }
                last_batch = (i + self.microbatch) >= batch.shape[0]
                t, weights = self.schedule_sampler.sample(
                    micro.shape[0], dist_util.dev()
                )
                compute_losses = functools.partial(
                    self.diffusion.training_losses,
                    self.ddp_model,
                    micro,
                    t,
                    micro_cond,
                )
                if last_batch or not self.use_ddp:
                    losses = compute_losses()
                else:
                    with self.ddp_model.no_sync():
                        losses = compute_losses()
                log_loss_dict(
                    self.diffusion,
                    t,
                    {f"eval_{k}": v * weights for k, v in losses.items()},
                )

    def forward_backward(self, batch, cond):
        """Zero gradients, then accumulate them over microbatches of ``batch``.

        ``batch[0]..batch[3]`` are (selfies_ids, caption_state, caption_mask,
        corrupted_selfies_ids); every microbatch slices all four along dim 0.
        ``cond`` is ignored: ``None`` is passed as the model kwargs. Each
        microbatch's weighted mean loss is back-propagated separately, so
        gradients are summed across microbatches.
        """
        self.opt.zero_grad()
        batch_len = batch[0].shape[0]
        for start in range(0, batch_len, self.microbatch):
            end = start + self.microbatch
            # Take only this microbatch's slice. Previously every iteration
            # reused the full batch, so with microbatch < batch_size the whole
            # batch was back-propagated several times per step.
            micro = tuple(batch[k][start:end].to(self.device) for k in range(4))
            last_batch = end >= batch_len
            t, weights = self.schedule_sampler.sample(micro[0].shape[0], self.device)
            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                None,
            )
            # Under DDP, skip the gradient all-reduce on all but the last
            # microbatch.
            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()
            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )
            loss = (losses["loss"] * weights).mean()
            if self.step % self.log_interval == 0 and self.rank == 0:
                print("rank0: ", self.step, loss.item())
                wandb.log({"loss": loss.item()})
            if self.use_fp16:
                # NOTE(review): fp16 loss scaling / backward is disabled, so no
                # gradients are produced on this path.
                pass
            else:
                loss.backward()

    def optimize_fp16(self):
        """fp16 update with dynamic loss scaling.

        If any model gradient is non-finite, lowers ``lg_loss_scale`` by 1 and
        skips the step. Otherwise copies grads to the master params, unscales
        them, steps, updates the EMAs, copies params back and grows the scale.
        NOTE(review): ``master_params[0]`` being the only unscaled tensor
        suggests flattened master params (from ``_setup_fp16``, which is never
        called); also ``forward_backward`` does not call backward under fp16,
        so ``p.grad`` is None here. This path looks unusable as-is.
        """
        if any(not torch.isfinite(p.grad).all() for p in self.model_params):
            self.lg_loss_scale -= 1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
            return
        model_grads_to_master_grads(self.model_params, self.master_params)
        self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale))
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    def grad_clip(self):
        """Clip gradients to a total norm of ``self.gradient_clipping``.

        Uses the optimizer's own ``clip_grad_norm`` when it provides one
        (e.g. sharded optimizers), otherwise
        ``torch.nn.utils.clip_grad_norm_`` over the model parameters.
        """
        max_grad_norm = self.gradient_clipping
        if hasattr(self.opt, "clip_grad_norm"):
            self.opt.clip_grad_norm(max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                max_grad_norm,
            )

    def optimize_normal(self):
        """fp32 update: optional clipping, LR annealing, step, EMA update."""
        if self.gradient_clipping > 0:
            self.grad_clip()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def _log_grad_norm(self):
        """Return the global L2 norm of the master-parameter gradients.

        Parameters without a gradient are skipped. The value is not sent to
        the logger at the moment.
        """
        sqsum = 0.0
        for p in self.master_params:
            if p.grad is not None:
                sqsum += (p.grad ** 2).sum().item()
        return sqsum ** 0.5

    def _anneal_lr(self):
        """Linearly decay the learning rate from ``self.lr`` toward 0.

        The decay spans ``lr_anneal_steps // world_size`` iterations, the
        same bound ``run_loop`` stops at. (Dividing by the raw
        ``lr_anneal_steps`` left the LR at ``(1 - 1/world_size) * lr`` when
        training ended.) No-op when ``lr_anneal_steps`` is 0.
        """
        if not self.lr_anneal_steps:
            return
        # max(..., 1) guards the degenerate lr_anneal_steps < world_size case.
        total_steps = max(self.lr_anneal_steps // self.world_size, 1)
        frac_done = (self.step + self.resume_step) / total_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Record the global step (and ``lg_loss_scale`` under fp16) via ``logger``."""
        logger.logkv("step", self.step + self.resume_step)
        if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)

    def save(self):
        """Write model and EMA weights from rank 0, then barrier all ranks.

        Files go to ``checkpoint_path`` as ``PLAIN_model{N:06d}.pt`` and
        ``PLAIN_ema_{rate}_{N:06d}.pt`` with
        ``N = (step + resume_step) * world_size``. Optimizer state is not saved.
        """
        step_tag = (self.step + self.resume_step) * self.world_size

        def save_checkpoint(rate, params):
            # Only rank 0 builds the state dict and writes it.
            if dist.get_rank() != 0:
                return
            state_dict = self._master_params_to_state_dict(params)
            print(f"saving model {rate}...")
            if not rate:
                filename = f"PLAIN_model{step_tag:06d}.pt"
            else:
                filename = f"PLAIN_ema_{rate}_{step_tag:06d}.pt"
            with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f:
                torch.save(state_dict, f)

        save_checkpoint(0, self.master_params)
        for rate, params in zip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)
        dist.barrier()

    def _master_params_to_state_dict(self, master_params):
        """Build a model state dict whose parameter entries are ``master_params``.

        ``master_params`` is matched positionally to ``model.named_parameters()``;
        under fp16 it is first unflattened against the model parameters.
        """
        if self.use_fp16:
            master_params = unflatten_master_params(
                list(self.model.parameters()), master_params
            )
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            state_dict[name] = master_params[i]
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Extract parameters from ``state_dict`` in ``named_parameters`` order."""
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        if self.use_fp16:
            return make_master_params(params)
        else:
            return params
def parse_resume_step_from_filename(filename):
    """Extract the step count from a checkpoint filename.

    Handles names like ``path/to/modelNNNNNN.pt``: the text following the
    last ``"model"`` and preceding the next ``"."`` is parsed as an int.
    Returns 0 when ``"model"`` is absent or that text is not an integer.
    """
    if "model" not in filename:
        return 0
    tail = filename.rpartition("model")[2]
    digits = tail.partition(".")[0]
    try:
        return int(digits)
    except ValueError:
        return 0
def get_blob_logdir():
    """Return the directory for blob artifacts.

    Prefers ``$DIFFUSION_BLOB_LOGDIR``; otherwise ``logger.get_dir()``.
    The logger directory is looked up in either case.
    """
    fallback = logger.get_dir()
    return os.getenv("DIFFUSION_BLOB_LOGDIR", fallback)
def find_resume_checkpoint():
    """Hook for automatically discovering a checkpoint to resume from.

    This default implementation finds nothing and returns ``None``;
    override it to, e.g., locate the newest checkpoint on your storage.
    """
    return None
def find_ema_checkpoint(main_checkpoint, step, rate):
    """Locate the EMA checkpoint stored next to ``main_checkpoint``.

    Looks in the main checkpoint's directory for ``ema_{rate}_{step:06d}.pt``
    and, failing that, ``PLAIN_ema_{rate}_{step:06d}.pt`` (the name
    ``TrainLoop.save`` in this module writes EMA weights under).

    Args:
        main_checkpoint: path of the main model checkpoint; ``None`` or an
            empty string means there is none.
        step: step number embedded in the filename.
        rate: EMA rate embedded in the filename.

    Returns:
        The first existing candidate path, or ``None``.
    """
    # An empty string is the "no checkpoint" default of resume_checkpoint;
    # treating it as a path would probe the current directory.
    if not main_checkpoint:
        return None
    directory = bf.dirname(main_checkpoint)
    # Keep the original name first so its precedence is unchanged.
    for prefix in ("", "PLAIN_"):
        path = bf.join(directory, f"{prefix}ema_{rate}_{step:06d}.pt")
        if bf.exists(path):
            return path
    return None
def log_loss_dict(diffusion, ts, losses, *, enabled=False):
    """Log mean losses, and per-timestep-quartile losses, via ``logger``.

    Logging is off by default (previously an unconditional ``return`` made
    the logging code unreachable); pass ``enabled=True`` to turn it on.

    Args:
        diffusion: object exposing ``num_timesteps``.
        ts: tensor of sampled timesteps, one per loss entry.
        losses: mapping of name -> per-sample loss tensor.
        enabled: when False, return immediately without logging.
    """
    if not enabled:
        return
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)