# Model construction, diffusion-scheduler setup, and checkpoint-loading utilities.
| import torch | |
| from model.mdm import MDM | |
| from diffusers import DDPMScheduler | |
| from utils.parser_util import get_cond_mode | |
| from data_loaders.humanml_utils import HML_EE_JOINT_NAMES | |
def load_model_wo_clip(model, state_dict):
    """
    Load `state_dict` into `model`, tolerating CLIP weights and positional
    encodings that are absent from (or sized differently than) the checkpoint.
    """
    # Fixed positional-encoding buffers can differ in length between the
    # checkpoint and the freshly built model; drop them so loading never
    # hits a shape mismatch (the model keeps its own buffers).
    for pe_key in ('sequence_pos_encoder.pe', 'embed_timestep.sequence_pos_encoder.pe'):
        state_dict.pop(pe_key, None)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    # Nothing in the checkpoint may go unused...
    assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
    # ...and the only weights allowed to be missing are CLIP's (loaded
    # separately) and the positional encodings removed above.
    assert all(k.startswith('clip_model.') or 'sequence_pos_encoder' in k for k in missing_keys), \
        f"Missing keys: {missing_keys}"
def create_model_and_diffusion(args, data):
    """
    Build the MDM network and its companion diffusion scheduler.

    Returns a `(model, scheduler)` pair.
    """
    return MDM(**get_model_args(args, data)), create_diffusion_scheduler(args)
def get_model_args(args, data):
    """
    Assemble the keyword arguments used to construct an MDM instance from the
    parsed CLI `args` and the dataset wrapper `data`.
    """
    # Per-dataset representation: (data_rep, njoints, nfeats).
    _DATASET_REP = {
        'humanml': ('hml_vec', 263, 1),
        'kit': ('hml_vec', 251, 1),
    }
    data_rep, njoints, nfeats = _DATASET_REP.get(args.dataset, ('rot6d', 25, 6))

    # Goal-joint conditioning is only defined for HumanML3D.
    if args.dataset == 'humanml':
        all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
    else:
        all_goal_joint_names = []

    # Backward compatibility: older configs predate prediction/context lengths.
    # NOTE: intentionally mutates `args` so downstream code sees the defaults.
    args.pred_len = getattr(args, 'pred_len', 0)
    args.context_len = getattr(args, 'context_len', 0)

    return {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': getattr(data.dataset, 'num_actions', 1),
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': "gelu",
        'data_rep': data_rep,
        'cond_mode': get_cond_mode(args),
        'cond_mask_prob': args.cond_mask_prob,
        'action_emb': 'tensor',
        'arch': args.arch,
        'emb_trans_dec': args.emb_trans_dec,
        'clip_version': 'ViT-B/32',
        'dataset': args.dataset,
        'text_encoder_type': args.text_encoder_type,
        'pos_embed_max_len': args.pos_embed_max_len,
        'mask_frames': args.mask_frames,
        'pred_len': args.pred_len,
        'context_len': args.context_len,
        'emb_policy': getattr(args, 'emb_policy', 'add'),
        'all_goal_joint_names': all_goal_joint_names,
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'multi'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
    }
def create_diffusion_scheduler(args):
    """
    Create a DDPM scheduler using Hugging Face's `diffusers` library.

    Beta-schedule hyperparameters are read from `args` when present and
    otherwise fall back to the standard DDPM values.
    """
    scheduler = DDPMScheduler(
        num_train_timesteps=args.diffusion_steps,
        beta_start=getattr(args, 'beta_start', 1e-4),
        beta_end=getattr(args, 'beta_end', 0.02),
        beta_schedule=getattr(args, 'noise_schedule', 'linear'),
    )
    # Pre-populate the inference timestep grid with the full training range.
    scheduler.set_timesteps(args.diffusion_steps)
    return scheduler
def load_saved_model(model, model_path, use_avg: bool = False):
    """
    Load weights from a checkpoint into `model` and return it.

    Args:
        model: the network to populate (modified in place).
        model_path: path to a checkpoint produced by training; may contain
            'model' and/or 'model_avg' entries, or be a bare state dict.
        use_avg: prefer the EMA/averaged weights ('model_avg') when available.

    Returns:
        The same `model` instance, with weights loaded.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu')

    if use_avg and 'model_avg' in checkpoint:
        state_dict = checkpoint['model_avg']
    else:
        # Fix: previously this fell back silently when averaged weights were
        # requested but absent; make that fallback visible to the user.
        if use_avg:
            print(f"Warning: use_avg=True but no 'model_avg' in {model_path}; "
                  "falling back to non-averaged weights.")
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    load_model_wo_clip(model, state_dict)
    return model