# NOTE: scrape artifacts ("Spaces:" / "Paused" page-status lines) removed —
# they were hosting-page residue, not part of the source.
| import torch | |
| from utils.model_util import create_model_and_diffusion, load_saved_model | |
| from utils.parser_util import generate_args | |
| import os | |
def load_model_and_args():
    """Load the pretrained motion-diffusion model together with its args.

    Starts from the defaults produced by ``generate_args()``, overlays the
    training-time arguments saved in ``../models/args.json`` when that file
    exists, builds the model and diffusion objects, restores the checkpoint
    weights, and moves the model to GPU when available.

    Returns:
        Tuple ``(model, diffusion, args, device)`` where ``model`` is already
        on ``device`` and in eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Resolve everything relative to this file (not the CWD) so loading works
    # no matter where the process was started from.
    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    args_path = os.path.join(models_dir, "args.json")
    args = generate_args()
    if os.path.exists(args_path):
        import json
        with open(args_path, "r") as f:
            # Override CLI defaults with the args the checkpoint was trained with.
            args.__dict__.update(json.load(f))
    model, diffusion = create_model_and_diffusion(args, data=None)
    # BUGFIX: the checkpoint path was previously the CWD-relative string
    # "models/model000475000.pt", inconsistent with args_path above (and with
    # generate_motion, which builds the same path __file__-relative).
    model = load_saved_model(model, os.path.join(models_dir, "model000475000.pt"))
    model.to(device)
    model.eval()
    return model, diffusion, args, device
def generate_motion(prompt: str, num_frames: int, style: str = "default"):
    """Generate a human motion sequence from a natural-language prompt.

    Runs the full text-to-motion diffusion pipeline end to end: loads the
    training-time arguments from ``../models/args.json``, rebuilds the model
    and diffusion objects, draws one sample with ``diffusion.p_sample_loop``,
    optionally recovers joint xyz positions from the HumanML vector
    representation, and returns a JSON-serializable result.

    Args:
        prompt: Text description of the desired motion; injected into the
            model's conditioning kwargs.
        num_frames: Requested motion length in frames (converted to seconds
            assuming 20 fps; clamped to the dataset's maximum below).
        style: Currently unused by this function — accepted for API
            compatibility only. TODO(review): wire into the model or drop.

    Returns:
        dict with ``"motion_data"`` (nested Python lists from the sampled
        tensor) and a human-readable ``"info"`` string.
    """
    # Heavy project imports are kept function-local so merely importing this
    # module stays cheap.
    import numpy as np
    from utils import dist_util
    from data_loaders.get_data import get_dataset_loader
    from data_loaders.humanml.scripts.motion_process import recover_from_ric
    import torch
    import json
    import os
    from argparse import Namespace
    # Load the checkpoint's training args from models/args.json (resolved
    # relative to this file). NOTE(review): unlike load_model_and_args, no
    # existence check — a missing file raises FileNotFoundError here.
    args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
    with open(args_path, "r") as f:
        args_dict = json.load(f)
    args = Namespace()
    args.__dict__.update(args_dict)
    # Set/override fields from API input:
    args.text_prompt = prompt
    args.num_samples = 1
    args.num_repetitions = 1
    args.motion_length = num_frames / 20  # seconds, assuming 20 fps
    args.output_dir = ""
    args.model_path = os.path.join(os.path.dirname(__file__), "..", "models", "model000475000.pt")
    # Fill in defaults for fields the saved args.json may predate, so the
    # model/diffusion constructors below don't hit AttributeError.
    if not hasattr(args, "unconstrained"):
        args.unconstrained = False
    if not hasattr(args, "use_ema"):
        args.use_ema = False
    if not hasattr(args, "context_len"):
        args.context_len = 0
    if not hasattr(args, "pred_len"):
        args.pred_len = 0
    if not hasattr(args, "guidance_param"):
        args.guidance_param = 1.0
    if not hasattr(args, "dataset"):
        args.dataset = "humanml"
    if not hasattr(args, "device"):
        args.device = 0  # presumably a CUDA device index — confirm with dist_util
    # Setup device (project helper; exact semantics defined in utils.dist_util)
    dist_util.setup_dist(args.device)
    # Prepare dataset and model_kwargs. The 196/60 frame caps mirror the
    # dataset-specific limits used elsewhere in this project — TODO confirm.
    max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
    # Requested length in frames, clamped to the dataset maximum.
    n_frames = min(max_frames, int(args.motion_length * 20))
    # The loader is used only to obtain one batch of conditioning kwargs
    # (hml_mode='text_only'); its text is replaced with the API prompt below.
    data = get_dataset_loader(name=args.dataset, batch_size=1, num_frames=max_frames, split='test', hml_mode='text_only', fixed_len=0, pred_len=0, device=dist_util.dev())
    # Load model and diffusion with dataset
    model, diffusion = create_model_and_diffusion(args, data)
    load_saved_model(model, args.model_path, use_avg=getattr(args, 'use_ema', False))
    model.to(dist_util.dev())
    model.eval()
    iterator = iter(data)
    _, model_kwargs = next(iterator)
    # Condition on the caller's prompt instead of the dataset sample's text.
    model_kwargs['y']['text'] = [prompt]
    # Move every tensor in the conditioning dict to the sampling device.
    model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()}
    # Sampling: batch of 1, shape (B, njoints, nfeats, frames).
    motion_shape = (1, model.njoints, model.nfeats, n_frames)
    sample_fn = diffusion.p_sample_loop
    with torch.no_grad():
        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,      # full denoising chain from pure noise
            init_image=None,
            progress=False,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
    # Post-process: recover joint xyz positions when the model outputs the
    # HumanML vector representation.
    if getattr(model, 'data_rep', None) == 'hml_vec':
        # 263-dim features presumably correspond to 22-joint HumanML3D,
        # otherwise 21-joint KIT — TODO confirm against the dataset spec.
        n_joints = 22 if sample.shape[1] == 263 else 21
        # Undo dataset normalization, then convert RIC features to positions.
        sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
        sample = recover_from_ric(sample, n_joints)
        sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
    # Convert to JSON-serializable format
    motion_data = sample.cpu().numpy().tolist()
    return {
        "motion_data": motion_data,
        "info": f"Generated motion for prompt: '{prompt}' with {num_frames} frames."
    }