# txt2motion/api/mdm_loader.py
# MDM (Motion Diffusion Model) loading and text-to-motion generation helpers.
import torch
from utils.model_util import create_model_and_diffusion, load_saved_model
from utils.parser_util import generate_args
import os
def load_model_and_args():
    """Load the MDM model, diffusion process, and generation args.

    Reads saved training args from ``models/args.json`` (if present) on top
    of the CLI-parsed defaults, builds the model/diffusion pair, and loads
    the checkpoint weights.

    Returns:
        tuple: ``(model, diffusion, args, device)`` — the model is moved to
        the selected device and put in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    args_path = os.path.join(models_dir, "args.json")
    args = generate_args()
    # Override CLI defaults with the args saved alongside the checkpoint,
    # so the model is rebuilt with its training-time configuration.
    if os.path.exists(args_path):
        import json
        with open(args_path, "r") as f:
            args.__dict__.update(json.load(f))
    model, diffusion = create_model_and_diffusion(args, data=None)
    # Resolve the checkpoint relative to this file, matching args_path.
    # (Previously a CWD-relative "models/..." string, which broke whenever
    # the API was launched from a different working directory.)
    model = load_saved_model(model, os.path.join(models_dir, "model000475000.pt"))
    model.to(device)
    model.eval()
    return model, diffusion, args, device
def generate_motion(prompt: str, num_frames: int, style: str = "default"):
    """Generate a motion sequence from a text prompt using MDM.

    Args:
        prompt: Natural-language description of the desired motion.
        num_frames: Requested frame count at an assumed 20 fps. Clamped to
            the dataset maximum (196 for kit/humanml, 60 otherwise).
        style: Accepted for API compatibility; currently unused.

    Returns:
        dict: ``{"motion_data": nested lists from the sampled tensor,
        "info": summary reporting the actual (possibly clamped) frame count}``.
    """
    import json
    import os
    import torch
    from argparse import Namespace
    from utils import dist_util
    from data_loaders.get_data import get_dataset_loader
    from data_loaders.humanml.scripts.motion_process import recover_from_ric

    # --- Load the training-time args saved next to the checkpoint ---------
    args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
    with open(args_path, "r") as f:
        args = Namespace(**json.load(f))

    # Override fields from the API input.
    args.text_prompt = prompt
    args.num_samples = 1
    args.num_repetitions = 1
    args.motion_length = num_frames / 20  # assuming 20 fps
    args.output_dir = ""
    args.model_path = os.path.join(
        os.path.dirname(__file__), "..", "models", "model000475000.pt"
    )

    # Defaults for fields model construction expects but that may be
    # missing from older args.json files. Existing values win.
    fallback_defaults = {
        "unconstrained": False,
        "use_ema": False,
        "context_len": 0,
        "pred_len": 0,
        "guidance_param": 1.0,
        "dataset": "humanml",
        "device": 0,
    }
    for field, default in fallback_defaults.items():
        if not hasattr(args, field):
            setattr(args, field, default)

    dist_util.setup_dist(args.device)

    # --- Dataset loader (supplies tokenization/normalization machinery) ---
    max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
    # Clamp the request to the model's supported sequence length.
    n_frames = min(max_frames, int(args.motion_length * 20))
    data = get_dataset_loader(name=args.dataset, batch_size=1, num_frames=max_frames,
                              split='test', hml_mode='text_only',
                              fixed_len=0, pred_len=0, device=dist_util.dev())

    # --- Model and checkpoint ----------------------------------------------
    model, diffusion = create_model_and_diffusion(args, data)
    load_saved_model(model, args.model_path, use_avg=getattr(args, 'use_ema', False))
    model.to(dist_util.dev())
    model.eval()

    # Borrow the conditioning dict from one test batch, then swap in our
    # prompt and move any tensors to the compute device.
    _, model_kwargs = next(iter(data))
    model_kwargs['y']['text'] = [prompt]
    model_kwargs['y'] = {
        key: val.to(dist_util.dev()) if torch.is_tensor(val) else val
        for key, val in model_kwargs['y'].items()
    }

    # --- Sampling ----------------------------------------------------------
    motion_shape = (1, model.njoints, model.nfeats, n_frames)
    with torch.no_grad():
        sample = diffusion.p_sample_loop(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=False,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

    # --- Post-process: HumanML vector representation -> joint xyz ----------
    if getattr(model, 'data_rep', None) == 'hml_vec':
        n_joints = 22 if sample.shape[1] == 263 else 21
        sample = data.dataset.t2m_dataset.inv_transform(
            sample.cpu().permute(0, 2, 3, 1)
        ).float()
        sample = recover_from_ric(sample, n_joints)
        sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

    return {
        "motion_data": sample.cpu().numpy().tolist(),
        # Report the actual frame count: when num_frames exceeded the
        # dataset maximum, the old message claimed more frames than were
        # generated.
        "info": f"Generated motion for prompt: '{prompt}' with {n_frames} frames.",
    }