Spaces:
Sleeping
Sleeping
small
Browse files
app.py
CHANGED
|
@@ -2,13 +2,42 @@ import gradio as gr
|
|
| 2 |
from datasets import load_dataset as load_dataset_hf
|
| 3 |
import os
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def load_rfm_dataset(dataset_name, config_name):
|
| 6 |
"""Load the RFM dataset from HuggingFace Hub."""
|
| 7 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
|
| 10 |
except Exception as e:
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def get_available_configs(dataset_name):
|
| 14 |
"""Get available configurations for a dataset."""
|
|
@@ -20,6 +49,24 @@ def get_available_configs(dataset_name):
|
|
| 20 |
print(f"Error getting configs: {e}")
|
| 21 |
return []
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def visualize_trajectory(dataset, index, dataset_name=None):
|
| 24 |
"""
|
| 25 |
Function to retrieve a trajectory and its metadata from the dataset.
|
|
@@ -32,7 +79,7 @@ def visualize_trajectory(dataset, index, dataset_name=None):
|
|
| 32 |
|
| 33 |
# Get metadata
|
| 34 |
task = item["task"]
|
| 35 |
-
|
| 36 |
is_robot = item["is_robot"]
|
| 37 |
data_source = item["data_source"]
|
| 38 |
|
|
@@ -56,7 +103,7 @@ def visualize_trajectory(dataset, index, dataset_name=None):
|
|
| 56 |
|
| 57 |
**Language Task:** {task}
|
| 58 |
|
| 59 |
-
**
|
| 60 |
|
| 61 |
**Data Type:** {'Robot' if is_robot else 'Human'}
|
| 62 |
|
|
@@ -78,20 +125,27 @@ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
|
|
| 78 |
|
| 79 |
# Dataset selection
|
| 80 |
with gr.Row():
|
| 81 |
-
with gr.Column(scale=
|
| 82 |
-
dataset_name_input = gr.
|
|
|
|
| 83 |
value="aliangdw/rfm",
|
| 84 |
label="Dataset Name",
|
| 85 |
-
|
|
|
|
| 86 |
)
|
| 87 |
|
| 88 |
-
with gr.Column(scale=
|
| 89 |
-
config_name_input = gr.
|
| 90 |
-
|
|
|
|
| 91 |
label="Configuration Name",
|
| 92 |
-
|
|
|
|
| 93 |
)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
| 95 |
with gr.Column(scale=1):
|
| 96 |
load_btn = gr.Button("Load Dataset", variant="primary")
|
| 97 |
|
|
@@ -196,7 +250,25 @@ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
|
|
| 196 |
interactive=False
|
| 197 |
)
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
# Connect the components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
load_btn.click(
|
| 201 |
fn=load_dataset,
|
| 202 |
inputs=[dataset_name_input, config_name_input],
|
|
@@ -237,9 +309,17 @@ with gr.Blocks(title="RFM Dataset Visualizer") as demo:
|
|
| 237 |
outputs=slider
|
| 238 |
)
|
| 239 |
|
| 240 |
-
# Load initial dataset
|
| 241 |
demo.load(
|
| 242 |
-
fn=lambda:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
outputs=[current_dataset, status_output, dataset_info, current_index]
|
| 244 |
).then(
|
| 245 |
fn=update_slider_range,
|
|
|
|
| 2 |
from datasets import load_dataset as load_dataset_hf
|
| 3 |
import os
|
| 4 |
|
| 5 |
+
# Predefined dataset names (configs will be fetched dynamically)
|
| 6 |
+
PREDEFINED_DATASETS = [
|
| 7 |
+
"aliangdw/rfm",
|
| 8 |
+
"abraranwar/libero_rfm",
|
| 9 |
+
]
|
| 10 |
+
|
| 11 |
def load_rfm_dataset(dataset_name, config_name):
|
| 12 |
"""Load the RFM dataset from HuggingFace Hub."""
|
| 13 |
try:
|
| 14 |
+
# Validate inputs
|
| 15 |
+
if not dataset_name or not config_name:
|
| 16 |
+
return None, "❌ Please provide both dataset name and configuration"
|
| 17 |
+
|
| 18 |
+
# Try to load the dataset
|
| 19 |
dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
|
| 20 |
+
|
| 21 |
+
# Check if dataset has the expected structure
|
| 22 |
+
expected_features = ["task", "frames", "quality_label", "is_robot", "data_source"]
|
| 23 |
+
missing_features = [f for f in expected_features if f not in dataset.features]
|
| 24 |
+
|
| 25 |
+
if missing_features:
|
| 26 |
+
return None, f"⚠️ Dataset loaded but missing expected features: {missing_features}"
|
| 27 |
+
|
| 28 |
+
# Check if dataset has any samples
|
| 29 |
+
if len(dataset) == 0:
|
| 30 |
+
return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"
|
| 31 |
+
|
| 32 |
return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
|
| 33 |
except Exception as e:
|
| 34 |
+
error_msg = str(e)
|
| 35 |
+
if "not found" in error_msg.lower():
|
| 36 |
+
return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
|
| 37 |
+
elif "authentication" in error_msg.lower():
|
| 38 |
+
return None, f"❌ Authentication required for {dataset_name}"
|
| 39 |
+
else:
|
| 40 |
+
return None, f"❌ Error loading dataset: {error_msg}"
|
| 41 |
|
| 42 |
def get_available_configs(dataset_name):
|
| 43 |
"""Get available configurations for a dataset."""
|
|
|
|
| 49 |
print(f"Error getting configs: {e}")
|
| 50 |
return []
|
| 51 |
|
| 52 |
+
def update_config_choices_with_custom(dataset_name):
|
| 53 |
+
"""Update config choices by fetching from the dataset."""
|
| 54 |
+
if not dataset_name:
|
| 55 |
+
return gr.update(choices=[], value="")
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
# Always try to fetch configs from the dataset
|
| 59 |
+
configs = get_available_configs(dataset_name)
|
| 60 |
+
if configs:
|
| 61 |
+
current_value = configs[0]
|
| 62 |
+
return gr.update(choices=configs, value=current_value)
|
| 63 |
+
else:
|
| 64 |
+
return gr.update(choices=[], value="")
|
| 65 |
+
except Exception as e:
|
| 66 |
+
# If fetching fails, allow custom input
|
| 67 |
+
print(f"Warning: Could not fetch configs for {dataset_name}: {e}")
|
| 68 |
+
return gr.update(choices=[], value="")
|
| 69 |
+
|
| 70 |
def visualize_trajectory(dataset, index, dataset_name=None):
|
| 71 |
"""
|
| 72 |
Function to retrieve a trajectory and its metadata from the dataset.
|
|
|
|
| 79 |
|
| 80 |
# Get metadata
|
| 81 |
task = item["task"]
|
| 82 |
+
optimal = item["optimal"]
|
| 83 |
is_robot = item["is_robot"]
|
| 84 |
data_source = item["data_source"]
|
| 85 |
|
|
|
|
| 103 |
|
| 104 |
**Language Task:** {task}
|
| 105 |
|
| 106 |
+
**Optimal:** {optimal}
|
| 107 |
|
| 108 |
**Data Type:** {'Robot' if is_robot else 'Human'}
|
| 109 |
|
|
|
|
| 125 |
|
| 126 |
# Dataset selection
|
| 127 |
with gr.Row():
|
| 128 |
+
with gr.Column(scale=2):
|
| 129 |
+
dataset_name_input = gr.Dropdown(
|
| 130 |
+
choices=PREDEFINED_DATASETS,
|
| 131 |
value="aliangdw/rfm",
|
| 132 |
label="Dataset Name",
|
| 133 |
+
allow_custom_value=True,
|
| 134 |
+
placeholder="Select dataset or enter custom"
|
| 135 |
)
|
| 136 |
|
| 137 |
+
with gr.Column(scale=2):
|
| 138 |
+
config_name_input = gr.Dropdown(
|
| 139 |
+
choices=[],
|
| 140 |
+
value="",
|
| 141 |
label="Configuration Name",
|
| 142 |
+
allow_custom_value=True,
|
| 143 |
+
placeholder="Select config or enter custom"
|
| 144 |
)
|
| 145 |
|
| 146 |
+
with gr.Column(scale=1):
|
| 147 |
+
refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
|
| 148 |
+
|
| 149 |
with gr.Column(scale=1):
|
| 150 |
load_btn = gr.Button("Load Dataset", variant="primary")
|
| 151 |
|
|
|
|
| 250 |
interactive=False
|
| 251 |
)
|
| 252 |
|
| 253 |
+
def update_config_choices(dataset_name):
|
| 254 |
+
"""Update the config dropdown choices based on selected dataset."""
|
| 255 |
+
return update_config_choices_with_custom(dataset_name)
|
| 256 |
+
|
| 257 |
# Connect the components
|
| 258 |
+
# Update config choices when dataset changes
|
| 259 |
+
dataset_name_input.change(
|
| 260 |
+
fn=update_config_choices,
|
| 261 |
+
inputs=[dataset_name_input],
|
| 262 |
+
outputs=[config_name_input]
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Refresh configs button for custom datasets
|
| 266 |
+
refresh_btn.click(
|
| 267 |
+
fn=update_config_choices_with_custom,
|
| 268 |
+
inputs=[dataset_name_input],
|
| 269 |
+
outputs=[config_name_input]
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
load_btn.click(
|
| 273 |
fn=load_dataset,
|
| 274 |
inputs=[dataset_name_input, config_name_input],
|
|
|
|
| 309 |
outputs=slider
|
| 310 |
)
|
| 311 |
|
| 312 |
+
# Load initial dataset and configs
|
| 313 |
demo.load(
|
| 314 |
+
fn=lambda: ("aliangdw/rfm", "libero_10"), # Set initial values
|
| 315 |
+
outputs=[dataset_name_input, config_name_input]
|
| 316 |
+
).then(
|
| 317 |
+
fn=update_config_choices_with_custom,
|
| 318 |
+
inputs=[dataset_name_input],
|
| 319 |
+
outputs=[config_name_input]
|
| 320 |
+
).then(
|
| 321 |
+
fn=lambda dataset_name, config_name: load_dataset(dataset_name, config_name),
|
| 322 |
+
inputs=[dataset_name_input, config_name_input],
|
| 323 |
outputs=[current_dataset, status_output, dataset_info, current_index]
|
| 324 |
).then(
|
| 325 |
fn=update_slider_range,
|