| | import pdb |
| |
|
| | from model.mdm import MDM |
| | from diffusion import gaussian_diffusion as gd |
| | from diffusion.respace import SpacedDiffusion, space_timesteps |
| |
|
| |
|
def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model``, tolerating only missing ``clip_model.`` keys.

    The state dict is loaded non-strictly. Afterwards:

    * any unexpected key is an error;
    * any missing key that does not start with ``clip_model.`` is an error.

    Missing keys under the ``clip_model.`` prefix are accepted. Presumably the CLIP
    weights are not stored in the checkpoint; TODO confirm.

    Args:
        model: object whose ``load_state_dict(state_dict, strict=False)`` returns a
            ``(missing_keys, unexpected_keys)`` pair (e.g. a ``torch.nn.Module``).
        state_dict: mapping of parameter/buffer names to tensors.

    Raises:
        AssertionError: if there are unexpected keys, or missing keys outside the
            ``clip_model.`` prefix. The error is raised explicitly rather than via an
            ``assert`` statement so the check still runs under ``python -O``. The
            exception type is kept the same as before for callers that catch it.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print(missing_keys, unexpected_keys)
    if unexpected_keys:
        raise AssertionError(f"Unexpected keys in state_dict: {unexpected_keys}")
    non_clip_missing = [k for k in missing_keys if not k.startswith('clip_model.')]
    if non_clip_missing:
        raise AssertionError(f"Missing non-CLIP keys in state_dict: {non_clip_missing}")
| |
|
| |
|
def create_model_and_diffusion(args, data):
    """Build an ``MDM`` model together with its Gaussian diffusion process.

    Args:
        args: parsed options, forwarded to ``get_model_args`` to configure the model.
        data: data wrapper, forwarded to ``get_model_args`` (which reads
            ``data.dataset``).

    Returns:
        A ``(model, diffusion)`` tuple.
    """
    model = MDM(**get_model_args(args, data))
    # create_gaussian_diffusion() in this module needs no arguments (its
    # hyperparameters are fixed internally). Passing ``args`` to the zero-parameter
    # definition raised TypeError.
    diffusion = create_gaussian_diffusion()
    return model, diffusion
| |
|
| |
|
def get_model_args(args, data):
    """Assemble the keyword arguments used to construct an ``MDM`` model.

    Conditioning mode:

    * ``'no_cond'`` when ``args.unconstrained`` is truthy;
    * otherwise ``'text'`` for the ``'kit'`` / ``'humanml'`` datasets;
    * otherwise ``'action'``.

    Motion representation:

    * ``'humanml'`` uses ``hml_vec`` with 263 x 1 features;
    * ``'kit'`` uses ``hml_vec`` with 251 x 1;
    * any other dataset uses ``rot6d`` with 25 x 6.

    ``num_actions`` comes from ``data.dataset.num_actions`` when that attribute
    exists, else 1.

    Args:
        args: options object read for ``unconstrained``, ``dataset``,
            ``latent_dim``, ``layers``, ``cond_mask_prob``, ``arch`` and
            ``emb_trans_dec``.
        data: object with a ``dataset`` attribute.

    Returns:
        dict of keyword arguments for ``MDM(**...)``.
    """
    # Per-dataset (data_rep, njoints, nfeats) for the text-conditioned datasets.
    text_dataset_shapes = {
        'humanml': ('hml_vec', 263, 1),
        'kit': ('hml_vec', 251, 1),
    }

    if args.unconstrained:
        cond_mode = 'no_cond'
    elif args.dataset in text_dataset_shapes:
        cond_mode = 'text'
    else:
        cond_mode = 'action'

    num_actions = getattr(data.dataset, 'num_actions', 1)

    data_rep, njoints, nfeats = text_dataset_shapes.get(args.dataset, ('rot6d', 25, 6))

    return dict(
        modeltype='',
        njoints=njoints,
        nfeats=nfeats,
        num_actions=num_actions,
        translation=True,
        pose_rep='rot6d',
        glob=True,
        glob_rot=True,
        latent_dim=args.latent_dim,
        ff_size=1024,
        num_layers=args.layers,
        num_heads=4,
        dropout=0.1,
        activation="gelu",
        data_rep=data_rep,
        cond_mode=cond_mode,
        cond_mask_prob=args.cond_mask_prob,
        action_emb='tensor',
        arch=args.arch,
        emb_trans_dec=args.emb_trans_dec,
        clip_version='ViT-B/32',
        dataset=args.dataset,
    )
| |
|
| |
|
def create_gaussian_diffusion(args=None):
    """Build the ``SpacedDiffusion`` process with this module's fixed hyperparameters.

    Configuration (all hard-coded here):

    * schedule: ``'cosine'`` beta schedule over 1000 steps, beta scale 1.0;
    * timesteps: ``timestep_respacing`` left empty, so ``[steps]`` is passed to
      ``space_timesteps``;
    * mean: the model predicts x_0 directly (``ModelMeanType.START_X``);
    * variance: fixed small (``ModelVarType.FIXED_SMALL``), since sigma is not
      learned;
    * loss: MSE, with ``rescale_timesteps=False``;
    * ``lambda_vel``, ``lambda_rcxyz`` and ``lambda_fc`` are all 0.0.

    Args:
        args: optional options object. Currently unused. It is accepted so that
            callers written against an ``args``-taking signature do not fail with
            ``TypeError``; calling with no arguments works as before.

    Returns:
        A configured ``SpacedDiffusion`` instance.
    """
    noise_schedule = 'cosine'
    steps = 1000
    scale_beta = 1.
    timestep_respacing = ''
    predict_xstart = True
    learn_sigma = False
    sigma_small = True
    rescale_timesteps = False
    lambda_vel = 0.0
    lambda_rcxyz = 0.0
    lambda_fc = 0.0

    betas = gd.get_named_beta_schedule(noise_schedule, steps, scale_beta)

    # An empty respacing spec is replaced by a single section of ``steps``.
    if not timestep_respacing:
        timestep_respacing = [steps]

    if predict_xstart:
        model_mean_type = gd.ModelMeanType.START_X
    else:
        model_mean_type = gd.ModelMeanType.EPSILON

    if learn_sigma:
        model_var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        model_var_type = gd.ModelVarType.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=lambda_vel,
        lambda_rcxyz=lambda_rcxyz,
        lambda_fc=lambda_fc,
    )