File size: 13,479 Bytes
22f4bcb
 
 
 
 
 
68c08d3
22f4bcb
9a16674
99bc590
7cddc01
9f87b78
640bb3b
b40820f
36df9f6
e49bcbb
c04c5bb
b4289ba
36df9f6
25813bc
78ca9a4
 
443fed2
6a3973a
02a22f3
 
946289c
51c4fed
afa7f83
ddb316e
f0e8104
02a22f3
6f3ce7a
 
946289c
faac200
a1978c5
b62dad4
 
22f4bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9010913
 
22f4bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcb84d4
22f4bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcb84d4
22f4bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9967263
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import gradio as gr
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
import os

# HuggingFace Hub dataset repos offered in the "Dataset Name" dropdown, in
# display order. Their configurations are not listed here; they are fetched
# from the Hub when a repo is selected.
PREDEFINED_DATASETS = """
    abraranwar/agibotworld_alpha_rfm
    abraranwar/libero_rfm
    abraranwar/usc_koch_rewind_rfm
    aliangdw/metaworld
    anqil/rh20t_rfm
    anqil/rh20t_subset_rfm
    jesbu1/auto_eval_rfm
    jesbu1/egodex_rfm
    jesbu1/epic_rfm
    jesbu1/fino_net_rfm
    jesbu1/failsafe_rfm
    jesbu1/hand_paired_rfm
    jesbu1/galaxea_rfm
    jesbu1/mit_franka_p-rank_rfm
    jesbu1/utd_so101_clean_policy_ranking_top
    jesbu1/utd_so101_clean_policy_ranking_wrist
    jesbu1/h2r_rfm
    jesbu1/humanoid_everyday_rfm
    jesbu1/molmoact_rfm
    jesbu1/motif_rfm
    jesbu1/oxe_rfm
    jesbu1/oxe_rfm_eval
    jesbu1/ph2d_rfm
    jesbu1/racer_rfm
    jesbu1/roboarena_0825_rfm
    jesbu1/soar_rfm
    jesbu1/usc_koch_human_robot_paired
    jesbu1/usc_koch_p_ranking_rfm
    ykorkmaz/libero_failure_rfm
    aliangdw/usc_xarm_policy_ranking
    aliangdw/usc_franka_policy_ranking
    aliangdw/utd_so101_policy_ranking
    aliangdw/utd_so101_human
""".split()

def load_rfm_dataset(dataset_name, config_name):
    """Load the RFM dataset from HuggingFace Hub."""
    try:
        # Validate inputs
        if not dataset_name or not config_name:
            return None, "❌ Please provide both dataset name and configuration"
        
        # Try to load the dataset
        dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
        
        # Check if dataset has the expected structure
        expected_features = ["task", "frames", "quality_label", "is_robot", "data_source"]
        missing_features = [f for f in expected_features if f not in dataset.features]
        
        if missing_features:
            return None, f"⚠️ Dataset loaded but missing expected features: {missing_features}"
        
        # Check if dataset has any samples
        if len(dataset) == 0:
            return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"
        
        return dataset, f"βœ… Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
    except Exception as e:
        error_msg = str(e)
        if "not found" in error_msg.lower():
            return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
        elif "authentication" in error_msg.lower():
            return None, f"❌ Authentication required for {dataset_name}"
        else:
            return None, f"❌ Error loading dataset: {error_msg}"

def get_available_configs(dataset_name):
    """Return the configuration names published for *dataset_name*.

    Any failure (network, unknown repo, missing credentials) is printed and an
    empty list is returned, so callers can treat "no configs" uniformly.
    """
    try:
        return get_dataset_config_names(dataset_name)
    except Exception as err:
        print(f"Error getting configs for {dataset_name}: {err}")
    return []

def update_config_choices_with_custom(dataset_name):
    """Build a Gradio update for the config dropdown of *dataset_name*.

    The dropdown is filled with the repo's configurations and the first one is
    selected. When there is no repo name, or no configurations are found, the
    choices and value are cleared (the dropdown still accepts typed values).
    """
    configs = []
    if dataset_name:
        try:
            configs = get_available_configs(dataset_name)
        except Exception as err:
            # Fetch failed: fall through to an empty dropdown.
            print(f"Warning: Could not fetch configs for {dataset_name}: {err}")
            configs = []

    if not configs:
        return gr.update(choices=[], value="")
    return gr.update(choices=configs, value=configs[0])

def _resolve_video_url(frames_path, dataset_name=None, default_repo="aliangdw/rfm"):
    """Return a playable URL for a trajectory's ``frames`` entry.

    Absolute ``http://`` / ``https://`` URLs are returned unchanged. Anything
    else is treated as a path relative to the root of the dataset repo's
    ``main`` revision on the HuggingFace Hub; ``default_repo`` is used when
    *dataset_name* is empty.
    """
    if frames_path.startswith(("http://", "https://")):
        return frames_path
    repo = dataset_name or default_repo
    return f"https://huggingface.co/datasets/{repo}/resolve/main/{frames_path}"


def visualize_trajectory(dataset, index, dataset_name=None):
    """Resolve one trajectory's video URL and build its Markdown metadata card.

    Args:
        dataset: Indexable collection of trajectory records (e.g. a HuggingFace
            ``Dataset`` or a list of dicts). Each record must provide ``task``,
            ``frames``, ``quality_label``, ``is_robot`` and ``data_source``;
            ``id`` is optional.
        index: Position of the trajectory; converted with ``int()``.
        dataset_name: Hub repo id used to build the URL for relative ``frames``
            paths. Falls back to ``aliangdw/rfm`` when omitted.

    Returns:
        A 4-tuple ``(video_path, metadata_markdown, title, image)``; ``image``
        is always ``None``. On failure ``video_path`` is ``None`` and the second
        and third elements hold an error message.
    """
    if dataset is None:
        return None, "Error: Could not load dataset", "Error: Could not load dataset", None

    try:
        position = int(index)
        item = dataset[position]

        task = item["task"]
        quality_label = item["quality_label"]
        is_robot = item["is_robot"]
        data_source = item["data_source"]

        frames_data = item["frames"]
        if not isinstance(frames_data, str):
            return None, "Error: Invalid video path", "Error: Invalid video path", None
        video_path = _resolve_video_url(frames_data, dataset_name)

        # Built without leading indentation: indented lines would be rendered
        # as code blocks by a Markdown renderer that does not dedent.
        metadata = "\n\n".join(
            [
                "## Trajectory Information",
                f"**Video path:** {video_path}",
                f"**Language Task:** {task}",
                f"**Quality Label:** {quality_label}",
                f"**Data Type:** {'Robot' if is_robot else 'Human'}",
                f"**Source:** {data_source}",
                f"**Trajectory ID:** {item.get('id', 'N/A')}",
            ]
        )

        return video_path, metadata, f"Trajectory {position}", None
    except Exception as e:
        # UI boundary: surface the problem in the metadata/title panes.
        return None, f"Error: {str(e)}", f"Error: {str(e)}", None

# Create the Gradio interface
with gr.Blocks(title="RFM Dataset Visualizer") as demo:
    # Repo and configuration selected when the page first loads. The config is
    # kept only if the repo exposes it (see `_initial_config_choices`).
    _DEFAULT_DATASET = "jesbu1/oxe_rfm"
    _DEFAULT_CONFIG = "oxe_jaco_play"

    gr.Markdown("# RFM Dataset Visualizer")
    gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")

    # Dataset selection
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name_input = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value=_DEFAULT_DATASET,
                label="Dataset Name",
                allow_custom_value=True,
            )

        with gr.Column(scale=2):
            config_name_input = gr.Dropdown(
                choices=[],
                value="",
                label="Configuration Name",
                allow_custom_value=True,
            )

        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")

        with gr.Column(scale=1):
            load_btn = gr.Button("Load Dataset", variant="primary")

    # Status message
    status_output = gr.Markdown("Ready to load dataset...")

    # Dataset info
    dataset_info = gr.Markdown("")

    # Visualization section
    with gr.Row():
        with gr.Column(scale=2):
            video_output = gr.Video(label="Trajectory Video", height=400, autoplay=True)
            # Hidden; every handler below sends None to it.
            image_output = gr.Image(label="Frame Preview", height=400, visible=False)

        with gr.Column(scale=1):
            metadata_output = gr.Markdown(label="Metadata")

    # Navigation controls
    with gr.Row():
        with gr.Column(scale=1):
            prev_btn = gr.Button("⬅️ Previous", variant="secondary")

        with gr.Column(scale=2):
            # Disabled with maximum 0 until a dataset is loaded (see update_slider_range).
            slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )

        with gr.Column(scale=1):
            next_btn = gr.Button("Next ➡️", variant="secondary")

    # Current trajectory title
    title_output = gr.Textbox(label="Current Trajectory", interactive=False)

    # Per-session state: the loaded dataset and the index currently displayed.
    current_dataset = gr.State(None)
    current_index = gr.State(0)

    # Components filled by trajectory-rendering handlers, in the order
    # visualize_trajectory returns them: (video, metadata, title, image).
    _trajectory_outputs = [video_output, metadata_output, title_output, image_output]
    # Components filled by load_dataset, in its return order.
    _load_outputs = [current_dataset, status_output, dataset_info, current_index]

    def _clamp_index(dataset, index):
        """Return *index* as an int clamped to ``[0, len(dataset) - 1]``.

        Returns 0 when no dataset is loaded or *index* is not a number.
        """
        if dataset is None or index is None or not isinstance(index, (int, float)):
            return 0
        return int(min(max(index, 0), len(dataset) - 1))

    def load_dataset(dataset_name, config_name):
        """Load a dataset and return one value per component in ``_load_outputs``.

        Returns ``(dataset, status, info_markdown, start_index)``; ``dataset``
        is ``None`` and ``info_markdown`` is empty when loading fails.
        """
        dataset, status = load_rfm_dataset(dataset_name, config_name)
        if dataset is None:
            return None, status, "", 0
        info = (
            f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n"
            f"- **Features:** {list(dataset.features.keys())}"
        )
        return dataset, status, info, 0

    def update_trajectory(dataset, index, dataset_name=None):
        """Render the trajectory at *index* (clamped) for ``_trajectory_outputs``."""
        if dataset is None:
            return None, "No dataset loaded", "No dataset loaded", None
        return visualize_trajectory(dataset, _clamp_index(dataset, index), dataset_name)

    def _step_trajectory(dataset, current_idx, delta, dataset_name=None):
        """Move *delta* positions from *current_idx* (clamped) and render it.

        Returns ``(new_index, video, metadata, title, image)``.
        """
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        new_idx = _clamp_index(dataset, (current_idx or 0) + delta)
        return (new_idx, *visualize_trajectory(dataset, new_idx, dataset_name))

    def next_trajectory(dataset, current_idx, dataset_name=None):
        """Go to the next trajectory; stays on the last one at the end."""
        return _step_trajectory(dataset, current_idx, 1, dataset_name)

    def prev_trajectory(dataset, current_idx, dataset_name=None):
        """Go to the previous trajectory; stays on the first one at the start."""
        return _step_trajectory(dataset, current_idx, -1, dataset_name)

    def update_slider_range(dataset):
        """Update the slider with new maximum value based on dataset length."""
        if dataset is not None:
            max_value = len(dataset) - 1
            return gr.update(
                maximum=max_value,
                value=0,  # Reset to beginning
                label=f"Trajectory Index (0 to {max_value})",
                interactive=True,
            )
        return gr.update(
            maximum=0,
            value=0,
            label="Select a dataset first",
            interactive=False,
        )

    def update_config_choices(dataset_name):
        """Update the config dropdown choices based on selected dataset."""
        return update_config_choices_with_custom(dataset_name)

    def _initial_config_choices(dataset_name):
        """Fetch config choices for the startup repo, preferring ``_DEFAULT_CONFIG``.

        If no configs could be fetched the default is still selected: the
        dropdown accepts custom values and loading reports any error.
        """
        configs = get_available_configs(dataset_name)
        if configs and _DEFAULT_CONFIG not in configs:
            value = configs[0]
        else:
            value = _DEFAULT_CONFIG
        return gr.update(choices=configs, value=value)

    # Shared step: reset the slider to the freshly loaded dataset's range.
    _refresh_slider = dict(fn=update_slider_range, inputs=current_dataset, outputs=slider)
    # Shared step: render the trajectory at current_index explicitly. Resetting
    # the slider to the value it already holds may not emit a change event,
    # which would otherwise leave a stale (or no) video on screen.
    _render_current = dict(
        fn=update_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=_trajectory_outputs,
    )

    # Update config choices when dataset changes
    dataset_name_input.change(
        fn=update_config_choices,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Refresh configs button for custom datasets
    refresh_btn.click(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    load_btn.click(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=_load_outputs,
    ).then(**_refresh_slider).then(**_render_current)

    # Slider drives the display; current_index follows it (clamped).
    slider.change(
        fn=update_trajectory,
        inputs=[current_dataset, slider, dataset_name_input],
        outputs=_trajectory_outputs,
    ).then(
        fn=_clamp_index,
        inputs=[current_dataset, slider],
        outputs=[current_index],
    )

    next_btn.click(
        fn=next_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index, *_trajectory_outputs],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    prev_btn.click(
        fn=prev_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index, *_trajectory_outputs],
    ).then(
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    # On page load: fetch configs for the default repo (keeping the preferred
    # default config when available), load it, and show the first trajectory.
    demo.load(
        fn=_initial_config_choices,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    ).then(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=_load_outputs,
    ).then(**_refresh_slider).then(**_render_current)

def main(**launch_kwargs):
    """Launch the RFM visualizer Gradio app.

    Args:
        **launch_kwargs: Forwarded unchanged to ``demo.launch`` (for example
            ``server_name``, ``server_port`` or ``share``). With no arguments
            this is exactly ``demo.launch()``.
    """
    demo.launch(**launch_kwargs)


# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    main()