"""Gradio app that shows random video/action samples from a registered dataset."""

import os
import random

import gradio as gr
import numpy as np
import pandas as pd
import torch

from wm.dataset.data_config import DATASET_REGISTRY

# Number of sample slots shown in the UI. The callback always returns exactly
# this many videos followed by this many action strings and one status string.
_NUM_SLOTS = 10

# Dataset pre-selected in the dropdown when it is present in the registry.
_DEFAULT_DATASET = "language_table"


def _error_result(message):
    """Return a full-width callback result with empty slots and *message* as status."""
    return [None] * _NUM_SLOTS + [None] * _NUM_SLOTS + [message]


def _format_actions(entry):
    """Return a text rendering (shape line + array dump) of an entry's actions.

    Uses ``entry["actions"]`` when present. Otherwise, when the entry has a
    ``"commands"`` dict (the RECON layout, per the original code comment), the
    linear and angular velocities are stacked along a new last axis. Entries
    with neither key yield an empty ``(1, 0)`` array.
    """
    if "actions" in entry:
        act = entry["actions"]
    elif "commands" in entry:
        commands = entry["commands"]
        # NOTE(review): torch.stack assumes both velocities are tensors of the
        # same shape -- confirm against the dataset conversion script.
        act = torch.stack(
            [commands["linear_velocity"], commands["angular_velocity"]],
            dim=-1,
        )
    else:
        act = torch.zeros((1, 0))

    if isinstance(act, torch.Tensor):
        # Tensor.numpy() raises for tensors that require grad or live on a
        # non-CPU device; detach and move to CPU first so display never fails
        # for those cases.
        act_np = act.detach().cpu().numpy()
    else:
        act_np = np.array(act)

    return f"Shape: {act_np.shape}\n" + str(act_np)


def load_dataset_samples(dataset_name):
    """Pick up to ``_NUM_SLOTS`` random entries of a dataset for display.

    Args:
        dataset_name: Key into ``DATASET_REGISTRY``.

    Returns:
        A flat list of ``2 * _NUM_SLOTS + 1`` items: the video paths (``None``
        for empty or missing slots), then the action texts, then one status
        string. On failure every slot is ``None`` and the status string holds
        the error message.
    """
    if dataset_name not in DATASET_REGISTRY:
        return _error_result(f"Error: Dataset {dataset_name} not found.")

    root_dir = DATASET_REGISTRY[dataset_name]["root_dir"]
    metadata_path = os.path.join(root_dir, "metadata.pt")

    if not os.path.exists(metadata_path):
        return _error_result(
            f"Error: Metadata not found at {metadata_path}. "
            "Wait for conversion to finish if this is a new dataset."
        )

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only point the registry at
        # metadata files from a trusted source.
        metadata = torch.load(metadata_path, weights_only=False)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return _error_result(f"Error loading metadata: {e}")

    num_samples = min(_NUM_SLOTS, len(metadata))
    indices = random.sample(range(len(metadata)), num_samples)

    videos = []
    actions_text = []

    for idx in indices:
        entry = metadata[idx]
        video_full_path = os.path.join(root_dir, entry["video_path"])

        if not os.path.exists(video_full_path):
            # Keep the slot so videos and action texts stay aligned.
            videos.append(None)
            actions_text.append(f"Video not found: {video_full_path}")
            continue

        videos.append(video_full_path)
        actions_text.append(_format_actions(entry))

    # Pad so the result always matches the fixed number of UI components.
    padding = _NUM_SLOTS - len(videos)
    videos.extend([None] * padding)
    actions_text.extend([""] * padding)

    return videos + actions_text + [
        f"Loaded {num_samples} samples from {dataset_name}"
    ]


def create_visualizer():
    """Build the Gradio Blocks UI and return it (without launching it).

    The layout is a dataset dropdown with a refresh button, a status line, and
    ``_NUM_SLOTS`` video/textbox pairs arranged two per row.
    """
    dataset_names = list(DATASET_REGISTRY.keys())
    # Only pre-select the default dataset when it is actually registered;
    # otherwise fall back to the first registered dataset (or nothing).
    if _DEFAULT_DATASET in dataset_names:
        default_dataset = _DEFAULT_DATASET
    elif dataset_names:
        default_dataset = dataset_names[0]
    else:
        default_dataset = None

    with gr.Blocks(title="Dataset Visualizer") as demo:
        gr.Markdown("# 📊 World Model Dataset Visualizer")

        with gr.Row():
            dataset_dropdown = gr.Dropdown(
                choices=dataset_names,
                label="Select Dataset",
                value=default_dataset,
            )
            refresh_btn = gr.Button(
                f"🔄 Random Sample {_NUM_SLOTS}", variant="primary"
            )

        status_text = gr.Markdown("Select a dataset and click refresh.")

        video_components = []
        action_components = []

        # Two sample columns per row.
        for row_start in range(0, _NUM_SLOTS, 2):
            with gr.Row():
                for slot in range(row_start, min(row_start + 2, _NUM_SLOTS)):
                    with gr.Column():
                        video_components.append(
                            gr.Video(label=f"Sample {slot}", interactive=False)
                        )
                        action_components.append(
                            gr.Textbox(label=f"Actions {slot}", lines=10)
                        )

        def on_refresh(name):
            # Result order is [v0..v9, a0..a9, status], matching `outputs` below.
            return load_dataset_samples(name)

        refresh_btn.click(
            fn=on_refresh,
            inputs=[dataset_dropdown],
            outputs=video_components + action_components + [status_text],
            api_name=False,
        )

    return demo


if __name__ == "__main__":
    demo = create_visualizer()
    # NOTE(review): binds to all interfaces and requests a Gradio share link;
    # confirm this exposure is intended for the deployment environment.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=True)