# mdm/utils/model_util.py
# (Hugging Face blob-view header kept as a comment: author hassanjbara,
#  commit 5007d4b "update model")
import torch
from torch.utils.data import Subset
from model.mdm import MDM
from model.mdm_controlnet import MDMControlNet
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
from utils.sampler_util import AutoRegressiveSampler
from data_loaders.humanml.scripts.motion_process import recover_from_ric
from data_loaders.tensors import collate
def get_cond_mode(args):
    """Resolve the conditioning mode implied by the run arguments.

    Unconstrained runs use no conditioning; the text datasets condition on
    text prompts; every other dataset conditions on an action label.
    """
    if args.unconstrained:
        return "no_cond"
    if args.dataset in ("kit", "humanml", "humanml_with_images"):
        return "text"
    return "action"
def load_model_wo_clip(model, state_dict):
    """Load *state_dict* into *model*, tolerating absent CLIP weights.

    For plain MDM models the fixed positional-encoding buffers are dropped
    from the checkpoint before loading: they are rebuilt deterministically at
    model construction, and older checkpoints stored them with a different
    length, which would cause a size-mismatch error.

    Raises:
        AssertionError: if the checkpoint contains unexpected keys, or is
            missing keys other than CLIP weights / positional encodings.
        KeyError: if a non-ControlNet checkpoint lacks the positional-encoding
            buffers (same behavior as the original ``del``).
    """
    if not isinstance(model, MDMControlNet):
        # No need to load these (fixed buffers); loading them from older
        # checkpoints causes size mismatches.
        del state_dict["sequence_pos_encoder.pe"]
        del state_dict["embed_timestep.sequence_pos_encoder.pe"]
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    # Only CLIP weights and positional encodings may be missing.
    assert all(
        k.startswith("clip_model.") or "sequence_pos_encoder" in k
        for k in missing_keys
    )
def create_model_and_diffusion(args, data):
    """Build an MDM model plus its matching Gaussian diffusion process."""
    diffusion = create_gaussian_diffusion(args)
    model = MDM(**get_model_args(args, data))
    return model, diffusion
def get_model_args(args, data):
    """Assemble the keyword arguments used to construct an ``MDM`` model.

    Fills in dataset-dependent dimensions, resolves the conditioning mode,
    and patches defaults onto ``args`` for options that older arg files
    predate (prediction/context windows, embedding policy, target options).
    """
    # Fixed defaults shared by every configuration.
    clip_version = "ViT-B/32"
    action_emb = "tensor"
    cond_mode = get_cond_mode(args)
    num_actions = getattr(data.dataset, "num_actions", 1)

    # SMPL defaults, overridden per dataset below.
    data_rep, njoints, nfeats = "rot6d", 25, 6
    all_goal_joint_names = []
    if args.dataset in ("humanml", "humanml_with_images"):
        data_rep, njoints, nfeats = "hml_vec", 263, 1
        all_goal_joint_names = ["pelvis"] + HML_EE_JOINT_NAMES
    elif args.dataset == "kit":
        data_rep, njoints, nfeats = "hml_vec", 251, 1

    # Compatibility with old models: patch missing options onto args.
    if not hasattr(args, "pred_len"):
        args.pred_len = 0
        args.context_len = 0
    emb_policy = args.__dict__.get("emb_policy", "add")
    multi_target_cond = args.__dict__.get("multi_target_cond", False)
    multi_encoder_type = args.__dict__.get("multi_encoder_type", "multi")
    target_enc_layers = args.__dict__.get("target_enc_layers", 1)

    return {
        "modeltype": "",
        "njoints": njoints,
        "nfeats": nfeats,
        "num_actions": num_actions,
        "translation": True,
        "pose_rep": "rot6d",
        "glob": True,
        "glob_rot": True,
        "latent_dim": args.latent_dim,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": 4,
        "dropout": 0.1,
        "activation": "gelu",
        "data_rep": data_rep,
        "cond_mode": cond_mode,
        "cond_mask_prob": args.cond_mask_prob,
        "action_emb": action_emb,
        "arch": args.arch,
        "emb_trans_dec": args.emb_trans_dec,
        "clip_version": clip_version,
        "dataset": args.dataset,
        "text_encoder_type": args.text_encoder_type,
        "pos_embed_max_len": args.pos_embed_max_len,
        "mask_frames": args.mask_frames,
        "pred_len": args.pred_len,
        "context_len": args.context_len,
        "emb_policy": emb_policy,
        "all_goal_joint_names": all_goal_joint_names,
        "multi_target_cond": multi_target_cond,
        "multi_encoder_type": multi_encoder_type,
        "target_enc_layers": target_enc_layers,
    }
def create_gaussian_diffusion(args):
    """Construct the ``SpacedDiffusion`` process described by *args*.

    Fixed project choices: the model always predicts x0 (START_X), uses MSE
    loss, a fixed (non-learned) variance, no beta scaling, no timestep
    rescaling, and no respacing (respacing could be used for DDIM sampling,
    but we don't use it).
    """
    steps = args.diffusion_steps
    scale_beta = 1.0  # no scaling
    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
    timestep_respacing = [steps]  # no respacing
    lambda_target_loc = getattr(args, "lambda_target_loc", 0.0)
    model_var_type = (
        gd.ModelVarType.FIXED_SMALL
        if args.sigma_small
        else gd.ModelVarType.FIXED_LARGE
    )
    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        # We always predict x_start (a.k.a. x0), that's our deal!
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=model_var_type,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=False,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        lambda_target_loc=lambda_target_loc,
    )
def load_saved_model(model, model_path, use_avg: bool = False):
    """Load a checkpoint from *model_path* into *model* and return it.

    Prefers the averaged ("model_avg") weights when *use_avg* is set and the
    checkpoint provides them; otherwise falls back to the raw "model" entry,
    or treats the whole checkpoint as a state dict.
    """
    state_dict = torch.load(model_path, map_location="cpu")
    # Use average model when possible.
    if use_avg and "model_avg" in state_dict:
        print("loading avg model")
        state_dict = state_dict["model_avg"]
    elif "model" in state_dict:
        print("loading model without avg")
        state_dict = state_dict["model"]
    else:
        print("checkpoint has no avg model, loading as usual.")
    load_model_wo_clip(model, state_dict)
    return model
def sample_from_model(
    model,
    diffusion,
    data=None,
    num_samples=1,
    num_repetitions=1,
    text_prompts=None,
    action_name=None,
    motion_length=6.0,
    guidance_param=3.0,
    n_frames=None,
    context_motion=None,
    context_len=0,
    pred_len=0,
    autoregressive=False,
    device="cuda",
    return_xyz=True,
    return_numpy=True,
    noise=None,
    const_noise=False,
    cond_images=None,
    frame_indices=None,
):
    """
    Sample motions from a trained MDM model.

    Parameters:
        model: The MDM model
        diffusion: The diffusion object
        data: Optional dataset loader (used for prefix sampling, action
            vocabulary lookup, and hml_vec inverse-normalization)
        num_samples: Number of samples (text prompts) to process
        num_repetitions: Number of different motions to generate for each prompt
        text_prompts: List of text prompts or single string prompt
        action_name: Action name(s) for action-conditioned generation
        motion_length: Length of motion in seconds
        guidance_param: Classifier-free guidance scale (1.0 disables CFG)
        n_frames: Number of frames to generate (calculated from motion_length if None)
        context_motion: Optional context motion for prefix-based generation
        context_len: Context length for prefix-based generation
        pred_len: Prediction length for each step in autoregressive generation
        autoregressive: Whether to use autoregressive sampling
        device: Device to use for sampling
        return_xyz: Whether to convert output to XYZ coordinates
        return_numpy: Whether to return numpy arrays (True) or torch tensors (False)
        noise: Optional noise tensor for sampling
        const_noise: Whether to use constant noise for sampling
        cond_images: Optional image conditioning; preprocessed via
            model.process_images and only valid for MDMControlNet models
        frame_indices: Optional frame indices forwarded to the model alongside
            cond_images (semantics defined by the model — TODO confirm)

    Returns:
        Dictionary containing:
        - motions: Generated motions with shape [num_samples*num_repetitions, njoints, 3, n_frames]
        - texts: Text prompts used for generation
        - lengths: Length of each generated motion
    """
    # BUGFIX: the original guard (`cond_images is not None or isinstance(...)`)
    # was inverted — it failed for every plain-MDM call made WITHOUT images,
    # i.e. the normal sampling path. Images are only valid for MDMControlNet.
    assert cond_images is None or isinstance(model, MDMControlNet), (
        "Image conditioning is only supported for MDMControlNet"
    )
    if cond_images is not None:
        cond_images = model.process_images(cond_images, device=device)

    model.eval()  # Ensure model is in eval mode

    # Move model to the right device if it's not there already.
    model_device = next(model.parameters()).device
    if str(model_device) != device:
        model = model.to(device)

    # Determine number of frames (capped at the dataset's max sequence length).
    fps = 12.5 if model.dataset == "kit" else 20
    if n_frames is None:
        n_frames = min(
            196 if model.dataset in ["kit", "humanml", "humanml_with_images"] else 60,
            int(motion_length * fps),
        )

    # Handle text prompts: broadcast a single string, or tile a short list,
    # so that len(text_prompts) matches num_samples.
    if text_prompts is not None:
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts] * num_samples
        elif len(text_prompts) < num_samples:
            text_prompts = text_prompts * (num_samples // len(text_prompts) + 1)
            text_prompts = text_prompts[:num_samples]
        num_samples = len(text_prompts)

    # Handle action names analogously.
    if action_name is not None:
        if isinstance(action_name, str):
            action_text = [action_name] * num_samples
        else:
            action_text = action_name
        num_samples = len(action_text)

    # Set up classifier-free guidance by wrapping the model; keep a handle to
    # the unwrapped model for text encoding and restoring at the end.
    original_model = model
    if guidance_param != 1.0:
        from utils.sampler_util import ClassifierFreeSampleModel

        model = ClassifierFreeSampleModel(model)

    # Set up autoregressive sampling if needed.
    sample_fn = diffusion.p_sample_loop
    if autoregressive:
        sample_cls = AutoRegressiveSampler({"pred_len": pred_len}, sample_fn, n_frames)
        sample_fn = sample_cls.sample

    # Prepare target shape for sampling.
    motion_shape = (num_samples, model.njoints, model.nfeats, n_frames)

    # Set up model kwargs.
    if context_motion is not None or context_len > 0:
        # For prefix-conditioned generation: pull a real batch from the dataset
        # to supply the context, optionally overriding its text.
        if data is None:
            raise ValueError("Dataset needed for context-based generation")
        iterator = iter(data)
        input_motion, model_kwargs = next(iterator)
        input_motion = input_motion.to(device)
        if text_prompts is not None:
            model_kwargs["y"]["text"] = text_prompts
    else:
        collate_args = [
            {"inp": torch.zeros(n_frames), "tokens": None, "lengths": n_frames}
        ] * num_samples
        if text_prompts is not None:
            # Text-to-motion
            collate_args = [
                dict(arg, text=txt) for arg, txt in zip(collate_args, text_prompts)
            ]
        elif action_name is not None:
            # Action-to-motion
            if hasattr(data.dataset, "action_name_to_action"):
                action = data.dataset.action_name_to_action(action_text)
                collate_args = [
                    dict(arg, action=one_action, action_text=one_action_text)
                    for arg, one_action, one_action_text in zip(
                        collate_args, action, action_text
                    )
                ]
            else:
                raise ValueError("Dataset doesn't support action conditioning")
        _, model_kwargs = collate(collate_args)

    # Move model_kwargs conditioning tensors to device.
    model_kwargs["y"] = {
        key: val.to(device) if torch.is_tensor(val) else val
        for key, val in model_kwargs["y"].items()
    }

    # Add image conditioning to model_kwargs if provided.
    if cond_images is not None:
        model_kwargs["cond_images"] = cond_images
        if frame_indices is not None:
            model_kwargs["frame_indices"] = frame_indices

    # Add CFG scale to batch.
    if guidance_param != 1.0:
        model_kwargs["y"]["scale"] = (
            torch.ones(num_samples, device=device) * guidance_param
        )

    # Pre-encode text once for efficiency (reused across all repetitions).
    if "text" in model_kwargs["y"]:
        model_kwargs["y"]["text_embed"] = original_model.encode_text(
            model_kwargs["y"]["text"]
        )

    # Store all generated motions and related information.
    all_motions = []
    all_text = []
    all_lengths = []

    # Run generation for each repetition.
    for rep_i in range(num_repetitions):
        print(f"### Sampling [repetition #{rep_i + 1}/{num_repetitions}]")

        # Sample from the model.
        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=noise,
            const_noise=const_noise,
        )

        # Get text information for this batch (empty strings if unconditioned).
        if "text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["text"]
        elif "action_text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["action_text"]
        else:
            batch_text = [""] * num_samples
        all_text.extend(batch_text)

        # Get lengths.
        batch_lengths = model_kwargs["y"]["lengths"].cpu()
        all_lengths.append(batch_lengths)

        # Post-process the sample if returning XYZ coordinates.
        if return_xyz:
            # Recover XYZ positions from vector representation if needed.
            if model.data_rep == "hml_vec":
                n_joints = 22 if sample.shape[1] == 263 else 21
                # NOTE(review): this path dereferences `data` — it will raise
                # if data is None; confirm callers always pass a loader here.
                if isinstance(data.dataset, Subset):
                    dataset = data.dataset.dataset
                else:
                    dataset = data.dataset
                sample = dataset.t2m_dataset.inv_transform(
                    sample.cpu().permute(0, 2, 3, 1)
                ).float()
                sample = recover_from_ric(sample, n_joints)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

            # Convert rotations to XYZ coordinates (identity for xyz input).
            rot2xyz_pose_rep = (
                "xyz" if model.data_rep in ["xyz", "hml_vec"] else model.data_rep
            )
            rot2xyz_mask = (
                None
                if rot2xyz_pose_rep == "xyz"
                else model_kwargs["y"]["mask"].reshape(num_samples, n_frames).bool()
            )
            sample = model.rot2xyz(
                x=sample,
                mask=rot2xyz_mask,
                pose_rep=rot2xyz_pose_rep,
                glob=True,
                translation=True,
                jointstype="smpl",
                vertstrans=True,
                betas=None,
                beta=0,
                glob_rot=None,
                get_rotations_back=False,
            )

        # Store this batch of samples.
        all_motions.append(sample)

    # Concatenate all repetitions.
    all_motions = torch.cat(all_motions, dim=0)
    all_lengths = torch.cat(all_lengths, dim=0)

    # Convert to numpy if requested.
    if return_numpy:
        all_motions = all_motions.cpu().numpy()
        all_lengths = all_lengths.numpy()

    # Reset model if we wrapped it for CFG.
    if guidance_param != 1.0:
        model = original_model

    return {"motions": all_motions, "texts": all_text, "lengths": all_lengths}