Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Subset | |
| from model.mdm import MDM | |
| from model.mdm_controlnet import MDMControlNet | |
| from diffusion import gaussian_diffusion as gd | |
| from diffusion.respace import SpacedDiffusion, space_timesteps | |
| from data_loaders.humanml_utils import HML_EE_JOINT_NAMES | |
| from utils.sampler_util import AutoRegressiveSampler | |
| from data_loaders.humanml.scripts.motion_process import recover_from_ric | |
| from data_loaders.tensors import collate | |
def get_cond_mode(args):
    """Pick the conditioning mode implied by the run arguments.

    Returns "no_cond" for unconstrained generation, "text" for the
    text-annotated datasets, and "action" otherwise.
    """
    if args.unconstrained:
        return "no_cond"
    if args.dataset in ["kit", "humanml", "humanml_with_images"]:
        return "text"
    return "action"
def load_model_wo_clip(model, state_dict):
    """Load a checkpoint into ``model``, tolerating CLIP and fixed buffers.

    The CLIP text encoder is loaded separately, so ``clip_model.*`` keys are
    allowed to be missing. The sinusoidal positional-encoding buffers are
    deterministic (not learned) and their saved size may not match newer
    models, so they are dropped from the state dict before loading.

    Raises:
        AssertionError: if the checkpoint contains keys the model does not
            have, or if keys other than CLIP / positional-encoding are missing.
    """
    if not isinstance(model, MDMControlNet):
        # pop(..., None) instead of del: older/newer checkpoints may not
        # contain these buffers at all, and del would raise KeyError.
        state_dict.pop("sequence_pos_encoder.pe", None)
        state_dict.pop("embed_timestep.sequence_pos_encoder.pe", None)
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0, f"unexpected keys in checkpoint: {unexpected_keys}"
    assert all(
        k.startswith("clip_model.") or "sequence_pos_encoder" in k
        for k in missing_keys
    ), f"missing non-CLIP keys: {missing_keys}"
def create_model_and_diffusion(args, data):
    """Instantiate an MDM model together with its matching diffusion process."""
    return MDM(**get_model_args(args, data)), create_gaussian_diffusion(args)
def get_model_args(args, data):
    """Assemble the keyword arguments used to construct an MDM model.

    Derives the data representation (dimensions, feature counts, goal joints)
    from ``args.dataset``, and fills in defaults for options that may be
    absent from older checkpoints' argument namespaces.

    Parameters:
        args: Parsed training/sampling arguments (argparse.Namespace-like).
        data: Data loader; ``data.dataset.num_actions`` is used when present.

    Returns:
        dict of keyword arguments for ``MDM(**kwargs)``.
    """
    # default args
    clip_version = "ViT-B/32"
    action_emb = "tensor"
    cond_mode = get_cond_mode(args)
    if hasattr(data.dataset, "num_actions"):
        num_actions = data.dataset.num_actions
    else:
        num_actions = 1

    # SMPL defaults
    data_rep = "rot6d"
    njoints = 25
    nfeats = 6
    all_goal_joint_names = []

    if args.dataset in ["humanml", "humanml_with_images"]:
        data_rep = "hml_vec"
        njoints = 263
        nfeats = 1
        all_goal_joint_names = ["pelvis"] + HML_EE_JOINT_NAMES
    elif args.dataset == "kit":
        data_rep = "hml_vec"
        njoints = 251
        nfeats = 1

    # Compatibility with old models: checkpoints saved before these options
    # existed lack the attributes, so default them here.
    if not hasattr(args, "pred_len"):
        args.pred_len = 0
        args.context_len = 0
    # getattr (not args.__dict__.get) is the idiomatic accessor and also works
    # for namespaces without a plain __dict__.
    emb_policy = getattr(args, "emb_policy", "add")
    multi_target_cond = getattr(args, "multi_target_cond", False)
    multi_encoder_type = getattr(args, "multi_encoder_type", "multi")
    target_enc_layers = getattr(args, "target_enc_layers", 1)

    return {
        "modeltype": "",
        "njoints": njoints,
        "nfeats": nfeats,
        "num_actions": num_actions,
        "translation": True,
        "pose_rep": "rot6d",
        "glob": True,
        "glob_rot": True,
        "latent_dim": args.latent_dim,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": 4,
        "dropout": 0.1,
        "activation": "gelu",
        "data_rep": data_rep,
        "cond_mode": cond_mode,
        "cond_mask_prob": args.cond_mask_prob,
        "action_emb": action_emb,
        "arch": args.arch,
        "emb_trans_dec": args.emb_trans_dec,
        "clip_version": clip_version,
        "dataset": args.dataset,
        "text_encoder_type": args.text_encoder_type,
        "pos_embed_max_len": args.pos_embed_max_len,
        "mask_frames": args.mask_frames,
        "pred_len": args.pred_len,
        "context_len": args.context_len,
        "emb_policy": emb_policy,
        "all_goal_joint_names": all_goal_joint_names,
        "multi_target_cond": multi_target_cond,
        "multi_encoder_type": multi_encoder_type,
        "target_enc_layers": target_enc_layers,
    }
def create_gaussian_diffusion(args):
    """Build the SpacedDiffusion process used for training and sampling.

    MDM's fixed design choices: the model always predicts x_start (a.k.a. x0,
    never epsilon), variances are fixed rather than learned, betas are not
    rescaled, and no timestep respacing is applied.
    """
    steps = args.diffusion_steps
    # scale_beta = 1.0 means no scaling of the named schedule.
    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, 1.0)

    # Older checkpoints may predate the target-location loss term.
    lambda_target_loc = getattr(args, "lambda_target_loc", 0.0)

    # Variances are fixed (not learned); args only chooses small vs. large.
    if args.sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, [steps]),  # no respacing
        betas=betas,
        model_mean_type=gd.ModelMeanType.START_X,  # always predict x0
        model_var_type=var_type,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=False,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        lambda_target_loc=lambda_target_loc,
    )
def load_saved_model(model, model_path, use_avg: bool = False):  # use_avg_model
    """Load a checkpoint file from disk into ``model``.

    Prefers the averaged (EMA) weights when ``use_avg`` is set and the
    checkpoint provides them; otherwise falls back to the raw model weights,
    or to the whole state dict for legacy checkpoints.

    Returns:
        The same ``model`` instance, with weights loaded.
    """
    state_dict = torch.load(model_path, map_location="cpu")
    if use_avg and "model_avg" in state_dict:
        print("loading avg model")
        state_dict = state_dict["model_avg"]
    elif "model" in state_dict:
        print("loading model without avg")
        state_dict = state_dict["model"]
    else:
        # Legacy checkpoint: the file itself is the state dict.
        print("checkpoint has no avg model, loading as usual.")
    load_model_wo_clip(model, state_dict)
    return model
def sample_from_model(
    model,
    diffusion,
    data=None,
    num_samples=1,
    num_repetitions=1,
    text_prompts=None,
    action_name=None,
    motion_length=6.0,
    guidance_param=3.0,
    n_frames=None,
    context_motion=None,
    context_len=0,
    pred_len=0,
    autoregressive=False,
    device="cuda",
    return_xyz=True,
    return_numpy=True,
    noise=None,
    const_noise=False,
    cond_images=None,
    frame_indices=None,
):
    """
    Sample motions from a trained MDM model.

    Parameters:
        model: The MDM model
        diffusion: The diffusion object
        data: Optional dataset loader (used for prefix sampling and, when
            return_xyz is True with hml_vec data, for inv_transform)
        num_samples: Number of samples (text prompts) to process
        num_repetitions: Number of different motions to generate for each prompt
        text_prompts: List of text prompts or single string prompt
        action_name: Action name(s) for action-conditioned generation
        motion_length: Length of motion in seconds
        guidance_param: Classifier-free guidance scale
        n_frames: Number of frames to generate (calculated from motion_length if None)
        context_motion: Optional context motion for prefix-based generation
        context_len: Context length for prefix-based generation
        pred_len: Prediction length for each step in autoregressive generation
        autoregressive: Whether to use autoregressive sampling
        device: Device to use for sampling
        return_xyz: Whether to convert output to XYZ coordinates
        return_numpy: Whether to return numpy arrays (True) or torch tensors (False)
        noise: Optional noise tensor for sampling
        const_noise: Whether to use constant noise for sampling
        cond_images: Optional image conditioning (MDMControlNet only)
        frame_indices: Optional frame indices forwarded in model_kwargs

    Returns:
        Dictionary containing:
        - motions: Generated motions with shape [num_samples*num_repetitions, njoints, 3, n_frames]
        - texts: Text prompts used for generation
        - lengths: Length of each generated motion
    """
    # BUGFIX: the original assertion was `cond_images is not None or
    # isinstance(model, MDMControlNet)`, which failed for every plain-MDM call
    # made WITHOUT images (the common path) and never guarded the case it was
    # meant to. Per its own message, the intended contract is: image
    # conditioning requires an MDMControlNet; no images is always allowed.
    assert cond_images is None or isinstance(model, MDMControlNet), (
        "Image conditioning is only supported for MDMControlNet"
    )
    if cond_images is not None:
        cond_images = model.process_images(cond_images, device=device)

    model.eval()  # Ensure model is in eval mode

    # Move model to the right device if it's not there already
    model_device = next(model.parameters()).device
    if str(model_device) != device:
        model = model.to(device)

    # Determine number of frames (KIT runs at 12.5 fps, others at 20)
    fps = 12.5 if model.dataset == "kit" else 20
    if n_frames is None:
        n_frames = min(
            196 if model.dataset in ["kit", "humanml", "humanml_with_images"] else 60,
            int(motion_length * fps),
        )

    # Handle text prompts: broadcast a single string, tile a short list,
    # then truncate so exactly num_samples prompts remain.
    if text_prompts is not None:
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts] * num_samples
        elif len(text_prompts) < num_samples:
            text_prompts = text_prompts * (num_samples // len(text_prompts) + 1)
        text_prompts = text_prompts[:num_samples]
        num_samples = len(text_prompts)

    # Handle action names
    if action_name is not None:
        if isinstance(action_name, str):
            action_text = [action_name] * num_samples
        else:
            action_text = action_name
        num_samples = len(action_text)

    # Set up classifier-free guidance (keep a handle on the unwrapped model
    # for text encoding below)
    original_model = model
    if guidance_param != 1.0:
        from utils.sampler_util import ClassifierFreeSampleModel

        model = ClassifierFreeSampleModel(model)

    # Set up autoregressive sampling if needed
    sample_fn = diffusion.p_sample_loop
    if autoregressive:
        sample_cls = AutoRegressiveSampler({"pred_len": pred_len}, sample_fn, n_frames)
        sample_fn = sample_cls.sample

    # Prepare for sampling
    motion_shape = (num_samples, model.njoints, model.nfeats, n_frames)

    # Set up model kwargs
    if context_motion is not None or context_len > 0:
        # For prefix-conditioned generation: draw a batch from the dataset to
        # obtain a valid kwargs structure, then override the text if given.
        if data is None:
            raise ValueError("Dataset needed for context-based generation")
        iterator = iter(data)
        input_motion, model_kwargs = next(iterator)
        input_motion = input_motion.to(device)
        if text_prompts is not None:
            model_kwargs["y"]["text"] = text_prompts
    else:
        collate_args = [
            {"inp": torch.zeros(n_frames), "tokens": None, "lengths": n_frames}
        ] * num_samples
        if text_prompts is not None:
            # Text-to-motion
            collate_args = [
                dict(arg, text=txt) for arg, txt in zip(collate_args, text_prompts)
            ]
        elif action_name is not None:
            # Action-to-motion
            if hasattr(data.dataset, "action_name_to_action"):
                action = data.dataset.action_name_to_action(action_text)
                collate_args = [
                    dict(arg, action=one_action, action_text=one_action_text)
                    for arg, one_action, one_action_text in zip(
                        collate_args, action, action_text
                    )
                ]
            else:
                raise ValueError("Dataset doesn't support action conditioning")
        _, model_kwargs = collate(collate_args)

    # Move model_kwargs to device
    model_kwargs["y"] = {
        key: val.to(device) if torch.is_tensor(val) else val
        for key, val in model_kwargs["y"].items()
    }

    # Add image conditioning to model_kwargs if provided
    if cond_images is not None:
        model_kwargs["cond_images"] = cond_images
    if frame_indices is not None:
        model_kwargs["frame_indices"] = frame_indices

    # Add CFG scale to batch
    if guidance_param != 1.0:
        model_kwargs["y"]["scale"] = (
            torch.ones(num_samples, device=device) * guidance_param
        )

    # Pre-encode text once for efficiency (reused across all repetitions);
    # use the unwrapped model, which owns the text encoder.
    if "text" in model_kwargs["y"]:
        model_kwargs["y"]["text_embed"] = original_model.encode_text(
            model_kwargs["y"]["text"]
        )

    # Store all generated motions and related information
    all_motions = []
    all_text = []
    all_lengths = []

    # Run generation for each repetition
    for rep_i in range(num_repetitions):
        print(f"### Sampling [repetition #{rep_i + 1}/{num_repetitions}]")

        # Sample from the model
        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=noise,
            const_noise=const_noise,
        )

        # Get text information for this batch
        if "text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["text"]
        elif "action_text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["action_text"]
        else:
            batch_text = [""] * num_samples
        all_text.extend(batch_text)

        # Get lengths
        batch_lengths = model_kwargs["y"]["lengths"].cpu()
        all_lengths.append(batch_lengths)

        # Post-process the sample if returning XYZ coordinates
        if return_xyz:
            # Recover XYZ positions from vector representation if needed
            if model.data_rep == "hml_vec":
                n_joints = 22 if sample.shape[1] == 263 else 21
                # NOTE(review): this path dereferences `data` — callers must
                # pass a loader when return_xyz=True on hml_vec models.
                if isinstance(data.dataset, Subset):
                    dataset = data.dataset.dataset
                else:
                    dataset = data.dataset
                sample = dataset.t2m_dataset.inv_transform(
                    sample.cpu().permute(0, 2, 3, 1)
                ).float()
                sample = recover_from_ric(sample, n_joints)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

            # Convert rotations to XYZ coordinates
            rot2xyz_pose_rep = (
                "xyz" if model.data_rep in ["xyz", "hml_vec"] else model.data_rep
            )
            rot2xyz_mask = (
                None
                if rot2xyz_pose_rep == "xyz"
                else model_kwargs["y"]["mask"].reshape(num_samples, n_frames).bool()
            )
            sample = model.rot2xyz(
                x=sample,
                mask=rot2xyz_mask,
                pose_rep=rot2xyz_pose_rep,
                glob=True,
                translation=True,
                jointstype="smpl",
                vertstrans=True,
                betas=None,
                beta=0,
                glob_rot=None,
                get_rotations_back=False,
            )

        # Store this batch of samples
        all_motions.append(sample)

    # Concatenate all repetitions
    all_motions = torch.cat(all_motions, dim=0)
    all_lengths = torch.cat(all_lengths, dim=0)

    # Convert to numpy if requested
    if return_numpy:
        all_motions = all_motions.cpu().numpy()
        all_lengths = all_lengths.numpy()

    # Reset model if we wrapped it (local rebind only)
    if guidance_param != 1.0:
        model = original_model

    return {"motions": all_motions, "texts": all_text, "lengths": all_lengths}