aliangdw commited on
Commit
6c9db9a
·
1 Parent(s): 89b6262
Files changed (1) hide show
  1. app.py +92 -12
app.py CHANGED
@@ -2,13 +2,42 @@ import gradio as gr
2
  from datasets import load_dataset as load_dataset_hf
3
  import os
4
 
 
 
 
 
 
 
5
  def load_rfm_dataset(dataset_name, config_name):
6
  """Load the RFM dataset from HuggingFace Hub."""
7
  try:
 
 
 
 
 
8
  dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
 
 
 
 
 
 
 
 
 
 
 
 
9
  return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
10
  except Exception as e:
11
- return None, f"❌ Error loading dataset: {e}"
 
 
 
 
 
 
12
 
13
  def get_available_configs(dataset_name):
14
  """Get available configurations for a dataset."""
@@ -20,6 +49,24 @@ def get_available_configs(dataset_name):
20
  print(f"Error getting configs: {e}")
21
  return []
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def visualize_trajectory(dataset, index, dataset_name=None):
24
  """
25
  Function to retrieve a trajectory and its metadata from the dataset.
@@ -32,7 +79,7 @@ def visualize_trajectory(dataset, index, dataset_name=None):
32
 
33
  # Get metadata
34
  task = item["task"]
35
- quality_label = item["quality_label"]
36
  is_robot = item["is_robot"]
37
  data_source = item["data_source"]
38
 
@@ -56,7 +103,7 @@ def visualize_trajectory(dataset, index, dataset_name=None):
56
 
57
  **Language Task:** {task}
58
 
59
- **Quality:** {quality_label}
60
 
61
  **Data Type:** {'Robot' if is_robot else 'Human'}
62
 
@@ -78,20 +125,27 @@ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
78
 
79
  # Dataset selection
80
  with gr.Row():
81
- with gr.Column(scale=1):
82
- dataset_name_input = gr.Textbox(
 
83
  value="aliangdw/rfm",
84
  label="Dataset Name",
85
- placeholder="username/dataset-name"
 
86
  )
87
 
88
- with gr.Column(scale=1):
89
- config_name_input = gr.Textbox(
90
- value="libero_10",
 
91
  label="Configuration Name",
92
- placeholder="config-name"
 
93
  )
94
 
 
 
 
95
  with gr.Column(scale=1):
96
  load_btn = gr.Button("Load Dataset", variant="primary")
97
 
@@ -196,7 +250,25 @@ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
196
  interactive=False
197
  )
198
 
 
 
 
 
199
  # Connect the components
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  load_btn.click(
201
  fn=load_dataset,
202
  inputs=[dataset_name_input, config_name_input],
@@ -237,9 +309,17 @@ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
237
  outputs=slider
238
  )
239
 
240
- # Load initial dataset
241
  demo.load(
242
- fn=lambda: load_dataset("aliangdw/rfm", "libero_10"),
 
 
 
 
 
 
 
 
243
  outputs=[current_dataset, status_output, dataset_info, current_index]
244
  ).then(
245
  fn=update_slider_range,
 
2
  from datasets import load_dataset as load_dataset_hf
3
  import os
4
 
5
+ # Predefined dataset names (configs will be fetched dynamically)
6
+ PREDEFINED_DATASETS = [
7
+ "aliangdw/rfm",
8
+ "abraranwar/libero_rfm",
9
+ ]
10
+
11
  def load_rfm_dataset(dataset_name, config_name):
12
  """Load the RFM dataset from HuggingFace Hub."""
13
  try:
14
+ # Validate inputs
15
+ if not dataset_name or not config_name:
16
+ return None, "❌ Please provide both dataset name and configuration"
17
+
18
+ # Try to load the dataset
19
  dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
20
+
21
+ # Check if dataset has the expected structure
22
+ expected_features = ["task", "frames", "quality_label", "is_robot", "data_source"]
23
+ missing_features = [f for f in expected_features if f not in dataset.features]
24
+
25
+ if missing_features:
26
+ return None, f"⚠️ Dataset loaded but missing expected features: {missing_features}"
27
+
28
+ # Check if dataset has any samples
29
+ if len(dataset) == 0:
30
+ return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"
31
+
32
  return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
33
  except Exception as e:
34
+ error_msg = str(e)
35
+ if "not found" in error_msg.lower():
36
+ return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
37
+ elif "authentication" in error_msg.lower():
38
+ return None, f"❌ Authentication required for {dataset_name}"
39
+ else:
40
+ return None, f"❌ Error loading dataset: {error_msg}"
41
 
42
  def get_available_configs(dataset_name):
43
  """Get available configurations for a dataset."""
 
49
  print(f"Error getting configs: {e}")
50
  return []
51
 
52
+ def update_config_choices_with_custom(dataset_name):
53
+ """Update config choices by fetching from the dataset."""
54
+ if not dataset_name:
55
+ return gr.update(choices=[], value="")
56
+
57
+ try:
58
+ # Always try to fetch configs from the dataset
59
+ configs = get_available_configs(dataset_name)
60
+ if configs:
61
+ current_value = configs[0]
62
+ return gr.update(choices=configs, value=current_value)
63
+ else:
64
+ return gr.update(choices=[], value="")
65
+ except Exception as e:
66
+ # If fetching fails, allow custom input
67
+ print(f"Warning: Could not fetch configs for {dataset_name}: {e}")
68
+ return gr.update(choices=[], value="")
69
+
70
  def visualize_trajectory(dataset, index, dataset_name=None):
71
  """
72
  Function to retrieve a trajectory and its metadata from the dataset.
 
79
 
80
  # Get metadata
81
  task = item["task"]
82
+ optimal = item["optimal"]
83
  is_robot = item["is_robot"]
84
  data_source = item["data_source"]
85
 
 
103
 
104
  **Language Task:** {task}
105
 
106
+ **Optimal:** {optimal}
107
 
108
  **Data Type:** {'Robot' if is_robot else 'Human'}
109
 
 
125
 
126
  # Dataset selection
127
  with gr.Row():
128
+ with gr.Column(scale=2):
129
+ dataset_name_input = gr.Dropdown(
130
+ choices=PREDEFINED_DATASETS,
131
  value="aliangdw/rfm",
132
  label="Dataset Name",
133
+ allow_custom_value=True,
134
+ placeholder="Select dataset or enter custom"
135
  )
136
 
137
+ with gr.Column(scale=2):
138
+ config_name_input = gr.Dropdown(
139
+ choices=[],
140
+ value="",
141
  label="Configuration Name",
142
+ allow_custom_value=True,
143
+ placeholder="Select config or enter custom"
144
  )
145
 
146
+ with gr.Column(scale=1):
147
+ refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
148
+
149
  with gr.Column(scale=1):
150
  load_btn = gr.Button("Load Dataset", variant="primary")
151
 
 
250
  interactive=False
251
  )
252
 
253
+ def update_config_choices(dataset_name):
254
+ """Update the config dropdown choices based on selected dataset."""
255
+ return update_config_choices_with_custom(dataset_name)
256
+
257
  # Connect the components
258
+ # Update config choices when dataset changes
259
+ dataset_name_input.change(
260
+ fn=update_config_choices,
261
+ inputs=[dataset_name_input],
262
+ outputs=[config_name_input]
263
+ )
264
+
265
+ # Refresh configs button for custom datasets
266
+ refresh_btn.click(
267
+ fn=update_config_choices_with_custom,
268
+ inputs=[dataset_name_input],
269
+ outputs=[config_name_input]
270
+ )
271
+
272
  load_btn.click(
273
  fn=load_dataset,
274
  inputs=[dataset_name_input, config_name_input],
 
309
  outputs=slider
310
  )
311
 
312
+ # Load initial dataset and configs
313
  demo.load(
314
+ fn=lambda: ("aliangdw/rfm", "libero_10"), # Set initial values
315
+ outputs=[dataset_name_input, config_name_input]
316
+ ).then(
317
+ fn=update_config_choices_with_custom,
318
+ inputs=[dataset_name_input],
319
+ outputs=[config_name_input]
320
+ ).then(
321
+ fn=lambda dataset_name, config_name: load_dataset(dataset_name, config_name),
322
+ inputs=[dataset_name_input, config_name_input],
323
  outputs=[current_dataset, status_output, dataset_info, current_index]
324
  ).then(
325
  fn=update_slider_range,