import copy
import functools
import os
import time
from types import SimpleNamespace
import numpy as np
import re
from os.path import join as pjoin
from typing import Optional

import blobfile as bf
import torch
from torch.optim import AdamW

from diffusion import logger
from utils import dist_util
from diffusion.fp16_util import MixedPrecisionTrainer
from diffusion.resample import LossAwareSampler, UniformSampler
from tqdm import tqdm
from diffusion.resample import create_named_schedule_sampler
from data_loaders.humanml.networks.evaluator_wrapper import EvaluatorMDMWrapper
from eval import eval_humanml, eval_humanact12_uestc
from sample.generate import main as generate
from data_loaders.get_data import get_dataset_loader
from utils.model_util import load_model_wo_clip
from data_loaders.humanml.scripts.motion_process import get_target_location, sample_goal, get_allowed_joint_options
from utils.sampler_util import ClassifierFreeSampleModel

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE = 20.0


class TrainLoop:
    def __init__(self, args, train_platform, model, diffusion, data):
        self.args = args
        self.dataset = args.dataset
        self.train_platform = train_platform
        self.model = model
        self.model_avg = None
        if self.args.use_ema:
            self.model_avg = copy.deepcopy(self.model)
        self.model_for_eval = self.model_avg if self.args.use_ema else self.model
        if args.gen_guidance_param != 1:
            self.model_for_eval = ClassifierFreeSampleModel(self.model_for_eval)  # wrapping model with the classifier-free sampler
        self.diffusion = diffusion
        self.cond_mode = model.cond_mode
        self.data = data
        self.batch_size = args.batch_size
        self.microbatch = args.batch_size  # deprecating this option
        self.lr = args.lr
        self.log_interval = args.log_interval
        self.save_interval = args.save_interval
        self.resume_checkpoint = args.resume_checkpoint
        self.use_fp16 = False  # deprecating this option
        self.fp16_scale_growth = 1e-3  # deprecating this option
        self.weight_decay = args.weight_decay
        self.lr_anneal_steps = args.lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size  # * dist.get_world_size()
        self.num_steps = args.num_steps
        self.num_epochs = self.num_steps // len(self.data) + 1
        self.sync_cuda = torch.cuda.is_available()

        self._load_and_sync_parameters()
        self.mp_trainer = MixedPrecisionTrainer(
            model=self.model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=self.fp16_scale_growth,
        )

        self.save_dir = args.save_dir
        self.overwrite = args.overwrite

        if self.args.use_ema:
            self.opt = AdamW(
                # with amp, we don't need to use the mp_trainer's master_params
                (self.model.parameters()
                 if self.use_fp16 else self.mp_trainer.master_params),
                lr=self.lr,
                weight_decay=self.weight_decay,
                betas=(0.9, self.args.adam_beta2),
            )
        else:
            self.opt = AdamW(
                self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
            )
        if self.resume_step:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.

        self.device = torch.device("cpu")
        if torch.cuda.is_available() and dist_util.dev() != 'cpu':
            self.device = torch.device(dist_util.dev())

        self.schedule_sampler_type = 'uniform'
        self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, diffusion)
        self.eval_wrapper, self.eval_data, self.eval_gt_data = None, None, None
        if args.dataset in ['kit', 'humanml'] and args.eval_during_training:
            mm_num_samples = 0  # mm is super slow hence we won't run it during training
            mm_num_repeats = 0  # mm is super slow hence we won't run it during training
            gen_loader = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size, num_frames=None,
                                            split=args.eval_split,
                                            hml_mode='eval',
                                            autoregressive=args.autoregressive,
                                            fixed_len=args.context_len + args.pred_len,
                                            pred_len=args.pred_len, device=dist_util.dev())
            self.eval_gt_data = get_dataset_loader(name=args.dataset, batch_size=args.eval_batch_size,
                                                   num_frames=None,
                                                   split=args.eval_split,
                                                   hml_mode='gt', device=dist_util.dev())
            self.eval_wrapper = EvaluatorMDMWrapper(args.dataset, dist_util.dev())
            self.eval_data = {
                'test': lambda: eval_humanml.get_mdm_loader(
                    self.args, self.model_for_eval, diffusion, args.eval_batch_size,
                    gen_loader, mm_num_samples, mm_num_repeats,
                    gen_loader.dataset.opt.max_motion_length,
                    args.eval_num_samples, scale=args.gen_guidance_param,
                )
            }
        self.use_ddp = False
        self.ddp_model = self.model
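
    # _load_and_sync_parameters() below handles checkpoint resume: newer checkpoints store
    # a dict with 'model' and 'model_avg' entries; legacy checkpoints store a flat state
    # dict, in which case model_avg (when EMA is enabled) is seeded from the raw model weights.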
    def _load_and_sync_parameters(self):
        resume_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint

        if resume_checkpoint:
            # We add 1 because the step recorded in the checkpoint has already been run and
            # we don't want to run it again -- in particular, we don't want to repeat the
            # evaluation and generation for that step.
            self.step += 1
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            state_dict = dist_util.load_state_dict(
                resume_checkpoint, map_location=dist_util.dev())
            if 'model_avg' in state_dict:
                print('loading both model and model_avg')
                state_dict, state_dict_avg = state_dict['model'], state_dict['model_avg']
                load_model_wo_clip(self.model, state_dict)
                load_model_wo_clip(self.model_avg, state_dict_avg)
            else:
                load_model_wo_clip(self.model, state_dict)
                if self.args.use_ema:
                    # in case we load from a legacy checkpoint, just copy the model
                    print('loading model_avg from model')
                    self.model_avg.load_state_dict(self.model.state_dict(), strict=False)
            # self.model.load_state_dict(
            #     dist_util.load_state_dict(
            #         resume_checkpoint, map_location=dist_util.dev()
            #     ), strict=False
            # )
    def _load_optimizer_state(self):
        main_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint), f"opt{self.resume_step:09}.pt"
        )
        if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev()
            )
            if self.use_fp16:
                if 'scaler' not in state_dict:
                    print("scaler state not found ... not loading it.")
                else:
                    # load grad scaler state
                    self.scaler.load_state_dict(state_dict['scaler'])
                # the rest of the dict is the optimizer state itself
                state_dict = state_dict['opt']

            tgt_wd = self.opt.param_groups[0]['weight_decay']
            print('target weight decay:', tgt_wd)
            self.opt.load_state_dict(state_dict)
            print('loaded weight decay (will be replaced):',
                  self.opt.param_groups[0]['weight_decay'])
            # preserve the weight decay parameter
            for group in self.opt.param_groups:
                group['weight_decay'] = tgt_wd
            self.opt.param_groups[0]['capturable'] = True
    def cond_modifiers(self, cond, motion):
        # All modifiers must be in-place
        self.target_cond_modifier(cond, motion)

    def target_cond_modifier(self, cond, motion):
        if self.args.multi_target_cond:
            batch_size = motion.shape[0]
            cond['target_joint_names'], cond['is_heading'] = sample_goal(
                batch_size, motion.device, self.args.target_joint_names)
            cond['target_cond'] = get_target_location(
                motion,
                self.data.dataset.mean[None, :, None, None],
                self.data.dataset.std[None, :, None, None],
                cond['lengths'],
                self.data.dataset.t2m_dataset.opt.joints_num,
                self.model.all_goal_joint_names,
                cond['target_joint_names'],
                cond['is_heading']).detach()
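
    # run_loop() below is the main training driver: each epoch iterates the data loader,
    # modifies the conditioning dict in place (e.g. sampling target joints when
    # multi_target_cond is on), moves the batch to the training device, and calls
    # run_step(). Logging, checkpoint saving, evaluation and sample generation are all
    # gated on the global step count (total_step()).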
    def run_loop(self):
        print('train steps:', self.num_steps)
        for epoch in range(self.num_epochs):
            print(f'Starting epoch {epoch}')
            for motion, cond in tqdm(self.data):
                if self.lr_anneal_steps and self.total_step() >= self.lr_anneal_steps:
                    break

                self.cond_modifiers(cond['y'], motion)  # Modify in-place for efficiency
                motion = motion.to(self.device)
                cond['y'] = {key: val.to(self.device) if torch.is_tensor(val) else val
                             for key, val in cond['y'].items()}

                self.run_step(motion, cond)
                if self.total_step() % self.log_interval == 0:
                    for k, v in logger.get_current().dumpkvs().items():
                        if k == 'loss':
                            print('step[{}]: loss[{:0.5f}]'.format(self.total_step(), v))
                        if k in ['step', 'samples'] or '_q' in k:
                            continue
                        else:
                            self.train_platform.report_scalar(name=k, value=v, iteration=self.total_step(),
                                                              group_name='Loss')

                if self.total_step() % self.save_interval == 0:
                    self.save()
                    self.model.eval()
                    if self.args.use_ema:
                        self.model_avg.eval()
                    self.evaluate()
                    self.generate_during_training()
                    self.model.train()
                    if self.args.use_ema:
                        self.model_avg.train()

                    # Run for a finite amount of time in integration tests.
                    if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.total_step() > 0:
                        return
                self.step += 1
            if self.lr_anneal_steps and self.total_step() >= self.lr_anneal_steps:
                break
        # Save the last checkpoint if it wasn't already saved.
        if (self.total_step() - 1) % self.save_interval != 0:
            self.save()
            self.evaluate()
    def evaluate(self):
        if not self.args.eval_during_training:
            return
        start_eval = time.time()

        if self.eval_wrapper is not None:
            print('Running evaluation loop: [Should take about 90 min]')
            log_file = os.path.join(self.save_dir, f'eval_humanml_{(self.total_step()):09d}.log')
            diversity_times = 300
            mm_num_times = 0  # mm is super slow hence we won't run it during training
            eval_dict = eval_humanml.evaluation(
                self.eval_wrapper, self.eval_gt_data, self.eval_data, log_file,
                replication_times=self.args.eval_rep_times, diversity_times=diversity_times,
                mm_num_times=mm_num_times, run_mm=False)
            print(eval_dict)
            for k, v in eval_dict.items():
                if k.startswith('R_precision'):
                    for i in range(len(v)):
                        self.train_platform.report_scalar(name=f'top{i + 1}_' + k, value=v[i],
                                                          iteration=self.total_step(),
                                                          group_name='Eval')
                else:
                    self.train_platform.report_scalar(name=k, value=v, iteration=self.total_step(),
                                                      group_name='Eval')

        elif self.dataset in ['humanact12', 'uestc']:
            eval_args = SimpleNamespace(num_seeds=self.args.eval_rep_times, num_samples=self.args.eval_num_samples,
                                        batch_size=self.args.eval_batch_size, device=self.device, guidance_param=1,
                                        dataset=self.dataset, unconstrained=self.args.unconstrained,
                                        model_path=os.path.join(self.save_dir, self.ckpt_file_name()))
            eval_dict = eval_humanact12_uestc.evaluate(eval_args, model=self.model, diffusion=self.diffusion,
                                                       data=self.data.dataset)
            print(f'Evaluation results on {self.dataset}: {sorted(eval_dict["feats"].items())}')
            for k, v in eval_dict["feats"].items():
                if 'unconstrained' not in k:
                    self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(),
                                                      iteration=self.total_step(), group_name='Eval')
                else:
                    self.train_platform.report_scalar(name=k, value=np.array(v).astype(float).mean(),
                                                      iteration=self.total_step(), group_name='Eval Unconstrained')

        end_eval = time.time()
        print(f'Evaluation time: {round((end_eval - start_eval) / 60, 2)}min')
    def run_step(self, batch, cond):
        self.forward_backward(batch, cond)
        self.mp_trainer.optimize(self.opt)
        self.update_average_model()
        self._anneal_lr()
        self.log_step()

    def update_average_model(self):
        # update the average model using an exponential moving average
        if self.args.use_ema:
            # master params are FP32
            params = self.model.parameters() if self.use_fp16 else self.mp_trainer.master_params
            for param, avg_param in zip(params, self.model_avg.parameters()):
                # avg = avg + (param - avg) * (1 - alpha)
                #     = avg + param * (1 - alpha) - (avg - alpha * avg)
                #     = alpha * avg + param * (1 - alpha)
                avg_param.data.mul_(self.args.avg_model_beta).add_(
                    param.data, alpha=1 - self.args.avg_model_beta)
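
    # Note on the EMA update above: with avg_model_beta close to 1 (e.g. 0.9999 -- an
    # illustrative value, the actual default comes from the training args), the averaged
    # weights track the raw weights with an effective horizon of roughly 1 / (1 - beta)
    # optimizer steps, so the EMA model changes much more slowly and is typically the
    # better evaluation checkpoint.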
    def forward_backward(self, batch, cond):
        self.mp_trainer.zero_grad()
        for i in range(0, batch.shape[0], self.microbatch):
            # Eliminates the microbatch feature
            assert i == 0
            assert self.microbatch == self.batch_size
            micro = batch
            micro_cond = cond

            last_batch = (i + self.microbatch) >= batch.shape[0]
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,  # [bs, ch, image_size, image_size]
                t,  # [bs](int) sampled timesteps
                model_kwargs=micro_cond,
                dataset=self.data.dataset
            )

            if last_batch or not self.use_ddp:
                losses = compute_losses()
            else:
                with self.ddp_model.no_sync():
                    losses = compute_losses()

            if isinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach()
                )

            loss = (losses["loss"] * weights).mean()
            log_loss_dict(
                self.diffusion, t, {k: v * weights for k, v in losses.items()}
            )
            self.mp_trainer.backward(loss)
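
    # _anneal_lr() applies a linear decay: the learning rate falls from self.lr at step 0
    # to 0 at lr_anneal_steps, so it is exactly half of self.lr at the halfway point.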
    def _anneal_lr(self):
        if not self.lr_anneal_steps:
            return
        frac_done = self.total_step() / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        logger.logkv("step", self.total_step())
        logger.logkv("samples", (self.total_step() + 1) * self.global_batch)

    def ckpt_file_name(self):
        return f"model{(self.total_step()):09d}.pt"
    def generate_during_training(self):
        if not self.args.gen_during_training:
            return
        gen_args = copy.deepcopy(self.args)
        gen_args.model_path = os.path.join(self.save_dir, self.ckpt_file_name())
        gen_args.output_dir = os.path.join(self.save_dir, f'{self.ckpt_file_name()}.samples')
        gen_args.num_samples = self.args.gen_num_samples
        gen_args.num_repetitions = self.args.gen_num_repetitions
        gen_args.guidance_param = self.args.gen_guidance_param
        gen_args.motion_length = 6  # fixed length
        gen_args.input_text = gen_args.text_prompt = gen_args.action_file = gen_args.action_name = gen_args.dynamic_text_path = ''
        if gen_args.multi_target_cond:
            gen_args.sampling_mode = 'goal'
            gen_args.target_joint_source = 'data'
        all_sample_save_path = generate(gen_args)
        self.train_platform.report_media(title='Motion', series='Predicted Motion', iteration=self.total_step(),
                                         local_path=all_sample_save_path)
    def find_resume_checkpoint(self) -> Optional[str]:
        '''Look for files named model{number}.pt in the save directory and return the one
        with the highest step number, or None if there are none.
        TODO: Implement this function (already existing in MDM), so that find_model will call it in case a ckpt exists.
        TODO: Change the call to find_resume_checkpoint and send save_dir as an arg.
        TODO: This means ignoring the resume_checkpoint flag in case some other ckpts exist in that dir!
        '''
        matches = {file: re.match(r'model(\d+)\.pt$', file) for file in os.listdir(self.args.save_dir)}
        models = {int(match.group(1)): file for file, match in matches.items() if match}
        return pjoin(self.args.save_dir, models[max(models)]) if models else None

    def total_step(self):
        # the global step, including steps already performed by a resumed checkpoint
        return self.step + self.resume_step
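
    # save() writes two files per checkpoint step: model{step:09d}.pt with the network
    # weights (CLIP weights stripped; with use_ema both 'model' and 'model_avg' are stored)
    # and opt{step:09d}.pt with the optimizer state.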
    def save(self):
        def save_checkpoint():
            def del_clip(state_dict):
                # Do not save CLIP weights
                clip_weights = [
                    e for e in state_dict.keys() if e.startswith('clip_model.')
                ]
                for e in clip_weights:
                    del state_dict[e]

            if self.use_fp16:
                state_dict = self.model.state_dict()
            else:
                state_dict = self.mp_trainer.master_params_to_state_dict(
                    self.mp_trainer.master_params)
            del_clip(state_dict)

            if self.args.use_ema:
                # save both the model and the average model
                state_dict_avg = self.model_avg.state_dict()
                del_clip(state_dict_avg)
                state_dict = {'model': state_dict, 'model_avg': state_dict_avg}

            logger.log("saving model...")
            filename = self.ckpt_file_name()
            with bf.BlobFile(bf.join(self.save_dir, filename), "wb") as f:
                torch.save(state_dict, f)

        save_checkpoint()

        with bf.BlobFile(
            bf.join(self.save_dir, f"opt{(self.total_step()):09d}.pt"),
            "wb",
        ) as f:
            opt_state = self.opt.state_dict()
            if self.use_fp16:
                # with fp16 we also save the grad scaler state
                opt_state = {
                    'opt': opt_state,
                    'scaler': self.scaler.state_dict(),
                }
            torch.save(opt_state, f)


def parse_resume_step_from_filename(filename):
    """
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")
    if len(split) < 2:
        return 0
    split1 = split[-1].split(".")[0]
    try:
        return int(split1)
    except ValueError:
        return 0


def get_blob_logdir():
    # You can change this to be a separate path to save checkpoints to
    # a blobstore or some external drive.
    return logger.get_dir()


def log_loss_dict(diffusion, ts, losses):
    for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())
        # Log the quantiles (four quartiles, in particular).
        for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile = int(4 * sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
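

# Minimal usage sketch (not part of this module): the training entry point is expected to
# build the dataset loader, model and diffusion, then hand everything to TrainLoop. The
# helper name create_model_and_diffusion and the exact entry-point script are assumptions
# based on the MDM code layout, not guaranteed by this file.
#
#     data = get_dataset_loader(name=args.dataset, batch_size=args.batch_size, num_frames=args.num_frames)
#     model, diffusion = create_model_and_diffusion(args, data)
#     model.to(dist_util.dev())
#     TrainLoop(args, train_platform, model, diffusion, data).run_loop()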