File size: 4,100 Bytes
f389eb9
1c71332
 
f389eb9
 
c0c84cd
 
 
 
 
 
 
 
 
 
 
 
 
f389eb9
 
1bf5dc4
 
 
 
 
 
 
653ab16
1bf5dc4
 
 
 
 
653ab16
1bf5dc4
653ab16
1bf5dc4
 
 
 
 
 
 
7480aa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bf5dc4
 
 
 
 
 
 
d25e8a6
 
 
 
 
 
 
1bf5dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4768330
 
1bf5dc4
4768330
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from utils.model_util import create_model_and_diffusion, load_saved_model
from utils.parser_util import generate_args
import os

def load_model_and_args():
    """Build the model and diffusion process from saved config and weights.

    Starts from ``generate_args()`` and then overlays the key/value pairs from
    ``models/args.json``, if that file exists. The file is resolved relative to
    this source file, not the working directory. Next, the model is built with
    ``create_model_and_diffusion(args, data=None)`` and the checkpoint
    ``models/model000475000.pt`` (resolved the same way) is loaded into it.
    Finally, the model is moved to the selected device and put in eval mode.

    Returns:
        tuple: ``(model, diffusion, args, device)``. ``device`` is CUDA when
        ``torch.cuda.is_available()`` is true and CPU otherwise.
    """
    import json

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models_dir = os.path.join(os.path.dirname(__file__), "..", "models")
    args_path = os.path.join(models_dir, "args.json")

    # NOTE(review): generate_args() presumably parses sys.argv; when this runs
    # inside a host process with unrelated CLI flags it may fail — confirm.
    args = generate_args()
    if os.path.exists(args_path):
        with open(args_path, "r", encoding="utf-8") as f:
            args.__dict__.update(json.load(f))

    model, diffusion = create_model_and_diffusion(args, data=None)

    # Resolve the checkpoint relative to this file, like args.json above; a
    # bare "models/..." path only worked when the CWD was the project root.
    model_path = os.path.join(models_dir, "model000475000.pt")
    # Weights are loaded into ``model`` in place (generate_motion relies on the
    # same behaviour and ignores the return value). Honour ``use_ema`` from the
    # saved config the same way generate_motion does.
    load_saved_model(model, model_path, use_avg=getattr(args, "use_ema", False))

    model.to(device)
    model.eval()
    return model, diffusion, args, device

def generate_motion(prompt: str, num_frames: int, style: str = "default"):
    """Generate one motion sample for a text prompt and return it as plain lists.

    The function proceeds in these steps:

    1. Load the saved argument namespace from ``models/args.json`` (next to this
       package) and override the fields needed for a single text-to-motion
       sample.
    2. Rebuild the model and diffusion with a ``text_only`` test-split data
       loader.
    3. Run ``diffusion.p_sample_loop`` once.
    4. For ``hml_vec`` models, convert the result back to joint positions with
       ``recover_from_ric``.

    Args:
        prompt: Conditioning text placed in ``model_kwargs['y']['text']``.
        num_frames: Requested number of frames (the code assumes 20 fps). It
            is clamped to 196 for the ``kit``/``humanml`` datasets and to 60
            otherwise.
        style: Currently unused; kept for API compatibility.

    Returns:
        dict: ``"motion_data"`` holds the sample converted with
        ``.numpy().tolist()``. ``"info"`` holds a summary string reporting the
        number of frames actually generated (after clamping).

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """
    # Validate before the heavy imports and the model/dataset load, so bad
    # requests fail fast instead of deep inside sampling.
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames!r}")

    import json
    import os
    from argparse import Namespace

    import torch
    from data_loaders.get_data import get_dataset_loader
    from data_loaders.humanml.scripts.motion_process import recover_from_ric
    from utils import dist_util

    fps = 20  # assumed frame rate

    # Load args from models/args.json.
    args_path = os.path.join(os.path.dirname(__file__), "..", "models", "args.json")
    with open(args_path, "r", encoding="utf-8") as f:
        args = Namespace(**json.load(f))

    # Set/override fields from API input.
    args.text_prompt = prompt
    args.num_samples = 1
    args.num_repetitions = 1
    args.motion_length = num_frames / fps
    args.output_dir = ""
    args.model_path = os.path.join(os.path.dirname(__file__), "..", "models", "model000475000.pt")

    # Fill in fields that may be missing from the saved args.json.
    for name, default in (
        ("unconstrained", False),
        ("use_ema", False),
        ("context_len", 0),
        ("pred_len", 0),
        ("guidance_param", 1.0),
        ("dataset", "humanml"),
        ("device", 0),
    ):
        if not hasattr(args, name):
            setattr(args, name, default)

    dist_util.setup_dist(args.device)
    device = dist_util.dev()

    max_frames = 196 if args.dataset in ("kit", "humanml") else 60
    # Take the length straight from num_frames instead of round-tripping
    # through the float motion_length (num_frames / fps * fps) before int().
    n_frames = min(max_frames, int(num_frames))

    data = get_dataset_loader(
        name=args.dataset,
        batch_size=1,
        num_frames=max_frames,
        split="test",
        hml_mode="text_only",
        fixed_len=0,
        pred_len=0,
        device=device,
    )

    # NOTE(review): the model and loader are rebuilt on every call; consider
    # caching them if this is called repeatedly.
    model, diffusion = create_model_and_diffusion(args, data)
    load_saved_model(model, args.model_path, use_avg=args.use_ema)
    model.to(device)
    model.eval()

    # One loader batch supplies the conditioning dict; only its text is replaced.
    # NOTE(review): any 'lengths'/'mask' entries come from the loader batch, not
    # from n_frames — confirm the model does not need them to match motion_shape.
    # NOTE(review): guidance_param is set above but never applied here (no
    # guidance wrapper or scale entry) — confirm unguided sampling is intended.
    _, model_kwargs = next(iter(data))
    model_kwargs["y"]["text"] = [prompt]
    model_kwargs["y"] = {
        key: val.to(device) if torch.is_tensor(val) else val
        for key, val in model_kwargs["y"].items()
    }

    motion_shape = (1, model.njoints, model.nfeats, n_frames)
    with torch.no_grad():
        sample = diffusion.p_sample_loop(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=False,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )

    # Post-process: recover xyz joint positions for hml_vec representations.
    if getattr(model, "data_rep", None) == "hml_vec":
        n_joints = 22 if sample.shape[1] == 263 else 21
        sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
        sample = recover_from_ric(sample, n_joints)
        sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

    return {
        "motion_data": sample.cpu().numpy().tolist(),
        # Report the frame count actually generated (num_frames may be clamped).
        "info": f"Generated motion for prompt: '{prompt}' with {n_frames} frames.",
    }