File size: 4,118 Bytes
f17ae24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
import gradio as gr
import random
import pandas as pd
import numpy as np
from wm.dataset.data_config import DATASET_REGISTRY

def load_dataset_samples(dataset_name):
    """Randomly sample up to 10 clips from a registered dataset.

    Args:
        dataset_name: Key into ``DATASET_REGISTRY``.

    Returns:
        A flat list of 21 values consumed positionally by the Gradio UI:
        10 video file paths (or ``None`` for empty slots), 10 action-summary
        strings, and a final status message.
    """

    def _error(msg):
        # Uniform error shape: blank out all 10 video and 10 text slots.
        return [None] * 10 + [None] * 10 + [msg]

    if dataset_name not in DATASET_REGISTRY:
        return _error(f"Error: Dataset {dataset_name} not found.")

    config = DATASET_REGISTRY[dataset_name]
    root_dir = config["root_dir"]
    metadata_path = os.path.join(root_dir, "metadata.pt")

    if not os.path.exists(metadata_path):
        return _error(
            f"Error: Metadata not found at {metadata_path}. "
            "Wait for conversion to finish if this is a new dataset."
        )

    try:
        # weights_only=False is required because the metadata is a pickled
        # Python object; only safe for trusted, locally generated files.
        metadata = torch.load(metadata_path, weights_only=False)
    except Exception as e:
        return _error(f"Error loading metadata: {e}")

    num_samples = min(10, len(metadata))
    indices = random.sample(range(len(metadata)), num_samples)

    videos = []
    actions_text = []

    for idx in indices:
        entry = metadata[idx]
        video_full_path = os.path.join(root_dir, entry["video_path"])

        if not os.path.exists(video_full_path):
            videos.append(None)
            actions_text.append(f"Video not found: {video_full_path}")
            continue

        videos.append(video_full_path)

        # Normalize the per-clip action data into something printable.
        if "actions" in entry:
            act = entry["actions"]
        elif "commands" in entry:
            # Handle RECON commands: stack the two velocity channels into
            # one (..., 2) action tensor.
            lin = entry["commands"]["linear_velocity"]
            ang = entry["commands"]["angular_velocity"]
            act = torch.stack([lin, ang], dim=-1)
        else:
            act = torch.zeros((1, 0))

        if isinstance(act, torch.Tensor):
            # detach()/cpu() so tensors that require grad or live on an
            # accelerator convert cleanly (plain .numpy() would raise).
            act_np = act.detach().cpu().numpy()
        else:
            act_np = np.array(act)

        # String representation shown in the sample's textbox.
        actions_text.append(f"Shape: {act_np.shape}\n{act_np}")

    # Pad to the fixed 10-slot layout the UI expects.
    while len(videos) < 10:
        videos.append(None)
        actions_text.append("")

    return videos + actions_text + [f"Loaded {num_samples} samples from {dataset_name}"]

def create_visualizer():
    """Assemble the Gradio Blocks app for browsing dataset samples.

    Returns:
        The constructed (unlaunched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(title="Dataset Visualizer") as demo:
        gr.Markdown("# 📊 World Model Dataset Visualizer")

        with gr.Row():
            dataset_dropdown = gr.Dropdown(
                choices=list(DATASET_REGISTRY.keys()),
                label="Select Dataset",
                value="language_table",
            )
            refresh_btn = gr.Button("🔄 Random Sample 10", variant="primary")

        status_text = gr.Markdown("Select a dataset and click refresh.")

        video_components = []
        action_components = []

        # Lay the 10 samples out as a 5-row x 2-column grid; component
        # order matches the flat list returned by load_dataset_samples.
        for row in range(5):
            with gr.Row():
                for col in (0, 1):
                    sample_idx = row * 2 + col
                    with gr.Column():
                        video_components.append(
                            gr.Video(label=f"Sample {sample_idx}", interactive=False)
                        )
                        action_components.append(
                            gr.Textbox(label=f"Actions {sample_idx}", lines=10)
                        )

        # Outputs are consumed positionally:
        # [v0..v9, a0..a9, status] — same order as the component lists.
        refresh_btn.click(
            fn=load_dataset_samples,
            inputs=[dataset_dropdown],
            outputs=video_components + action_components + [status_text],
            api_name=False,
        )

    return demo

if __name__ == "__main__":
    # Bind on all interfaces at port 7861; share=True also publishes a
    # temporary public Gradio link.
    create_visualizer().launch(server_name="0.0.0.0", server_port=7861, share=True)