# NOTE(review): the lines below are file-viewer chrome captured by a scrape
# (status badges, file size, commit hashes, line-number gutter) — not source.
# Preserved as comments so the module parses:
# Spaces:
# Paused
# Paused
# File size: 4,100 Bytes
# f389eb9 1c71332 f389eb9 c0c84cd f389eb9 1bf5dc4 653ab16 1bf5dc4 653ab16 1bf5dc4 653ab16 1bf5dc4 7480aa2 1bf5dc4 d25e8a6 1bf5dc4 4768330 1bf5dc4 4768330 |
# 1 2 3 4 5 6 7 8 9 10 ... 108 109 110 |
import torch
from utils.model_util import create_model_and_diffusion, load_saved_model
from utils.parser_util import generate_args
import os
def load_model_and_args():
    """Load the pretrained motion-diffusion model and its saved arguments.

    Restores the training-time arguments from ``../models/args.json``
    (resolved relative to this file), builds the model/diffusion pair,
    loads the checkpoint weights, and moves the model to CUDA when
    available (CPU otherwise).

    Returns:
        tuple: ``(model, diffusion, args, device)`` with the model in
        eval mode on ``device``.
    """
    import json

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    args_path = os.path.join(models_dir, "args.json")
    args = generate_args()
    if os.path.exists(args_path):
        with open(args_path, "r") as f:
            args.__dict__.update(json.load(f))
    # data=None: the model is built from saved args alone, no dataset needed here.
    model, diffusion = create_model_and_diffusion(args, data=None)
    # BUG FIX: the checkpoint path was CWD-relative ("models/model000475000.pt")
    # while args.json above is resolved relative to __file__; resolve both the
    # same way so loading works regardless of the working directory.
    model = load_saved_model(model, os.path.join(models_dir, "model000475000.pt"))
    model.to(device)
    model.eval()
    return model, diffusion, args, device
def generate_motion(prompt: str, num_frames: int, style: str = "default"):
    """Generate a human motion sequence from a text prompt.

    Loads the saved training arguments and checkpoint from ``../models``
    (resolved relative to this file), builds the diffusion model, samples
    one motion conditioned on ``prompt``, and converts the result to
    JSON-serializable nested lists.

    Args:
        prompt: Natural-language description of the desired motion.
        num_frames: Requested number of output frames at an assumed
            20 fps; clamped to the dataset maximum (196 for kit/humanml,
            60 otherwise).
        style: Accepted for API compatibility; currently unused.

    Returns:
        dict: ``{"motion_data": nested lists of the sampled motion,
        "info": human-readable summary string}``.
    """
    import json
    from argparse import Namespace

    from utils import dist_util
    from data_loaders.get_data import get_dataset_loader
    from data_loaders.humanml.scripts.motion_process import recover_from_ric

    # Restore the training-time args saved next to the checkpoint.
    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    with open(os.path.join(models_dir, "args.json"), "r") as f:
        args = Namespace(**json.load(f))

    # Override generation-specific fields from the API input.
    args.text_prompt = prompt
    args.num_samples = 1
    args.num_repetitions = 1
    args.motion_length = num_frames / 20  # seconds, assuming 20 fps
    args.output_dir = ""
    args.model_path = os.path.join(models_dir, "model000475000.pt")

    # Fill defaults for fields that older saved args.json files may lack.
    for key, value in [
        ("unconstrained", False),
        ("use_ema", False),
        ("context_len", 0),
        ("pred_len", 0),
        ("guidance_param", 1.0),
        ("dataset", "humanml"),
        ("device", 0),
    ]:
        if not hasattr(args, key):
            setattr(args, key, value)

    dist_util.setup_dist(args.device)

    # Clamp the requested length to the dataset maximum.
    # BUG FIX: n_frames was previously int(args.motion_length * 20); the float
    # round-trip through num_frames / 20 can land fractionally below the
    # integer and int() then truncates a frame. Use num_frames directly.
    max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
    n_frames = min(max_frames, num_frames)

    data = get_dataset_loader(name=args.dataset, batch_size=1,
                              num_frames=max_frames, split='test',
                              hml_mode='text_only', fixed_len=0, pred_len=0,
                              device=dist_util.dev())

    # Build the model against the dataset so conditioning dims match.
    model, diffusion = create_model_and_diffusion(args, data)
    load_saved_model(model, args.model_path,
                     use_avg=getattr(args, 'use_ema', False))
    model.to(dist_util.dev())
    model.eval()

    # Take one batch for its conditioning metadata, then swap in our prompt
    # and move any tensor conditioning onto the compute device.
    _, model_kwargs = next(iter(data))
    model_kwargs['y']['text'] = [prompt]
    model_kwargs['y'] = {
        key: val.to(dist_util.dev()) if torch.is_tensor(val) else val
        for key, val in model_kwargs['y'].items()
    }

    # Sample: denoise from pure noise down to a single motion tensor of
    # shape (1, njoints, nfeats, n_frames).
    motion_shape = (1, model.njoints, model.nfeats, n_frames)
    with torch.no_grad():
        sample = diffusion.p_sample_loop(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=False,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

    # HumanML3D vector representation: un-normalize, then recover xyz joint
    # positions (263 features => 22 joints; otherwise assumed KIT's 21).
    if getattr(model, 'data_rep', None) == 'hml_vec':
        n_joints = 22 if sample.shape[1] == 263 else 21
        sample = data.dataset.t2m_dataset.inv_transform(
            sample.cpu().permute(0, 2, 3, 1)).float()
        sample = recover_from_ric(sample, n_joints)
        sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

    return {
        "motion_data": sample.cpu().numpy().tolist(),
        "info": f"Generated motion for prompt: '{prompt}' with {num_frames} frames.",
    }
# | (trailing scrape artifact, commented out so the module parses)