# movimento/generation.py
# Fix CUDA import order - import spaces before torch (commit e28bffd, 2a5255e verified)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import Optional
import numpy as np
import torch
import viser
from kimodo.constraints import (
TYPE_TO_CLASS,
FullBodyConstraintSet,
Root2DConstraintSet,
)
from kimodo.exports.mujoco import apply_g1_real_robot_projection
from kimodo.skeleton import G1Skeleton34, SOMASkeleton30
from kimodo.tools import seed_everything
from .embedding_cache import CachedTextEncoder
from .state import ClientSession, ModelBundle
def compute_model_constraints_lst(
session: ClientSession,
model_bundle: ModelBundle,
num_frames: int,
device: str,
):
"""Compute the lst of constraints for the model based on the constraints in viser."""
assert len(session.motions) == 1, "Only one motion allowed for constrained generation"
if not session.constraints:
return []
model_skeleton = model_bundle.model.skeleton
# For SOMA, UI uses somaskel77; extract 30-joint subset for the model
use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints
skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None
dense_smooth_root_pos_2d = None
if session.constraints["2D Root"].dense_path:
# get the full 2d root
dense_smooth_root_pos_2d = session.constraints["2D Root"].get_constraint_info(device=device)["root_pos"][
:, [0, 2]
]
model_constraints = []
for track_name, constraint in session.constraints.items():
constraint_info = constraint.get_constraint_info(device=device)
frame_idx = constraint_info["frame_idx"]
# drop any constraints outside the generation range
valid_info = [(i, fi) for i, fi in enumerate(frame_idx) if fi < num_frames]
valid_idx = [i for i, _ in valid_info]
valid_frame_idx = [fi for _, fi in valid_info]
if len(valid_frame_idx) == 0:
continue
frame_indices = torch.tensor(valid_frame_idx)
if track_name == "2D Root":
smooth_root_pos_2d = constraint_info["root_pos"][valid_idx][:, [0, 2]].to(device)
# same as "smooth_root_2d"
model_constraints.append(
Root2DConstraintSet(
model_skeleton,
frame_indices,
smooth_root_pos_2d,
)
)
elif track_name == "Full-Body":
constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
if skel_slice is not None:
constraint_joints_pos = constraint_joints_pos[:, skel_slice]
constraint_joints_rot = constraint_joints_rot[:, skel_slice]
smooth_root_pos_2d = None
if dense_smooth_root_pos_2d is not None:
smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]
model_constraints.append(
FullBodyConstraintSet(
model_skeleton,
frame_indices,
constraint_joints_pos,
constraint_joints_rot,
smooth_root_2d=smooth_root_pos_2d,
)
)
elif track_name == "End-Effectors":
constraint_joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
constraint_joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
if skel_slice is not None:
constraint_joints_pos = constraint_joints_pos[:, skel_slice]
constraint_joints_rot = constraint_joints_rot[:, skel_slice]
end_effector_type_set_lst = [
end_effector_type_set
for i, end_effector_type_set in enumerate(constraint_info["end_effector_type"])
if i in valid_idx
]
# regroup the end effector data by type
cls_idx = defaultdict(list)
for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):
for end_effector_type in end_effector_type_set:
cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)
for cls, lst_idx in cls_idx.items():
frame_indices_cls = frame_indices[lst_idx]
smooth_root_pos_2d = None
if dense_smooth_root_pos_2d is not None:
smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]
constraint_joints_pos_el = constraint_joints_pos[lst_idx]
constraint_joints_rot_el = constraint_joints_rot[lst_idx]
model_constraints.append(
cls(
model_skeleton,
frame_indices_cls,
constraint_joints_pos_el,
constraint_joints_rot_el,
smooth_root_2d=smooth_root_pos_2d,
)
)
else:
raise ValueError(f"Unsupported constraint type: {constraint.display_name}")
return model_constraints
def generate(
    *,
    client: viser.ClientHandle,
    session: ClientSession,
    model_bundle: ModelBundle,
    prompts: list[str],
    num_frames: list[int],
    num_samples: int,
    seed: int,
    diffusion_steps: int,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    postprocess_parameters: Optional[dict] = None,
    transitions_parameters: Optional[dict] = None,
    real_robot_rotations: bool = False,
    device: str,
    clear_motions,
    add_character_motion,
) -> None:
    """Generate motion samples with the model and hand them to display callbacks.

    Args:
        client: Viser client handle identifying the requester.
        session: Client session with constraints and the display skeleton.
        model_bundle: Bundle holding the generation model.
        prompts: One text prompt per segment.
        num_frames: Number of frames per segment; summed for the total length.
        num_samples: Number of samples to generate and display side by side.
        seed: RNG seed applied before generation.
        diffusion_steps: Number of diffusion steps for the model call.
        cfg_weight: Classifier-free-guidance weights; defaults to [2.0, 2.0].
        cfg_type: Optional CFG variant passed through to the model.
        postprocess_parameters: Extra keyword arguments merged into the model call.
        transitions_parameters: Extra keyword arguments merged into the model call.
        real_robot_rotations: If True and the skeleton is a G1, project to real
            robot DoF (1-DoF per joint, clamped) for display.
        device: Torch device string used when building constraints.
        clear_motions: Callback removing the client's previous motions.
        add_character_motion: Callback displaying one sample on a character.
    """
    # Local stdlib import keeps the file-level import block untouched.
    from contextlib import nullcontext

    client_id = client.client_id
    print(
        f"Generating {num_samples} samples for a total of {sum(num_frames)} frames with those prompt: {prompts} (client {client_id})"
    )
    seed_everything(seed)
    model_constraints = compute_model_constraints_lst(session, model_bundle, sum(num_frames), device)
    cfg_weight = cfg_weight or [2.0, 2.0]
    postprocess_parameters = postprocess_parameters or {}
    transitions_parameters = transitions_parameters or {}

    # Reuse cached text embeddings when the encoder supports it; otherwise run
    # under a no-op context. A single call site replaces the previously
    # duplicated model invocation in both branches.
    encoder = getattr(model_bundle.model, "text_encoder", None)
    encoder_ctx = encoder.session_context(session) if isinstance(encoder, CachedTextEncoder) else nullcontext()
    with encoder_ctx:
        pred_joints_output = model_bundle.model(
            prompts,
            num_frames,
            diffusion_steps,
            multi_prompt=True,
            constraint_lst=model_constraints,
            cfg_weight=cfg_weight,
            num_samples=num_samples,
            cfg_type=cfg_type,
            **(postprocess_parameters | transitions_parameters),
        )  # [B, T, motion_rep_dim]

    joints_pos = pred_joints_output["posed_joints"]  # [B, T, J, 3]
    joints_rot = pred_joints_output["global_rot_mats"]
    foot_contacts = pred_joints_output.get("foot_contacts")  # optional key; may be None

    # Optionally project G1 to real robot DoF (1-DoF per joint, clamped) for display.
    if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):
        joints_pos, joints_rot = apply_g1_real_robot_projection(
            session.skeleton,
            pred_joints_output["posed_joints"],
            pred_joints_output["global_rot_mats"],
            clamp_to_limits=True,
        )

    # Display on characters (callbacks keep this module UI-agnostic).
    clear_motions(client_id)
    # Keep one sample centered at the origin so constraints align.
    spread_factor = 1.0  # meters between adjacent samples along x
    center_idx = num_samples // 2
    x_trans = (np.arange(num_samples) - center_idx) * spread_factor
    for i in range(num_samples):
        cur_joints_pos = joints_pos[i]
        cur_joints_pos[..., 0] += x_trans[i]
        add_character_motion(
            client,
            session.skeleton,
            cur_joints_pos,
            joints_rot[i],
            # Guard: the model may not emit foot contacts, in which case
            # indexing None would raise TypeError.
            foot_contacts[i] if foot_contacts is not None else None,
        )