| | import os |
| | import torch |
| | import gradio as gr |
| | import random |
| | import pandas as pd |
| | import numpy as np |
| | from wm.dataset.data_config import DATASET_REGISTRY |
| |
|
def load_dataset_samples(dataset_name):
    """Randomly sample up to 10 videos and their action tensors from a dataset.

    Returns a flat list of 21 items matching the Gradio output components:
    10 video file paths (or None for empty/missing slots), 10 action-summary
    strings, and one final status string.
    """

    def _error(msg):
        # Error results must still fill all 10 video + 10 text output slots.
        return [None] * 10 + [None] * 10 + [msg]

    if dataset_name not in DATASET_REGISTRY:
        return _error(f"Error: Dataset {dataset_name} not found.")

    config = DATASET_REGISTRY[dataset_name]
    root_dir = config["root_dir"]
    metadata_path = os.path.join(root_dir, "metadata.pt")

    if not os.path.exists(metadata_path):
        return _error(f"Error: Metadata not found at {metadata_path}. Wait for conversion to finish if this is a new dataset.")

    try:
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # metadata files produced by trusted conversion pipelines.
        metadata = torch.load(metadata_path, weights_only=False)
    except Exception as e:
        return _error(f"Error loading metadata: {e}")

    num_samples = min(10, len(metadata))
    indices = random.sample(range(len(metadata)), num_samples)

    videos = []
    actions_text = []

    for idx in indices:
        entry = metadata[idx]
        # video_path in metadata is relative to the dataset root.
        video_full_path = os.path.join(root_dir, entry["video_path"])

        if not os.path.exists(video_full_path):
            videos.append(None)
            actions_text.append(f"Video not found: {video_full_path}")
            continue

        videos.append(video_full_path)

        # Prefer explicit per-step actions; fall back to velocity commands
        # stacked as (..., 2) = [linear, angular]; otherwise an empty tensor.
        if "actions" in entry:
            act = entry["actions"]
        elif "commands" in entry:
            lin = entry["commands"]["linear_velocity"]
            ang = entry["commands"]["angular_velocity"]
            act = torch.stack([lin, ang], dim=-1)
        else:
            act = torch.zeros((1, 0))

        if isinstance(act, torch.Tensor):
            # detach().cpu() so tensors that require grad or live on an
            # accelerator don't raise in .numpy().
            act_np = act.detach().cpu().numpy()
        else:
            act_np = np.array(act)

        actions_text.append(f"Shape: {act_np.shape}\n" + str(act_np))

    # Pad to exactly 10 slots so the list always matches the UI components.
    while len(videos) < 10:
        videos.append(None)
        actions_text.append("")

    return videos + actions_text + [f"Loaded {num_samples} samples from {dataset_name}"]
| |
|
def create_visualizer():
    """Build the Gradio Blocks UI for browsing random dataset samples."""
    with gr.Blocks(title="Dataset Visualizer") as demo:
        gr.Markdown("# ๐ World Model Dataset Visualizer")

        with gr.Row():
            dataset_dropdown = gr.Dropdown(
                choices=list(DATASET_REGISTRY.keys()),
                label="Select Dataset",
                value="language_table",
            )
            refresh_btn = gr.Button("๐ Random Sample 10", variant="primary")

        status_text = gr.Markdown("Select a dataset and click refresh.")

        video_components = []
        action_components = []

        # 5 rows x 2 columns = 10 sample slots, numbered 0..9 left-to-right.
        for row in range(5):
            with gr.Row():
                for col in range(2):
                    slot = row * 2 + col
                    with gr.Column():
                        video_components.append(
                            gr.Video(label=f"Sample {slot}", interactive=False)
                        )
                        action_components.append(
                            gr.Textbox(label=f"Actions {slot}", lines=10)
                        )

        def on_refresh(name):
            # load_dataset_samples returns values already ordered to match
            # the outputs list below: 10 videos, 10 textboxes, 1 status.
            return load_dataset_samples(name)

        refresh_btn.click(
            fn=on_refresh,
            inputs=[dataset_dropdown],
            outputs=video_components + action_components + [status_text],
            api_name=False,
        )

    return demo
| |
|
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    demo = create_visualizer()
    # Bind to all interfaces on port 7861; share=True also requests a public
    # Gradio tunnel URL in addition to the local server.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=True)
|