aliangdw commited on
Commit
9df6336
·
1 Parent(s): 82bc251
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+ import os
4
+
5
+ def load_rfm_dataset(dataset_name, config_name):
6
+ """Load the RFM dataset from HuggingFace Hub."""
7
+ try:
8
+ dataset = load_dataset(dataset_name, config_name, split="train")
9
+ return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
10
+ except Exception as e:
11
+ return None, f"❌ Error loading dataset: {e}"
12
+
13
+ def get_available_configs(dataset_name):
14
+ """Get available configurations for a dataset."""
15
+ try:
16
+ dataset_info = load_dataset(dataset_name, trust_remote_code=True)
17
+ configs = list(dataset_info.keys())
18
+ return configs
19
+ except Exception as e:
20
+ return []
21
+
22
+ def visualize_trajectory(dataset, index):
23
+ """
24
+ Function to retrieve a trajectory and its metadata from the dataset.
25
+ """
26
+ if dataset is None:
27
+ return None, "Error: Could not load dataset", "Error: Could not load dataset"
28
+
29
+ try:
30
+ item = dataset[int(index)]
31
+
32
+ # Get the video file path
33
+ video_path = item["frames"]
34
+
35
+ # Get metadata
36
+ task = item["task"]
37
+ optimal = item["optimal"]
38
+ is_robot = item["is_robot"]
39
+ data_source = item["data_source"]
40
+
41
+ # Create metadata text
42
+ metadata = f"""
43
+ ## Trajectory Information
44
+
45
+ **Task:** {task}
46
+
47
+ **Optimality:** {optimal}
48
+
49
+ **Data Type:** {'🤖 Robot' if is_robot else '👤 Human'}
50
+
51
+ **Source:** {data_source}
52
+
53
+ **Video Path:** `{video_path}`
54
+
55
+ **Trajectory ID:** {item.get('id', 'N/A')}
56
+ """
57
+
58
+ return video_path, metadata, f"Trajectory {index}"
59
+
60
+ except Exception as e:
61
+ return None, f"Error: {str(e)}", f"Error: {str(e)}"
62
+
63
+ # Create the Gradio interface
64
+ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
65
+ gr.Markdown("# RFM Dataset Visualizer")
66
+ gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")
67
+
68
+ # Dataset selection
69
+ with gr.Row():
70
+ with gr.Column(scale=1):
71
+ dataset_name_input = gr.Textbox(
72
+ value="aliangdw/rfm",
73
+ label="Dataset Name",
74
+ placeholder="username/dataset-name"
75
+ )
76
+
77
+ with gr.Column(scale=1):
78
+ config_name_input = gr.Textbox(
79
+ value="libero_10",
80
+ label="Configuration Name",
81
+ placeholder="config-name"
82
+ )
83
+
84
+ with gr.Column(scale=1):
85
+ load_btn = gr.Button("Load Dataset", variant="primary")
86
+
87
+ # Status message
88
+ status_output = gr.Markdown("Ready to load dataset...")
89
+
90
+ # Dataset info
91
+ dataset_info = gr.Markdown("")
92
+
93
+ # Visualization section
94
+ with gr.Row():
95
+ with gr.Column(scale=2):
96
+ # Video display
97
+ video_output = gr.Video(label="Trajectory Video", height=400)
98
+
99
+ with gr.Column(scale=1):
100
+ # Metadata display
101
+ metadata_output = gr.Markdown(label="Metadata")
102
+
103
+ # Navigation controls
104
+ with gr.Row():
105
+ with gr.Column(scale=1):
106
+ prev_btn = gr.Button("⬅️ Previous", variant="secondary")
107
+
108
+ with gr.Column(scale=2):
109
+ # Slider for navigation
110
+ slider = gr.Slider(
111
+ minimum=0,
112
+ maximum=0,
113
+ step=1,
114
+ value=0,
115
+ label="Trajectory Index"
116
+ )
117
+
118
+ with gr.Column(scale=1):
119
+ next_btn = gr.Button("Next ➡️", variant="secondary")
120
+
121
+ # Current trajectory title
122
+ title_output = gr.Textbox(label="Current Trajectory", interactive=False)
123
+
124
+ # State variables
125
+ current_dataset = gr.State(None)
126
+ current_index = gr.State(0)
127
+
128
+ def load_dataset(dataset_name, config_name):
129
+ """Load the dataset and update the interface."""
130
+ dataset, status = load_rfm_dataset(dataset_name, config_name)
131
+ if dataset is not None:
132
+ max_index = len(dataset) - 1
133
+ info = f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n- **Features:** {list(dataset.features.keys())}"
134
+ return dataset, status, info, max_index, 0
135
+ else:
136
+ return None, status, "", 0, 0
137
+
138
+ def update_trajectory(dataset, index):
139
+ """Update the displayed trajectory."""
140
+ if dataset is None:
141
+ return None, "No dataset loaded", "No dataset loaded"
142
+ return visualize_trajectory(dataset, index)
143
+
144
+ def next_trajectory(dataset, current_idx):
145
+ """Go to next trajectory."""
146
+ if dataset is None:
147
+ return current_idx, None, "No dataset loaded", "No dataset loaded"
148
+ next_idx = min(current_idx + 1, len(dataset) - 1)
149
+ video, metadata, title = visualize_trajectory(dataset, next_idx)
150
+ return next_idx, video, metadata, title
151
+
152
+ def prev_trajectory(dataset, current_idx):
153
+ """Go to previous trajectory."""
154
+ if dataset is None:
155
+ return current_idx, None, "No dataset loaded", "No dataset loaded"
156
+ prev_idx = max(current_idx - 1, 0)
157
+ video, metadata, title = visualize_trajectory(dataset, prev_idx)
158
+ return prev_idx, video, metadata, title
159
+
160
+ # Connect the components
161
+ load_btn.click(
162
+ fn=load_dataset,
163
+ inputs=[dataset_name_input, config_name_input],
164
+ outputs=[current_dataset, status_output, dataset_info, slider, current_index]
165
+ )
166
+
167
+ slider.change(
168
+ fn=lambda dataset, idx: update_trajectory(dataset, idx),
169
+ inputs=[current_dataset, slider],
170
+ outputs=[video_output, metadata_output, title_output]
171
+ )
172
+
173
+ next_btn.click(
174
+ fn=next_trajectory,
175
+ inputs=[current_dataset, current_index],
176
+ outputs=[current_index, video_output, metadata_output, title_output]
177
+ ).then(
178
+ fn=lambda idx: idx,
179
+ inputs=current_index,
180
+ outputs=slider
181
+ )
182
+
183
+ prev_btn.click(
184
+ fn=prev_trajectory,
185
+ inputs=[current_dataset, current_index],
186
+ outputs=[current_index, video_output, metadata_output, title_output]
187
+ ).then(
188
+ fn=lambda idx: idx,
189
+ inputs=current_index,
190
+ outputs=slider
191
+ )
192
+
193
+ # Load initial dataset
194
+ demo.load(
195
+ fn=lambda: load_dataset("aliangdw/rfm", "libero_10"),
196
+ outputs=[current_dataset, status_output, dataset_info, slider, current_index]
197
+ )
198
+
199
+ # Launch the app
200
+ if __name__ == "__main__":
201
+ demo.launch()