import torch
from torch.utils.data import Subset

from model.mdm import MDM
from model.mdm_controlnet import MDMControlNet
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
from utils.sampler_util import AutoRegressiveSampler
from data_loaders.humanml.scripts.motion_process import recover_from_ric
from data_loaders.tensors import collate


def get_cond_mode(args):
    """Return the conditioning mode implied by the run arguments.

    'no_cond' for unconstrained generation, 'text' for the text-conditioned
    datasets (kit / humanml variants), 'action' for everything else.
    """
    if args.unconstrained:
        cond_mode = "no_cond"
    elif args.dataset in ["kit", "humanml", "humanml_with_images"]:
        cond_mode = "text"
    else:
        cond_mode = "action"
    return cond_mode


def load_model_wo_clip(model, state_dict):
    """Load checkpoint weights into `model`, tolerating absent CLIP weights.

    The fixed (non-learned) positional-encoding buffers are deleted from the
    state dict before loading: they do not need to be restored, and older
    checkpoints may store them with a different length, causing a size
    mismatch. ControlNet models are exempt from this stripping.

    Raises:
        AssertionError: if the checkpoint contains unexpected keys, or if any
            missing key is not a CLIP weight / positional-encoding buffer.
    """
    if not isinstance(model, MDMControlNet):
        # No need to load these (they are fixed), and they cause a size
        # mismatch for older models.
        del state_dict["sequence_pos_encoder.pe"]
        del state_dict["embed_timestep.sequence_pos_encoder.pe"]

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    assert all(
        k.startswith("clip_model.") or "sequence_pos_encoder" in k
        for k in missing_keys
    )


def create_model_and_diffusion(args, data):
    """Build an MDM model and its matching diffusion process from run args."""
    model = MDM(**get_model_args(args, data))
    diffusion = create_gaussian_diffusion(args)
    return model, diffusion


def get_model_args(args, data):
    """Assemble the keyword arguments used to construct an MDM model.

    Dataset-specific representation settings (data_rep / njoints / nfeats)
    are selected here. Attributes that older checkpoints' args objects lack
    fall back to backward-compatible defaults.
    """
    # default args
    clip_version = "ViT-B/32"
    action_emb = "tensor"
    cond_mode = get_cond_mode(args)
    if hasattr(data.dataset, "num_actions"):
        num_actions = data.dataset.num_actions
    else:
        num_actions = 1

    # SMPL defaults
    data_rep = "rot6d"
    njoints = 25
    nfeats = 6
    all_goal_joint_names = []

    if args.dataset in ["humanml", "humanml_with_images"]:
        data_rep = "hml_vec"
        njoints = 263
        nfeats = 1
        all_goal_joint_names = ["pelvis"] + HML_EE_JOINT_NAMES
    elif args.dataset == "kit":
        data_rep = "hml_vec"
        njoints = 251
        nfeats = 1

    # Compatibility with old models whose args predate these fields.
    if not hasattr(args, "pred_len"):
        args.pred_len = 0
        args.context_len = 0
    emb_policy = getattr(args, "emb_policy", "add")
    multi_target_cond = getattr(args, "multi_target_cond", False)
    multi_encoder_type = getattr(args, "multi_encoder_type", "multi")
    target_enc_layers = getattr(args, "target_enc_layers", 1)

    return {
        "modeltype": "",
        "njoints": njoints,
        "nfeats": nfeats,
        "num_actions": num_actions,
        "translation": True,
        "pose_rep": "rot6d",
        "glob": True,
        "glob_rot": True,
        "latent_dim": args.latent_dim,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": 4,
        "dropout": 0.1,
        "activation": "gelu",
        "data_rep": data_rep,
        "cond_mode": cond_mode,
        "cond_mask_prob": args.cond_mask_prob,
        "action_emb": action_emb,
        "arch": args.arch,
        "emb_trans_dec": args.emb_trans_dec,
        "clip_version": clip_version,
        "dataset": args.dataset,
        "text_encoder_type": args.text_encoder_type,
        "pos_embed_max_len": args.pos_embed_max_len,
        "mask_frames": args.mask_frames,
        "pred_len": args.pred_len,
        "context_len": args.context_len,
        "emb_policy": emb_policy,
        "all_goal_joint_names": all_goal_joint_names,
        "multi_target_cond": multi_target_cond,
        "multi_encoder_type": multi_encoder_type,
        "target_enc_layers": target_enc_layers,
    }


def create_gaussian_diffusion(args):
    """Construct the SpacedDiffusion object for training/sampling.

    Always predicts x_start (not epsilon); sigma is fixed (not learned), and
    its magnitude is controlled by `args.sigma_small`.
    """
    # default params
    predict_xstart = True  # we always predict x_start (a.k.a. x0), that's our deal!
    steps = args.diffusion_steps
    scale_beta = 1.0  # no scaling
    timestep_respacing = ""  # can be used for ddim sampling, we don't use it.
    learn_sigma = False
    rescale_timesteps = False

    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
    loss_type = gd.LossType.MSE

    if not timestep_respacing:
        timestep_respacing = [steps]

    # Older args objects may not define this loss weight.
    lambda_target_loc = getattr(args, "lambda_target_loc", 0.0)

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not args.sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        lambda_target_loc=lambda_target_loc,
    )


def load_saved_model(model, model_path, use_avg: bool = False):  # use_avg_model
    """Load a checkpoint from disk into `model` and return it.

    If `use_avg` is set and the checkpoint contains an averaged-weights copy
    ('model_avg'), that copy is loaded; otherwise the plain 'model' entry (or
    the raw state dict, for older checkpoints) is used.
    """
    state_dict = torch.load(model_path, map_location="cpu")
    # Use average model when possible
    if use_avg and "model_avg" in state_dict.keys():
        # if use_avg_model:
        print("loading avg model")
        state_dict = state_dict["model_avg"]
    else:
        if "model" in state_dict:
            print("loading model without avg")
            state_dict = state_dict["model"]
        else:
            print("checkpoint has no avg model, loading as usual.")
    load_model_wo_clip(model, state_dict)
    return model


def sample_from_model(
    model,
    diffusion,
    data=None,
    num_samples=1,
    num_repetitions=1,
    text_prompts=None,
    action_name=None,
    motion_length=6.0,
    guidance_param=3.0,
    n_frames=None,
    context_motion=None,
    context_len=0,
    pred_len=0,
    autoregressive=False,
    device="cuda",
    return_xyz=True,
    return_numpy=True,
    noise=None,
    const_noise=False,
    cond_images=None,
    frame_indices=None,
):
    """
    Sample motions from a trained MDM model.

    Parameters:
        model: The MDM model
        diffusion: The diffusion object
        data: Optional dataset loader (used for prefix sampling if needed)
        num_samples: Number of samples (text prompts) to process
        num_repetitions: Number of different motions to generate for each prompt
        text_prompts: List of text prompts or single string prompt
        action_name: Action name(s) for action-conditioned generation
        motion_length: Length of motion in seconds
        guidance_param: Classifier-free guidance scale
        n_frames: Number of frames to generate (calculated from motion_length if None)
        context_motion: Optional context motion for prefix-based generation
        context_len: Context length for prefix-based generation
        pred_len: Prediction length for each step in autoregressive generation
        autoregressive: Whether to use autoregressive sampling
        device: Device to use for sampling
        return_xyz: Whether to convert output to XYZ coordinates
        return_numpy: Whether to return numpy arrays (True) or torch tensors (False)
        noise: Optional noise tensor for sampling
        const_noise: Whether to use constant noise for sampling
        cond_images: Optional image conditioning; only valid for MDMControlNet
            models (passed through model.process_images)
        frame_indices: Optional frame indices forwarded to the model via
            model_kwargs (semantics defined by the model)

    Returns:
        Dictionary containing:
        - motions: Generated motions with shape [num_samples*num_repetitions, njoints, 3, n_frames]
        - texts: Text prompts used for generation
        - lengths: Length of each generated motion
    """
    # BUGFIX: the original condition was inverted (`cond_images is not None or ...`),
    # which rejected plain MDM models even when no image conditioning was given.
    assert cond_images is None or isinstance(model, MDMControlNet), (
        "Image conditioning is only supported for MDMControlNet"
    )
    if cond_images is not None:
        cond_images = model.process_images(cond_images, device=device)

    model.eval()  # Ensure model is in eval mode

    # Move model to the right device if it's not there already
    model_device = next(model.parameters()).device
    if str(model_device) != device:
        model = model.to(device)

    # Determine number of frames
    fps = 12.5 if model.dataset == "kit" else 20
    if n_frames is None:
        n_frames = min(
            196 if model.dataset in ["kit", "humanml", "humanml_with_images"] else 60,
            int(motion_length * fps),
        )

    # Handle text prompts: broadcast a single string, or tile a short list,
    # so that exactly num_samples prompts are used.
    if text_prompts is not None:
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts] * num_samples
        elif len(text_prompts) < num_samples:
            text_prompts = text_prompts * (num_samples // len(text_prompts) + 1)
            text_prompts = text_prompts[:num_samples]
        num_samples = len(text_prompts)

    # Handle action names
    if action_name is not None:
        if isinstance(action_name, str):
            action_text = [action_name] * num_samples
        else:
            action_text = action_name
            num_samples = len(action_text)

    # Set up classifier-free guidance
    original_model = model
    if guidance_param != 1.0:
        from utils.sampler_util import ClassifierFreeSampleModel

        model = ClassifierFreeSampleModel(model)

    # Set up autoregressive sampling if needed
    sample_fn = diffusion.p_sample_loop
    if autoregressive:
        sample_cls = AutoRegressiveSampler({"pred_len": pred_len}, sample_fn, n_frames)
        sample_fn = sample_cls.sample

    # Prepare for sampling
    motion_shape = (num_samples, model.njoints, model.nfeats, n_frames)

    # Set up model kwargs
    if context_motion is not None or context_len > 0:
        # For prefix-conditioned generation we pull a real batch from the
        # dataset to act as the motion context.
        if data is None:
            raise ValueError("Dataset needed for context-based generation")
        iterator = iter(data)
        input_motion, model_kwargs = next(iterator)
        input_motion = input_motion.to(device)
        if text_prompts is not None:
            model_kwargs["y"]["text"] = text_prompts
    else:
        collate_args = [
            {"inp": torch.zeros(n_frames), "tokens": None, "lengths": n_frames}
        ] * num_samples
        if text_prompts is not None:
            # Text-to-motion
            collate_args = [
                dict(arg, text=txt) for arg, txt in zip(collate_args, text_prompts)
            ]
        elif action_name is not None:
            # Action-to-motion
            if hasattr(data.dataset, "action_name_to_action"):
                action = data.dataset.action_name_to_action(action_text)
                collate_args = [
                    dict(arg, action=one_action, action_text=one_action_text)
                    for arg, one_action, one_action_text in zip(
                        collate_args, action, action_text
                    )
                ]
            else:
                raise ValueError("Dataset doesn't support action conditioning")
        _, model_kwargs = collate(collate_args)

    # Move model_kwargs to device
    model_kwargs["y"] = {
        key: val.to(device) if torch.is_tensor(val) else val
        for key, val in model_kwargs["y"].items()
    }

    # Add image conditioning to model_kwargs if provided
    if cond_images is not None:
        model_kwargs["cond_images"] = cond_images
    if frame_indices is not None:
        model_kwargs["frame_indices"] = frame_indices

    # Add CFG scale to batch
    if guidance_param != 1.0:
        model_kwargs["y"]["scale"] = (
            torch.ones(num_samples, device=device) * guidance_param
        )

    # Pre-encode text for efficiency (encode once, reuse per repetition)
    if "text" in model_kwargs["y"]:
        model_kwargs["y"]["text_embed"] = original_model.encode_text(
            model_kwargs["y"]["text"]
        )

    # Store all generated motions and related information
    all_motions = []
    all_text = []
    all_lengths = []

    # Run generation for each repetition
    for rep_i in range(num_repetitions):
        print(f"### Sampling [repetition #{rep_i + 1}/{num_repetitions}]")

        # Sample from the model
        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=noise,
            const_noise=const_noise,
        )

        # Get text information for this batch
        if "text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["text"]
        elif "action_text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["action_text"]
        else:
            batch_text = [""] * num_samples
        all_text.extend(batch_text)

        # Get lengths
        batch_lengths = model_kwargs["y"]["lengths"].cpu()
        all_lengths.append(batch_lengths)

        # Post-process the sample if returning XYZ coordinates
        if return_xyz:
            # Recover XYZ positions from vector representation if needed
            if model.data_rep == "hml_vec":
                n_joints = 22 if sample.shape[1] == 263 else 21
                if isinstance(data.dataset, Subset):
                    dataset = data.dataset.dataset
                else:
                    dataset = data.dataset
                sample = dataset.t2m_dataset.inv_transform(
                    sample.cpu().permute(0, 2, 3, 1)
                ).float()
                sample = recover_from_ric(sample, n_joints)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

            # Convert rotations to XYZ coordinates
            rot2xyz_pose_rep = (
                "xyz" if model.data_rep in ["xyz", "hml_vec"] else model.data_rep
            )
            rot2xyz_mask = (
                None
                if rot2xyz_pose_rep == "xyz"
                else model_kwargs["y"]["mask"].reshape(num_samples, n_frames).bool()
            )
            sample = model.rot2xyz(
                x=sample,
                mask=rot2xyz_mask,
                pose_rep=rot2xyz_pose_rep,
                glob=True,
                translation=True,
                jointstype="smpl",
                vertstrans=True,
                betas=None,
                beta=0,
                glob_rot=None,
                get_rotations_back=False,
            )

        # Store this batch of samples
        all_motions.append(sample)

    # Concatenate all repetitions
    all_motions = torch.cat(all_motions, dim=0)
    all_lengths = torch.cat(all_lengths, dim=0)

    # Convert to numpy if requested
    if return_numpy:
        all_motions = all_motions.cpu().numpy()
        all_lengths = all_lengths.numpy()

    # Reset model if we wrapped it
    if guidance_param != 1.0:
        model = original_model

    return {"motions": all_motions, "texts": all_text, "lengths": all_lengths}