# File size: 4,118 Bytes | f17ae24
import os
import torch
import gradio as gr
import random
import pandas as pd
import numpy as np
from wm.dataset.data_config import DATASET_REGISTRY
def load_dataset_samples(dataset_name, num_slots=10):
    """Draw random samples from a registered dataset for display.

    Looks up ``dataset_name`` in ``DATASET_REGISTRY``, loads
    ``<root_dir>/metadata.pt`` and picks up to ``num_slots`` random entries.

    Args:
        dataset_name: Key into ``DATASET_REGISTRY``.
        num_slots: Non-negative number of video/action slots to fill.
            Defaults to 10.

    Returns:
        A flat list of length ``2 * num_slots + 1``: first ``num_slots``
        video paths (``None`` for empty slots or missing files), then
        ``num_slots`` action entries (strings, or ``None`` on an error
        return), and finally one status message string.
    """

    def _failure(message):
        # Keep the return length fixed on error paths: every slot is None.
        return [None] * num_slots + [None] * num_slots + [message]

    if dataset_name not in DATASET_REGISTRY:
        return _failure(f"Error: Dataset {dataset_name} not found.")

    root_dir = DATASET_REGISTRY[dataset_name]["root_dir"]
    metadata_path = os.path.join(root_dir, "metadata.pt")
    if not os.path.exists(metadata_path):
        return _failure(
            f"Error: Metadata not found at {metadata_path}. "
            "Wait for conversion to finish if this is a new dataset."
        )

    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only point this at metadata files from a trusted source.
        metadata = torch.load(metadata_path, weights_only=False)
    except Exception as e:
        # Boundary handler: report any load failure in the status slot
        # instead of raising.
        return _failure(f"Error loading metadata: {e}")

    num_samples = min(num_slots, len(metadata))
    indices = random.sample(range(len(metadata)), num_samples)

    videos = []
    actions_text = []
    missing = 0
    for idx in indices:
        entry = metadata[idx]
        video_full_path = os.path.join(root_dir, entry["video_path"])
        if not os.path.exists(video_full_path):
            missing += 1
            videos.append(None)
            actions_text.append(f"Video not found: {video_full_path}")
            continue
        videos.append(video_full_path)

        # Format actions
        if "actions" in entry:
            act = entry["actions"]
        elif "commands" in entry:
            # Handle RECON commands. as_tensor lets list/ndarray velocities
            # be stacked too; existing tensors pass through unchanged.
            commands = entry["commands"]
            lin = torch.as_tensor(commands["linear_velocity"])
            ang = torch.as_tensor(commands["angular_velocity"])
            act = torch.stack([lin, ang], dim=-1)
        else:
            act = torch.zeros((1, 0))

        if isinstance(act, torch.Tensor):
            # .numpy() raises on tensors that require grad or live off-CPU.
            act_np = act.detach().cpu().numpy()
        else:
            act_np = np.array(act)

        # Create a string representation for the textbox
        actions_text.append(f"Shape: {act_np.shape}\n" + str(act_np))

    # Pad the unused slots so the return length stays 2 * num_slots + 1.
    padding = num_slots - len(videos)
    videos.extend([None] * padding)
    actions_text.extend([""] * padding)

    status = f"Loaded {num_samples} samples from {dataset_name}"
    if missing:
        # Surface missing files in the status line as well as per slot.
        status += f" ({missing} video file(s) missing)"
    return videos + actions_text + [status]
def create_visualizer():
    """Build the Gradio app that previews random dataset samples.

    The layout is a dataset dropdown and a refresh button, a status line,
    and a 5x2 grid of video players, each paired with a textbox for its
    actions. Clicking the button calls ``load_dataset_samples`` with the
    selected dataset and writes the result into the grid and status line.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    dataset_names = list(DATASET_REGISTRY.keys())
    # Prefer "language_table" as the initial selection, but do not hand the
    # dropdown a value that is absent from its choices.
    if "language_table" in DATASET_REGISTRY:
        default_dataset = "language_table"
    elif dataset_names:
        default_dataset = dataset_names[0]
    else:
        default_dataset = None

    with gr.Blocks(title="Dataset Visualizer") as demo:
        gr.Markdown("# 📊 World Model Dataset Visualizer")

        with gr.Row():
            dataset_dropdown = gr.Dropdown(
                choices=dataset_names,
                label="Select Dataset",
                value=default_dataset,
            )
            refresh_btn = gr.Button("🔄 Random Sample 10", variant="primary")

        status_text = gr.Markdown("Select a dataset and click refresh.")

        video_components = []
        action_components = []
        # Five rows of two columns; slots are numbered 0..9 in row-major order.
        for row in range(5):
            with gr.Row():
                for col in range(2):
                    slot = row * 2 + col
                    with gr.Column():
                        video_components.append(
                            gr.Video(label=f"Sample {slot}", interactive=False)
                        )
                        action_components.append(
                            gr.Textbox(label=f"Actions {slot}", lines=10)
                        )

        def on_refresh(name):
            # results is [v0, v1, ..., v9, a0, a1, ..., a9, status],
            # matching the order of `outputs` below.
            return load_dataset_samples(name)

        refresh_btn.click(
            fn=on_refresh,
            inputs=[dataset_dropdown],
            outputs=video_components + action_components + [status_text],
            api_name=False,
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it on all interfaces at
    # port 7861 with Gradio's share option enabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7861,
        "share": True,
    }
    demo = create_visualizer()
    demo.launch(**launch_options)
|