import torch

from diffusers import DDPMScheduler

from model.mdm import MDM
from utils.parser_util import get_cond_mode
from data_loaders.humanml_utils import HML_EE_JOINT_NAMES


def load_model_wo_clip(model, state_dict):
    """
    Load model weights, skipping CLIP weights and fixed positional encodings.

    Args:
        model: the MDM instance to load into.
        state_dict: checkpoint state dict. NOTE: mutated in place — the
            positional-encoding buffers are popped before loading.

    Raises:
        AssertionError: if the checkpoint contains unexpected keys, or is
            missing keys other than CLIP weights / positional encodings.
    """
    # Remove fixed (non-learned) positional encodings so checkpoints trained
    # with a different maximum sequence length still load without a size
    # mismatch; the freshly built model already has its own buffers.
    state_dict.pop('sequence_pos_encoder.pe', None)
    state_dict.pop('embed_timestep.sequence_pos_encoder.pe', None)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
    # Only CLIP weights (frozen, re-created at model init) and the popped
    # positional encodings may legitimately be absent from the checkpoint.
    assert all(
        k.startswith('clip_model.') or 'sequence_pos_encoder' in k
        for k in missing_keys
    ), f"Missing keys: {missing_keys}"


def create_model_and_diffusion(args, data):
    """
    Instantiate the MDM model and the diffusion scheduler.

    Args:
        args: parsed training/eval arguments.
        data: data loader whose ``dataset`` may expose ``num_actions``.

    Returns:
        (model, scheduler) tuple.
    """
    model = MDM(**get_model_args(args, data))
    scheduler = create_diffusion_scheduler(args)
    return model, scheduler


def get_model_args(args, data):
    """
    Build the keyword-argument dict used to construct :class:`MDM`.

    Dataset-specific defaults (feature layout, goal joints) are chosen from
    ``args.dataset``; optional newer arguments are read with ``getattr`` so
    older argument namespaces remain loadable.
    """
    # Default configuration
    clip_version = 'ViT-B/32'
    action_emb = 'tensor'
    cond_mode = get_cond_mode(args)
    num_actions = getattr(data.dataset, 'num_actions', 1)

    # Data representation defaults
    if args.dataset == 'humanml':
        data_rep = 'hml_vec'
        njoints, nfeats = 263, 1
        all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
    elif args.dataset == 'kit':
        data_rep = 'hml_vec'
        njoints, nfeats = 251, 1
        all_goal_joint_names = []
    else:
        data_rep = 'rot6d'
        njoints, nfeats = 25, 6
        all_goal_joint_names = []

    # Ensure backward compatibility with argument namespaces that predate
    # prefix-completion training (no pred/context lengths).
    args.pred_len = getattr(args, 'pred_len', 0)
    args.context_len = getattr(args, 'context_len', 0)

    return {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': num_actions,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': "gelu",
        'data_rep': data_rep,
        'cond_mode': cond_mode,
        'cond_mask_prob': args.cond_mask_prob,
        'action_emb': action_emb,
        'arch': args.arch,
        'emb_trans_dec': args.emb_trans_dec,
        'clip_version': clip_version,
        'dataset': args.dataset,
        'text_encoder_type': args.text_encoder_type,
        'pos_embed_max_len': args.pos_embed_max_len,
        'mask_frames': args.mask_frames,
        'pred_len': args.pred_len,
        'context_len': args.context_len,
        'emb_policy': getattr(args, 'emb_policy', 'add'),
        'all_goal_joint_names': all_goal_joint_names,
        'multi_target_cond': getattr(args, 'multi_target_cond', False),
        'multi_encoder_type': getattr(args, 'multi_encoder_type', 'multi'),
        'target_enc_layers': getattr(args, 'target_enc_layers', 1),
    }


def create_diffusion_scheduler(args):
    """
    Create a DDPM scheduler using Hugging Face's `diffusers` library.

    Args:
        args: must provide ``diffusion_steps``; ``beta_start``, ``beta_end``
            and ``noise_schedule`` are optional with standard DDPM defaults.

    Returns:
        A ``DDPMScheduler`` with its timesteps already initialized.
    """
    # Define beta schedule parameters
    beta_start = getattr(args, 'beta_start', 1e-4)
    beta_end = getattr(args, 'beta_end', 0.02)
    beta_schedule = getattr(args, 'noise_schedule', 'linear')

    # diffusers names the cosine schedule 'squaredcos_cap_v2'; translate the
    # conventional 'cosine' flag so DDPMScheduler does not raise
    # NotImplementedError on configs trained with a cosine noise schedule.
    if beta_schedule == 'cosine':
        beta_schedule = 'squaredcos_cap_v2'

    scheduler = DDPMScheduler(
        num_train_timesteps=args.diffusion_steps,
        beta_start=beta_start,
        beta_end=beta_end,
        beta_schedule=beta_schedule,
    )
    # Initialize scheduler timesteps
    scheduler.set_timesteps(args.diffusion_steps)
    return scheduler


def load_saved_model(model, model_path, use_avg: bool = False):
    """
    Load weights from a checkpoint, optionally using an averaged model.

    Args:
        model: the MDM instance to load into.
        model_path: path to a ``torch.save`` checkpoint.
        use_avg: prefer the EMA/averaged weights (``'model_avg'``) when the
            checkpoint contains them.

    Returns:
        The same ``model`` with weights loaded.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu')

    if use_avg and 'model_avg' in checkpoint:
        state_dict = checkpoint['model_avg']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        # Raw state dict saved without a wrapping dict.
        state_dict = checkpoint

    load_model_wo_clip(model, state_dict)
    return model