# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # ruff: noqa: I001 import math import os import threading import time from typing import Optional from kimodo.constraints import load_constraints_lst, save_constraints_lst from kimodo.exports.bvh import motion_to_bvh_bytes, save_motion_bvh from kimodo.exports.motion_io import ( amass_npz_to_bytes, g1_csv_to_bytes, kimodo_npz_to_bytes, load_motion_file, save_kimodo_npz, ) from kimodo.model.registry import kimodo_short_key_for_skeleton_dataset, registry_skeleton_for_joint_count from kimodo.tools import to_torch from kimodo.viz import viser_utils from kimodo.viz.viser_utils import GuiElements import numpy as np import torch import viser from viser._timeline_api import PROMPT_COLORS from . import generation from ._qwen_prompts import call_qwen_for_prompts from .config import ( DEFAULT_CUR_DURATION, DEMO_UI_INSTRUCTIONS_TAB_MD, get_datasets, get_model_info, get_models_for_dataset_skeleton, get_skeleton_display_name, get_skeleton_display_names_for_dataset, get_skeleton_key_from_display_name, get_short_key_from_display_name, HF_MODE, INIT_POSTPROCESSING, MODEL_NAMES, NB_TRANSITION_FRAMES, SHOW_TRANSITION_PARAMS, ) from .state import ClientSession from kimodo.skeleton import G1Skeleton34, SOMASkeleton30, SOMASkeleton77 QWEN_EXAMPLE_NAME = "09_qwen_agentic_actions" QWEN_EXAMPLE_LEGACY_NAME = "09_qwen_agentic_10_actions" def extract_intervals_and_singles(t: torch.Tensor): intervals = [] intervals_indices = [] single_frames = [] single_frames_indices = [] start_idx = 0 for i in range(1, len(t) + 1): # End of run if: # - end of tensor # - non-consecutive value if i == len(t) or t[i] != t[i - 1] + 1: run_length = i - start_idx if run_length >= 2: intervals.append((int(t[start_idx]), int(t[i - 1]))) intervals_indices.append((start_idx, i - 1)) else: single_frames.append(int(t[start_idx])) single_frames_indices.append(start_idx) start_idx = i return intervals, intervals_indices, single_frames, single_frames_indices def create_gui( demo, client: viser.ClientHandle, model_name: str, model_fps: float, ): """Create GUI elements for a specific client.""" client_id = client.client_id def get_active_session(event_client: viser.ClientHandle | None): if event_client is None: return None if not demo.client_active(event_client.client_id): return None return demo.client_sessions[event_client.client_id] def build_timeline_tracks(): timeline = client.timeline demo.set_timeline_defaults(timeline, model_fps) timeline.set_visible(True) timeline.set_current_frame(0) timeline_tracks = {} fullbody_id = timeline.add_track( "Full-Body", track_type="keyframe", color=(219, 148, 86), height_scale=0.5, ) timeline_tracks[fullbody_id] = { "name": "Full-Body", "track_type": "keyframe", "color": (219, 148, 86), "height_scale": 0.5, } root2d_id = timeline.add_track( "2D Root", track_type="keyframe", color=(150, 100, 200), height_scale=0.5, ) timeline_tracks[root2d_id] = { "name": "2D Root", "track_type": "keyframe", "color": (150, 100, 200), "height_scale": 0.5, } lefthand_id = timeline.add_track( "Left Hand", track_type="keyframe", color=(100, 200, 150), height_scale=0.5, ) timeline_tracks[lefthand_id] = { "name": "Left Hand", "track_type": "keyframe", "color": (100, 200, 150), "height_scale": 0.5, } righthand_id = timeline.add_track( "Right Hand", track_type="keyframe", color=(200, 100, 150), height_scale=0.5, ) timeline_tracks[righthand_id] = { "name": "Right Hand", "track_type": "keyframe", "color": (200, 100, 150), "height_scale": 0.5, } leftfoot_id = timeline.add_track( "Left Foot", track_type="keyframe", color=(219, 148, 86), height_scale=0.5, ) timeline_tracks[leftfoot_id] = { "name": "Left Foot", "track_type": "keyframe", "color": (219, 148, 86), "height_scale": 0.5, } rightfoot_id = timeline.add_track( "Right Foot", track_type="keyframe", color=(150, 100, 200), height_scale=0.5, ) timeline_tracks[rightfoot_id] = { "name": "Right Foot", "track_type": "keyframe", "color": (150, 100, 200), "height_scale": 0.5, } return timeline, timeline_tracks timeline, timeline_tracks = build_timeline_tracks() # These handles are part of GuiElements, but the demo currently uses timeline + buttons # embedded in the Viser UI instead of custom controls. gui_play_pause_button = None gui_next_frame_button = None gui_prev_frame_button = None gui_timeline = None gui_duration_slider = None # now other gui elements tab_group = client.gui.add_tab_group() # # Playback and Motion generation controls # with tab_group.add_tab("Generate", viser.Icon.WALK): with client.gui.add_folder("Model Selection", expand_by_default=True): info = get_model_info(model_name) if info is None: info = get_model_info(next(iter(MODEL_NAMES))) def get_allowed_skeleton_labels(dataset_ui_label: str) -> list[str]: labels = get_skeleton_display_names_for_dataset(dataset_ui_label, family="Kimodo") if HF_MODE: labels = [label for label in labels if get_skeleton_key_from_display_name(label) != "SMPLX"] return labels dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label datasets = ["Rigplay"] if HF_MODE else get_datasets(family="Kimodo") skeleton_labels = get_allowed_skeleton_labels(dataset_ui_label) initial_skeleton_label = get_skeleton_display_name(info.skeleton) if initial_skeleton_label not in skeleton_labels and skeleton_labels: initial_skeleton_label = skeleton_labels[0] initial_skeleton_key = ( get_skeleton_key_from_display_name(initial_skeleton_label) if skeleton_labels else None ) models_for_pair = ( get_models_for_dataset_skeleton(dataset_ui_label, initial_skeleton_key, family="Kimodo") if initial_skeleton_key is not None else [] ) version_options = [m.display_name for m in models_for_pair] initial_version = ( info.display_name if info.display_name in version_options else (version_options[0] if version_options else "") ) gui_dataset_selector = client.gui.add_dropdown( "Training dataset", options=datasets, initial_value=dataset_ui_label, visible=not HF_MODE, ) gui_skeleton_selector = client.gui.add_dropdown( "Model" if HF_MODE else "Skeleton", options=skeleton_labels, initial_value=initial_skeleton_label, ) gui_version_selector = client.gui.add_dropdown( "Version", options=version_options, initial_value=initial_version, ) gui_version_selector.visible = len(models_for_pair) > 1 gui_model_display = client.gui.add_markdown( content=f"**Model:** {initial_version}", ) gui_load_model_button = client.gui.add_button( "Load model", hint="Load the selected model (dataset, skeleton, version).", ) class ModelSelectorHandle: """Wrapper so session and callbacks can treat three dropdowns as one.""" def __init__(self): self._dataset = gui_dataset_selector self._skeleton = gui_skeleton_selector self._version = gui_version_selector self._display = gui_model_display @property def value(self) -> str: return get_short_key_from_display_name(self._version.value) or "" def set_from_short_key(self, short_key: str) -> None: info = get_model_info(short_key) if info is None: return dataset_ui_label = "Rigplay" if HF_MODE else info.dataset_ui_label self._dataset.value = dataset_ui_label self._skeleton.options = get_allowed_skeleton_labels(dataset_ui_label) skeleton_label = get_skeleton_display_name(info.skeleton) if skeleton_label not in self._skeleton.options and self._skeleton.options: skeleton_label = self._skeleton.options[0] self._skeleton.value = skeleton_label skeleton_key = get_skeleton_key_from_display_name(skeleton_label) if skeleton_key is None: return models = get_models_for_dataset_skeleton(dataset_ui_label, skeleton_key, family="Kimodo") self._version.options = [m.display_name for m in models] self._version.value = ( info.display_name if info.display_name in self._version.options else self._version.options[0] ) self._version.visible = len(models) > 1 self._display.content = f"**Model:** {self._version.value}" gui_model_selector = ModelSelectorHandle() with client.gui.add_folder("Examples", expand_by_default=True): examples_base_dir = demo.get_examples_base_dir(model_name, absolute=True) example_dict = viser_utils.load_example_cases(examples_base_dir) example_names = list(example_dict.keys()) example_names.append(QWEN_EXAMPLE_NAME) gui_examples_dropdown = client.gui.add_dropdown( "Example", options=example_names, initial_value=example_names[0], ) gui_load_example_button = client.gui.add_button( "Load Example", hint="Load the selected example (or Qwen agentic prompt plan).", disabled=False, ) def update_examples_dropdown( new_example_dict: dict[str, str], keep_selection: bool = True, ) -> None: example_names_local = list(new_example_dict.keys()) if QWEN_EXAMPLE_NAME not in example_names_local: example_names_local.append(QWEN_EXAMPLE_NAME) if QWEN_EXAMPLE_LEGACY_NAME not in example_names_local: example_names_local.append(QWEN_EXAMPLE_LEGACY_NAME) gui_examples_dropdown.options = example_names_local if keep_selection and gui_examples_dropdown.value in example_names_local: return gui_examples_dropdown.value = example_names_local[0] with client.gui.add_folder("Generate", expand_by_default=True): gui_duration = client.gui.add_markdown(content=f"Total duration: {DEFAULT_CUR_DURATION:.1f} (sec)") def update_duration_gui(duration): gui_duration.content = f"Total duration: {duration:.1f} (sec)" def compute_prompt_num_frames(prompt_values): """Convert timeline prompt bounds to per-prompt frame counts. Convention in this demo: - All prompts except the last are treated as [start_frame, end_frame) (end is exclusive). - The last prompt is treated as [start_frame, end_frame] (end is inclusive). - This assumes the prompts values are sorted by start_frame. """ if len(prompt_values) == 0: return [] num_frames = [] for i, x in enumerate(prompt_values): cur = x.end_frame - x.start_frame if i == len(prompt_values) - 1: cur += 1 num_frames.append(cur) return num_frames def update_duration_auto(): session = demo.client_sessions[client_id] prompt_values = sorted( [x for x in timeline._prompts.values()], key=lambda x: x.start_frame, ) num_frames = compute_prompt_num_frames(prompt_values) total_nb_frames = sum(num_frames) cur_duration = total_nb_frames / session.model_fps set_new_duration(client_id, cur_duration) update_duration_gui(cur_duration) gui_num_samples_slider = client.gui.add_slider( "Num Samples", min=1, max=10, step=1, initial_value=1, visible=not HF_MODE, ) gui_use_soma_layer_checkbox = client.gui.add_checkbox( "SOMA layer", initial_value=False, visible="soma" in (model_name or ""), ) with client.gui.add_folder("Model Parameters", expand_by_default=False): gui_seed = client.gui.add_number("Seed", initial_value=42) with client.gui.add_folder("Diffusion", expand_by_default=False): gui_diffusion_steps_slider = client.gui.add_slider( "Denoising Steps", min=2, max=1000, step=10, initial_value=100, ) with client.gui.add_folder("Classifier-Free Guidance", expand_by_default=False): gui_cfg_checkbox = client.gui.add_checkbox( "Enable", initial_value=True, visible=True, ) gui_cfg_text_weight_slider = client.gui.add_slider( "Text Weight", min=0.0, max=5.0, step=0.1, initial_value=2.0, visible=True, ) gui_cfg_constraint_weight_slider = client.gui.add_slider( "Constraint Weight", min=0.0, max=5.0, step=0.1, initial_value=2.0, visible=True, ) with client.gui.add_folder( "Transitions", expand_by_default=False, visible=SHOW_TRANSITION_PARAMS, ): gui_num_transition_frames_slider = client.gui.add_slider( "Transition frames", min=1, max=10, step=1, initial_value=NB_TRANSITION_FRAMES, visible=True, ) gui_share_transition_checkbox = client.gui.add_checkbox( # noqa "Override previous frames", initial_value=False, visible=True, ) gui_percentage_transition_sharing_slider = client.gui.add_slider( "Percentage overriding frames", min=0, max=30, step=1, initial_value=10, visible=True, ) @gui_share_transition_checkbox.on_update def _(event: viser.GuiEvent) -> None: if get_active_session(event.client) is None: return # disable the slider if sharing transition is False gui_percentage_transition_sharing_slider.visible = gui_share_transition_checkbox.value with client.gui.add_folder("Post Processing", expand_by_default=False): _model_name = model_name or "" _postprocess_visible = "g1" not in _model_name gui_postprocess_checkbox = client.gui.add_checkbox( "Enable", initial_value=INIT_POSTPROCESSING, hint="Apply motion post-processing (not available for G1)", visible=_postprocess_visible, ) gui_root_margin = client.gui.add_number( "Root Margin", min=0.0, # max=0.5, step=0.01, initial_value=0.04, hint="Margin for root position (meters). Lower values pin root closer to target.", visible=INIT_POSTPROCESSING and _postprocess_visible, ) @gui_postprocess_checkbox.on_update def _(event: viser.GuiEvent) -> None: if get_active_session(event.client) is None: return # disable the slider if sharing transition is False gui_root_margin.visible = gui_postprocess_checkbox.value gui_real_robot_rotations_checkbox = client.gui.add_checkbox( "Real robot rotations", initial_value=False, hint="Project joint rotations to G1 real robot DoF (1-DoF per joint) and clamp to axis limits from the MuJoCo XML.", visible="g1" in _model_name, ) with client.gui.add_folder("Qwen Auto-Prompts", expand_by_default=True): gui_qwen_scene = client.gui.add_text( "Scene context", initial_value="A lone figure moving through an empty plaza", hint="Describe the scene or character context for Qwen to generate motion prompts.", ) gui_qwen_actions = client.gui.add_slider( "Target actions", min=1, max=10, step=1, initial_value=6, hint="Number of prompt segments to place on the timeline.", ) gui_qwen_auto_run = client.gui.add_checkbox( "Auto-run Generate after loading prompts", initial_value=False, ) gui_qwen_status = client.gui.add_markdown(content="") gui_qwen_button = client.gui.add_button("Fill Timeline via Qwen", color="blue") @gui_qwen_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return def _run_qwen_fill_and_maybe_generate() -> None: gui_qwen_button.disabled = True gui_qwen_status.content = "⏳ Calling Qwen…" target_actions = int(gui_qwen_actions.value) history: list[str] = [] all_texts: list[str] = [] all_durations: list[float] = [] rounds = 0 # Keep requesting batches until we fill target actions (max 10). while len(all_texts) < target_actions and rounds < 6: remaining = target_actions - len(all_texts) batch, history = call_qwen_for_prompts( scene=gui_qwen_scene.value, history=history, requested_actions=min(5, remaining), ) batch_texts = batch.get("texts", []) batch_durations = batch.get("durations", []) for t, d in zip(batch_texts, batch_durations): if len(all_texts) >= target_actions: break all_texts.append(t) all_durations.append(float(d)) rounds += 1 if len(all_texts) == 0: gui_qwen_status.content = "⚠ Qwen did not return usable prompts" gui_qwen_button.disabled = False return fps = session.model_fps event_client.timeline.clear_prompts() frame_cursor = 0 for i, (txt, dur) in enumerate(zip(all_texts, all_durations)): n = max(1, int(round(dur * fps))) start_f = frame_cursor end_f = frame_cursor + n if i < len(all_texts) - 1 else frame_cursor + n - 1 color = PROMPT_COLORS[i % len(PROMPT_COLORS)] event_client.timeline.add_prompt(txt, start_f, end_f, color=color) frame_cursor += n # Keep timeline readable by expanding zoom to planned sequence length. target_visible_frames = int(math.ceil(1.10 * frame_cursor)) event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames)) update_duration_auto() gui_qwen_status.content = f"✓ Loaded {len(all_texts)} Qwen prompt segments" if gui_qwen_auto_run.value: gui_qwen_status.content = f"✓ Loaded {len(all_texts)} prompts, generating motion…" try: demo.generate( event_client, all_texts, [max(1, int(round(d * fps))) for d in all_durations], gui_num_samples_slider.value, gui_seed.value, gui_diffusion_steps_slider.value, cfg_weight=[ gui_cfg_text_weight_slider.value, gui_cfg_constraint_weight_slider.value, ], cfg_type="separated" if gui_cfg_checkbox.value else "nocfg", postprocess_parameters={ "post_processing": gui_postprocess_checkbox.value, "root_margin": gui_root_margin.value, }, transitions_parameters={ "num_transition_frames": gui_num_transition_frames_slider.value, "share_transition": gui_share_transition_checkbox.value, "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100, }, real_robot_rotations=gui_real_robot_rotations_checkbox.value, ) gui_qwen_status.content = f"✓ Generated motion from {len(all_texts)} Qwen actions" except Exception as exc: gui_qwen_status.content = f"⚠ Generate error: {exc}" gui_qwen_button.disabled = False threading.Thread(target=_run_qwen_fill_and_maybe_generate, daemon=True).start() gui_generate_button = client.gui.add_button("Generate", color="green") with client.gui.add_folder("Constraints", expand_by_default=False): gui_gizmo_space_dropdown = client.gui.add_dropdown( "Gizmo space", ("Local", "World"), initial_value="Local", visible="g1" not in _model_name, ) gui_edit_constraint_button = client.gui.add_button("Enter Editing Mode") gui_snap_to_constraint_button = client.gui.add_button( "Snap to Constraint", disabled=True, ) gui_reset_constraint_button = client.gui.add_button( "Reset Constraint", disabled=True, ) gui_undo_drag_button = client.gui.add_button( "Undo Move", disabled=True, ) with client.gui.add_folder("Root 2D Options", expand_by_default=True): gui_dense_path_checkbox = client.gui.add_checkbox( "Make Smooth Path", initial_value=False, visible=True, ) gui_show_only_current_constraint_checkbox = client.gui.add_checkbox( "Show only Current", initial_value=False, hint="Show only constraint overlays at the current frame; uncheck to show all.", ) def apply_constraint_overlay_visibility(session: ClientSession) -> None: demo._apply_constraint_overlay_visibility(session) @gui_show_only_current_constraint_checkbox.on_update def _(event: viser.GuiEvent) -> None: session = get_active_session(event.client) if session is None: return session.show_only_current_constraint = gui_show_only_current_constraint_checkbox.value apply_constraint_overlay_visibility(session) gui_clear_all_constraints_button = client.gui.add_button( "Clear All Constraints", color="red", ) def has_constraint_at_frame(session: ClientSession, frame_idx: int) -> bool: for constraint_name in ["Full-Body", "End-Effectors", "2D Root"]: constraint = session.constraints.get(constraint_name) if constraint is None: continue if frame_idx in constraint.keyframes: return True return False def update_snap_to_constraint_button(session: ClientSession) -> None: gui_snap_to_constraint_button.disabled = not has_constraint_at_frame(session, session.frame_idx) def ensure_edit_snapshot(session: ClientSession, motion, frame_idx: int) -> None: if session.edit_mode_snapshot is None: session.edit_mode_snapshot = {} if frame_idx in session.edit_mode_snapshot: return session.edit_mode_snapshot[frame_idx] = { "joints_pos": motion.get_joints_pos(frame_idx), "joints_rot": motion.get_joints_rot(frame_idx), } def _update_dense_path(motion, session): constraint_info = session.constraints["2D Root"].get_constraint_info() if len(constraint_info["frame_idx"]) > 0: min_root_frame = min(constraint_info["frame_idx"]) max_root_frame = max(constraint_info["frame_idx"]) motion.set_projected_root_pos_path( constraint_info["root_pos"][:, [0, 2]], min_frame_idx=min_root_frame, max_frame_idx=max_root_frame, ) # Delay (ms) after last keyframe/interval move before updating path = "on release". DENSE_PATH_AFTER_RELEASE_MS = 300 def _schedule_dense_path_after_release(session): """Schedule a single path update to run after user stops dragging.""" if "2D Root" not in session.constraints or not session.constraints["2D Root"].dense_path: return tdata = session.timeline_data if tdata.get("dense_path_after_release_timer"): tdata["dense_path_after_release_timer"].cancel() delay = DENSE_PATH_AFTER_RELEASE_MS / 1000.0 def run(): if not demo.client_active(client_id): return sess = demo.client_sessions[client_id] tdata["dense_path_after_release_timer"] = None if "2D Root" not in sess.constraints or not sess.constraints["2D Root"].dense_path: return mot = list(sess.motions.values())[0] _update_dense_path(mot, sess) t = threading.Timer(delay, run) tdata["dense_path_after_release_timer"] = t t.start() @gui_dense_path_checkbox.on_update def _(event: viser.GuiEvent) -> None: session = get_active_session(event.client) if session is None: return if gui_dense_path_checkbox.value: # Make sure 0 and max_frame_idx keyframes are added to the constraint # since dense path should cover full duration for best model performance root_2d_track = session.timeline_data["tracks_ids"]["2D Root"] # add a locked keyframe at 0 start_keyframe_id = client.timeline.add_locked_keyframe( # noqa root_2d_track, 0, opacity=0.0, ) session.timeline_data["keyframes"][start_keyframe_id] = { "frame": 0, "track_id": root_2d_track, "locked": True, "opacity": 0.0, "value": None, } add_constraint_callback( start_keyframe_id, "2D Root", (0, 0), verbose=False, ) # add a locked keyframe at max_frame_idx end_keyframe_id = client.timeline.add_locked_keyframe( root_2d_track, session.max_frame_idx, opacity=0.0, ) session.timeline_data["keyframes"][end_keyframe_id] = { "frame": session.max_frame_idx, "track_id": root_2d_track, "locked": True, "opacity": 0.0, "value": None, } add_constraint_callback( end_keyframe_id, "2D Root", (session.max_frame_idx, session.max_frame_idx), verbose=False, ) # add a locked interval only for visual purposes locked_interval = client.timeline.add_locked_interval( # noqa root_2d_track, start_frame=0, end_frame=session.max_frame_idx, ) session.timeline_data["intervals"][locked_interval] = { "track_id": root_2d_track, "start_frame_idx": 0, "end_frame_idx": session.max_frame_idx, "locked": True, "opacity": 0.3, "value": None, } session.constraints["2D Root"].set_dense_path(gui_dense_path_checkbox.value) if session.constraints["2D Root"].dense_path: # update the character motion to reflect the full path # will be full length by construction, no need to specify min/max frame idx motion = list(session.motions.values())[0] _update_dense_path(motion, session) # remove locked interval and locked keyframes if not gui_dense_path_checkbox.value: # Get all locked keyframes keyframes_to_remove = [] for uuid, keyframe in client.timeline._keyframes.items(): if keyframe.locked: keyframes_to_remove.append(uuid) _data = session.timeline_data["keyframes"][uuid] remove_constraint_callback( uuid, constraint_type=session.timeline_data["tracks"][_data["track_id"]]["name"], frame_range=(_data["frame"], _data["frame"]), verbose=False, ) intervals_to_remove = [] # remove all locked intervals for uuid, interval in client.timeline._intervals.items(): if interval.locked: intervals_to_remove.append(uuid) # removing keyframes and intervals for uuid in keyframes_to_remove: client.timeline.remove_keyframe(uuid) for uuid in intervals_to_remove: client.timeline.remove_interval(uuid) apply_constraint_overlay_visibility(session) with client.gui.add_folder( "Load/Save", expand_by_default=False, visible=not HF_MODE, ): with client.gui.add_folder("Motion", expand_by_default=False): gui_save_motion_path_text = client.gui.add_text("Save Path", initial_value="output") gui_save_motion_format_dropdown = client.gui.add_dropdown( "Save Format", options=( ["NPZ", "CSV"] if "g1" in model_name.lower() else ["NPZ", "AMASS NPZ"] if "smplx" in model_name.lower() else ["NPZ", "BVH"] ), initial_value="NPZ", ) gui_save_motion_button = client.gui.add_button( "Save Motion", hint="Save the current motion (format + path above)", ) gui_load_motion_path_text = client.gui.add_text( "Load Path", initial_value="output.npz", hint="SOMA .bvh, Kimodo or AMASS .npz, or G1 MuJoCo .csv", ) gui_load_motion_button = client.gui.add_button( "Load Motion", hint="Load the selected motion", ) with client.gui.add_folder("Constraints", expand_by_default=False): gui_save_constraints_path_text = client.gui.add_text( "Save Path", initial_value="output_constraints.json" ) gui_save_constraints_button = client.gui.add_button("Save Constraints") gui_load_constraints_path_text = client.gui.add_text( "Load Path", initial_value="output_constraints.json" ) gui_load_constraints_button = client.gui.add_button("Load Constraints") with client.gui.add_folder("Example", expand_by_default=False): gui_save_example_path_text = client.gui.add_text( "Save Dir", initial_value=os.path.join( demo.get_examples_base_dir(model_name, absolute=True), "custom_example_1", ), ) gui_save_example_button = client.gui.add_button("Save Example") gui_load_example_path_text = client.gui.add_text( "Load Dir", initial_value=os.path.join( demo.get_examples_base_dir(model_name, absolute=True), "custom_example_1", ), ) gui_load_gt_checkbox = client.gui.add_checkbox( "Load GT instead", initial_value=False, ) gui_load_example_from_path_button = client.gui.add_button("Load Example") def _get_primary_motion(session: ClientSession): return list(session.motions.values())[0] def _motion_to_numpy_dict(motion) -> dict[str, np.ndarray]: joints_pos = motion.joints_pos.detach().cpu().numpy() joints_rot = motion.joints_rot.detach().cpu().numpy() joints_local_rot = motion.joints_local_rot.detach().cpu().numpy() if joints_pos.ndim != 3: raise ValueError(f"Expected unbatched joints_pos with shape [T, J, 3], got {joints_pos.shape}") if joints_rot.ndim != 4: raise ValueError(f"Expected unbatched joints_rot with shape [T, J, 3, 3], got {joints_rot.shape}") if joints_local_rot.ndim != 4: raise ValueError( "Expected unbatched joints_local_rot with shape " f"[T, J, 3, 3], got {joints_local_rot.shape}" ) motion_data = { "posed_joints": joints_pos, "global_rot_mats": joints_rot, "local_rot_mats": joints_local_rot, "root_positions": joints_pos[:, motion.skeleton.root_idx, :], } if motion.foot_contacts is not None: foot_contacts = motion.foot_contacts.detach().cpu().numpy() if foot_contacts.ndim != 2: raise ValueError( f"Expected unbatched foot_contacts with shape [T, C], got {foot_contacts.shape}" ) motion_data["foot_contacts"] = foot_contacts return motion_data def _coerce_save_path(raw_path: str, *, ext: str) -> str: """Ensure the save path ends with the correct extension for the chosen format.""" name = (raw_path or "").strip() if name == "": return f"output{ext}" known_exts = (".npz", ".bvh", ".csv") if name.lower().endswith(known_exts): return os.path.splitext(name)[0] + ext if os.path.splitext(name)[1] == "": return name + ext return name def save_motion(client, save_path, fmt): session = demo.client_sessions[client.client_id] motion = _get_primary_motion(session) motion_data = _motion_to_numpy_dict(motion) if fmt == "BVH": save_path = _coerce_save_path(save_path, ext=".bvh") save_motion_bvh( save_path, motion.joints_local_rot, motion.joints_pos[:, session.skeleton.root_idx, :], skeleton=session.skeleton, fps=float(session.model_fps), ) elif fmt == "CSV": save_path = _coerce_save_path(save_path, ext=".csv") data = g1_csv_to_bytes(motion_data, session.skeleton, demo.device) with open(save_path, "wb") as f: f.write(data) elif fmt == "AMASS NPZ": save_path = _coerce_save_path(save_path, ext=".npz") data = amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps) with open(save_path, "wb") as f: f.write(data) else: save_path = _coerce_save_path(save_path, ext=".npz") save_kimodo_npz(save_path, motion_data) return save_path @gui_save_motion_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client if get_active_session(event_client) is None: return raw_path = gui_save_motion_path_text.value fmt = str(gui_save_motion_format_dropdown.value).upper() try: saved_path = save_motion(event_client, raw_path, fmt) event_client.add_notification( title="Motion saved!", body=f"Saved motion to {saved_path}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to save motion!", body=str(e), auto_close_seconds=5.0, color="red", ) def load_motion(client, load_path): session = demo.client_sessions[client.client_id] fps_arg = session.model_fps if session.model_fps and session.model_fps > 0 else None motion_dict, num_joints_motion = load_motion_file(load_path, target_fps=fps_arg) target_skel = registry_skeleton_for_joint_count(num_joints_motion) current_info = get_model_info(session.model_name) current_skel = current_info.skeleton if current_info is not None else None if current_skel != target_skel: dataset = current_info.dataset if current_info is not None else "RP" new_key = kimodo_short_key_for_skeleton_dataset(target_skel, dataset) if new_key is None: new_key = kimodo_short_key_for_skeleton_dataset(target_skel, "RP") if new_key is None: raise ValueError( f"No Kimodo model found for skeleton {target_skel} (motion has J={num_joints_motion})." ) if new_key != session.model_name: gui_model_selector.set_from_short_key(new_key) apply_model_selection(new_key) _update_visibility_for_loaded_model(new_key) client.add_notification( title="Model switched", body=f"Switched to {new_key} to match loaded motion (J={num_joints_motion}).", auto_close_seconds=5.0, color="blue", ) session = demo.client_sessions[client.client_id] joints_pos = motion_dict["posed_joints"].to(device=demo.device, dtype=torch.float32) joints_rot = motion_dict["global_rot_mats"].to(device=demo.device, dtype=torch.float32) foot_contacts = motion_dict.get("foot_contacts") if foot_contacts is not None: foot_contacts = foot_contacts.to(device=demo.device, dtype=torch.float32) # Support both batched [B, T, J, 3] and unbatched [T, J, 3]; take first sample if batched if joints_pos.ndim == 4: joints_pos = joints_pos[0] if joints_rot.ndim == 5: joints_rot = joints_rot[0] if foot_contacts is not None and foot_contacts.ndim == 3: foot_contacts = foot_contacts[0] # Motion must match the current model's skeleton after auto-switch num_joints_loaded = joints_pos.shape[1] num_joints_skeleton = session.skeleton.nbjoints if num_joints_loaded != num_joints_skeleton: # Backward compat: expand 30-joint SOMA motion to 77 if ( num_joints_loaded == 30 and num_joints_skeleton == 77 and isinstance(session.skeleton, SOMASkeleton77) ): from kimodo.skeleton import global_rots_to_local_rots skel30 = SOMASkeleton30().to(demo.device) if "local_rot_mats" in motion_dict: local_rot_30 = motion_dict["local_rot_mats"].to(device=demo.device, dtype=torch.float32) if local_rot_30.ndim == 4: local_rot_30 = local_rot_30[0] else: local_rot_30 = global_rots_to_local_rots(joints_rot, skel30) local_rot_77 = skel30.to_SOMASkeleton77(local_rot_30) root_positions = joints_pos[:, skel30.root_idx, :] joints_rot, joints_pos, _ = session.skeleton.fk(local_rot_77, root_positions) if foot_contacts is not None and foot_contacts.shape[-1] == 4: foot_contacts = torch.cat( [ foot_contacts[..., :2], foot_contacts[..., 1:2], foot_contacts[..., 2:4], foot_contacts[..., 3:4], ], dim=-1, ) else: raise ValueError( f"The loaded motion has {num_joints_loaded} joints but the current model " f"({session.model_name}) has {num_joints_skeleton} joints. " "Load a motion generated with the same skeleton, or switch the model to match the motion." ) elif joints_rot.shape[1] != num_joints_skeleton: raise ValueError( f"Rotation data has {joints_rot.shape[1]} joints but the current model has " f"{num_joints_skeleton} joints. The NPZ may be corrupted or from a different skeleton." ) # Apply G1 real robot projection (1-DoF per joint + axis limits) if enabled. if ( "g1" in session.model_name and isinstance(session.skeleton, G1Skeleton34) and gui_real_robot_rotations_checkbox.value ): joints_pos, joints_rot = generation.apply_g1_real_robot_projection( session.skeleton, joints_pos, joints_rot ) # Update duration and frame range based on loaded motion num_frames = joints_pos.shape[0] duration = num_frames / session.model_fps # Update GUI elements session.cur_duration = duration session.max_frame_idx = num_frames - 1 # Clear existing motions and add the loaded one demo.clear_motions(client.client_id) demo.add_character_motion( client, session.skeleton, joints_pos, joints_rot, foot_contacts, ) # Reset to frame 0 demo.set_frame(client.client_id, 0) @gui_load_motion_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return load_path = gui_load_motion_path_text.value loading_notif = event_client.add_notification( title="Loading motion...", body=f"Loading from {load_path}", loading=True, with_close_button=False, auto_close_seconds=None, ) try: load_motion(event_client, load_path) loading_notif.title = "Motion loaded!" loading_notif.body = f"Loaded motion from {load_path} ({session.max_frame_idx + 1} frames, {session.cur_duration:.2f}s)" loading_notif.loading = False loading_notif.with_close_button = True loading_notif.auto_close_seconds = 5.0 loading_notif.color = "green" except Exception as e: import traceback traceback.print_exc() loading_notif.title = "Failed to load motion!" loading_notif.body = str(e) loading_notif.loading = False loading_notif.with_close_button = True loading_notif.auto_close_seconds = 10.0 loading_notif.color = "red" def save_constraints(client, save_path): session = demo.client_sessions[client.client_id] # Keep save behavior aligned with demo frame convention: # valid frame indices are [0, max_frame_idx], so count is +1. num_frames = session.max_frame_idx + 1 model_bundle = demo.load_model(session.model_name) constraints_lst = demo.compute_model_constraints_lst(session, model_bundle, num_frames) save_constraints_lst(save_path, constraints_lst) @gui_save_constraints_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client if get_active_session(event_client) is None: return try: save_path = gui_save_constraints_path_text.value save_constraints(event_client, save_path) event_client.add_notification( title="Constraints saved!", body=f"Saved constraints to {save_path}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to save constraints!", body=str(e), auto_close_seconds=10.0, color="red", ) def load_constraints(client, load_path): session = demo.client_sessions[client.client_id] constraints_lst = load_constraints_lst(load_path, skeleton=session.skeleton) # Clear existing constraints first with session.timeline_data["keyframe_update_lock"]: for constraint in list(session.constraints.values()): constraint.clear() client.timeline.clear_keyframes() client.timeline.clear_intervals() # Add loaded constraints to the session # We need to directly add constraint data, not read from current motion device = demo.device for constraint_obj in constraints_lst: constraint_type = constraint_obj.name # decompose the frame indices into intervals or single keyframes frame_indices = constraint_obj.frame_indices ( intervals, intervals_indices, single_frames, single_frames_indices, ) = extract_intervals_and_singles(frame_indices) load_targets: list[dict] = [] root_pos = None if constraint_type == "root2d": # smooth_root_2d is [T, 2] (x, z), convert to [T, 3] (x, 0, z) num_frames = constraint_obj.smooth_root_2d.shape[0] root_pos = torch.zeros(num_frames, 3, device=device) root_pos[:, 0] = constraint_obj.smooth_root_2d[:, 0] root_pos[:, 2] = constraint_obj.smooth_root_2d[:, 1] load_targets = [ { "track_name": "2D Root", "constraint_track": session.constraints["2D Root"], } ] elif constraint_type == "fullbody": load_targets = [ { "track_name": "Full-Body", "constraint_track": session.constraints["Full-Body"], } ] elif constraint_type in { "left-hand", "right-hand", "left-foot", "right-foot", }: track_name = { "left-hand": "Left Hand", "right-hand": "Right Hand", "left-foot": "Left Foot", "right-foot": "Right Foot", }[constraint_type] load_targets = [ { "track_name": track_name, "constraint_track": session.constraints["End-Effectors"], "joint_names": constraint_obj.joint_names, "end_effector_type": constraint_type, } ] elif constraint_type in {"end-effector", "end-effectors"}: # Backward-compatible loader: # split a generic end-effector constraint into per-limb timeline tracks. joint_names_set = set(constraint_obj.joint_names) for jname, track_name, eff_type in [ ("LeftHand", "Left Hand", "left-hand"), ("RightHand", "Right Hand", "right-hand"), ("LeftFoot", "Left Foot", "left-foot"), ("RightFoot", "Right Foot", "right-foot"), ]: if jname not in joint_names_set: continue target_joint_names = [jname] if "Hips" in joint_names_set: target_joint_names.append("Hips") load_targets.append( { "track_name": track_name, "constraint_track": session.constraints["End-Effectors"], "joint_names": target_joint_names, "end_effector_type": eff_type, } ) if not load_targets: raise KeyError( "No recognized end-effector joint in constraint " f"joint_names={constraint_obj.joint_names}" ) else: raise KeyError(f"Unsupported constraint type in loader: {constraint_type}") for target in load_targets: track_id = session.timeline_data["tracks_ids"][target["track_name"]] constraint_track = target["constraint_track"] # add intervals for (start_idx, end_idx), (start_idx_t, end_idx_t) in zip(intervals, intervals_indices): # Add to timeline interval_id = client.timeline.add_interval(track_id, start_idx, end_idx) session.timeline_data["intervals"][interval_id] = { "track_id": track_id, "start_frame_idx": start_idx, "end_frame_idx": end_idx, "locked": False, "opacity": 1.0, "value": None, } if constraint_type == "root2d": constraint_track.add_interval( interval_id, start_idx, end_idx, root_pos[start_idx_t : end_idx_t + 1], ) elif constraint_type == "fullbody": constraint_track.add_interval( interval_id, start_idx, end_idx, constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1], constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1], ) else: constraint_track.add_interval( interval_id, start_idx, end_idx, constraint_obj.global_joints_positions[start_idx_t : end_idx_t + 1], constraint_obj.global_joints_rots[start_idx_t : end_idx_t + 1], target["joint_names"], target["end_effector_type"], ) # add keyframes for frame, frame_t in zip(single_frames, single_frames_indices): # Add to timeline keyframe_id = client.timeline.add_keyframe(track_id, frame) session.timeline_data["keyframes"][keyframe_id] = { "track_id": track_id, "frame": frame, "locked": False, "opacity": 1.0, "value": None, } if constraint_type == "root2d": constraint_track.add_keyframe( keyframe_id, frame, root_pos[frame_t], ) elif constraint_type == "fullbody": constraint_track.add_keyframe( keyframe_id, frame, constraint_obj.global_joints_positions[frame_t], constraint_obj.global_joints_rots[frame_t], ) else: constraint_track.add_keyframe( keyframe_id, frame, constraint_obj.global_joints_positions[frame_t], constraint_obj.global_joints_rots[frame_t], target["joint_names"], target["end_effector_type"], ) @gui_load_constraints_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client if get_active_session(event_client) is None: return try: load_path = gui_load_constraints_path_text.value load_constraints(event_client, load_path) session = demo.client_sessions[event_client.client_id] apply_constraint_overlay_visibility(session) event_client.add_notification( title="Constraints loaded!", body=f"Loaded constraints from {load_path}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to load constraints!", body=str(e), auto_close_seconds=10.0, color="red", ) with client.gui.add_folder("Exports", expand_by_default=False): with client.gui.add_folder("Screenshot", expand_by_default=False, visible=not HF_MODE): gui_screenshot_path_text = client.gui.add_text( "Save Path", initial_value="render.png", hint="Filename for the screenshot (PNG).", ) gui_screenshot_button = client.gui.add_button( "Download Screenshot", hint="Capture the current canvas and download a PNG.", ) with client.gui.add_folder("Video", expand_by_default=False, visible=not HF_MODE): gui_video_path_text = client.gui.add_text( "Save Path", initial_value="render.mp4", hint="Filename for the video (MP4).", ) gui_video_button = client.gui.add_button( "Download Video", hint="Render every frame and download as MP4.", ) with client.gui.add_folder("Motion", expand_by_default=True): gui_download_name_text = client.gui.add_text( "Name", initial_value="output", hint="Base filename to save as (extension will be added based on format if omitted).", ) gui_download_format_dropdown = client.gui.add_dropdown( "Format", options=( ["NPZ", "CSV"] if "g1" in model_name.lower() else ["NPZ", "AMASS NPZ"] if "smplx" in model_name.lower() else ["NPZ", "BVH"] ), initial_value="NPZ", ) gui_download_button = client.gui.add_button( "Download", hint="Download the current motion (format + name above).", ) def _download_bytes_to_browser( event_client: viser.ClientHandle, *, data: bytes, filename: str, mime_type: str = "application/octet-stream", ) -> None: """Trigger a browser download for an in-memory byte payload. Important: this intentionally does NOT use `showSaveFilePicker()` to avoid Chrome/Edge's file-write permission prompt ("this site can see edits you make"). If you want "always ask where to save", configure your browser download settings. """ import base64 import json # Base64 is the most robust way to move binary over our websocket JS channel. b64 = base64.b64encode(data).decode("ascii") js = f""" (() => {{ const filename = {json.dumps(filename)}; const mimeType = {json.dumps(mime_type)}; const b64 = {json.dumps(b64)}; // Decode base64 -> Uint8Array. const binStr = atob(b64); const bytes = new Uint8Array(binStr.length); for (let i = 0; i < binStr.length; i++) bytes[i] = binStr.charCodeAt(i); const blob = new Blob([bytes], {{ type: mimeType }}); // Standard browser download behavior. const url = URL.createObjectURL(blob); const a = document.createElement("a"); a.href = url; a.download = filename; document.body.appendChild(a); a.click(); a.remove(); URL.revokeObjectURL(url); }})(); """ # Reuse viser’s JS execution mechanism (used for Plotly setup). from viser import _messages as _viser_messages event_client.gui._websock_interface.queue_message( # type: ignore[attr-defined] _viser_messages.RunJavascriptMessage(source=js) ) def _motion_to_npz_bytes(motion) -> bytes: motion_data = _motion_to_numpy_dict(motion) return kimodo_npz_to_bytes(motion_data) def _motion_to_csv_bytes(motion, session: ClientSession) -> bytes: motion_data = _motion_to_numpy_dict(motion) return g1_csv_to_bytes(motion_data, session.skeleton, demo.device) def _motion_to_amass_npz_bytes(motion, session: ClientSession) -> bytes: motion_data = _motion_to_numpy_dict(motion) return amass_npz_to_bytes(motion_data, session.skeleton, session.model_fps) def _get_motion_export_formats(loaded_model_name: str) -> list[str]: model_name_lower = (loaded_model_name or "").lower() if "g1" in model_name_lower: return ["NPZ", "CSV"] if "smplx" in model_name_lower: return ["NPZ", "AMASS NPZ"] return ["NPZ", "BVH"] def _update_format_dropdown(dropdown, loaded_model_name: str) -> None: new_options = _get_motion_export_formats(loaded_model_name) current_value = str(dropdown.value) dropdown.options = new_options dropdown.value = current_value if current_value in new_options else new_options[0] def _update_motion_export_dropdown(loaded_model_name: str) -> None: _update_format_dropdown(gui_download_format_dropdown, loaded_model_name) _update_format_dropdown(gui_save_motion_format_dropdown, loaded_model_name) def _coerce_download_filename(raw_name: str, *, ext: str) -> str: """Coerce a user-entered filename to a safe basename with the desired extension. - If empty: uses "output{ext}" - If no extension: appends ext - If endswith a known export extension: rewrites extension to ext (prevents mismatches) - Any provided directory components are stripped """ import os name = (raw_name or "").strip() name = os.path.basename(name.replace("\\", "/")) if name == "": return f"output{ext}" known_exts = (".npz", ".bvh", ".csv", ".png", ".mp4") lower = name.lower() if lower.endswith(known_exts): return os.path.splitext(name)[0] + ext root, cur_ext = os.path.splitext(name) if cur_ext == "": return name + ext return name def _get_render_size(event_client: viser.ClientHandle) -> tuple[int, int]: width = int(event_client.camera.image_width) height = int(event_client.camera.image_height) if width <= 0 or height <= 0: # Fall back to a reasonable default if the camera hasn't synced yet. return (1280, 720) return (width, height) def _round_up_to_multiple(value: int, multiple: int) -> int: if multiple <= 0: return value return ((value + multiple - 1) // multiple) * multiple def _download_canvas_to_browser(event_client: viser.ClientHandle, *, filename: str) -> None: """Use the client-side canvas save path to avoid server-side renders.""" import json js = f""" (() => {{ const filename = {json.dumps(filename)}; const canvases = Array.from(document.querySelectorAll("canvas")); if (!canvases.length) {{ console.error("No canvases found to save."); return; }} // Pick the largest canvas by area (usually the main 3D view). const canvas = canvases.reduce((best, cur) => {{ const bestArea = (best?.width || 0) * (best?.height || 0); const curArea = (cur?.width || 0) * (cur?.height || 0); return curArea > bestArea ? cur : best; }}, null); if (!canvas) {{ console.error("No canvas selected to save."); return; }} canvas.toBlob((blob) => {{ if (!blob) {{ console.error("Export failed"); return; }} const url = URL.createObjectURL(blob); const a = document.createElement("a"); a.href = url; a.download = filename; document.body.appendChild(a); a.click(); a.remove(); URL.revokeObjectURL(url); }}, "image/png"); }})(); """ from viser import _messages as _viser_messages event_client.gui._websock_interface.queue_message( # type: ignore[attr-defined] _viser_messages.RunJavascriptMessage(source=js) ) @gui_screenshot_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client if get_active_session(event_client) is None: return try: filename = _coerce_download_filename( str(gui_screenshot_path_text.value), ext=".png", ) _download_canvas_to_browser(event_client, filename=filename) event_client.add_notification( title="Screenshot download started", body=f"Saving {filename}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to download screenshot!", body=str(e), auto_close_seconds=10.0, color="red", ) @gui_video_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return recording_notification: viser.NotificationHandle | None = None try: recording_notification = event_client.add_notification( title="Recording video...", body="Saving frames, please wait.", loading=True, with_close_button=False, auto_close_seconds=None, color="blue", ) event_client.timeline.disable_constraints() width, height = _get_render_size(event_client) # Avoid ffmpeg macro block resizing warnings. width = _round_up_to_multiple(width, 16) height = _round_up_to_multiple(height, 16) original_frame = session.frame_idx frames = [] for frame_idx in range(session.max_frame_idx + 1): demo.set_frame( event_client.client_id, frame_idx, update_timeline=True, ) frames.append( event_client.get_render( height=height, width=width, transport_format="jpeg", ) ) # Restore the original frame (and timeline). demo.set_frame(event_client.client_id, original_frame) import imageio.v3 as iio filename = _coerce_download_filename( str(gui_video_path_text.value), ext=".mp4", ) payload = iio.imwrite( "", frames, extension=".mp4", fps=float(session.model_fps), codec="h264", plugin="pyav", ) event_client.send_file_download(filename, payload, save_immediately=True) event_client.add_notification( title="Video download started", body=f"Saving {filename}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to download video!", body=str(e), auto_close_seconds=10.0, color="red", ) finally: event_client.timeline.enable_constraints() if recording_notification is not None: recording_notification.remove() @gui_download_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return motion = _get_primary_motion(session) try: fmt = str(gui_download_format_dropdown.value).upper() raw_name = str(gui_download_name_text.value) if fmt == "BVH": filename = _coerce_download_filename(raw_name, ext=".bvh") payload = motion_to_bvh_bytes( motion.joints_local_rot, motion.joints_pos[:, session.skeleton.root_idx, :], # root positions skeleton=session.skeleton, fps=float(session.model_fps), ) mime = "text/plain" elif fmt == "CSV": filename = _coerce_download_filename(raw_name, ext=".csv") payload = _motion_to_csv_bytes(motion, session) mime = "text/csv" elif fmt == "AMASS NPZ": filename = _coerce_download_filename(raw_name, ext=".npz") payload = _motion_to_amass_npz_bytes(motion, session) mime = "application/octet-stream" else: # Default to NPZ (most common and matches existing save/load). filename = _coerce_download_filename(raw_name, ext=".npz") payload = _motion_to_npz_bytes(motion) mime = "application/octet-stream" _download_bytes_to_browser( event_client, data=payload, filename=filename, mime_type=mime, ) event_client.add_notification( title="Download started", body=f"Saving {filename}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to download motion!", body=str(e), auto_close_seconds=10.0, color="red", ) @gui_save_example_button.on_click def _(event: viser.GuiEvent) -> None: from kimodo.tools import save_json event_client = event.client session = get_active_session(event_client) if session is None: return save_dir = gui_save_example_path_text.value if os.path.exists(save_dir): event_client.add_notification( title="Failed to save example!", body="Example directory already exists", auto_close_seconds=10.0, color="red", ) return try: os.makedirs(save_dir) # save the constraints constraint_path = os.path.join(save_dir, "constraints.json") save_constraints(event_client, constraint_path) # save the motion motion_path = os.path.join(save_dir, "motion.npz") save_motion(event_client, motion_path, "NPZ") # save the gui metadata meta_path = os.path.join(save_dir, "meta.json") prompt_texts = [] prompt_durations_sec = [] prompt_values = sorted( [x for x in client.timeline._prompts.values()], key=lambda x: x.start_frame, ) for i, prompt in enumerate(prompt_values): prompt_texts.append(prompt.text) # Match demo/generation convention: # non-last prompts: [start, end) ; last prompt: [start, end]. n_frames = prompt.end_frame - prompt.start_frame if i == len(prompt_values) - 1: n_frames += 1 prompt_durations_sec.append(n_frames / session.model_fps) if len(prompt_texts) == 1: meta_info = { "text": prompt_texts[0], "duration": prompt_durations_sec[0], } else: meta_info = { "texts": prompt_texts, "durations": prompt_durations_sec, } meta_info["num_samples"] = gui_num_samples_slider.value meta_info["seed"] = gui_seed.value meta_info["diffusion_steps"] = gui_diffusion_steps_slider.value meta_info["cfg"] = { "enabled": gui_cfg_checkbox.value, "text_weight": gui_cfg_text_weight_slider.value, "constraint_weight": gui_cfg_constraint_weight_slider.value, } save_json(meta_path, meta_info) # update the example dropdown session.example_dict = viser_utils.load_example_cases(session.examples_base_dir) update_examples_dropdown(session.example_dict, keep_selection=True) event_client.add_notification( title="Example saved!", body=f"Saved example to {save_dir}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to save example!", body=str(e), auto_close_seconds=10.0, color="red", ) def set_new_duration(client_id, new_duration): session = demo.client_sessions[client_id] session.cur_duration = new_duration update_duration_gui(new_duration) session.max_frame_idx = int(session.cur_duration * session.model_fps - 1) if session.frame_idx > session.max_frame_idx: demo.set_frame(client_id, session.max_frame_idx) def apply_model_selection(new_model_name: str) -> None: session = demo.client_sessions[client_id] if new_model_name == session.model_name: return session.playing = False # Pause playback when switching models. old_model_fps = session.model_fps old_duration = session.cur_duration old_prompts = [ (prompt.text, prompt.start_frame, prompt.end_frame) for prompt in client.timeline._prompts.values() ] old_default_zoom_frames = client.timeline._default_num_frames_zoom old_max_zoom_frames = client.timeline._max_frames_zoom model_bundle = demo.load_model(new_model_name) # Clear motions and constraints when switching models. if session.edit_mode and session.motions: exit_editing_mode(session) session.edit_mode = False demo.clear_motions(client_id) with session.timeline_data["keyframe_update_lock"]: for constraint in list(session.constraints.values()): constraint.clear() session.constraints = demo.build_constraint_tracks(client, model_bundle.skeleton) session.timeline_data["keyframes"] = {} session.timeline_data["intervals"] = {} client.timeline.clear_keyframes() client.timeline.clear_intervals() session.model_name = new_model_name session.model_fps = model_bundle.model_fps session.skeleton = model_bundle.skeleton session.motion_rep = model_bundle.motion_rep session.cur_duration = old_duration session.max_frame_idx = int(session.cur_duration * session.model_fps - 1) session.frame_idx = 0 session.edit_mode = False demo.set_timeline_defaults(client.timeline, session.model_fps) client.timeline.set_current_frame(0) gui_model_fps.value = session.model_fps update_duration_gui(session.cur_duration) if old_model_fps > 0: default_zoom_seconds = old_default_zoom_frames / old_model_fps max_zoom_seconds = old_max_zoom_frames / old_model_fps new_default_zoom = int(round(default_zoom_seconds * session.model_fps)) new_max_zoom = int(round(max_zoom_seconds * session.model_fps)) new_default_zoom = max(1, new_default_zoom) new_max_zoom = max(new_default_zoom, new_max_zoom) client.timeline.set_zoom_settings( default_num_frames_zoom=new_default_zoom, max_frames_zoom=new_max_zoom, ) client.timeline.clear_prompts() if old_prompts and old_model_fps > 0: for i, (prompt_text, start_frame, end_frame) in enumerate(old_prompts): start_sec = start_frame / old_model_fps end_sec = end_frame / old_model_fps new_start = int(round(start_sec * session.model_fps)) new_end = int(round(end_sec * session.model_fps)) new_start = max(0, min(new_start, session.max_frame_idx)) new_end = max(new_start, min(new_end, session.max_frame_idx)) color = PROMPT_COLORS[i % len(PROMPT_COLORS)] client.timeline.add_prompt(prompt_text, new_start, new_end, color=color) session.examples_base_dir = demo.get_examples_base_dir(new_model_name, absolute=True) session.example_dict = viser_utils.load_example_cases(session.examples_base_dir) update_examples_dropdown(session.example_dict, keep_selection=False) gui_save_example_path_text.value = os.path.join( demo.get_examples_base_dir(new_model_name, absolute=True), "custom_example_1", ) gui_load_example_path_text.value = os.path.join( demo.get_examples_base_dir(new_model_name, absolute=True), "custom_example_1", ) demo.add_character_motion(client, session.skeleton) apply_constraint_overlay_visibility(session) def _update_version_and_display_from_dataset_skeleton() -> None: dataset_ui = gui_dataset_selector.value skeleton_display = gui_skeleton_selector.value skeleton_val = get_skeleton_key_from_display_name(skeleton_display) if skeleton_val is None: return models = get_models_for_dataset_skeleton(dataset_ui, skeleton_val, family="Kimodo") if not models: return gui_version_selector.options = [m.display_name for m in models] gui_version_selector.value = models[0].display_name gui_version_selector.visible = len(models) > 1 gui_model_display.content = f"**Model:** {models[0].display_name}" def _update_visibility_for_loaded_model(loaded_model_name: str) -> None: """Update model-specific controls from the currently loaded model only.""" if not loaded_model_name: return _update_motion_export_dropdown(loaded_model_name) gui_use_soma_layer_checkbox.visible = "soma" in loaded_model_name _is_g1 = "g1" in loaded_model_name gui_real_robot_rotations_checkbox.visible = _is_g1 gui_postprocess_checkbox.visible = not _is_g1 gui_root_margin.visible = not _is_g1 and gui_postprocess_checkbox.value if _is_g1: gui_gizmo_space_dropdown.value = "Local" gui_gizmo_space_dropdown.visible = not _is_g1 gui_gizmo_space_dropdown.disabled = _is_g1 def _on_load_model_click(event: viser.GuiEvent) -> None: """Load the currently selected model (called from Load model button).""" if get_active_session(event.client) is None: return new_model_name = gui_model_selector.value if not new_model_name: return info = get_model_info(new_model_name) if info is None: return session = demo.client_sessions[event.client.client_id] if new_model_name == session.model_name: return loading_notif = event.client.add_notification( title="Loading model...", body=f"Loading {info.display_name}", loading=True, with_close_button=False, ) try: apply_model_selection(new_model_name) _update_visibility_for_loaded_model(new_model_name) loading_notif.title = "Model loaded" loading_notif.body = f"{info.display_name} is ready." loading_notif.loading = False loading_notif.with_close_button = True loading_notif.auto_close_seconds = 5.0 loading_notif.color = "green" except Exception as e: loading_notif.loading = False loading_notif.with_close_button = True event.client.add_notification( title="Model failed to load", body=str(e), color="red", auto_close_seconds=10.0, ) gui_model_selector.set_from_short_key(session.model_name) @gui_load_model_button.on_click def _(event: viser.GuiEvent) -> None: _on_load_model_click(event) @gui_dataset_selector.on_update def _(event: viser.GuiEvent) -> None: if get_active_session(event.client) is None: return skeleton_labels = get_allowed_skeleton_labels(gui_dataset_selector.value) gui_skeleton_selector.options = skeleton_labels gui_skeleton_selector.value = skeleton_labels[0] if skeleton_labels else "" _update_version_and_display_from_dataset_skeleton() @gui_skeleton_selector.on_update def _(event: viser.GuiEvent) -> None: if get_active_session(event.client) is None: return _update_version_and_display_from_dataset_skeleton() @gui_version_selector.on_update def _(event: viser.GuiEvent) -> None: if get_active_session(event.client) is None: return info = get_model_info(gui_model_selector.value) if info is not None: gui_model_display.content = f"**Model:** {info.display_name}" @gui_use_soma_layer_checkbox.on_update def _(event: viser.GuiEvent) -> None: session = get_active_session(event.client) if session is None or "soma" not in (session.model_name or ""): return loading_notif = event.client.add_notification( title="Applying SOMA layer...", body="Updating mesh.", loading=True, with_close_button=False, ) try: current_motion = list(session.motions.values())[0] if session.motions else None current_frame_idx = session.frame_idx # Recreate the character to apply the new SOMA mesh mode selection. demo.clear_motions(event.client.client_id) if current_motion is None: demo.add_character_motion(event.client, session.skeleton) else: demo.add_character_motion( event.client, session.skeleton, current_motion.joints_pos, current_motion.joints_rot, current_motion.foot_contacts, ) demo.set_frame(event.client.client_id, current_frame_idx) except Exception as e: print(e) event.client.add_notification( title="SOMA layer failed", body=str(e), color="red", auto_close_seconds=10.0, ) gui_use_soma_layer_checkbox.value = not gui_use_soma_layer_checkbox.value finally: loading_notif.loading = False loading_notif.with_close_button = True loading_notif.auto_close_seconds = 2.0 @gui_real_robot_rotations_checkbox.on_update def _(event: viser.GuiEvent) -> None: session = get_active_session(event.client) if session is None or "g1" not in session.model_name: return if not isinstance(session.skeleton, G1Skeleton34) or not session.motions: return if not gui_real_robot_rotations_checkbox.value: return # Reproject all displayed G1 motions to real robot DoF (1-DoF per joint + axis limits). from kimodo.skeleton import global_rots_to_local_rots current_frame_idx = session.frame_idx for motion in session.motions.values(): if motion.length <= 1: continue rest_pos = motion.joints_pos[0:1] rest_rot = motion.joints_rot[0:1] same_as_rest = (motion.joints_pos - rest_pos).abs().max().item() < 1e-6 and ( motion.joints_rot - rest_rot ).abs().max().item() < 1e-6 if same_as_rest: continue new_pos, new_rot = generation.apply_g1_real_robot_projection( session.skeleton, motion.joints_pos, motion.joints_rot, ) motion.joints_pos = new_pos motion.joints_rot = new_rot motion.joints_local_rot = global_rots_to_local_rots(new_rot, session.skeleton) # Refresh skeleton and skinned mesh caches so the viz uses new positions. motion.precompute_mesh_info() demo.set_frame(event.client.client_id, current_frame_idx) event.client.add_notification( title="Real robot projection applied", body="The motion is projected to G1 real robot DoF (1-DoF per joint, clamped to axis limits).", auto_close_seconds=4.0, color="green", ) def load_example_from_path( event_client: viser.ClientHandle, example_path: str, load_gt: bool = False, ) -> None: from kimodo.meta import parse_prompts_from_meta from kimodo.tools import load_json session = get_active_session(event_client) if session is None: return # Pause playback when loading an example. session.playing = False if not os.path.isdir(example_path): event_client.add_notification( title="Example path not found", body=f"Directory does not exist: {example_path}", auto_close_seconds=5.0, color="red", ) return try: # constraints constraints_path = os.path.join(example_path, "constraints.json") if os.path.exists(constraints_path): load_constraints(event_client, constraints_path) else: # clear all existing constraints with session.timeline_data["keyframe_update_lock"]: for constraint in list(session.constraints.values()): constraint.clear() event_client.timeline.clear_keyframes() event_client.timeline.clear_intervals() # motion motion_filename = "gt_motion.npz" if load_gt else "motion.npz" motion_path = os.path.join(example_path, motion_filename) if os.path.exists(motion_path): load_motion(event_client, motion_path) # metadata meta_path = os.path.join(example_path, "meta.json") if os.path.exists(meta_path): meta_info = load_json(meta_path) event_client.timeline.clear_prompts() texts, durations_sec = parse_prompts_from_meta(meta_info) fps = session.model_fps # Convert durations (seconds) to consecutive frame bounds num_frames = 0 frame_bounds = [] for i, d in enumerate(durations_sec): n_frames = max(1, int(round(d * fps))) start_frame = num_frames # Inverse of compute_prompt_num_frames(): # non-last prompts end at next prompt start (exclusive), # last prompt includes its end frame. if i == len(durations_sec) - 1: end_frame = num_frames + n_frames - 1 else: end_frame = num_frames + n_frames frame_bounds.append((start_frame, end_frame)) num_frames += n_frames # Adapt timeline zoom to the loaded motion. target_visible_frames = int(math.ceil(1.10 * num_frames)) event_client.timeline.set_zoom_settings( default_num_frames_zoom=target_visible_frames, ) for i, (prompt_text, (start_frame, end_frame)) in enumerate(zip(texts, frame_bounds)): color = PROMPT_COLORS[i % len(PROMPT_COLORS)] event_client.timeline.add_prompt(prompt_text, start_frame, end_frame, color=color) update_duration_auto() # Only load optional fields if present if "num_samples" in meta_info: gui_num_samples_slider.value = meta_info["num_samples"] if "seed" in meta_info: gui_seed.value = meta_info["seed"] if "diffusion_steps" in meta_info: gui_diffusion_steps_slider.value = meta_info["diffusion_steps"] if "cfg" in meta_info: cfg = meta_info["cfg"] if "enabled" in cfg: gui_cfg_checkbox.value = cfg["enabled"] if "text_weight" in cfg: gui_cfg_text_weight_slider.value = cfg["text_weight"] if "constraint_weight" in cfg: gui_cfg_constraint_weight_slider.value = cfg["constraint_weight"] # Set frame to 0 when example is loaded. session.frame_idx = 0 event_client.timeline.set_current_frame(0) demo.set_frame(event_client.client_id, 0) event_client.add_notification( title="Example loaded!", body=f"Loaded example from {example_path}", auto_close_seconds=5.0, color="green", ) except Exception as e: import traceback traceback.print_exc() event_client.add_notification( title="Failed to load example!", body=str(e), auto_close_seconds=10.0, color="red", ) def load_qwen_example_plan(event_client: viser.ClientHandle) -> None: """Load a Qwen-generated 10-action prompt plan into the timeline. This preserves the native UI flow: 1) Load Example -> fills timeline prompt segments 2) Generate -> synthesizes motion from loaded prompts """ session = get_active_session(event_client) if session is None: return def _thread_fn() -> None: try: history: list[str] = [] all_texts: list[str] = [] all_durations: list[float] = [] target_actions = 10 rounds = 0 while len(all_texts) < target_actions and rounds < 8: remaining = target_actions - len(all_texts) batch, history = call_qwen_for_prompts( scene="Agentic demo: keep one character in continuous motion", history=history, requested_actions=min(5, remaining), ) texts = batch.get("texts", []) durations = batch.get("durations", []) for t, d in zip(texts, durations): if len(all_texts) >= target_actions: break all_texts.append(t) all_durations.append(float(d)) rounds += 1 if len(all_texts) == 0: event_client.add_notification( title="Qwen example load failed", body="No prompt segments were produced.", auto_close_seconds=6.0, color="red", ) return fps = session.model_fps event_client.timeline.clear_prompts() frame_cursor = 0 for i, (txt, dur) in enumerate(zip(all_texts, all_durations)): n_frames = max(1, int(round(dur * fps))) start_frame = frame_cursor end_frame = frame_cursor + n_frames if i < len(all_texts) - 1 else frame_cursor + n_frames - 1 color = PROMPT_COLORS[i % len(PROMPT_COLORS)] event_client.timeline.add_prompt(txt, start_frame, end_frame, color=color) frame_cursor += n_frames target_visible_frames = int(math.ceil(1.10 * frame_cursor)) event_client.timeline.set_zoom_settings(default_num_frames_zoom=max(60, target_visible_frames)) update_duration_auto() event_client.add_notification( title="Qwen example loaded", body=f"Loaded {len(all_texts)} prompt segments. Click Generate to synthesize motion.", auto_close_seconds=6.0, color="green", ) except Exception as e: event_client.add_notification( title="Qwen example load failed", body=str(e), auto_close_seconds=8.0, color="red", ) threading.Thread(target=_thread_fn, daemon=True).start() @gui_load_example_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return if gui_examples_dropdown.value in (QWEN_EXAMPLE_NAME, QWEN_EXAMPLE_LEGACY_NAME): load_qwen_example_plan(event_client) return if not session.example_dict or (gui_examples_dropdown.value not in session.example_dict): event_client.add_notification( title="No examples available", body="No examples found for the selected model.", auto_close_seconds=5.0, color="red", ) return example_path = session.example_dict[gui_examples_dropdown.value] load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value) @gui_load_example_from_path_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return example_path = gui_load_example_path_text.value if not example_path: event_client.add_notification( title="No example path", body="Please provide an example directory.", auto_close_seconds=5.0, color="red", ) return load_example_from_path(event_client, example_path, gui_load_gt_checkbox.value) @gui_cfg_checkbox.on_update def _(_) -> None: if not demo.client_active(client_id): return val = gui_cfg_checkbox.value gui_cfg_text_weight_slider.visible = val gui_cfg_constraint_weight_slider.visible = val def exit_editing_mode(session: ClientSession): gui_edit_constraint_button.label = "Enter Editing Mode" gui_generate_button.disabled = False gui_generate_button.label = "Generate" gui_reset_constraint_button.disabled = True if "g1" in session.model_name: gui_gizmo_space_dropdown.value = "Local" gui_gizmo_space_dropdown.disabled = True gui_gizmo_space_dropdown.visible = False else: gui_gizmo_space_dropdown.disabled = False gui_gizmo_space_dropdown.visible = True gui_undo_drag_button.disabled = True gui_use_soma_layer_checkbox.disabled = False session.edit_mode_snapshot = None session.undo_drag_snapshot = None motion = list(session.motions.values())[0] motion.clear_all_gizmos() motion.character.set_skinned_mesh_wireframe(False) motion.character.set_skeleton_visibility(False) motion.character.set_skinned_mesh_visibility(True) motion.character.set_skinned_mesh_opacity(1.0) session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 1.0 # If the path is dense, put the motion back on the path if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path: _update_dense_path(motion, session) gui_viz_skinned_mesh_checkbox.value = True gui_viz_skeleton_checkbox.value = False # enter editing mode callback @gui_edit_constraint_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return session.edit_mode = not session.edit_mode edit_alert = "Entered editing mode" no_edit_alert = "Exited editing mode" edit_message = "You can now modify pose or path constraints." no_edit_message = "Can now generate motions." event_client.add_notification( title=edit_alert if session.edit_mode else no_edit_alert, body=edit_message if session.edit_mode else no_edit_message, auto_close_seconds=10.0, color="blue", ) if session.edit_mode: gui_edit_constraint_button.label = "Exit Editing Mode" gui_generate_button.disabled = True gui_generate_button.label = "Generate Disabled In Editing Mode" if "g1" in session.model_name: gui_gizmo_space_dropdown.value = "Local" gui_gizmo_space_dropdown.disabled = True gui_use_soma_layer_checkbox.disabled = True assert len(session.motions) == 1, "Only one motion allowed in edit mode" motion = list(session.motions.values())[0] snapshot_frame_idx = min(session.frame_idx, motion.length - 1) session.edit_mode_snapshot = {} ensure_edit_snapshot(session, motion, snapshot_frame_idx) gui_reset_constraint_button.disabled = False motion.character.set_skeleton_visibility(True) # motion.character.set_skinned_mesh_wireframe(True) motion.character.set_skinned_mesh_opacity(0.65) session.gui_elements.gui_viz_skinned_mesh_opacity_slider.value = 0.65 motion.character.set_skinned_mesh_visibility(True) gui_viz_skinned_mesh_checkbox.value = True gui_viz_skeleton_checkbox.value = True # need gizmos for root translation and individual joints def _on_root2d_gizmo_release(): if "2D Root" in session.constraints and session.constraints["2D Root"].dense_path: mot = list(session.motions.values())[0] _update_dense_path(mot, session) def _on_gizmo_drag_start(): mot = list(session.motions.values())[0] frame_idx = min(session.frame_idx, mot.length - 1) session.undo_drag_snapshot = { "frame_idx": frame_idx, "joints_pos": mot.get_joints_pos(frame_idx), "joints_rot": mot.get_joints_rot(frame_idx), } gui_undo_drag_button.disabled = False motion.add_root_translation_gizmo( session.constraints, on_2d_root_drag_end=_on_root2d_gizmo_release, on_drag_start=_on_gizmo_drag_start, ) gizmo_space = "local" if "g1" in session.model_name else gui_gizmo_space_dropdown.value.lower() motion.add_joint_gizmos( session.constraints, space=gizmo_space, on_drag_start=_on_gizmo_drag_start, ) else: exit_editing_mode(session) @gui_reset_constraint_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None or not session.edit_mode_snapshot: return if not session.motions: return motion = list(session.motions.values())[0] snapshot_frame_idx = min(session.frame_idx, motion.length - 1) if snapshot_frame_idx not in session.edit_mode_snapshot: return motion.update_pose_at_frame( snapshot_frame_idx, joints_pos=session.edit_mode_snapshot[snapshot_frame_idx]["joints_pos"], joints_rot=session.edit_mode_snapshot[snapshot_frame_idx]["joints_rot"], ) demo.set_frame(event_client.client_id, snapshot_frame_idx, update_timeline=False) @gui_undo_drag_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None or session.undo_drag_snapshot is None: return if not session.motions: return motion = list(session.motions.values())[0] frame_idx = session.undo_drag_snapshot["frame_idx"] motion.update_pose_at_frame( frame_idx, joints_pos=session.undo_drag_snapshot["joints_pos"], joints_rot=session.undo_drag_snapshot["joints_rot"], ) demo.set_frame(event_client.client_id, frame_idx, update_timeline=False) session.undo_drag_snapshot = None gui_undo_drag_button.disabled = True def validate_interval(start_frame_idx: int, end_frame_idx: int, max_frame_idx: int) -> bool: if start_frame_idx < 0 or start_frame_idx > max_frame_idx: return False if end_frame_idx < 0 or end_frame_idx > max_frame_idx: return False if end_frame_idx < start_frame_idx: return False return True def clamp_interval_to_range( start_frame_idx: int, end_frame_idx: int, max_frame_idx: int ) -> Optional[tuple[int, int]]: if end_frame_idx < 0 or start_frame_idx > max_frame_idx: return None start_clamped = max(0, start_frame_idx) end_clamped = min(max_frame_idx, end_frame_idx) if end_clamped < start_clamped: return None return start_clamped, end_clamped # add constraint callback def add_constraint_callback( constraint_id: str, constraint_type: str, frame_range: tuple[int, int], joint_names: list[str] = None, verbose: bool = True, ): """Add a constraint to the session. Args: constraint_type: str, the type of constraint to add frame_range: tuple[int, int], the frame range to add the constraint to joint_names: list[str], the names of the joints to constraint if the constraint type is End-Effectors """ # Check if session still exists if not demo.client_active(client_id): return session = demo.client_sessions[client_id] assert len(session.motions) == 1, "Only one motion allowed for adding constraints" motion = list(session.motions.values())[0] end_effector_type = None if constraint_type in [ "Left Hand", "Right Hand", "Left Foot", "Right Foot", ]: joint_names = [constraint_type.replace(" ", ""), "Hips"] # Hips are required because of smooth root representation end_effector_type = constraint_type.replace(" ", "-").lower() constraint_type = "End-Effectors" # check to make sure interval is valid is_interval = frame_range[1] != frame_range[0] start_frame_idx = int(frame_range[0]) end_frame_idx = int(frame_range[1]) if is_interval: clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx) if clamped is None: print("Interval outside range! Couldn't add constraint.") return start_frame_idx, end_frame_idx = clamped else: if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx): print("Invalid interval! Couldn't add constraint.") return # collect input args for the constraint based on which track it is if is_interval: constraint_kwargs = { "interval_id": constraint_id, "start_frame_idx": start_frame_idx, "end_frame_idx": end_frame_idx, } else: constraint_kwargs = { "keyframe_id": constraint_id, "frame_idx": start_frame_idx, } if constraint_type in ["Full-Body", "End-Effectors"]: constraint_kwargs["joints_pos"] = motion.get_joints_pos(start_frame_idx, end_frame_idx) constraint_kwargs["joints_rot"] = motion.get_joints_rot(start_frame_idx, end_frame_idx) if constraint_type == "End-Effectors": constraint_kwargs["joint_names"] = joint_names constraint_kwargs["end_effector_type"] = end_effector_type elif constraint_type == "2D Root": constraint_kwargs["root_pos"] = motion.get_projected_root_pos(start_frame_idx, end_frame_idx) # add the keyframe(s) to the constraint track constraint = session.constraints[constraint_type] if is_interval: constraint.add_interval(**constraint_kwargs) else: constraint.add_keyframe(**constraint_kwargs) apply_constraint_overlay_visibility(session) if verbose: client.add_notification( title="Constraint added", body="", auto_close_seconds=5.0, color="blue", ) # timeline callbacks for keyframes and intervals @client.timeline.on_keyframe_add def _(keyframe_id: str, track_id: str, frame: int): """Called when a keyframe is added to a track.""" if not demo.client_active(client_id): return session = demo.client_sessions[client_id] with session.timeline_data["keyframe_update_lock"]: constraint_type = session.timeline_data["tracks"][track_id]["name"] add_constraint_callback( keyframe_id, constraint_type, (frame, frame), verbose=False, ) keyframe_data = client.timeline._keyframes.get(keyframe_id) session.timeline_data["keyframes"][keyframe_id] = { "frame": frame, "track_id": track_id, "locked": bool(keyframe_data.locked) if keyframe_data is not None else False, "opacity": keyframe_data.opacity if keyframe_data is not None else 1.0, "value": keyframe_data.value if keyframe_data is not None else None, } # Update smooth path when adding a keyframe (single action, not drag). if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: motion = list(session.motions.values())[0] _update_dense_path(motion, session) @client.timeline.on_interval_add def handle_interval_add(interval_id: str, track_id: str, start_frame: int, end_frame: int): """Called when an interval is added to a track.""" if not demo.client_active(client_id): return session = demo.client_sessions[client_id] with session.timeline_data["keyframe_update_lock"]: constraint_type = session.timeline_data["tracks"][track_id]["name"] add_constraint_callback( interval_id, constraint_type, (start_frame, end_frame), verbose=False, ) interval_data = client.timeline._intervals.get(interval_id) session.timeline_data["intervals"][interval_id] = { "track_id": track_id, "start_frame_idx": start_frame, "end_frame_idx": end_frame, "locked": bool(interval_data.locked) if interval_data is not None else False, "opacity": interval_data.opacity if interval_data is not None else 1.0, "value": interval_data.value if interval_data is not None else None, } if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: motion = list(session.motions.values())[0] _update_dense_path(motion, session) def remove_constraint_callback( constraint_id: str, constraint_type: str, frame_range: tuple[int, int], verbose: bool = True, ) -> None: if not demo.client_active(client_id): return session = demo.client_sessions[client_id] session.updating_motions = True is_interval = frame_range[1] != frame_range[0] start_frame_idx = int(frame_range[0]) end_frame_idx = int(frame_range[1]) if is_interval: clamped = clamp_interval_to_range(start_frame_idx, end_frame_idx, session.max_frame_idx) if clamped is None: return start_frame_idx, end_frame_idx = clamped else: if not validate_interval(start_frame_idx, end_frame_idx, session.max_frame_idx): print("Invalid interval! Couldn't remove constraint.") return if constraint_type in [ "Left Hand", "Right Hand", "Left Foot", "Right Foot", ]: constraint_type = "End-Effectors" constraint = session.constraints[constraint_type] if is_interval: constraint.remove_interval(constraint_id, start_frame_idx, end_frame_idx) else: constraint.remove_keyframe(constraint_id, start_frame_idx) if verbose: client.add_notification( title="Constraint removed", body="", auto_close_seconds=5.0, color="blue", ) @client.timeline.on_keyframe_move def handle_keyframe_move(keyframe_id: str, new_frame: int): """Called when a keyframe is moved to a new frame.""" # print(f"Keyframe moved: {keyframe_id} to frame {new_frame}") if not demo.client_active(client_id): return session = demo.client_sessions[client_id] # Cancel any pending timer for this keyframe timeline_data = session.timeline_data with timeline_data["keyframe_update_lock"]: if keyframe_id in timeline_data["keyframe_move_timers"]: timeline_data["keyframe_move_timers"][keyframe_id].cancel() # Store the latest target frame timeline_data["pending_keyframe_moves"][keyframe_id] = new_frame # Create a new timer to execute the actual move after a delay # This debounces rapid movements - only execute when user stops moving timer = threading.Timer( 0.03, # 10ms delay - adjust as needed _execute_keyframe_move, args=(client_id, keyframe_id, new_frame, session), ) timeline_data["keyframe_move_timers"][keyframe_id] = timer timer.start() def _execute_keyframe_move( client_id: int, keyframe_id: str, new_frame: int, session: ClientSession, ): """Actually execute the keyframe move operation (called after debounce delay).""" timeline_data = session.timeline_data with timeline_data["keyframe_update_lock"]: # Check if this move is still the latest one if keyframe_id not in timeline_data["pending_keyframe_moves"]: return # Move was cancelled if timeline_data["pending_keyframe_moves"][keyframe_id] != new_frame: return # A newer move superseded this one # Remove from pending del timeline_data["pending_keyframe_moves"][keyframe_id] if keyframe_id in timeline_data["keyframe_move_timers"]: del timeline_data["keyframe_move_timers"][keyframe_id] # Now execute the actual move (keep it in the lock so we don't delete it while moving) if keyframe_id not in timeline_data["keyframes"]: # double check return keyframe_data = timeline_data["keyframes"][keyframe_id] if not keyframe_data: return # if the frame did not move, don't do anything if keyframe_data["frame"] == new_frame: return track_id = keyframe_data["track_id"] constraint_type = timeline_data["tracks"][track_id]["name"] cur_frame = keyframe_data["frame"] # Remove constraint at old frame remove_constraint_callback( keyframe_id, constraint_type, (cur_frame, cur_frame), verbose=False, ) # Add constraint at new frame add_constraint_callback( keyframe_id, constraint_type, (new_frame, new_frame), verbose=False, ) # update our data keyframe_data["frame"] = new_frame # Schedule path update only after user stops dragging (no move for 300ms). if constraint_type == "2D Root": _schedule_dense_path_after_release(session) @client.timeline.on_keyframe_delete def handle_keyframe_delete(keyframe_id: str): """Called when a keyframe is deleted.""" if not demo.client_active(client_id): return session = demo.client_sessions[client_id] with session.timeline_data["keyframe_update_lock"]: if keyframe_id not in session.timeline_data["keyframes"]: return keyframe_data = session.timeline_data["keyframes"][keyframe_id] track_id = keyframe_data["track_id"] constraint_type = session.timeline_data["tracks"][track_id]["name"] cur_frame = keyframe_data["frame"] remove_constraint_callback( keyframe_id, constraint_type, (cur_frame, cur_frame), verbose=False, ) del session.timeline_data["keyframes"][keyframe_id] if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: motion = list(session.motions.values())[0] _update_dense_path(motion, session) @client.timeline.on_interval_move def handle_interval_move(interval_id: str, new_start: int, new_end: int): """Called when an interval is moved or resized.""" # print(f"Interval moved: {interval_id} to {new_start}-{new_end}") if not demo.client_active(client_id): return session = demo.client_sessions[client_id] # Cancel any pending timer for this interval # We share the same lock for keyframe and interval moves assuming the user can't move both at the same time timeline_data = session.timeline_data with timeline_data["keyframe_update_lock"]: if interval_id in timeline_data["keyframe_move_timers"]: timeline_data["keyframe_move_timers"][interval_id].cancel() # Store the latest target frame new_interval = (new_start, new_end) timeline_data["pending_keyframe_moves"][interval_id] = new_interval # Create a new timer to execute the actual move after a delay # This debounces rapid movements - only execute when user stops moving timer = threading.Timer( 0.5, # 100ms delay - adding interval is much slower than moving a keyframe _execute_interval_move, args=(client_id, interval_id, new_interval, session), ) timeline_data["keyframe_move_timers"][interval_id] = timer timer.start() def _execute_interval_move( client_id: int, interval_id: str, new_interval: tuple[int, int], session: ClientSession, ): """Actually execute the interval move operation (called after debounce delay).""" timeline_data = session.timeline_data with timeline_data["keyframe_update_lock"]: # Check if this move is still the latest one if interval_id not in timeline_data["pending_keyframe_moves"]: return # Move was cancelled if timeline_data["pending_keyframe_moves"][interval_id] != new_interval: return # A newer move superseded this one # Remove from pending del timeline_data["pending_keyframe_moves"][interval_id] if interval_id in timeline_data["keyframe_move_timers"]: del timeline_data["keyframe_move_timers"][interval_id] # Now execute the actual move if interval_id not in timeline_data["intervals"]: return interval_data = timeline_data["intervals"][interval_id] if not interval_data: return # if the interval did not move, don't do anything if ( interval_data["start_frame_idx"] == new_interval[0] and interval_data["end_frame_idx"] == new_interval[1] ): return track_id = interval_data["track_id"] constraint_type = timeline_data["tracks"][track_id]["name"] cur_range = ( interval_data["start_frame_idx"], interval_data["end_frame_idx"], ) # Remove constraint at old frame remove_constraint_callback( interval_id, constraint_type, cur_range, verbose=False, ) # Add constraint at new frame add_constraint_callback( interval_id, constraint_type, new_interval, verbose=False, ) # update our data interval_data["start_frame_idx"] = new_interval[0] interval_data["end_frame_idx"] = new_interval[1] # Schedule path update only after user stops dragging (no move for 300ms). if constraint_type == "2D Root": _schedule_dense_path_after_release(session) @client.timeline.on_interval_delete def handle_interval_delete(interval_id: str): """Called when an interval is deleted.""" if not demo.client_active(client_id): return session = demo.client_sessions[client_id] with session.timeline_data["keyframe_update_lock"]: if interval_id not in session.timeline_data["intervals"]: return interval_data = session.timeline_data["intervals"][interval_id] track_id = interval_data["track_id"] constraint_type = session.timeline_data["tracks"][track_id]["name"] remove_constraint_callback( interval_id, constraint_type, ( interval_data["start_frame_idx"], interval_data["end_frame_idx"], ), verbose=False, ) del session.timeline_data["intervals"][interval_id] if constraint_type == "2D Root" and session.constraints["2D Root"].dense_path: motion = list(session.motions.values())[0] _update_dense_path(motion, session) @gui_snap_to_constraint_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return target_character_motion = list(session.motions.values())[0] frame_idx = session.frame_idx if frame_idx >= target_character_motion.length: # frame idx larger than the motion, could not snap return for constraint_name in ["Full-Body", "End-Effectors"]: if ( constraint_name in session.constraints and frame_idx in session.constraints[constraint_name].keyframes ): pos = session.constraints[constraint_name].keyframes[frame_idx]["joints_pos"] rot = session.constraints[constraint_name].keyframes[frame_idx]["joints_rot"] # update the full joints_pos of the character to match the constraints target_character_motion.update_pose_at_frame( frame_idx, joints_pos=pos, joints_rot=rot, ) target_character_motion.set_frame(frame_idx) return # motion already fully changed if "2D Root" in session.constraints and frame_idx in session.constraints["2D Root"].keyframes: # update only the root position new_root_pos = session.constraints["2D Root"].keyframes[frame_idx] old_root_pos = target_character_motion.get_projected_root_pos(frame_idx) root_diff = new_root_pos - old_root_pos root_diff[1] = 0.0 # don't change height new_joints_pos = ( target_character_motion.joints_pos[frame_idx] + to_torch( root_diff, device=target_character_motion.joints_pos.device, dtype=target_character_motion.joints_pos.dtype, )[None] ) rot = target_character_motion.joints_rot[frame_idx] target_character_motion.update_pose_at_frame( frame_idx, joints_pos=new_joints_pos, joints_rot=rot, ) target_character_motion.set_frame(frame_idx) @gui_clear_all_constraints_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return with session.timeline_data["keyframe_update_lock"]: # use the lock here to wait for any constraint updates to finish for constraint in list(session.constraints.values()): constraint.clear() client.timeline.clear_keyframes() client.timeline.clear_intervals() if gui_dense_path_checkbox.value: gui_dense_path_checkbox.value = False if "2D Root" in session.constraints: session.constraints["2D Root"].set_dense_path(False) # generation callback @gui_generate_button.on_click def _(event: viser.GuiEvent) -> None: event_client = event.client session = get_active_session(event_client) if session is None: return generating_notif = event_client.add_notification( title="Generating motion...", body="Generating motions for the given prompt!", loading=True, with_close_button=False, ) gui_generate_button.disabled = True client.timeline.disable_constraints() num_samples = gui_num_samples_slider.value timeline = session.client.timeline # sort them to avoid issues: prompt_values = sorted([x for x in timeline._prompts.values()], key=lambda x: x.start_frame) texts = [x.text for x in prompt_values] num_frames = compute_prompt_num_frames(prompt_values) # compute the total duration total_nb_frames = sum(num_frames) total_duration = total_nb_frames / session.model_fps # update just in case set_new_duration(client_id, total_duration) transitions_parameters = { "num_transition_frames": gui_num_transition_frames_slider.value, "share_transition": gui_share_transition_checkbox.value, "percentage_transition_override": gui_percentage_transition_sharing_slider.value / 100, } # G1: postprocessing is disabled (does not work well for this model). postprocess_parameters = { "post_processing": (False if "g1" in session.model_name else gui_postprocess_checkbox.value), "root_margin": gui_root_margin.value, } try: demo.generate( event_client, texts, num_frames, num_samples, gui_seed.value, gui_diffusion_steps_slider.value, cfg_weight=[ gui_cfg_text_weight_slider.value, gui_cfg_constraint_weight_slider.value, ], cfg_type="separated" if gui_cfg_checkbox.value else "nocfg", postprocess_parameters=postprocess_parameters, transitions_parameters=transitions_parameters, real_robot_rotations=gui_real_robot_rotations_checkbox.value, ) session.max_frame_idx = int(session.cur_duration * session.model_fps - 1) session.max_frame_idx = int(session.cur_duration * session.model_fps) - 1 if session.frame_idx > session.max_frame_idx: session.frame_idx = session.max_frame_idx if num_samples > 1: # add mesh selector to choose character to commit def commit_motion(event: viser.GuiEvent) -> None: target = event.target commit_name = target.name.split("/")[1] # e.g. /character0/simple_skinned print(f"Committing motion for character: {commit_name}") # delete non-selected motions new_motion_kwargs = None for character_name, motion in session.motions.items(): if character_name == commit_name: new_motion_kwargs = { "skeleton": session.skeleton, "joints_rot": motion.joints_rot, "foot_contacts": motion.foot_contacts, } root_x_offset = motion.joints_pos[0, session.skeleton.root_idx, 0] new_joints_pos = motion.joints_pos.clone() new_joints_pos[..., 0] -= root_x_offset new_motion_kwargs["joints_pos"] = new_joints_pos break # clear and re-add the selected motion demo.clear_motions(event_client.client_id) demo.add_character_motion(event_client, **new_motion_kwargs) gui_edit_constraint_button.disabled = False gui_generate_button.disabled = False gui_snap_to_constraint_button.disabled = False client.timeline.enable_constraints() gui_generate_button.label = "Generate" gui_save_example_button.disabled = False gui_save_motion_button.disabled = False gui_download_button.disabled = False gui_save_constraints_button.disabled = False gui_load_example_button.disabled = False for motion in session.motions.values(): char = motion.character character_name = char.name # e.g. "character0" if char.skinned_mesh is not None: char.skinned_mesh.on_click(commit_motion) elif char.g1_mesh_rig is not None: # Register click on every part so any part can be clicked, # and use highlight_group so the whole robot highlights together. for handle in char.g1_mesh_rig.mesh_handles: handle.on_click(commit_motion, highlight_group=character_name) gui_edit_constraint_button.disabled = True gui_generate_button.disabled = True gui_snap_to_constraint_button.disabled = True gui_generate_button.label = "Choose Sample Before Generating" gui_save_example_button.disabled = True gui_save_motion_button.disabled = True gui_download_button.disabled = True gui_save_constraints_button.disabled = True gui_load_example_button.disabled = True else: gui_edit_constraint_button.disabled = False gui_generate_button.disabled = False gui_snap_to_constraint_button.disabled = False client.timeline.enable_constraints() generating_notif.title = "Motion generation finished!" generating_notif.body = "Motions have been generated successfully for the given prompt." if num_samples > 1: generating_notif.body += " Now choose which sample to commit." generating_notif.loading = False generating_notif.with_close_button = True generating_notif.auto_close_seconds = 5.0 generating_notif.color = "green" # put the motion at zero demo.set_frame(client_id, 0) except Exception as e: import traceback traceback.print_exc() print(f"Error during generation for client {event_client.client_id}: {e}") # Re-enable buttons and notify the user if event_client.client_id in demo.client_sessions: session = demo.client_sessions[event_client.client_id] gui_generate_button.disabled = False gui_load_example_button.disabled = False gui_save_example_button.disabled = False gui_save_motion_button.disabled = False gui_download_button.disabled = False # Reuse persistent notification instead of creating a new one try: generating_notif.title = "Generation failed!" generating_notif.body = f"Error: {str(e)}" generating_notif.loading = False generating_notif.with_close_button = True generating_notif.auto_close_seconds = 6.0 generating_notif.color = "red" except Exception: pass demo.check_cuda_health() # # Visualization settings # with tab_group.add_tab("Visualize", viser.Icon.EYE): with client.gui.add_folder("Playback", expand_by_default=True): gui_model_fps = client.gui.add_number("Model FPS", initial_value=model_fps, disabled=True) gui_playback_speed_buttons = client.gui.add_button_group( "Playback Speed", options=[ "0.5x", "1x", "2x", ], ) gui_playback_speed_buttons.value = "1x" @client.timeline.on_frame_change def handle_timeline_frame_change(new_frame_idx: int): """Update the frame when the user clicks on the timeline.""" demo.set_frame(client_id, new_frame_idx, update_timeline=False) session = demo.client_sessions.get(client_id) if session is not None: if session.edit_mode and session.motions: motion = list(session.motions.values())[0] snapshot_frame_idx = min(session.frame_idx, motion.length - 1) ensure_edit_snapshot(session, motion, snapshot_frame_idx) update_snap_to_constraint_button(session) @client.timeline.on_prompt_add async def _on_add( prompt_id: str, start_frame: int, end_frame: int, text: str, color: tuple[int, int, int] | None, ) -> None: update_duration_auto() @client.timeline.on_prompt_update async def _on_update(prompt_id: str, new_text: str) -> None: update_duration_auto() @client.timeline.on_prompt_resize async def _on_resize(prompt_id: str, new_start: int, new_end: int) -> None: update_duration_auto() @client.timeline.on_prompt_move async def _on_move(prompt_id: str, new_start: int, new_end: int) -> None: update_duration_auto() @client.timeline.on_prompt_delete async def _on_delete(prompt_id: str) -> None: update_duration_auto() def play_pause_button_callback(session: ClientSession): session.playing = not session.playing def next_frame_callback(session: ClientSession): if session.frame_idx < session.max_frame_idx: session.frame_idx += 1 if session.frame_idx == session.max_frame_idx: pass demo.set_frame(client_id, session.frame_idx) def prev_frame_callback(session: ClientSession): if session.frame_idx > 0: session.frame_idx -= 1 if session.frame_idx == 0: pass demo.set_frame(client_id, session.frame_idx) @gui_playback_speed_buttons.on_click def _(_) -> None: if not demo.client_active(client_id): return speed_map = { "0.5x": 0.5, "1x": 1.0, "2x": 2.0, } session = demo.client_sessions[client_id] session.playback_speed = speed_map[gui_playback_speed_buttons.value] with client.gui.add_folder("Body options", expand_by_default=True): gui_viz_skinned_mesh_checkbox = client.gui.add_checkbox("Show Mesh", initial_value=True) gui_viz_skinned_mesh_opacity_slider = client.gui.add_slider( "Mesh Opacity", min=0.0, max=1.0, step=0.01, initial_value=1.0 ) gui_viz_skeleton_checkbox = client.gui.add_checkbox("Show Skeleton", initial_value=False) gui_viz_foot_contacts_checkbox = client.gui.add_checkbox("Show Foot Contacts", initial_value=False) gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value with client.gui.add_folder("Camera options", expand_by_default=True): gui_camera_fov_slider = client.gui.add_slider( "Camera FOV (deg)", min=30.0, max=90.0, step=1.0, initial_value=45.0, ) client.camera.fov = np.deg2rad(gui_camera_fov_slider.value) with client.gui.add_folder("Interface options", expand_by_default=True): gui_show_timeline_checkbox = client.gui.add_checkbox( "Show Timeline", initial_value=True, ) gui_show_constraint_tracks_checkbox = client.gui.add_checkbox( "Show Constraint tracks", initial_value=True, ) gui_show_constraint_labels_checkbox = client.gui.add_checkbox( "Show Constraint labels", initial_value=True, ) gui_show_starting_direction_checkbox = client.gui.add_checkbox( "Show Starting Direction", initial_value=True, ) gui_dark_mode_checkbox = client.gui.add_checkbox( "Dark Mode", initial_value=False, # Default to light mode ) gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value) @gui_dark_mode_checkbox.on_update def _(_): # Apply the theme using configure_theme (pass uuid so titlebar toggle stays) demo.configure_theme( client, gui_dark_mode_checkbox.value, titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid, ) session = demo.client_sessions[client.client_id] for motion in session.motions.values(): motion.character.change_theme(gui_dark_mode_checkbox.value) # Show dark mode toggle in titlebar (right of Github), hide sidebar checkbox demo.configure_theme( client, gui_dark_mode_checkbox.value, titlebar_dark_mode_checkbox_uuid=gui_dark_mode_checkbox.uuid, ) gui_dark_mode_checkbox.visible = False @gui_show_constraint_labels_checkbox.on_update def _(_): if not demo.client_active(client_id): return session = demo.client_sessions[client_id] for constraint in session.constraints.values(): constraint.set_label_visibility(gui_show_constraint_labels_checkbox.value) @gui_show_timeline_checkbox.on_update def _(_): if not demo.client_active(client_id): return session = demo.client_sessions[client_id] session.client.timeline.set_visible(gui_show_timeline_checkbox.value) gui_show_constraint_tracks_checkbox.visible = gui_show_timeline_checkbox.value if gui_show_timeline_checkbox.value: demo.set_constraint_tracks_visible(session, gui_show_constraint_tracks_checkbox.value) @gui_show_constraint_tracks_checkbox.on_update def _(_): if not demo.client_active(client_id): return session = demo.client_sessions[client_id] demo.set_constraint_tracks_visible(session, gui_show_constraint_tracks_checkbox.value) @gui_show_starting_direction_checkbox.on_update def _(_): if not demo.client_active(client_id): return demo.set_start_direction_visible(client_id, gui_show_starting_direction_checkbox.value) @gui_viz_skeleton_checkbox.on_update def _(_) -> None: if not demo.client_active(client_id): return session = demo.client_sessions[client_id] gui_viz_foot_contacts_checkbox.visible = gui_viz_skeleton_checkbox.value if not gui_viz_skeleton_checkbox.value: gui_viz_foot_contacts_checkbox.value = False for motion in session.motions.values(): motion.character.set_skeleton_visibility(gui_viz_skeleton_checkbox.value) @gui_viz_foot_contacts_checkbox.on_update def _(_) -> None: if not demo.client_active(client_id): return session = demo.client_sessions[client_id] for motion in session.motions.values(): motion.character.set_show_foot_contacts( gui_viz_foot_contacts_checkbox.value, frame_idx=motion.cur_frame_idx ) @gui_viz_skinned_mesh_checkbox.on_update def _(_) -> None: if not demo.client_active(client_id): return session = demo.client_sessions[client_id] for motion in session.motions.values(): motion.character.set_skinned_mesh_visibility(gui_viz_skinned_mesh_checkbox.value) @gui_viz_skinned_mesh_opacity_slider.on_update def _(_) -> None: if not demo.client_active(client_id): return session = demo.client_sessions[client_id] for motion in session.motions.values(): motion.character.set_skinned_mesh_opacity(gui_viz_skinned_mesh_opacity_slider.value) @gui_camera_fov_slider.on_update def _(_) -> None: if not demo.client_active(client_id): return client.camera.fov = np.deg2rad(gui_camera_fov_slider.value) # # Instructions tab # with tab_group.add_tab("Instructions", viser.Icon.INFO_CIRCLE): client.gui.add_markdown(DEMO_UI_INSTRUCTIONS_TAB_MD) # # Keyboard events # @client.scene.on_keyboard_event("keydown", debounce_ms=100) def handle_key(event: viser.KeyboardEvent) -> None: # Check if client session still exists if client_id not in demo.client_sessions: return session = demo.client_sessions[client_id] # Space bar: only toggle on FIRST press if event.key == " ": now = time.monotonic() if now - session.last_space_toggle_time >= 0.2: session.last_space_toggle_time = now play_pause_button_callback(session) return # Handle arrow keys: frame navigation (fast OS repeat with 50ms debounce). elif event.key == "ArrowLeft": prev_frame_callback(session) elif event.key == "ArrowRight": next_frame_callback(session) gui_elements = GuiElements( gui_play_pause_button=gui_play_pause_button, gui_next_frame_button=gui_next_frame_button, gui_prev_frame_button=gui_prev_frame_button, gui_generate_button=gui_generate_button, gui_model_fps=gui_model_fps, gui_timeline=gui_timeline, gui_viz_skeleton_checkbox=gui_viz_skeleton_checkbox, gui_viz_foot_contacts_checkbox=gui_viz_foot_contacts_checkbox, gui_viz_skinned_mesh_checkbox=gui_viz_skinned_mesh_checkbox, gui_viz_skinned_mesh_opacity_slider=gui_viz_skinned_mesh_opacity_slider, gui_camera_fov_slider=gui_camera_fov_slider, gui_duration_slider=gui_duration_slider, gui_num_samples_slider=gui_num_samples_slider, gui_cfg_checkbox=gui_cfg_checkbox, gui_cfg_text_weight_slider=gui_cfg_text_weight_slider, gui_cfg_constraint_weight_slider=gui_cfg_constraint_weight_slider, gui_diffusion_steps_slider=gui_diffusion_steps_slider, gui_seed=gui_seed, gui_postprocess_checkbox=gui_postprocess_checkbox, gui_root_margin=gui_root_margin, gui_real_robot_rotations_checkbox=gui_real_robot_rotations_checkbox, gui_dark_mode_checkbox=gui_dark_mode_checkbox, gui_use_soma_layer_checkbox=gui_use_soma_layer_checkbox, ) return ( gui_elements, timeline_tracks, example_dict, gui_examples_dropdown, gui_save_example_path_text, gui_model_selector, )