Spaces:
Running
Running
File size: 13,479 Bytes
22f4bcb 68c08d3 22f4bcb 9a16674 99bc590 7cddc01 9f87b78 640bb3b b40820f 36df9f6 e49bcbb c04c5bb b4289ba 36df9f6 25813bc 78ca9a4 443fed2 6a3973a 02a22f3 946289c 51c4fed afa7f83 ddb316e f0e8104 02a22f3 6f3ce7a 946289c faac200 a1978c5 b62dad4 22f4bcb 9010913 22f4bcb dcb84d4 22f4bcb dcb84d4 22f4bcb 9967263 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
import gradio as gr
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
import os
# Hub repository ids offered as preset choices for the dataset selector.
# Only repository names are listed; the configurations of each dataset are
# fetched dynamically at runtime.
PREDEFINED_DATASETS = [
    "abraranwar/agibotworld_alpha_rfm",
    "abraranwar/libero_rfm",
    "abraranwar/usc_koch_rewind_rfm",
    "aliangdw/metaworld",
    "anqil/rh20t_rfm",
    "anqil/rh20t_subset_rfm",
    "jesbu1/auto_eval_rfm",
    "jesbu1/egodex_rfm",
    "jesbu1/epic_rfm",
    "jesbu1/fino_net_rfm",
    "jesbu1/failsafe_rfm",
    "jesbu1/hand_paired_rfm",
    "jesbu1/galaxea_rfm",
    "jesbu1/mit_franka_p-rank_rfm",
    "jesbu1/utd_so101_clean_policy_ranking_top",
    "jesbu1/utd_so101_clean_policy_ranking_wrist",
    "jesbu1/h2r_rfm",
    "jesbu1/humanoid_everyday_rfm",
    "jesbu1/molmoact_rfm",
    "jesbu1/motif_rfm",
    "jesbu1/oxe_rfm",
    "jesbu1/oxe_rfm_eval",
    "jesbu1/ph2d_rfm",
    "jesbu1/racer_rfm",
    "jesbu1/roboarena_0825_rfm",
    "jesbu1/soar_rfm",
    "jesbu1/usc_koch_human_robot_paired",
    "jesbu1/usc_koch_p_ranking_rfm",
    "ykorkmaz/libero_failure_rfm",
    "aliangdw/usc_xarm_policy_ranking",
    "aliangdw/usc_franka_policy_ranking",
    "aliangdw/utd_so101_policy_ranking",
    "aliangdw/utd_so101_human",
]
def load_rfm_dataset(dataset_name, config_name):
    """Load the ``train`` split of an RFM dataset from HuggingFace Hub.

    Args:
        dataset_name: Hub repository id of the dataset.
        config_name: Name of the dataset configuration to load.

    Returns:
        A ``(dataset, status)`` tuple. ``dataset`` is the loaded dataset, or
        ``None`` when an input is missing, loading raises, an expected
        feature is absent, or the split is empty. ``status`` is a
        human-readable message describing the outcome. Loading errors are
        reported through ``status`` rather than raised.
    """
    # Validate inputs before attempting any load.
    if not dataset_name or not config_name:
        return None, "❌ Please provide both dataset name and configuration"

    try:
        dataset = load_dataset_hf(dataset_name, name=config_name, split="train")

        # The viewer reads these columns from every record.
        expected_features = ["task", "frames", "quality_label", "is_robot", "data_source"]
        missing_features = [f for f in expected_features if f not in dataset.features]
        if missing_features:
            return None, f"⚠️ Dataset loaded but missing expected features: {missing_features}"

        if len(dataset) == 0:
            return None, f"⚠️ Dataset {dataset_name}/{config_name} is empty"

        return dataset, f"✅ Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
    except Exception as e:
        # Map the most common failure texts onto friendlier messages.
        error_msg = str(e)
        lowered = error_msg.lower()
        if "not found" in lowered:
            return None, f"❌ Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"❌ Authentication required for {dataset_name}"
        return None, f"❌ Error loading dataset: {error_msg}"
def get_available_configs(dataset_name):
    """Return the configuration names available for ``dataset_name``.

    The lookup is best-effort: if it raises, the error is printed and an
    empty list is returned instead.
    """
    try:
        # Delegate to the dedicated config-name lookup.
        return get_dataset_config_names(dataset_name)
    except Exception as exc:
        print(f"Error getting configs for {dataset_name}: {exc}")
    return []
def update_config_choices_with_custom(dataset_name):
    """Build a Gradio update for the configuration dropdown.

    When configurations can be fetched for ``dataset_name``, the update
    lists them and selects the first one. In every other case (no dataset
    name, no configurations, or an error while fetching) the update clears
    the choices and sets the value to an empty string.
    """
    if dataset_name:
        try:
            # Always try to fetch configs from the dataset
            fetched = get_available_configs(dataset_name)
            if fetched:
                return gr.update(choices=fetched, value=fetched[0])
        except Exception as exc:
            # If fetching fails, allow custom input
            print(f"Warning: Could not fetch configs for {dataset_name}: {exc}")
    return gr.update(choices=[], value="")
def visualize_trajectory(dataset, index, dataset_name=None):
    """Look up one trajectory and build the values displayed for it.

    Args:
        dataset: Indexable collection of trajectory records, or ``None``.
            Each record is read through the keys ``task``, ``quality_label``,
            ``is_robot``, ``data_source`` and ``frames``, plus the optional
            key ``id``.
        index: Position of the trajectory; converted with ``int()``.
        dataset_name: Hub repository id used to build the video URL. When
            falsy, the ``aliangdw/rfm`` repository is used instead.

    Returns:
        A 4-tuple ``(video_url, metadata_markdown, title, image)``. ``image``
        is always ``None``. On any failure ``video_url`` is ``None`` and the
        two text slots carry the same ``"Error: ..."`` message; exceptions
        are reported this way rather than raised.
    """
    if dataset is None:
        message = "Error: Could not load dataset"
        return None, message, message, None

    try:
        # Normalise once so the lookup and the title use the same integer.
        position = int(index)
        item = dataset[position]

        task = item["task"]
        quality_label = item["quality_label"]
        is_robot = item["is_robot"]
        data_source = item["data_source"]
        frames_data = item["frames"]

        # Only a string path can be turned into a Hub URL.
        if not isinstance(frames_data, str):
            message = "Error: Invalid video path"
            return None, message, message, None

        # Use the selected dataset repository if provided, otherwise the default.
        repo = dataset_name if dataset_name else "aliangdw/rfm"
        video_path = f"https://huggingface.co/datasets/{repo}/resolve/main/{frames_data}"

        metadata = "\n".join(
            [
                "",
                "## Trajectory Information",
                f"**Video path:** {video_path}",
                f"**Language Task:** {task}",
                f"**Quality Label:** {quality_label}",
                f"**Data Type:** {'Robot' if is_robot else 'Human'}",
                f"**Source:** {data_source}",
                f"**Trajectory ID:** {item.get('id', 'N/A')}",
                "",
            ]
        )

        # Return the video URL for Gradio to display.
        return video_path, metadata, f"Trajectory {position}", None
    except Exception as e:
        message = f"Error: {str(e)}"
        return None, message, message, None
# Create the Gradio interface.
# NOTE(review): the nesting of rows/columns below is reconstructed from the
# `with` statements and the section comments — confirm against the rendered UI.
with gr.Blocks(title="RFM Dataset Visualizer") as demo:
    gr.Markdown("# RFM Dataset Visualizer")
    gr.Markdown("Browse through trajectory videos and their metadata from the Reward Foundation Model dataset.")

    # Dataset selection
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name_input = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value="jesbu1/oxe_rfm",
                label="Dataset Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=2):
            config_name_input = gr.Dropdown(
                choices=[],
                value="",
                label="Configuration Name",
                allow_custom_value=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
        with gr.Column(scale=1):
            load_btn = gr.Button("Load Dataset", variant="primary")

    # Status message
    status_output = gr.Markdown("Ready to load dataset...")

    # Dataset info
    dataset_info = gr.Markdown("")

    # Visualization section
    with gr.Row():
        with gr.Column(scale=2):
            # Video/Image display
            video_output = gr.Video(label="Trajectory Video", height=400, autoplay=True)
            image_output = gr.Image(label="Frame Preview", height=400, visible=False)
        with gr.Column(scale=1):
            # Metadata display
            metadata_output = gr.Markdown(label="Metadata")

    # Navigation controls
    with gr.Row():
        with gr.Column(scale=1):
            prev_btn = gr.Button("⬅️ Previous", variant="secondary")
        with gr.Column(scale=2):
            # Slider for navigation; its range is set once a dataset is loaded.
            slider = gr.Slider(
                minimum=0,
                maximum=0,
                step=1,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        with gr.Column(scale=1):
            next_btn = gr.Button("Next ➡️", variant="secondary")

    # Current trajectory title
    title_output = gr.Textbox(label="Current Trajectory", interactive=False)

    # State variables
    current_dataset = gr.State(None)
    current_index = gr.State(0)

    # Component groups shared by several listeners.
    load_outputs = [current_dataset, status_output, dataset_info, current_index]
    view_outputs = [video_output, metadata_output, title_output, image_output]

    def load_dataset(dataset_name, config_name):
        """Load a dataset and return one value per load-event output.

        Returns ``(dataset, status, info_markdown, index)``; ``index`` is
        always 0 so browsing restarts at the first trajectory. On failure
        ``dataset`` is ``None`` and ``info_markdown`` is empty.
        """
        dataset, status = load_rfm_dataset(dataset_name, config_name)
        if dataset is None:
            return None, status, "", 0
        info = f"**Dataset Info:**\n- **Total Trajectories:** {len(dataset)}\n- **Features:** {list(dataset.features.keys())}"
        return dataset, status, info, 0

    def update_trajectory(dataset, index, dataset_name=None):
        """Show the trajectory at ``index``, clamped to the dataset bounds."""
        if dataset is None:
            return None, "No dataset loaded", "No dataset loaded", None
        # Ensure index is within bounds and is a valid number
        if index is None or not isinstance(index, (int, float)):
            index = 0
        elif index >= len(dataset):
            index = len(dataset) - 1
        elif index < 0:
            index = 0
        return visualize_trajectory(dataset, int(index), dataset_name)

    def next_trajectory(dataset, current_idx, dataset_name=None):
        """Advance by one trajectory, stopping at the last one."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        next_idx = min(current_idx + 1, len(dataset) - 1)
        video, metadata, title, image = visualize_trajectory(dataset, next_idx, dataset_name)
        return next_idx, video, metadata, title, image

    def prev_trajectory(dataset, current_idx, dataset_name=None):
        """Step back by one trajectory, stopping at the first one."""
        if dataset is None:
            return current_idx, None, "No dataset loaded", "No dataset loaded", None
        prev_idx = max(current_idx - 1, 0)
        video, metadata, title, image = visualize_trajectory(dataset, prev_idx, dataset_name)
        return prev_idx, video, metadata, title, image

    def update_slider_range(dataset):
        """Return a slider update matching the dataset length.

        With a dataset, the slider spans ``0 .. len(dataset) - 1``, is reset
        to 0 and made interactive; without one it is disabled.
        """
        if dataset is None:
            return gr.update(
                maximum=0,
                value=0,
                label="Select a dataset first",
                interactive=False,
            )
        max_value = len(dataset) - 1
        return gr.update(
            maximum=max_value,
            value=0,  # Reset to beginning
            label=f"Trajectory Index (0 to {max_value})",
            interactive=True,
        )

    def update_config_choices(dataset_name):
        """Update the config dropdown choices based on selected dataset."""
        return update_config_choices_with_custom(dataset_name)

    def sync_index(dataset, idx):
        """Mirror the slider value into the index state, capped at the last index."""
        if dataset is None or idx < len(dataset):
            return idx
        return len(dataset) - 1

    # Connect the components

    # Update config choices when dataset changes
    dataset_name_input.change(
        fn=update_config_choices,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    # Refresh configs button for custom datasets
    refresh_btn.click(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    )

    load_btn.click(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=load_outputs,
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        # Render the first trajectory explicitly: resetting a slider that is
        # already at 0 may not fire its `change` listener.
        fn=update_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=view_outputs,
    )

    slider.change(
        fn=update_trajectory,
        inputs=[current_dataset, slider, dataset_name_input],
        outputs=view_outputs,
    ).then(
        fn=sync_index,
        inputs=[current_dataset, slider],
        outputs=[current_index],
    )

    next_btn.click(
        fn=next_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index] + view_outputs,
    ).then(
        # Keep the slider in step with the index state.
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    prev_btn.click(
        fn=prev_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=[current_index] + view_outputs,
    ).then(
        # Keep the slider in step with the index state.
        fn=lambda idx: idx,
        inputs=current_index,
        outputs=slider,
    )

    # Load initial dataset and configs.
    # NOTE(review): the second step always overwrites the configuration value
    # (first fetched config, or "" when none is found), so "oxe_jaco_play" is
    # only what gets loaded if it happens to be the first config — confirm intent.
    demo.load(
        fn=lambda: ("jesbu1/oxe_rfm", "oxe_jaco_play"),  # Set initial values
        outputs=[dataset_name_input, config_name_input],
    ).then(
        fn=update_config_choices_with_custom,
        inputs=[dataset_name_input],
        outputs=[config_name_input],
    ).then(
        fn=load_dataset,
        inputs=[dataset_name_input, config_name_input],
        outputs=load_outputs,
    ).then(
        fn=update_slider_range,
        inputs=current_dataset,
        outputs=slider,
    ).then(
        # Same explicit first render as for the "Load Dataset" button.
        fn=update_trajectory,
        inputs=[current_dataset, current_index, dataset_name_input],
        outputs=view_outputs,
    )
def main():
    """Launch the Gradio app held in ``demo``."""
    demo.launch()


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()
|