File size: 15,406 Bytes
5007d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
import torch
from torch.utils.data import Subset

from model.mdm import MDM
from model.mdm_controlnet import MDMControlNet
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
from utils.sampler_util import AutoRegressiveSampler
from data_loaders.humanml.scripts.motion_process import recover_from_ric
from data_loaders.tensors import collate


def get_cond_mode(args):
    """Resolve the conditioning mode ("no_cond" / "text" / "action") implied by args."""
    if args.unconstrained:
        return "no_cond"
    text_datasets = ("kit", "humanml", "humanml_with_images")
    return "text" if args.dataset in text_datasets else "action"


def load_model_wo_clip(model, state_dict):
    """Load a checkpoint into `model`, tolerating absent CLIP / pos-encoding keys.

    Mutates `state_dict` in place (pos-encoding buffers are removed before loading).
    """
    if not isinstance(model, MDMControlNet):
        # The positional-encoding buffers are fixed (not learned), so there is no
        # need to load them — and older checkpoints have a different size, which
        # would otherwise trigger a shape mismatch.
        for pe_key in ("sequence_pos_encoder.pe", "embed_timestep.sequence_pos_encoder.pe"):
            del state_dict[pe_key]
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    # Only CLIP weights (never stored) and the deleted pos-encoding buffers may be missing.
    assert all(
        k.startswith("clip_model.") or "sequence_pos_encoder" in k
        for k in missing_keys
    )


def create_model_and_diffusion(args, data):
    """Build an MDM model and its matching Gaussian diffusion process."""
    mdm_model = MDM(**get_model_args(args, data))
    gaussian_diffusion = create_gaussian_diffusion(args)
    return mdm_model, gaussian_diffusion


def get_model_args(args, data):
    """Assemble the keyword arguments for constructing an MDM model.

    Parameters:
        args: Parsed training/sampling arguments (argparse-style namespace).
        data: A data loader; `data.dataset.num_actions` is used when present.

    Returns:
        dict of constructor kwargs for `MDM(**kwargs)`.
    """
    # default args
    clip_version = "ViT-B/32"
    action_emb = "tensor"
    cond_mode = get_cond_mode(args)
    # Action-conditioned datasets expose `num_actions`; others default to 1.
    num_actions = getattr(data.dataset, "num_actions", 1)

    # SMPL defaults
    data_rep = "rot6d"
    njoints = 25
    nfeats = 6
    all_goal_joint_names = []

    if args.dataset in ["humanml", "humanml_with_images"]:
        data_rep = "hml_vec"
        njoints = 263
        nfeats = 1
        all_goal_joint_names = ["pelvis"] + HML_EE_JOINT_NAMES
    elif args.dataset == "kit":
        data_rep = "hml_vec"
        njoints = 251
        nfeats = 1

    # Compatibility with old models: older checkpoints were saved without
    # prefix-prediction arguments, so default them to 0 (disabled).
    if not hasattr(args, "pred_len"):
        args.pred_len = 0
        args.context_len = 0

    # getattr with a default keeps compatibility with argument sets that
    # predate these options (idiomatic replacement for args.__dict__.get).
    emb_policy = getattr(args, "emb_policy", "add")
    multi_target_cond = getattr(args, "multi_target_cond", False)
    multi_encoder_type = getattr(args, "multi_encoder_type", "multi")
    target_enc_layers = getattr(args, "target_enc_layers", 1)

    return {
        "modeltype": "",
        "njoints": njoints,
        "nfeats": nfeats,
        "num_actions": num_actions,
        "translation": True,
        "pose_rep": "rot6d",
        "glob": True,
        "glob_rot": True,
        "latent_dim": args.latent_dim,
        "ff_size": 1024,
        "num_layers": args.layers,
        "num_heads": 4,
        "dropout": 0.1,
        "activation": "gelu",
        "data_rep": data_rep,
        "cond_mode": cond_mode,
        "cond_mask_prob": args.cond_mask_prob,
        "action_emb": action_emb,
        "arch": args.arch,
        "emb_trans_dec": args.emb_trans_dec,
        "clip_version": clip_version,
        "dataset": args.dataset,
        "text_encoder_type": args.text_encoder_type,
        "pos_embed_max_len": args.pos_embed_max_len,
        "mask_frames": args.mask_frames,
        "pred_len": args.pred_len,
        "context_len": args.context_len,
        "emb_policy": emb_policy,
        "all_goal_joint_names": all_goal_joint_names,
        "multi_target_cond": multi_target_cond,
        "multi_encoder_type": multi_encoder_type,
        "target_enc_layers": target_enc_layers,
    }


def create_gaussian_diffusion(args):
    """Construct a SpacedDiffusion configured for MDM training/sampling.

    Parameters:
        args: Parsed arguments; uses `diffusion_steps`, `noise_schedule`,
              `sigma_small`, and the `lambda_*` loss weights.

    Returns:
        A `SpacedDiffusion` instance predicting x_start with MSE loss.
    """
    # default params
    predict_xstart = True  # we always predict x_start (a.k.a. x0), that's our deal!
    steps = args.diffusion_steps
    scale_beta = 1.0  # no scaling
    timestep_respacing = ""  # can be used for ddim sampling, we don't use it.
    learn_sigma = False
    rescale_timesteps = False

    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
    loss_type = gd.LossType.MSE

    if not timestep_respacing:
        timestep_respacing = [steps]

    # Older argument sets predate target-location loss; default its weight to 0.
    lambda_target_loc = getattr(args, "lambda_target_loc", 0.0)

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not args.sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        lambda_target_loc=lambda_target_loc,
    )


def load_saved_model(model, model_path, use_avg: bool = False):
    """Load weights from `model_path` into `model` and return it.

    When `use_avg` is True and the checkpoint carries an averaged (EMA) copy,
    that copy is preferred over the raw weights.
    """
    # NOTE(review): torch.load without weights_only unpickles arbitrary objects —
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location="cpu")
    if use_avg and "model_avg" in checkpoint:
        print("loading avg model")
        weights = checkpoint["model_avg"]
    elif "model" in checkpoint:
        print("loading model without avg")
        weights = checkpoint["model"]
    else:
        # Legacy checkpoints store the state dict at the top level.
        print("checkpoint has no avg model, loading as usual.")
        weights = checkpoint
    load_model_wo_clip(model, weights)
    return model


def sample_from_model(
    model,
    diffusion,
    data=None,
    num_samples=1,
    num_repetitions=1,
    text_prompts=None,
    action_name=None,
    motion_length=6.0,
    guidance_param=3.0,
    n_frames=None,
    context_motion=None,
    context_len=0,
    pred_len=0,
    autoregressive=False,
    device="cuda",
    return_xyz=True,
    return_numpy=True,
    noise=None,
    const_noise=False,
    cond_images=None,
    frame_indices=None,
):
    """
    Sample motions from a trained MDM model.

    Parameters:
        model: The MDM model
        diffusion: The diffusion object
        data: Optional dataset loader (used for prefix sampling if needed)
        num_samples: Number of samples (text prompts) to process
        num_repetitions: Number of different motions to generate for each prompt
        text_prompts: List of text prompts or single string prompt
        action_name: Action name(s) for action-conditioned generation
        motion_length: Length of motion in seconds
        guidance_param: Classifier-free guidance scale
        n_frames: Number of frames to generate (calculated from motion_length if None)
        context_motion: Optional context motion for prefix-based generation
        context_len: Context length for prefix-based generation
        pred_len: Prediction length for each step in autoregressive generation
        autoregressive: Whether to use autoregressive sampling
        device: Device to use for sampling
        return_xyz: Whether to convert output to XYZ coordinates
        return_numpy: Whether to return numpy arrays (True) or torch tensors (False)
        noise: Optional noise tensor for sampling
        const_noise: Whether to use constant noise for sampling
        cond_images: Optional image conditioning (MDMControlNet only)
        frame_indices: Optional frame indices forwarded to the model

    Returns:
        Dictionary containing:
            - motions: Generated motions with shape [num_samples*num_repetitions, njoints, 3, n_frames]
            - texts: Text prompts used for generation
            - lengths: Length of each generated motion
    """
    # BUGFIX: the original condition was inverted (`cond_images is not None or ...`),
    # which rejected every plain-MDM call *without* images and accepted images on
    # non-ControlNet models. Image conditioning requires an MDMControlNet.
    assert cond_images is None or isinstance(model, MDMControlNet), (
        "Image conditioning is only supported for MDMControlNet"
    )
    if cond_images is not None:
        cond_images = model.process_images(cond_images, device=device)

    model.eval()  # Ensure model is in eval mode

    # Move model to the right device if it's not there already
    model_device = next(model.parameters()).device
    if str(model_device) != device:
        model = model.to(device)

    # Determine number of frames (KIT runs at 12.5 fps, HumanML at 20 fps)
    fps = 12.5 if model.dataset == "kit" else 20
    if n_frames is None:
        n_frames = min(
            196 if model.dataset in ["kit", "humanml", "humanml_with_images"] else 60,
            int(motion_length * fps),
        )

    # Handle text prompts: broadcast a single prompt, or cycle a short list,
    # so that exactly `num_samples` prompts are available.
    if text_prompts is not None:
        if isinstance(text_prompts, str):
            text_prompts = [text_prompts] * num_samples
        elif len(text_prompts) < num_samples:
            text_prompts = text_prompts * (num_samples // len(text_prompts) + 1)
            text_prompts = text_prompts[:num_samples]
        num_samples = len(text_prompts)

    # Handle action names
    if action_name is not None:
        if isinstance(action_name, str):
            action_text = [action_name] * num_samples
        else:
            action_text = action_name
            num_samples = len(action_text)

    # Set up classifier-free guidance (wrap the model; keep the original for
    # text encoding and for restoring at the end)
    original_model = model
    if guidance_param != 1.0:
        from utils.sampler_util import ClassifierFreeSampleModel

        model = ClassifierFreeSampleModel(model)

    # Set up autoregressive sampling if needed
    sample_fn = diffusion.p_sample_loop
    if autoregressive:
        sample_cls = AutoRegressiveSampler({"pred_len": pred_len}, sample_fn, n_frames)
        sample_fn = sample_cls.sample

    # Prepare for sampling
    motion_shape = (num_samples, model.njoints, model.nfeats, n_frames)

    # Set up model kwargs
    if context_motion is not None or context_len > 0:
        # For prefix-conditioned generation
        if data is None:
            raise ValueError("Dataset needed for context-based generation")
        iterator = iter(data)
        input_motion, model_kwargs = next(iterator)
        input_motion = input_motion.to(device)
        if text_prompts is not None:
            model_kwargs["y"]["text"] = text_prompts
    else:
        collate_args = [
            {"inp": torch.zeros(n_frames), "tokens": None, "lengths": n_frames}
        ] * num_samples

        if text_prompts is not None:
            # Text-to-motion
            collate_args = [
                dict(arg, text=txt) for arg, txt in zip(collate_args, text_prompts)
            ]
        elif action_name is not None:
            # Action-to-motion
            if hasattr(data.dataset, "action_name_to_action"):
                action = data.dataset.action_name_to_action(action_text)
                collate_args = [
                    dict(arg, action=one_action, action_text=one_action_text)
                    for arg, one_action, one_action_text in zip(
                        collate_args, action, action_text
                    )
                ]
            else:
                raise ValueError("Dataset doesn't support action conditioning")

        _, model_kwargs = collate(collate_args)

    # Move model_kwargs to device
    model_kwargs["y"] = {
        key: val.to(device) if torch.is_tensor(val) else val
        for key, val in model_kwargs["y"].items()
    }

    # Add image conditioning to model_kwargs if provided
    if cond_images is not None:
        model_kwargs["cond_images"] = cond_images

    if frame_indices is not None:
        model_kwargs["frame_indices"] = frame_indices

    # Add CFG scale to batch
    if guidance_param != 1.0:
        model_kwargs["y"]["scale"] = (
            torch.ones(num_samples, device=device) * guidance_param
        )

    # Pre-encode text once for efficiency (reused across repetitions)
    if "text" in model_kwargs["y"]:
        model_kwargs["y"]["text_embed"] = original_model.encode_text(
            model_kwargs["y"]["text"]
        )

    # Store all generated motions and related information
    all_motions = []
    all_text = []
    all_lengths = []

    # Run generation for each repetition
    for rep_i in range(num_repetitions):
        print(f"### Sampling [repetition #{rep_i + 1}/{num_repetitions}]")

        # Sample from the model
        sample = sample_fn(
            model,
            motion_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=noise,
            const_noise=const_noise,
        )

        # Get text information for this batch
        if "text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["text"]
        elif "action_text" in model_kwargs["y"]:
            batch_text = model_kwargs["y"]["action_text"]
        else:
            batch_text = [""] * num_samples

        all_text.extend(batch_text)

        # Get lengths
        batch_lengths = model_kwargs["y"]["lengths"].cpu()
        all_lengths.append(batch_lengths)

        # Post-process the sample if returning XYZ coordinates
        if return_xyz:
            # Recover XYZ positions from vector representation if needed
            if model.data_rep == "hml_vec":
                # 263 features => HumanML3D (22 joints); otherwise KIT (21 joints)
                n_joints = 22 if sample.shape[1] == 263 else 21
                if isinstance(data.dataset, Subset):
                    dataset = data.dataset.dataset
                else:
                    dataset = data.dataset
                sample = dataset.t2m_dataset.inv_transform(
                    sample.cpu().permute(0, 2, 3, 1)
                ).float()
                sample = recover_from_ric(sample, n_joints)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)

            # Convert rotations to XYZ coordinates
            rot2xyz_pose_rep = (
                "xyz" if model.data_rep in ["xyz", "hml_vec"] else model.data_rep
            )
            rot2xyz_mask = (
                None
                if rot2xyz_pose_rep == "xyz"
                else model_kwargs["y"]["mask"].reshape(num_samples, n_frames).bool()
            )

            sample = model.rot2xyz(
                x=sample,
                mask=rot2xyz_mask,
                pose_rep=rot2xyz_pose_rep,
                glob=True,
                translation=True,
                jointstype="smpl",
                vertstrans=True,
                betas=None,
                beta=0,
                glob_rot=None,
                get_rotations_back=False,
            )

        # Store this batch of samples
        all_motions.append(sample)

    # Concatenate all repetitions
    all_motions = torch.cat(all_motions, dim=0)
    all_lengths = torch.cat(all_lengths, dim=0)

    # Convert to numpy if requested
    if return_numpy:
        all_motions = all_motions.cpu().numpy()
        all_lengths = all_lengths.numpy()

    # Reset model if we wrapped it
    if guidance_param != 1.0:
        model = original_model

    return {"motions": all_motions, "texts": all_text, "lengths": all_lengths}