aliangdw committed on
Commit
22f4bcb
·
1 Parent(s): e429295
Files changed (1) hide show
  1. app.py +335 -0
app.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
3
+ import os
4
+
5
# Hub dataset ids offered in the dataset dropdown; the configurations of each
# dataset are looked up at runtime rather than listed here.
PREDEFINED_DATASETS = ["aliangdw/rfm", "abraranwar/libero_rfm", "jesbu1/oxe_rfm"]
11
+
12
def load_rfm_dataset(dataset_name, config_name):
    """Load the ``train`` split of an RFM dataset from the HuggingFace Hub.

    Args:
        dataset_name: Hub dataset id, e.g. ``"aliangdw/rfm"``.
        config_name: Name of the dataset configuration to load.

    Returns:
        A ``(dataset, status)`` tuple. ``dataset`` is the loaded split, or
        ``None`` when the inputs are missing, loading fails, an expected
        feature is absent, or the split is empty. ``status`` is a
        human-readable message describing the outcome.
    """
    # Surrounding whitespace is never part of a valid id; stripping it first
    # lets whitespace-only input be rejected here instead of reaching the Hub.
    if isinstance(dataset_name, str):
        dataset_name = dataset_name.strip()
    if isinstance(config_name, str):
        config_name = config_name.strip()

    if not dataset_name or not config_name:
        return None, "❌ Please provide both dataset name and configuration"

    try:
        dataset = load_dataset_hf(dataset_name, name=config_name, split="train")

        # Check that the dataset exposes every column the viewer reads.
        expected_features = ["task", "frames", "quality_label", "is_robot", "data_source"]
        missing_features = [f for f in expected_features if f not in dataset.features]
        if missing_features:
            return None, f"⚠️ Dataset loaded but missing expected features: {missing_features}"

        if len(dataset) == 0:
            return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"

        return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
    except Exception as e:
        # UI boundary: any loader failure becomes a status message, classified
        # by keywords found in the exception text.
        error_msg = str(e)
        lowered = error_msg.lower()
        if "not found" in lowered:
            return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"❌ Authentication required for {dataset_name}"
        return None, f"❌ Error loading dataset: {error_msg}"
42
+
43
def get_available_configs(dataset_name):
    """Return the configuration names reported for ``dataset_name``.

    The lookup is delegated to ``get_dataset_config_names``. If it raises,
    the error is printed and an empty list is returned instead.
    """
    try:
        return get_dataset_config_names(dataset_name)
    except Exception as e:
        print(f"Error getting configs for {dataset_name}: {e}")
    return []
52
+
53
def update_config_choices_with_custom(dataset_name):
    """Build the update for the configuration dropdown of ``dataset_name``.

    When configurations are found, they become the dropdown choices and the
    first one is selected. With no dataset name, no configurations, or an
    error, the dropdown is reset to no choices and an empty value.
    """
    if not dataset_name:
        return gr.update(choices=[], value="")

    try:
        # Always ask the dataset itself for its configurations.
        available = get_available_configs(dataset_name)
        if not available:
            return gr.update(choices=[], value="")
        return gr.update(choices=available, value=available[0])
    except Exception as e:
        # Fetching failed: fall back to an empty dropdown.
        print(f"Warning: Could not fetch configs for {dataset_name}: {e}")
        return gr.update(choices=[], value="")
70
+
71
def visualize_trajectory(dataset, index, dataset_name=None):
    """Retrieve one trajectory and format it for display.

    Args:
        dataset: Indexable collection of trajectory records. Each record is
            read through the keys ``"task"``, ``"quality_label"``,
            ``"is_robot"``, ``"data_source"`` and ``"frames"``; ``"id"`` is
            optional.
        index: Position of the trajectory; converted with ``int()``.
        dataset_name: Hub dataset id used to build the video URL. When empty,
            ``"aliangdw/rfm"`` is used.

    Returns:
        A ``(video_url, metadata_markdown, title, image)`` tuple. ``image`` is
        always ``None``. On failure ``video_url`` is ``None`` and both text
        slots carry the error message.
    """
    from urllib.parse import quote

    if dataset is None:
        return None, "Error: Could not load dataset", "Error: Could not load dataset", None

    try:
        item = dataset[int(index)]

        # Get metadata
        task = item["task"]
        quality_label = item["quality_label"]
        is_robot = item["is_robot"]
        data_source = item["data_source"]

        # The frames field holds the video location as a string.
        frames_data = item["frames"]
        if not isinstance(frames_data, str) or not frames_data:
            return None, "Error: Invalid video path", "Error: Invalid video path", None

        if frames_data.startswith(("http://", "https://")):
            # Already an absolute URL: use it unchanged.
            video_path = frames_data
        else:
            # Repository-relative path: resolve it against the dataset repo.
            # Percent-encode it so spaces, '#', '?' etc. cannot break the URL,
            # while keeping '/' as the path separator.
            repo_id = dataset_name if dataset_name else "aliangdw/rfm"
            encoded = quote(frames_data, safe="/")
            video_path = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{encoded}"

        # Markdown paragraphs separated by blank lines.
        sections = [
            "## Trajectory Information",
            f"**Language Task:** {task}",
            f"**Quality Label:** {quality_label}",
            f"**Data Type:** {'Robot' if is_robot else 'Human'}",
            f"**Source:** {data_source}",
            f"**Trajectory ID:** {item.get('id', 'N/A')}",
        ]
        metadata = "\n" + "\n\n".join(sections) + "\n"

        # Return video path for Gradio to display
        return video_path, metadata, f"Trajectory {index}", None

    except Exception as e:
        # UI boundary: report lookup/format failures instead of raising.
        return None, f"Error: {str(e)}", f"Error: {str(e)}", None
121
+
122
+ # Create the Gradio interface
123
with gr.Blocks(title="RFM Dataset Visualizer") as demo:
    gr.Markdown("# RFM Dataset Visualizer")
    gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")

    # Dataset selection
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name_input = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value="aliangdw/rfm",
                label="Dataset Name",
                allow_custom_value=True,
            )

        with gr.Column(scale=2):
            config_name_input = gr.Dropdown(
                choices=[],
                value="",
                label="Configuration Name",
                allow_custom_value=True,
            )

        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")

        with gr.Column(scale=1):
            load_btn = gr.Button("Load Dataset", variant="primary")

    # Status message
    status_output = gr.Markdown("Ready to load dataset...")

    # Dataset info
    dataset_info = gr.Markdown("")

    # Visualization section
    with gr.Row():
        with gr.Column(scale=2):
            # Video/Image display
            video_output = gr.Video(label="Trajectory Video", height=400, autoplay=True)
            image_output = gr.Image(label="Frame Preview", height=400, visible=False)

        with gr.Column(scale=1):
            # Metadata display
            metadata_output = gr.Markdown(label="Metadata")

    # Navigation controls
    with gr.Row():
        with gr.Column(scale=1):
            prev_btn = gr.Button("⬅️ Previous", variant="secondary")

        with gr.Column(scale=2):
            # Slider for navigation; its maximum is set once a dataset is loaded
            slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )

        with gr.Column(scale=1):
            next_btn = gr.Button("Next ➡️", variant="secondary")

    # Current trajectory title
    title_output = gr.Textbox(label="Current Trajectory", interactive=False)

    # State variables
    current_dataset = gr.State(None)
    current_index = gr.State(0)

    # Components refreshed whenever a trajectory is (re)rendered.
    trajectory_outputs = [video_output, metadata_output, title_output, image_output]

    def load_dataset(dataset_name, config_name):
        """Load the dataset and return ``(dataset, status, info, index)``.

        The handler is wired to exactly four outputs (dataset state, status
        text, info text, current index), so exactly four values are returned.
        The index is always reset to 0.
        """
        dataset, status = load_rfm_dataset(dataset_name, config_name)
        if dataset is None:
            return None, status, "", 0
        info = f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n- **Features:** {list(dataset.features.keys())}"
        return dataset, status, info, 0

    def update_trajectory(dataset, index, dataset_name=None):
        """Render the trajectory at ``index``, clamped to the dataset bounds."""
        if dataset is None:
            return None, "No dataset loaded", "No dataset loaded", None
        # Ensure index is within bounds and is a valid number
        if index is None or not isinstance(index, (int, float)):
            index = 0
        elif index >= len(dataset):
            index = len(dataset) - 1
        elif index < 0:
            index = 0
        return visualize_trajectory(dataset, int(index), dataset_name)

    def next_trajectory(dataset, current_idx, dataset_name=None):
        """Go to next trajectory, stopping at the last one."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        next_idx = min(current_idx + 1, len(dataset) - 1)
        video, metadata, title, image = visualize_trajectory(dataset, next_idx, dataset_name)
        return next_idx, video, metadata, title, image

    def prev_trajectory(dataset, current_idx, dataset_name=None):
        """Go to previous trajectory, stopping at index 0."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        prev_idx = max(current_idx - 1, 0)
        video, metadata, title, image = visualize_trajectory(dataset, prev_idx, dataset_name)
        return prev_idx, video, metadata, title, image

    def update_slider_range(dataset):
        """Update the slider with new maximum value based on dataset length."""
        if dataset is None:
            return gr.update(
                maximum=0,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        max_value = len(dataset) - 1
        return gr.update(
            maximum=max_value,
            value=0,  # Reset to beginning
            label=f"Trajectory Index (0 to {max_value})",
            interactive=True,
        )

    def update_config_choices(dataset_name):
        """Update the config dropdown choices based on selected dataset."""
        return update_config_choices_with_custom(dataset_name)

    def sync_index_from_slider(dataset, idx):
        """Return the slider value, capped at the last valid dataset index."""
        if dataset is None or idx < len(dataset):
            return idx
        return len(dataset) - 1

    def initial_config_choices(dataset_name, requested_config):
        """Fill the config dropdown at startup, keeping the requested config.

        The requested configuration stays selected when the dataset offers
        it; otherwise the first available configuration is selected. With no
        configurations the dropdown is reset to empty.
        """
        configs = get_available_configs(dataset_name) if dataset_name else []
        if not configs:
            return gr.update(choices=[], value="")
        selected = requested_config if requested_config in configs else configs[0]
        return gr.update(choices=configs, value=selected)

    # Connect the components
    # Update config choices when dataset changes
    dataset_name_input.change(
        fn=update_config_choices,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Refresh configs button for custom datasets
    refresh_btn.click(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    load_btn.click(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=[current_dataset, status_output, dataset_info, current_index],
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        # The slider is reset to 0, which is usually already its value, so its
        # change handler may not run; render the first trajectory explicitly.
        fn=update_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=trajectory_outputs,
    )

    slider.change(
        fn=update_trajectory,
        inputs=[current_dataset, slider, dataset_name_input],
        outputs=trajectory_outputs,
    ).then(
        fn=sync_index_from_slider,
        inputs=[current_dataset, slider],
        outputs=[current_index],
    )

    next_btn.click(
        fn=next_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index] + trajectory_outputs,
    ).then(
        # Keep the slider position in step with the stored index.
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    prev_btn.click(
        fn=prev_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index] + trajectory_outputs,
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    # Load initial dataset and configs
    demo.load(
        fn=lambda: ("aliangdw/rfm", "libero_10"),  # Set initial values
        outputs=[dataset_name_input, config_name_input],
    ).then(
        # Populate the choices without discarding the initial config above.
        fn=initial_config_choices,
        inputs=[dataset_name_input, config_name_input],
        outputs=[config_name_input],
    ).then(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=[current_dataset, status_output, dataset_info, current_index],
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        fn=update_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=trajectory_outputs,
    )
328
+
329
def main():
    """Launch the module-level ``demo`` interface."""
    demo.launch()


# Run the visualizer only when this file is executed as a script.
if __name__ == "__main__":
    main()