| |
| |
|
|
| from collections import defaultdict |
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
|
|
| import viser |
| from kimodo.constraints import ( |
| TYPE_TO_CLASS, |
| FullBodyConstraintSet, |
| Root2DConstraintSet, |
| ) |
| from kimodo.exports.mujoco import apply_g1_real_robot_projection |
| from kimodo.skeleton import G1Skeleton34, SOMASkeleton30 |
| from kimodo.tools import seed_everything |
|
|
| from .embedding_cache import CachedTextEncoder |
| from .state import ClientSession, ModelBundle |
|
|
|
|
def _gather_joint_targets(constraint_info, valid_idx, skel_slice, device):
    """Return the (joints_pos, joints_rot) rows at ``valid_idx``, moved to ``device``.

    When ``skel_slice`` is not None it is applied along the second axis of both
    tensors so that only the joints known to the model skeleton are kept.
    """
    joints_pos = constraint_info["joints_pos"][valid_idx].to(device)
    joints_rot = constraint_info["joints_rot"][valid_idx].to(device)
    if skel_slice is not None:
        joints_pos = joints_pos[:, skel_slice]
        joints_rot = joints_rot[:, skel_slice]
    return joints_pos, joints_rot


def compute_model_constraints_lst(
    session: ClientSession,
    model_bundle: ModelBundle,
    num_frames: int,
    device: str,
):
    """Compute the list of constraints for the model based on the constraints in viser.

    Each track in ``session.constraints`` is converted into one or more
    constraint-set objects built on the model skeleton. Keyframes whose frame
    index is ``>= num_frames`` are dropped, and a track left with no keyframes
    is skipped entirely.

    Args:
        session: Client session holding the motions, skeleton and constraint tracks.
        model_bundle: Bundle whose ``model.skeleton`` the constraints are built on.
        num_frames: Total number of generated frames; keyframes at or past it are ignored.
        device: Device the constraint tensors are moved to.

    Returns:
        A list of constraint-set objects; empty when the session has no constraints.

    Raises:
        AssertionError: If the session does not hold exactly one motion.
        ValueError: If a track with remaining keyframes has an unsupported name.
    """
    assert len(session.motions) == 1, "Only one motion allowed for constrained generation"
    if not session.constraints:
        return []

    model_skeleton = model_bundle.model.skeleton
    # Only a SOMASkeleton30 model whose joint count differs from the session
    # skeleton needs the session joints sliced down to the model's joints.
    use_skel_slice = isinstance(model_skeleton, SOMASkeleton30) and session.skeleton.nbjoints != model_skeleton.nbjoints
    skel_slice = model_skeleton.get_skel_slice(session.skeleton) if use_skel_slice else None

    # When the "2D Root" track is a dense path, its root positions are also
    # forwarded to the other constraint types as a smooth 2D root reference.
    # The track may be absent from the mapping, so look it up defensively
    # instead of subscripting (which would raise KeyError).
    dense_smooth_root_pos_2d = None
    root_track = session.constraints.get("2D Root")
    if root_track is not None and root_track.dense_path:
        # Columns 0 and 2 of root_pos are kept (presumably the ground-plane
        # coordinates -- TODO confirm against the constraint definition).
        dense_smooth_root_pos_2d = root_track.get_constraint_info(device=device)["root_pos"][:, [0, 2]]

    model_constraints = []
    for track_name, constraint in session.constraints.items():
        constraint_info = constraint.get_constraint_info(device=device)

        # Keep only the keyframes that fall inside the generated range, tracking
        # both their position in the track (valid_idx) and their frame number.
        valid_idx = []
        valid_frame_idx = []
        for i, fi in enumerate(constraint_info["frame_idx"]):
            if fi < num_frames:
                valid_idx.append(i)
                valid_frame_idx.append(fi)

        if not valid_frame_idx:
            continue

        frame_indices = torch.tensor(valid_frame_idx)
        if track_name == "2D Root":
            smooth_root_pos_2d = constraint_info["root_pos"][valid_idx][:, [0, 2]].to(device)
            model_constraints.append(
                Root2DConstraintSet(
                    model_skeleton,
                    frame_indices,
                    smooth_root_pos_2d,
                )
            )
        elif track_name == "Full-Body":
            constraint_joints_pos, constraint_joints_rot = _gather_joint_targets(
                constraint_info, valid_idx, skel_slice, device
            )

            smooth_root_pos_2d = None
            if dense_smooth_root_pos_2d is not None:
                smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices]

            model_constraints.append(
                FullBodyConstraintSet(
                    model_skeleton,
                    frame_indices,
                    constraint_joints_pos,
                    constraint_joints_rot,
                    smooth_root_2d=smooth_root_pos_2d,
                )
            )
        elif track_name == "End-Effectors":
            constraint_joints_pos, constraint_joints_rot = _gather_joint_targets(
                constraint_info, valid_idx, skel_slice, device
            )

            # Set membership keeps this filter linear in the number of keyframes.
            valid_idx_set = set(valid_idx)
            end_effector_type_set_lst = [
                end_effector_type_set
                for i, end_effector_type_set in enumerate(constraint_info["end_effector_type"])
                if i in valid_idx_set
            ]

            # Group keyframes by the constraint class of each end-effector type
            # they carry; a keyframe with several types lands in several groups.
            cls_idx = defaultdict(list)
            for idx, end_effector_type_set in enumerate(end_effector_type_set_lst):
                for end_effector_type in end_effector_type_set:
                    cls_idx[TYPE_TO_CLASS[end_effector_type]].append(idx)

            # Emit one constraint set per class, restricted to its keyframes.
            for cls, lst_idx in cls_idx.items():
                frame_indices_cls = frame_indices[lst_idx]
                smooth_root_pos_2d = None
                if dense_smooth_root_pos_2d is not None:
                    smooth_root_pos_2d = dense_smooth_root_pos_2d[frame_indices_cls]

                model_constraints.append(
                    cls(
                        model_skeleton,
                        frame_indices_cls,
                        constraint_joints_pos[lst_idx],
                        constraint_joints_rot[lst_idx],
                        smooth_root_2d=smooth_root_pos_2d,
                    )
                )
        else:
            raise ValueError(f"Unsupported constraint type: {constraint.display_name}")
    return model_constraints
|
|
|
|
def generate(
    *,
    client: viser.ClientHandle,
    session: ClientSession,
    model_bundle: ModelBundle,
    prompts: list[str],
    num_frames: list[int],
    num_samples: int,
    seed: int,
    diffusion_steps: int,
    cfg_weight: Optional[list[float]] = None,
    cfg_type: Optional[str] = None,
    postprocess_parameters: Optional[dict] = None,
    transitions_parameters: Optional[dict] = None,
    real_robot_rotations: bool = False,
    device: str,
    clear_motions,
    add_character_motion,
) -> None:
    """Run the model for the given prompts and display the generated samples.

    Seeds the RNGs, converts the session constraints into model constraints,
    calls the model once, optionally projects the result with
    ``apply_g1_real_robot_projection``, clears the client's existing motions
    and adds one character motion per generated sample, spread along axis 0.

    Args:
        client: Viser client the motions are displayed for.
        session: Client session providing the skeleton and constraints.
        model_bundle: Bundle holding the callable model.
        prompts: Text prompts, passed to the model with ``multi_prompt=True``.
        num_frames: Frame counts passed to the model; their sum bounds the constraints.
        num_samples: Number of samples to generate and display.
        seed: Seed passed to ``seed_everything``.
        diffusion_steps: Number of diffusion steps passed to the model.
        cfg_weight: Guidance weights; defaults to ``[2.0, 2.0]`` when falsy.
        cfg_type: Guidance type forwarded to the model.
        postprocess_parameters: Extra keyword arguments forwarded to the model.
        transitions_parameters: Extra keyword arguments forwarded to the model;
            on key collision these override ``postprocess_parameters``.
        real_robot_rotations: When true and the session skeleton is a
            ``G1Skeleton34``, apply the real-robot projection to the output.
        device: Device used for the constraint tensors.
        clear_motions: Callback invoked as ``clear_motions(client_id)``.
        add_character_motion: Callback invoked once per sample as
            ``add_character_motion(client, skeleton, joints_pos, joints_rot, foot_contacts)``.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import nullcontext

    client_id = client.client_id
    print(
        f"Generating {num_samples} samples for a total of {sum(num_frames)} frames with those prompt: {prompts} (client {client_id})"
    )

    seed_everything(seed)

    model_constraints = compute_model_constraints_lst(session, model_bundle, sum(num_frames), device)
    cfg_weight = cfg_weight or [2.0, 2.0]
    postprocess_parameters = postprocess_parameters or {}
    transitions_parameters = transitions_parameters or {}

    # A CachedTextEncoder needs the model call wrapped in its per-session
    # context; any other encoder (or none) runs under a no-op context.
    encoder = getattr(model_bundle.model, "text_encoder", None)
    if isinstance(encoder, CachedTextEncoder):
        encoder_context = encoder.session_context(session)
    else:
        encoder_context = nullcontext()

    with encoder_context:
        pred_joints_output = model_bundle.model(
            prompts,
            num_frames,
            diffusion_steps,
            multi_prompt=True,
            constraint_lst=model_constraints,
            cfg_weight=cfg_weight,
            num_samples=num_samples,
            cfg_type=cfg_type,
            **(postprocess_parameters | transitions_parameters),
        )

    joints_pos = pred_joints_output["posed_joints"]
    joints_rot = pred_joints_output["global_rot_mats"]
    # May be None: the model output is not required to contain foot contacts.
    foot_contacts = pred_joints_output.get("foot_contacts")

    # Optionally replace the raw output with the G1 real-robot projection.
    if real_robot_rotations and isinstance(session.skeleton, G1Skeleton34):
        joints_pos, joints_rot = apply_g1_real_robot_projection(
            session.skeleton,
            joints_pos,
            joints_rot,
            clamp_to_limits=True,
        )

    # Replace whatever the client was showing with the new samples.
    clear_motions(client_id)

    # Spread the samples along axis 0, one unit apart, around the middle sample.
    spread_factor = 1.0
    center_idx = num_samples // 2
    x_trans = (np.arange(num_samples) - center_idx) * spread_factor
    for i in range(num_samples):
        cur_joints_pos = joints_pos[i]
        # NOTE: this offsets joints_pos in place.
        cur_joints_pos[..., 0] += x_trans[i]
        # Guard the optional foot contacts instead of indexing None.
        # NOTE(review): assumes add_character_motion accepts None here -- confirm.
        cur_foot_contacts = foot_contacts[i] if foot_contacts is not None else None
        add_character_motion(
            client,
            session.skeleton,
            cur_joints_pos,
            joints_rot[i],
            cur_foot_contacts,
        )
|
|