|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import re |
| import tempfile |
| from functools import lru_cache |
|
|
| import gradio as gr |
| import h5py |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from huggingface_hub import hf_hub_download |
| from PIL import Image, ImageDraw |
|
|
| try: |
| import imageio.v2 as imageio |
| except Exception: |
| imageio = None |
|
|
| try: |
| import cv2 |
| except Exception: |
| cv2 = None |
|
|
|
|
# Selectable Hugging Face dataset presets. Each entry gives the hub repo id,
# the HDF5 file inside that repo, and the preset's default for reversing the
# image channel order (read by get_default_reverse_channels).
DATASET_PRESETS = {
    "Robosuite Square Correction": {
        "repo_id": "Zhaoting123/Robosuite_Square_image_abs_with_state",
        "filename": (
            "20260410_205606_Diffusion_CLIC_intervention_Circular_square_image_abs_"
            "Ta16_offlineFalse_Scale0.01/trajectory_buffer_0.hdf5"
        ),
        "default_reverse_channels": False,
    },
    "InsertT Demonstration": {
        "repo_id": "Zhaoting123/InsertT",
        "filename": "trajectory_buffer_Nov10_demo.hdf5",
        "default_reverse_channels": True,
    },
    "InsertT Correction": {
        "repo_id": "Zhaoting123/InsertT",
        "filename": "trajectory_buffer_Nov11_intervention.hdf5",
        "default_reverse_channels": True,
    },
    "RoundTable Correction": {
        "repo_id": "Zhaoting123/Furniture_Bench_Round_Table_Assembly",
        "filename": "trajectory_buffer_0_Nov24_intervention_relabeled.hdf5",
        "default_reverse_channels": True,
    },
}


# Preset used when no preset, or an unknown one, is selected.
DEFAULT_PRESET = "Robosuite Square Correction"
# repo_type argument passed to hf_hub_download.
REPO_TYPE = "dataset"
# NOTE(review): presumably the initial UI values for the action-chunk length
# and the image display scale; their uses are outside this view — confirm.
DEFAULT_CHUNK_LEN = 16
DEFAULT_DISPLAY_SCALE = 1
# Matplotlib figure size (inches) and DPI of the status plot drawn into
# exported videos.
VIDEO_STATUS_FIGSIZE = (6.0, 1.8)
VIDEO_STATUS_DPI = 120
# Observation keys listed first, in this order, when offering image keys.
PREFERRED_IMAGE_KEYS = [
    "image1",
    "image2",
    "agentview_image",
    "robot0_eye_in_hand_image",
    "front_image",
    "wrist_image",
]
# Case-insensitive substrings that mark an observation key as an image
# regardless of its array shape.
IMAGE_KEY_HINTS = ["rgb", "image", "img", "camera", "cam"]
|
|
|
|
def resolve_dataset(preset_name, custom_repo_id=None, custom_filename=None):
    """Map a UI preset selection to a ``(repo_id, filename)`` pair.

    ``"Custom"`` takes the repo id and HDF5 path from the free-text inputs
    (converted to ``str`` and whitespace-stripped). It raises ``ValueError``
    if either is blank. Any other name is looked up in ``DATASET_PRESETS``;
    empty or unknown names fall back to ``DEFAULT_PRESET``.
    """
    name = preset_name or DEFAULT_PRESET
    if name != "Custom":
        preset = DATASET_PRESETS.get(name, DATASET_PRESETS[DEFAULT_PRESET])
        return preset["repo_id"], preset["filename"]

    repo_id = str(custom_repo_id or "").strip()
    filename = str(custom_filename or "").strip()
    if repo_id and filename:
        return repo_id, filename
    raise ValueError("For Custom mode, provide both repo_id and HDF5 filename/path.")
|
|
|
|
def get_default_reverse_channels(preset_name):
    """Return the dataset-specific default for BGR<->RGB channel reversal.

    The value comes from the preset's ``default_reverse_channels`` entry in
    ``DATASET_PRESETS``. Robosuite Square uses normal RGB ordering, while the
    InsertT / PushT-style presets need reversal. Empty or unknown names use
    ``DEFAULT_PRESET``. ``"Custom"`` datasets default to ``False`` so users
    can still override the flag manually.
    """
    name = preset_name or DEFAULT_PRESET
    preset = None if name == "Custom" else DATASET_PRESETS.get(name, DATASET_PRESETS[DEFAULT_PRESET])
    if preset is None:
        return False
    return bool(preset.get("default_reverse_channels", False))
|
|
|
|
@lru_cache(maxsize=8)
def get_local_hdf5_path(repo_id, filename):
    """Return a local filesystem path for ``filename`` in the hub dataset ``repo_id``.

    The file is fetched with ``hf_hub_download``. The resulting path is
    memoized per ``(repo_id, filename)`` for up to 8 distinct files.
    """
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=REPO_TYPE,
    )
    return local_path
|
|
|
|
| def _natural_sort_key(name): |
| match = re.search(r"([0-9]+)$", str(name)) |
| if match: |
| return 0, int(match.group(1)) |
| return 1, str(name) |
|
|
|
|
@lru_cache(maxsize=8)
def get_trajectory_keys(repo_id, filename):
    """Return the HDF5 group paths that each hold one trajectory, naturally sorted.

    Layouts are tried in order:
      1. root-level groups whose names start with ``episode_``;
      2. otherwise every subgroup of a root ``data`` group, returned as
         ``"data/<name>"``;
      3. otherwise every root-level group.
    """
    path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(path, "r") as f:
        root_groups = [name for name in f.keys() if isinstance(f[name], h5py.Group)]

        episodes = [name for name in root_groups if str(name).startswith("episode_")]
        if episodes:
            return tuple(sorted(episodes, key=_natural_sort_key))

        if "data" in f and isinstance(f["data"], h5py.Group):
            data_group = f["data"]
            children = sorted(
                (name for name in data_group.keys() if isinstance(data_group[name], h5py.Group)),
                key=_natural_sort_key,
            )
            return tuple("data/" + name for name in children)

        return tuple(sorted(root_groups, key=_natural_sort_key))
|
|
|
|
@lru_cache(maxsize=8)
def get_num_trajectories(repo_id, filename):
    """Return how many trajectory groups ``get_trajectory_keys`` finds in the file."""
    keys = get_trajectory_keys(repo_id, filename)
    return len(keys)
|
|
|
|
def inspect_hdf5_tree(preset_name, custom_repo_id, custom_filename, max_lines=180):
    """Return a text listing of the selected HDF5 file's groups and datasets.

    Each group or dataset reached by h5py's ``visititems`` becomes one line,
    either ``GROUP <path>`` or ``DATASET <path> shape=... dtype=...``.
    At most ``max_lines`` entries are listed. If further groups or datasets
    exist, a final ``...`` line marks the truncation. Returns
    ``"No HDF5 contents found."`` when the file has no groups or datasets.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    path = get_local_hdf5_path(repo_id, filename)

    lines = []
    truncated = False

    def visitor(name, obj):
        nonlocal truncated
        if isinstance(obj, h5py.Dataset):
            entry = "DATASET {} shape={} dtype={}".format(name, obj.shape, obj.dtype)
        elif isinstance(obj, h5py.Group):
            entry = "GROUP {}".format(name)
        else:
            return None
        if len(lines) >= max_lines:
            truncated = True
            # Any non-None return value stops h5py's traversal, so large files
            # are not walked to the end once the listing is full.
            return True
        lines.append(entry)
        return None

    with h5py.File(path, "r") as f:
        f.visititems(visitor)

    # Only mark truncation when an entry was actually left out.
    if truncated:
        lines.append("...")
    return "\n".join(lines) if lines else "No HDF5 contents found."
|
|
|
|
| def _read_dataset_value(dataset): |
| value = dataset[()] |
| if isinstance(value, bytes): |
| return value.decode("utf-8") |
| return value |
|
|
|
|
def _read_group_recursive(group):
    """Load an HDF5 group into nested dicts.

    Datasets become their values (via ``_read_dataset_value``) and subgroups
    become nested dicts. Members of any other type are skipped.
    """
    result = {}
    for name, member in group.items():
        if isinstance(member, h5py.Group):
            result[name] = _read_group_recursive(member)
        elif isinstance(member, h5py.Dataset):
            result[name] = _read_dataset_value(member)
    return result
|
|
|
|
| def _find_first_key(mapping, candidate_keys): |
| for key in candidate_keys: |
| if key in mapping: |
| return key |
| return None |
|
|
|
|
| def _infer_time_length(data): |
| for key in ["timesteps", "dones", "robot_actions", "teacher_actions", "actions"]: |
| if key in data: |
| arr = np.asarray(data[key]) |
| if arr.ndim >= 1: |
| return int(arr.shape[0]) |
|
|
| obs_group = None |
| if isinstance(data.get("observation"), dict): |
| obs_group = data["observation"] |
| elif isinstance(data.get("obs"), dict): |
| obs_group = data["obs"] |
|
|
| if obs_group: |
| lengths = [] |
| for value in obs_group.values(): |
| arr = np.asarray(value) |
| if arr.ndim >= 1: |
| lengths.append(int(arr.shape[0])) |
| if lengths: |
| values, counts = np.unique(lengths, return_counts=True) |
| return int(values[np.argmax(counts)]) |
| return 1 |
|
|
|
|
| def _slice_time(value, t, T): |
| arr = np.asarray(value) |
| if arr.ndim >= 1 and arr.shape[0] == T: |
| return arr[t] |
| return arr |
|
|
|
|
@lru_cache(maxsize=64)
def load_traj(repo_id, filename, traj_id):
    """Load one trajectory and unpack it into a list of per-timestep dicts.

    ``traj_id`` is clipped into the range of ``get_trajectory_keys``. The whole
    trajectory group is read into memory, and its length ``T`` is inferred
    with ``_infer_time_length``. Every array whose leading axis has length
    ``T`` is then sliced per step.

    Each step dict has these keys: ``obs``, ``robot_action``,
    ``teacher_action``, ``done``, ``timestep``, ``no_robot_action``,
    ``no_teacher_action``, ``episode_id`` (the HDF5 group path) and
    ``if_success``. Missing fields get fallbacks:
    - teacher and robot actions fall back to ``actions``/``action``, or to a
      single float32 zero;
    - flags fall back to ``False``;
    - ``timestep`` falls back to the step index.

    Returns an empty list when the file has no trajectory groups.
    """
    traj_keys = get_trajectory_keys(repo_id, filename)
    if not traj_keys:
        return []

    episode_key = traj_keys[int(np.clip(int(traj_id), 0, len(traj_keys) - 1))]
    hdf5_path = get_local_hdf5_path(repo_id, filename)
    with h5py.File(hdf5_path, "r") as handle:
        data = _read_group_recursive(handle[episode_key])

    T = _infer_time_length(data)

    obs_all = {}
    for obs_name in ("observation", "obs"):
        if isinstance(data.get(obs_name), dict):
            obs_all = data[obs_name]
            break

    action_key = _find_first_key(data, ["actions", "action"])
    teacher_key = _find_first_key(data, ["teacher_actions", "teacher_action"])
    robot_key = _find_first_key(data, ["robot_actions", "robot_action"])
    no_teacher_key = _find_first_key(data, ["no_teacher_actions", "no_teacher_action"])
    no_robot_key = _find_first_key(data, ["no_robot_actions", "no_robot_action"])
    done_key = _find_first_key(data, ["dones", "done"])
    timestep_key = _find_first_key(data, ["timesteps", "timestep"])
    success_key = _find_first_key(data, ["if_success", "success", "successes"])

    def at(key, t):
        # Per-step value of a top-level trajectory field.
        return _slice_time(data[key], t, T)

    def first_flag(value):
        # Collapse a (possibly array-valued) flag to its first element.
        return bool(np.asarray(value).reshape(-1)[0])

    steps = []
    for t in range(T):
        obs_t = {name: _slice_time(value, t, T) for name, value in obs_all.items()}

        if action_key is not None:
            default_action = at(action_key, t)
        else:
            default_action = np.zeros(1, dtype=np.float32)

        teacher_action = at(teacher_key, t) if teacher_key else default_action
        robot_action = at(robot_key, t) if robot_key else default_action
        no_teacher = at(no_teacher_key, t) if no_teacher_key else False
        no_robot = at(no_robot_key, t) if no_robot_key else False
        done = at(done_key, t) if done_key else False
        if_success = at(success_key, t) if success_key else False

        if timestep_key is None:
            timestep = t
        else:
            timestep = int(np.asarray(at(timestep_key, t)).reshape(-1)[0])

        steps.append({
            "obs": obs_t,
            "robot_action": np.asarray(robot_action),
            "teacher_action": np.asarray(teacher_action),
            "done": first_flag(done),
            "timestep": timestep,
            "no_robot_action": first_flag(no_robot),
            "no_teacher_action": first_flag(no_teacher),
            "episode_id": episode_key,
            "if_success": first_flag(if_success),
        })

    return steps
|
|
|
|
| def _extract_latest_obs_value(value): |
| """Return the latest stacked observation only when there is a clear stack axis. |
| |
| Important: |
| - [obs_T, C, H, W] or [obs_T, H, W, C] should become the latest frame. |
| - [C, H, W] must NOT be sliced, otherwise an RGB image becomes one |
| grayscale channel. |
| """ |
| arr = np.asarray(value) |
|
|
| |
| if arr.ndim == 4 and arr.shape[0] in (1, 2, 3, 4): |
| channel_first = arr.shape[1] in (1, 3, 4) |
| channel_last = arr.shape[-1] in (1, 3, 4) |
| if channel_first or channel_last: |
| return arr[-1] |
|
|
| |
| if arr.ndim == 2 and arr.shape[0] in (1, 2): |
| return arr[-1] |
|
|
| return arr |
|
|
|
|
def _looks_like_image_array(key, value):
    """Heuristically decide whether an observation entry is a displayable image.

    A key that contains any ``IMAGE_KEY_HINTS`` substring (case-insensitive)
    always qualifies. Otherwise the shape decides. First, any leading
    observation-history axis is dropped with ``_extract_latest_obs_value``,
    the same rule ``_extract_display_image`` applies. The result must then be:
    - a 2-D grayscale array,
    - a 3-D array with a 1/3/4-sized channel axis first or last, or
    - a 4-D array with such a channel axis.

    Sharing the display path's stack rule keeps stacked low-dim observations
    such as ``[2, D]`` out of the list. Before, every 2-D array matched, so
    such keys were offered as images and then failed to render, because the
    display path reduces them to 1-D.
    """
    key_l = str(key).lower()
    key_hint = any(hint in key_l for hint in IMAGE_KEY_HINTS)

    arr = _extract_latest_obs_value(value)
    if arr.ndim == 2:
        shape_hint = True
    elif arr.ndim == 3:
        shape_hint = arr.shape[-1] in (1, 3, 4) or arr.shape[0] in (1, 3, 4)
    elif arr.ndim == 4:
        shape_hint = arr.shape[1] in (1, 3, 4) or arr.shape[-1] in (1, 3, 4)
    else:
        shape_hint = False

    return key_hint or shape_hint
|
|
|
|
| def _float_img_to_uint8(img): |
| arr = img.astype(np.float32) |
| arr_min = float(np.nanmin(arr)) |
| arr_max = float(np.nanmax(arr)) |
|
|
| if arr_min >= -1.01 and arr_max <= 1.01: |
| if arr_min < 0.0: |
| arr = (arr + 1.0) * 0.5 |
| arr = np.clip(arr, 0.0, 1.0) * 255.0 |
| elif arr_max <= 255.0: |
| arr = np.clip(arr, 0.0, 255.0) |
| else: |
| arr = 255.0 * (arr - arr_min) / max(arr_max - arr_min, 1e-8) |
|
|
| return np.round(arr).astype(np.uint8) |
|
|
|
|
def _extract_display_image(value, reverse_channels=False):
    """Turn one observation entry into a 3-D ``(H, W, C)`` uint8 image.

    Steps:
    1. Drop any observation-history axis (``_extract_latest_obs_value``).
    2. Expand 2-D grayscale to 3 channels, or move a leading 1/3/4-sized
       channel axis to the end.
    3. Replicate a single channel to 3 and drop a 4th (alpha) channel.
    4. Copy uint8 data; convert anything else with ``_float_img_to_uint8``.
    5. Optionally reverse a 3-channel order (BGR <-> RGB).

    Raises:
        ValueError: if the array is not 3-D after the layout conversions.
    """
    frame = np.asarray(_extract_latest_obs_value(value))

    if frame.ndim == 2:
        frame = np.repeat(frame[..., None], 3, axis=-1)
    elif frame.ndim == 3 and frame.shape[0] in (1, 3, 4):
        frame = np.transpose(frame, (1, 2, 0))

    if frame.ndim == 3:
        channels = frame.shape[-1]
        if channels == 1:
            frame = np.repeat(frame, 3, axis=-1)
        elif channels == 4:
            frame = frame[..., :3]
    else:
        raise ValueError("Unsupported image shape: {}".format(frame.shape))

    rgb = frame.copy() if frame.dtype == np.uint8 else _float_img_to_uint8(frame)
    if reverse_channels and rgb.shape[-1] == 3:
        rgb = rgb[..., ::-1]
    return rgb
|
|
|
|
| def _resize_image_for_display(img, display_scale): |
| scale = float(display_scale) |
| if scale == 1.0: |
| return img |
|
|
| h, w = img.shape[:2] |
| new_size = (max(1, int(round(w * scale))), max(1, int(round(h * scale)))) |
|
|
| if cv2 is not None: |
| return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST) |
|
|
| pil_img = Image.fromarray(img) |
| return np.asarray(pil_img.resize(new_size, resample=Image.Resampling.NEAREST)) |
|
|
|
|
| def _extract_mixed_action_chunk(traj, start_idx, chunk_length): |
| chunk = [] |
| sources = [] |
| end_idx = min(len(traj), int(start_idx) + int(chunk_length)) |
|
|
| for idx in range(int(start_idx), end_idx): |
| step = traj[idx] |
| use_teacher = not bool(step.get("no_teacher_action", False)) |
| action = step["teacher_action"] if use_teacher else step["robot_action"] |
| chunk.append(np.asarray(action, dtype=np.float32).reshape(-1)) |
| sources.append("T" if use_teacher else "R") |
|
|
| if not chunk: |
| return None, "" |
| return np.stack(chunk, axis=0), "".join(sources) |
|
|
|
|
| def _extract_robot_action_chunk(traj, start_idx, chunk_length): |
| chunk = [] |
| end_idx = min(len(traj), int(start_idx) + int(chunk_length)) |
|
|
| for idx in range(int(start_idx), end_idx): |
| step = traj[idx] |
| chunk.append(np.asarray(step["robot_action"], dtype=np.float32).reshape(-1)) |
|
|
| if not chunk: |
| return None |
| return np.stack(chunk, axis=0) |
|
|
|
|
| def _safe_array_str(value, precision=3, max_items=24): |
| arr = np.asarray(value).reshape(-1) |
| shown = arr[:max_items] |
| text = np.array2string(shown, precision=precision, separator=", ") |
| if arr.size > max_items: |
| text += " ... +{} more".format(arr.size - max_items) |
| return text |
|
|
|
|
| def _make_action_chunk_plot(mixed_chunk, robot_chunk): |
| if mixed_chunk is None: |
| return None |
|
|
| mixed_chunk = np.asarray(mixed_chunk, dtype=np.float32) |
| if mixed_chunk.ndim == 1: |
| mixed_chunk = mixed_chunk[:, None] |
|
|
| fig, ax = plt.subplots(figsize=(7, 3.2), dpi=140) |
| x = np.arange(mixed_chunk.shape[0]) |
| max_dims = min(mixed_chunk.shape[1], 10) |
|
|
| for dim in range(max_dims): |
| ax.plot(x, mixed_chunk[:, dim], label="mixed[{}]".format(dim)) |
|
|
| if robot_chunk is not None: |
| robot_chunk = np.asarray(robot_chunk, dtype=np.float32) |
| if robot_chunk.ndim == 1: |
| robot_chunk = robot_chunk[:, None] |
| for dim in range(min(robot_chunk.shape[1], max_dims)): |
| ax.plot( |
| x, |
| robot_chunk[:, dim], |
| linestyle="--", |
| alpha=0.55, |
| label="robot[{}]".format(dim), |
| ) |
|
|
| ax.set_title("Action chunk") |
| ax.set_xlabel("chunk step") |
| ax.set_ylabel("action value") |
| ax.grid(True, alpha=0.3) |
| ax.legend(loc="upper right", fontsize=7, ncol=2) |
| fig.tight_layout() |
| fig.canvas.draw() |
| rgba = np.asarray(fig.canvas.buffer_rgba()) |
| image = rgba[..., :3].copy() |
| plt.close(fig) |
| return image |
|
|
|
|
@lru_cache(maxsize=8192)
def get_cached_gallery_items(repo_id, filename, traj_id, timestep, image_keys_tuple, display_scale, reverse_channels):
    """Build and memoize the gallery entries for one timestep.

    ``timestep`` is clipped into the trajectory. Each requested key is
    converted with ``_extract_display_image`` and scaled with
    ``_resize_image_for_display``. Returns ``(items, warnings)``:
    - ``items`` is a list of ``(image, key)`` pairs;
    - ``warnings`` is a tuple of messages for keys that were missing or
      failed to convert (those keys are skipped).

    NOTE(review): the returned list object is shared by every caller that
    hits the cache, so callers must not mutate it.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    step_idx = int(np.clip(int(timestep), 0, len(traj) - 1))
    obs = traj[step_idx].get("obs", {})

    items = []
    problems = []
    for key in image_keys_tuple:
        if key not in obs:
            problems.append("Missing image key: {}".format(key))
            continue
        try:
            image = _resize_image_for_display(
                _extract_display_image(obs[key], reverse_channels=bool(reverse_channels)),
                float(display_scale),
            )
        except Exception as exc:
            # One bad key should not blank the whole gallery; report it instead.
            problems.append("{}: {}".format(key, exc))
        else:
            items.append((image, key))

    return items, tuple(problems)
|
|
|
|
| def _compute_valid_start_indices(traj, min_seq_len): |
| """Match the original local script's valid-start heuristic. |
| |
| A timestep is valid when the following min_seq_len steps all have |
| no_teacher_action == False. |
| """ |
| total_steps = len(traj) |
| min_seq_len = int(max(1, min_seq_len)) |
| no_teacher = np.asarray( |
| [int(bool(step.get("no_teacher_action", False))) for step in traj], |
| dtype=np.int32, |
| ) |
|
|
| valid_indices = [] |
| max_start = total_steps - min_seq_len + 1 |
| for t in range(max(0, max_start)): |
| if int(np.sum(no_teacher[t:t + min_seq_len])) == 0: |
| valid_indices.append(t) |
|
|
| return no_teacher, valid_indices |
|
|
|
|
| def _make_trajectory_status_plot(traj, timestep, min_seq_len): |
| """Render the same high-level status figure as the local matplotlib tool. |
| |
| Shows: |
| - orange no_teacher_action step plot |
| - green triangles for algorithmic valid start points |
| - black vertical cursor at current timestep |
| """ |
| total_steps = len(traj) |
| if total_steps == 0: |
| return None, False, 0 |
|
|
| timestep = int(np.clip(int(timestep), 0, total_steps - 1)) |
| timesteps = np.asarray( |
| [int(np.asarray(step.get("timestep", idx)).reshape(-1)[0]) for idx, step in enumerate(traj)], |
| dtype=np.int32, |
| ) |
| no_teacher, valid_indices = _compute_valid_start_indices(traj, min_seq_len) |
| is_valid_start = timestep in set(valid_indices) |
|
|
| fig, ax = plt.subplots(figsize=(10.5, 2.8), dpi=170) |
|
|
| ax.step( |
| np.arange(total_steps), |
| no_teacher, |
| where="post", |
| label="no_teacher_action", |
| color="orange", |
| ) |
|
|
| if valid_indices: |
| ax.scatter( |
| valid_indices, |
| [-0.15] * len(valid_indices), |
| color="green", |
| marker="^", |
| s=18, |
| label="Valid Start (len >= {})".format(int(min_seq_len)), |
| ) |
|
|
| ax.axvline(timestep, color="black", linestyle="-", alpha=0.85, linewidth=1.5) |
| ax.set_xlim(0, max(total_steps - 1, 1)) |
| ax.set_ylim(-0.38, 1.1) |
| ax.set_ylabel("Flag", fontsize=10) |
| ax.set_xlabel("Timestep index", fontsize=10) |
| ax.set_yticks([0, 1]) |
| ax.set_yticklabels(["False", "True"]) |
| ax.grid(True, axis="x", alpha=0.2) |
|
|
| title = "no_teacher_action | step {} / {}".format(timestep, total_steps - 1) |
| if is_valid_start: |
| title += " | VALID START" |
| ax.set_title(title, fontsize=11) |
| ax.tick_params(axis="both", labelsize=9) |
| ax.legend(loc="upper right", fontsize=9) |
|
|
| |
| saved_timestep = int(timesteps[timestep]) if len(timesteps) else timestep |
| if saved_timestep != timestep: |
| ax.text( |
| 0.01, |
| 0.04, |
| "saved timestep: {}".format(saved_timestep), |
| transform=ax.transAxes, |
| fontsize=8, |
| va="bottom", |
| ha="left", |
| ) |
|
|
| fig.tight_layout() |
| fig.canvas.draw() |
| rgba = np.asarray(fig.canvas.buffer_rgba()) |
| image = rgba[..., :3].copy() |
| plt.close(fig) |
|
|
| return image, bool(is_valid_start), len(valid_indices) |
|
|
|
|
@lru_cache(maxsize=8192)
def get_cached_status_plot(repo_id, filename, traj_id, timestep, min_seq_len):
    """Memoized ``_make_trajectory_status_plot`` for one trajectory timestep.

    ``timestep`` is clipped into the trajectory first. Returns
    ``(image, is_valid_start, num_valid_starts)``.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    last_index = len(traj) - 1
    clipped = int(np.clip(int(timestep), 0, last_index))
    return _make_trajectory_status_plot(traj, clipped, int(min_seq_len))
|
|
|
|
def preload_current_trajectory(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, chunk_len, display_scale, reverse_channels):
    """Warm the per-timestep gallery and status-plot caches for one trajectory.

    Every timestep is rendered once through ``get_cached_gallery_items`` and
    ``get_cached_status_plot``. The arguments are converted the same way
    ``render_frame`` converts them (float scale, bool reverse flag, int
    chunk length), so later scrubbing hits the caches. Returns a status
    message.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return "No trajectories found."

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return "Trajectory could not be loaded."

    if image_keys is None:
        keys = ()
    elif isinstance(image_keys, str):
        keys = (image_keys,)
    else:
        keys = tuple(image_keys)

    total = len(traj)
    for t in range(total):
        get_cached_gallery_items(repo_id, filename, traj_id, t, keys, float(display_scale), bool(reverse_channels))
        get_cached_status_plot(repo_id, filename, traj_id, t, int(chunk_len))

    report = [
        "Preloaded trajectory {}".format(traj_id),
        "Frames cached: {}".format(total),
        "Image keys: {}".format(", ".join(keys) if keys else "none"),
    ]
    return "\n".join(report)
|
|
|
|
def _compose_video_frame(gallery_items, frame_label, status_plot=None):
    """Compose one RGB video frame and return it as an array.

    Top: the selected observation images side by side, each under its key
    label and below a ``frame_label`` header. A dark placeholder is used when
    there are no images. Bottom (when ``status_plot`` is given): the
    trajectory-status plot with its moving timestep cursor, 8 px below.

    Important: the status plot is NOT downscaled to the image width. It
    contains tick labels and a legend, so keeping its native width makes the
    generated MP4 much more readable; the narrower part is centred instead.
    The final canvas is padded with black to multiples of 16 px.
    """
    text_y = 3

    def center_on(img, width, fill):
        # Widen img to `width`, centring it horizontally on a `fill` background.
        wide = Image.new("RGB", (width, img.height), color=fill)
        wide.paste(img, ((width - img.width) // 2, 0))
        return wide

    if gallery_items:
        panels = []
        for image, caption in gallery_items:
            picture = Image.fromarray(np.asarray(image, dtype=np.uint8)).convert("RGB")
            caption_h = 16
            panel = Image.new("RGB", (picture.width, picture.height + caption_h), color=(0, 0, 0))
            panel.paste(picture, (0, caption_h))
            ImageDraw.Draw(panel).text((4, text_y), str(caption), fill=(220, 220, 220))
            panels.append(panel)

        gap = 8
        header_h = 18
        obs_w = sum(p.width for p in panels) + gap * max(len(panels) - 1, 0)
        obs_h = max(p.height for p in panels) + header_h
        obs_canvas = Image.new("RGB", (obs_w, obs_h), color=(0, 0, 0))
        ImageDraw.Draw(obs_canvas).text((6, text_y), frame_label, fill=(220, 220, 220))

        offset = 0
        for panel in panels:
            obs_canvas.paste(panel, (offset, header_h))
            offset += panel.width + gap
    else:
        obs_canvas = Image.new("RGB", (640, 360), color=(20, 20, 20))
        ImageDraw.Draw(obs_canvas).text((8, text_y), "No selected image keys", fill=(255, 255, 255))

    if status_plot is None:
        canvas = obs_canvas
    else:
        status_img = Image.fromarray(np.asarray(status_plot, dtype=np.uint8)).convert("RGB")
        final_w = max(obs_canvas.width, status_img.width)
        if obs_canvas.width < final_w:
            obs_canvas = center_on(obs_canvas, final_w, (0, 0, 0))
        elif status_img.width < final_w:
            status_img = center_on(status_img, final_w, (255, 255, 255))

        gap_h = 8
        canvas = Image.new(
            "RGB",
            (final_w, obs_canvas.height + gap_h + status_img.height),
            color=(0, 0, 0),
        )
        canvas.paste(obs_canvas, (0, 0))
        canvas.paste(status_img, (0, obs_canvas.height + gap_h))

    # Round both dimensions up to the next multiple of 16 px.
    pad_w = -(-canvas.width // 16) * 16
    pad_h = -(-canvas.height // 16) * 16
    if (pad_w, pad_h) != canvas.size:
        padded = Image.new("RGB", (pad_w, pad_h), color=(0, 0, 0))
        padded.paste(canvas, (0, 0))
        canvas = padded

    return np.asarray(canvas)
|
|
|
|
@lru_cache(maxsize=128)
def get_video_status_plot_base(repo_id, filename, traj_id, valid_window_len):
    """Render the static part of the status plot once for video export.

    Running Matplotlib for every frame is slow. This function draws the
    no_teacher_action step plot and the valid-start markers once and records
    the axes' pixel bounds. The moving cursor is drawn later with NumPy/PIL,
    which is much faster.

    Returns ``(base_rgb, (x0, y0, x1, y1), total_steps)``. The bounds are in
    image pixel coordinates with y measured from the top. For an empty
    trajectory it returns ``(None, (0, 0, 1, 1), 0)``.
    """
    traj = load_traj(repo_id, filename, int(traj_id))
    total_steps = len(traj)
    if total_steps == 0:
        return None, (0, 0, 1, 1), 0

    no_teacher, valid_indices = _compute_valid_start_indices(traj, int(valid_window_len))

    fig, ax = plt.subplots(figsize=VIDEO_STATUS_FIGSIZE, dpi=VIDEO_STATUS_DPI)
    # Always close the figure, even if rendering fails, so pyplot does not
    # keep it alive.
    try:
        ax.step(
            np.arange(total_steps),
            no_teacher,
            where="post",
            label="no_teacher_action",
            color="orange",
        )

        if valid_indices:
            ax.scatter(
                valid_indices,
                [-0.15] * len(valid_indices),
                color="green",
                marker="^",
                s=18,
                label="Valid Start (len >= {})".format(int(valid_window_len)),
            )

        ax.set_xlim(0, max(total_steps - 1, 1))
        ax.set_ylim(-0.38, 1.1)
        ax.set_ylabel("Flag", fontsize=8)
        ax.set_xlabel("Timestep index", fontsize=8)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["False", "True"])
        ax.grid(True, axis="x", alpha=0.2)
        ax.set_title("no_teacher_action and valid starts", fontsize=9)
        ax.tick_params(axis="both", labelsize=7)
        ax.legend(loc="upper right", fontsize=7)
        fig.tight_layout()
        fig.canvas.draw()

        rgba = np.asarray(fig.canvas.buffer_rgba())
        base = rgba[..., :3].copy()

        bbox = ax.get_window_extent()
        height = base.shape[0]

        # Matplotlib's display y axis points up from the bottom edge, while
        # image rows count down from the top, so flip the y bounds.
        x0 = int(round(bbox.x0))
        x1 = int(round(bbox.x1))
        y0 = int(round(height - bbox.y1))
        y1 = int(round(height - bbox.y0))
    finally:
        plt.close(fig)

    return base, (x0, y0, x1, y1), total_steps
|
|
|
|
@lru_cache(maxsize=8192)
def get_cached_video_status_frame(repo_id, filename, traj_id, timestep, valid_window_len):
    """Draw the moving cursor on a cached static status plot.

    The cursor is a 4-px-wide black vertical line. It sits at the clipped
    ``timestep``, mapped linearly across the axes' pixel bounds. A white
    ``step i/N`` label box is drawn in the axes' top-left corner. Returns the
    annotated RGB image as an array, or ``None`` for an empty trajectory.
    """
    base, (x0, y0, x1, y1), total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        int(traj_id),
        int(valid_window_len),
    )
    if base is None:
        return None

    step = int(np.clip(int(timestep), 0, max(total_steps - 1, 0)))
    span = max(total_steps - 1, 1)
    cursor_x = int(round(x0 + (x1 - x0) * float(step) / float(span)))

    canvas = Image.fromarray(np.asarray(base, dtype=np.uint8)).convert("RGB")
    pen = ImageDraw.Draw(canvas)
    pen.line([(cursor_x, y0), (cursor_x, y1)], fill=(0, 0, 0), width=4)
    pen.rectangle((x0 + 4, y0 + 4, x0 + 118, y0 + 24), fill=(255, 255, 255))
    pen.text((x0 + 8, y0 + 7), "step {}/{}".format(step, total_steps - 1), fill=(0, 0, 0))
    return np.asarray(canvas)
|
|
| def _draw_status_cursor_on_base(base, bounds, total_steps, timestep): |
| """Fast video status frame: copy one static Matplotlib image and draw cursor. |
| |
| This avoids calling the lru-cached per-timestep status frame function during |
| video export. For long trajectories, caching thousands of status images can |
| consume a lot of memory and still requires PIL conversion for every frame. |
| """ |
| if base is None: |
| return None |
|
|
| total_steps = int(max(total_steps, 1)) |
| timestep = int(np.clip(int(timestep), 0, total_steps - 1)) |
| x0, y0, x1, y1 = [int(v) for v in bounds] |
| denom = max(total_steps - 1, 1) |
| x = int(round(x0 + (x1 - x0) * float(timestep) / float(denom))) |
|
|
| img = np.asarray(base, dtype=np.uint8).copy() |
|
|
| |
| |
| x_left = max(0, x - 2) |
| x_right = min(img.shape[1], x + 2) |
| y_top = max(0, y0) |
| y_bottom = min(img.shape[0], y1) |
| img[y_top:y_bottom, x_left:x_right, :] = 0 |
|
|
| |
| pil_img = Image.fromarray(img).convert("RGB") |
| draw = ImageDraw.Draw(pil_img) |
| label = "step {}/{}".format(timestep, total_steps - 1) |
| draw.rectangle((x0 + 4, y0 + 4, x0 + 126, y0 + 24), fill=(255, 255, 255)) |
| draw.text((x0 + 8, y0 + 7), label, fill=(0, 0, 0)) |
| return np.asarray(pil_img) |
|
|
|
|
def _get_fast_video_writer(out_path, fps):
    """Use ffmpeg's ultrafast x264 preset for interactive Spaces exports.

    The imageio writer encodes libx264 at CRF 28 with yuv420p pixels,
    ``+faststart`` movflags and a 16-px macro block size.
    """
    x264_args = [
        "-preset", "ultrafast",
        "-crf", "28",
        "-pix_fmt", "yuv420p",
        "-movflags", "+faststart",
    ]
    return imageio.get_writer(
        out_path,
        fps=float(fps),
        codec="libx264",
        macro_block_size=16,
        ffmpeg_params=x264_args,
    )
|
|
|
|
def build_current_trajectory_video(preset_name, custom_repo_id, custom_filename, traj_id, image_keys, display_scale, reverse_channels, fps, valid_window_len, video_stride=4):
    """Export the selected trajectory as an MP4 and return ``(path, status)``.

    Every ``video_stride``-th timestep, plus the final one, becomes a frame.
    Each frame shows the selected observation images on top and the
    no_teacher_action / valid-start status plot with a moving cursor below.
    The static part of the status plot is rendered once, and the cursor is
    drawn per frame. Returns ``(None, message)`` when imageio is unavailable
    or the dataset/trajectory is empty.

    The output file name includes a short checksum of every render setting.
    Before, exports that differed only in image keys, display scale, channel
    order, valid-window length or fractional fps shared one temp path and
    overwrote each other's video.
    """
    import zlib

    if imageio is None:
        return None, "Video export requires imageio and imageio-ffmpeg in requirements.txt."

    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return None, "No trajectories found."

    traj_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    traj = load_traj(repo_id, filename, traj_id)
    if not traj:
        return None, "Trajectory could not be loaded."

    if image_keys is None:
        image_keys = []
    if isinstance(image_keys, str):
        image_keys = [image_keys]
    image_keys_tuple = tuple(image_keys)

    # Sample every video_stride-th step and always end on the last step.
    video_stride = int(max(1, int(video_stride)))
    frame_indices = list(range(0, len(traj), video_stride))
    if frame_indices and frame_indices[-1] != len(traj) - 1:
        frame_indices.append(len(traj) - 1)

    # Tag the file name with every setting that affects the output, so
    # different exports never share a path.
    settings = repr((
        repo_id,
        filename,
        traj_id,
        image_keys_tuple,
        float(display_scale),
        bool(reverse_channels),
        float(fps),
        int(valid_window_len),
        video_stride,
    ))
    settings_tag = "{:08x}".format(zlib.crc32(settings.encode("utf-8")) & 0xFFFFFFFF)

    safe_repo = re.sub(r"[^A-Za-z0-9_.-]+", "_", repo_id)
    safe_file = re.sub(r"[^A-Za-z0-9_.-]+", "_", filename)[-80:]
    out_path = os.path.join(
        tempfile.gettempdir(),
        "trajectory_{}_{}_traj{:04d}_fps{}_stride{}_{}.mp4".format(
            safe_repo, safe_file, traj_id, int(fps), video_stride, settings_tag
        ),
    )

    # Render the static status plot once; only the cursor changes per frame.
    status_base, status_bounds, total_steps = get_video_status_plot_base(
        repo_id,
        filename,
        traj_id,
        int(valid_window_len),
    )

    writer = _get_fast_video_writer(out_path, fps)
    written = 0
    try:
        for t in frame_indices:
            gallery_items, _warnings = get_cached_gallery_items(
                repo_id,
                filename,
                traj_id,
                t,
                image_keys_tuple,
                float(display_scale),
                bool(reverse_channels),
            )
            label = "trajectory {} | frame {}/{}".format(traj_id, t, len(traj) - 1)
            status_plot = _draw_status_cursor_on_base(status_base, status_bounds, total_steps, t)
            frame = _compose_video_frame(gallery_items, label, status_plot=status_plot)
            writer.append_data(frame)
            written += 1
    finally:
        writer.close()

    approx_seconds = float(written) / float(max(float(fps), 1.0))
    status = "Built trajectory video with optimized encoder and status rendering"
    status += "\nTrajectory: {}".format(traj_id)
    status += "\nOriginal timesteps: {} | Written frames: {} | Stride: {}".format(len(traj), written, video_stride)
    status += "\nFPS: {} | Approx video duration: {:.1f}s".format(fps, approx_seconds)
    status += "\nValid-window length: {}".format(int(valid_window_len))
    status += "\nSpeedups: x264 ultrafast preset; static status plot rendered once; cursor drawn with NumPy/PIL"
    return out_path, status
|
|
def get_available_image_keys(repo_id, filename, traj_id):
    """List observation keys that look like images, with preferred keys first.

    Only the first timestep of the (clipped) trajectory is inspected. Keys in
    ``PREFERRED_IMAGE_KEYS`` come first, in that list's order. The remaining
    detected keys follow in observation order. Keys whose detection raises
    are skipped. Returns ``[]`` for an empty dataset or trajectory.
    """
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return []

    traj = load_traj(repo_id, filename, int(np.clip(int(traj_id), 0, n_traj - 1)))
    if not traj:
        return []

    detected = []
    for key, value in traj[0].get("obs", {}).items():
        try:
            is_image = _looks_like_image_array(key, value)
        except Exception:
            # Best-effort detection: one odd entry must not break the key list.
            continue
        if is_image:
            detected.append(key)

    preferred = [key for key in PREFERRED_IMAGE_KEYS if key in detected]
    return preferred + [key for key in detected if key not in preferred]
|
|
|
|
def update_custom_visibility(preset_name):
    """Return two Gradio visibility updates, visible only when ``"Custom"`` is selected.

    Presumably these target the custom repo-id and filename inputs — confirm
    against the UI wiring.
    """
    show_custom = preset_name == "Custom"
    return tuple(gr.update(visible=show_custom) for _ in range(2))
|
|
|
|
def update_after_dataset_change(preset_name, custom_repo_id, custom_filename):
    """Refresh the UI after the dataset selection changes.

    Returns five values, in order:
    1. an update with maximum = last trajectory index (at least 1) and value 0;
    2. an update with maximum = trajectory 0's last step index (at least 1)
       and value 0;
    3. the detected image keys of trajectory 0 as choices, with the first
       two selected;
    4. a status message;
    5. the reverse-channels value reset to the preset's default.

    With no trajectories, both maxima are 1 and the choices are empty.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    reverse_default = get_default_reverse_channels(preset_name)

    status = "\n".join([
        "Loaded `{}` / `{}`".format(repo_id, filename),
        "Detected trajectories: {}".format(n_traj),
        "reverse_channels default: {}".format(int(reverse_default)),
    ])

    if n_traj == 0:
        traj_slider = gr.update(maximum=1, value=0)
        step_slider = gr.update(maximum=1, value=0)
        key_choices = gr.update(choices=[], value=[])
    else:
        keys = get_available_image_keys(repo_id, filename, 0)
        traj = load_traj(repo_id, filename, 0)
        traj_slider = gr.update(maximum=max(n_traj - 1, 1), value=0)
        step_slider = gr.update(maximum=max(len(traj) - 1, 1), value=0)
        key_choices = gr.update(choices=keys, value=keys[:2])

    return traj_slider, step_slider, key_choices, status, gr.update(value=reverse_default)
|
|
|
|
def update_after_traj_change(preset_name, custom_repo_id, custom_filename, traj_id):
    """Refresh the timestep range and image-key choices for a newly chosen trajectory.

    Returns two Gradio updates:
    1. maximum = the trajectory's last step index (at least 1), value 0;
    2. the detected image keys as choices, with the first two selected.

    With no trajectories it returns maximum 1 and empty choices.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    n_traj = get_num_trajectories(repo_id, filename)
    if n_traj == 0:
        return gr.update(maximum=1, value=0), gr.update(choices=[], value=[])

    clipped_id = int(np.clip(int(traj_id), 0, n_traj - 1))
    last_step = len(load_traj(repo_id, filename, clipped_id)) - 1
    keys = get_available_image_keys(repo_id, filename, clipped_id)
    return (
        gr.update(maximum=max(last_step, 1), value=0),
        gr.update(choices=keys, value=keys[:2]),
    )
|
|
|
|
def render_frame(preset_name, custom_repo_id, custom_filename, traj_id, timestep, image_keys, chunk_len, display_scale, reverse_channels):
    """Render one timestep of one trajectory for the UI.

    Trajectory and timestep indices are clipped into range.

    Args:
        preset_name, custom_repo_id, custom_filename: Dataset selection,
            resolved via ``resolve_dataset``.
        traj_id: Trajectory index.
        timestep: Step index inside the trajectory.
        image_keys: Selected observation keys; ``None``, a single string, or
            a sequence of strings.
        chunk_len: Valid-window length forwarded to the status plot.
        display_scale: Image scale forwarded to the gallery renderer.
        reverse_channels: Whether to swap BGR<->RGB.

    Returns:
        A tuple ``(gallery_items, status_plot, info_text)``. When nothing can
        be rendered, it is ``([], None, message)``.
    """
    repo_id, filename = resolve_dataset(preset_name, custom_repo_id, custom_filename)
    num_trajs = get_num_trajectories(repo_id, filename)
    if num_trajs == 0:
        return [], None, "No trajectory groups found. Open Debug: HDF5 tree."

    traj_idx = int(np.clip(int(traj_id), 0, num_trajs - 1))
    trajectory = load_traj(repo_id, filename, traj_idx)
    if not trajectory:
        return [], None, "Trajectory could not be loaded. Open Debug: HDF5 tree."

    last_t = len(trajectory) - 1
    t = int(np.clip(int(timestep), 0, last_t))
    window = int(chunk_len)
    scale = float(display_scale)

    # Normalize the key selection: None -> [], bare string -> [string].
    if image_keys is None:
        selected = []
    elif isinstance(image_keys, str):
        selected = [image_keys]
    else:
        selected = image_keys

    frame = trajectory[t]
    key_tuple = tuple(selected)  # tuple so the cached helper receives a hashable argument

    gallery_items, warning_tuple = get_cached_gallery_items(
        repo_id, filename, traj_idx, t, key_tuple, scale, bool(reverse_channels)
    )
    warning_list = list(warning_tuple)

    plot, valid_start, n_valid = get_cached_status_plot(repo_id, filename, traj_idx, t, window)

    # Shape/dtype summary for each selected key present in this step's obs.
    obs = frame.get("obs", {})
    tensor_lines = []
    for key in selected:
        if key not in obs:
            continue
        arr = np.asarray(obs[key])
        tensor_lines.append("{} shape={} dtype={}".format(key, tuple(arr.shape), arr.dtype))

    def as_flag(field):
        # Missing boolean fields are reported as 0.
        return int(bool(frame.get(field, False)))

    lines = [
        "dataset: {} / {}".format(repo_id, filename),
        "detected trajectories: {}".format(num_trajs),
        "trajectory: {}".format(traj_idx),
        "episode_id: {}".format(frame.get("episode_id", "")),
        "timestep: {} / {}".format(t, last_t),
        "saved timestep: {}".format(frame.get("timestep", t)),
    ]
    lines.extend(
        "{}: {}".format(field, as_flag(field))
        for field in ("done", "if_success", "no_teacher_action", "no_robot_action")
    )
    lines.extend(
        [
            "valid-window length: {}".format(window),
            "valid_start: {}".format(int(bool(valid_start))),
            "num_valid_starts: {}".format(n_valid),
            "",
            "teacher_action: {}".format(_safe_array_str(frame.get("teacher_action", []))),
            "robot_action: {}".format(_safe_array_str(frame.get("robot_action", []))),
            "",
            "selected image tensors:",
        ]
    )
    lines.extend(tensor_lines)

    if warning_list:
        lines.extend(["", "Image warnings:"])
        lines.extend(warning_list)

    return gallery_items, plot, "\n".join(lines)
|
|
|
|
def build_app():
    """Build the Gradio Blocks UI for browsing TrajectoryBuffer-style HDF5 datasets.

    The default preset is probed once at build time. The result sizes the
    trajectory slider and pre-populates the image-key choices. If that probe
    raises, the app is still built with empty defaults and the exception
    ``repr`` is shown as a startup warning. ``demo.load`` re-runs the dataset
    refresh when a page is loaded.

    Returns:
        The assembled, not yet launched, ``gr.Blocks`` app.
    """
    repo_id, filename = resolve_dataset(DEFAULT_PRESET)

    try:
        n_traj = get_num_trajectories(repo_id, filename)
        first_keys = get_available_image_keys(repo_id, filename, 0) if n_traj else []
        startup_warning = ""
    except Exception as exc:
        # Best effort: a download/parse failure must not stop the UI from building.
        n_traj = 0
        first_keys = []
        startup_warning = repr(exc)

    default_reverse = get_default_reverse_channels(DEFAULT_PRESET)
    default_status = "Loaded default dataset\nDetected trajectories: {}\nreverse_channels default: {}".format(
        n_traj, int(default_reverse)
    )

    with gr.Blocks(title="HDF5 Trajectory Viewer") as demo:
        gr.Markdown(
            "# HDF5 Trajectory Viewer\n\n"
            "Standalone viewer for TrajectoryBuffer-style HDF5 datasets on Hugging Face.\n\n"
            "The status plot matches the local labeling view: orange `no_teacher_action`, green valid-start markers, and a black timestep cursor."
        )

        if startup_warning:
            gr.Markdown("Startup warning: `{}`".format(startup_warning))

        with gr.Row():
            preset = gr.Dropdown(
                choices=list(DATASET_PRESETS.keys()) + ["Custom"],
                value=DEFAULT_PRESET,
                label="Dataset preset",
            )
            custom_repo_id = gr.Textbox(value="", label="Custom repo_id, e.g. Zhaoting123/InsertT", visible=False)
            custom_filename = gr.Textbox(value="", label="Custom HDF5 path in repo", visible=False)

        dataset_status = gr.Textbox(label="Dataset status", lines=2, value=default_status, interactive=False)

        with gr.Row():
            traj_slider = gr.Slider(minimum=0, maximum=max(n_traj - 1, 1), value=0, step=1, label="Trajectory index")
            timestep_slider = gr.Slider(minimum=0, maximum=1, value=0, step=1, label="Timestep")

        with gr.Row():
            image_keys = gr.CheckboxGroup(choices=first_keys, value=first_keys[:2], label="Image keys")
            chunk_len = gr.Slider(minimum=1, maximum=64, value=DEFAULT_CHUNK_LEN, step=1, label="Valid-window length")
            # Hidden state, not a visible control: the display scale stays at
            # DEFAULT_DISPLAY_SCALE but is still passed to the handlers.
            display_scale = gr.State(value=DEFAULT_DISPLAY_SCALE)
            reverse_channels = gr.Checkbox(value=default_reverse, label="Reverse channels BGR↔RGB")

        with gr.Row():
            render_btn = gr.Button("Render frame", variant="primary")
            preload_btn = gr.Button("Preload current trajectory")
            video_btn = gr.Button("Build trajectory video")
            video_fps = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Video FPS")
            video_stride = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Video frame stride")

        preload_status = gr.Textbox(label="Preload / video status", lines=4, value="Not preloaded yet.", interactive=False)

        with gr.Row():
            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    label="Camera images",
                    columns=2,
                    height=360,
                    object_fit="contain",
                )
            with gr.Column(scale=2):
                status_plot = gr.Image(
                    label="no_teacher_action + valid starts",
                    type="numpy",
                    height=360,
                )

        trajectory_video = gr.Video(label="Trajectory video: smooth browser-side playback")
        info = gr.Textbox(label="Frame info", lines=16)

        with gr.Accordion("Debug: HDF5 tree", open=False):
            inspect_btn = gr.Button("Inspect HDF5 structure")
            hdf5_tree = gr.Textbox(lines=24, label="HDF5 tree")

        # Input/output lists shared by the event wiring below.
        dataset_inputs = [preset, custom_repo_id, custom_filename]
        dataset_outputs = [traj_slider, timestep_slider, image_keys, dataset_status, reverse_channels]
        render_inputs = dataset_inputs + [
            traj_slider, timestep_slider, image_keys, chunk_len, display_scale, reverse_channels,
        ]
        render_outputs = [gallery, status_plot, info]

        preset.change(
            fn=update_custom_visibility,
            inputs=preset,
            outputs=[custom_repo_id, custom_filename],
        ).then(
            fn=update_after_dataset_change,
            inputs=dataset_inputs,
            outputs=dataset_outputs,
        ).then(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )

        # After a custom-dataset submit, re-render as well, matching preset.change.
        # Otherwise the gallery and info can keep showing a frame from the
        # previously selected dataset.
        for custom_box in (custom_repo_id, custom_filename):
            custom_box.submit(
                fn=update_after_dataset_change,
                inputs=dataset_inputs,
                outputs=dataset_outputs,
            ).then(
                fn=render_frame,
                inputs=render_inputs,
                outputs=render_outputs,
            )

        traj_slider.change(
            fn=update_after_traj_change,
            inputs=dataset_inputs + [traj_slider],
            outputs=[timestep_slider, image_keys],
        ).then(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )

        # The timestep slider renders on `.release`, not on every `.change`.
        timestep_slider.release(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )

        for widget in (image_keys, chunk_len, reverse_channels):
            widget.change(
                fn=render_frame,
                inputs=render_inputs,
                outputs=render_outputs,
            )

        render_btn.click(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )

        preload_btn.click(
            fn=preload_current_trajectory,
            inputs=dataset_inputs + [traj_slider, image_keys, chunk_len, display_scale, reverse_channels],
            outputs=preload_status,
        )

        video_btn.click(
            fn=build_current_trajectory_video,
            inputs=dataset_inputs + [
                traj_slider, image_keys, display_scale, reverse_channels, video_fps, chunk_len, video_stride,
            ],
            outputs=[trajectory_video, preload_status],
        )

        inspect_btn.click(
            fn=inspect_hdf5_tree,
            inputs=dataset_inputs,
            outputs=hdf5_tree,
        )

        demo.load(
            fn=update_after_dataset_change,
            inputs=dataset_inputs,
            outputs=dataset_outputs,
        ).then(
            fn=render_frame,
            inputs=render_inputs,
            outputs=render_outputs,
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the UI, then serve it on all network interfaces.
    # The port comes from $PORT (default 7860); public share links and
    # Gradio server-side rendering are both disabled.
    demo = build_app()
    listen_port = int(os.environ.get("PORT", 7860))
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=listen_port,
        share=False,
        ssr_mode=False,
    )
    demo.launch(**launch_options)