#!/usr/bin/env python3 """ Gradio app for RBM (Reward Foundation Model) inference visualization. Supports single video (progress/success) and dual video (preference/progress) predictions. Uses eval server for inference instead of loading models locally. """ import os import tempfile from pathlib import Path from typing import Optional, Tuple import logging import gradio as gr try: import spaces # Required for ZeroGPU on Hugging Face Spaces except ImportError: spaces = None # Not available when running locally import matplotlib matplotlib.use("Agg") # Use non-interactive backend import matplotlib.pyplot as plt import numpy as np import requests from typing import Any, List, Optional, Tuple from dataset_types import Trajectory, ProgressSample, PreferenceSample from eval_utils import build_payload, post_batch_npy from eval_viz_utils import create_combined_progress_success_plot, extract_frames from datasets import load_dataset as load_dataset_hf, get_dataset_config_names logger = logging.getLogger(__name__) # Predefined dataset names (same as visualizer) PREDEFINED_DATASETS = [ "abraranwar/agibotworld_alpha_rfm", "abraranwar/libero_rfm", "abraranwar/usc_koch_rewind_rfm", "aliangdw/metaworld", "anqil/rh20t_rfm", "anqil/rh20t_subset_rfm", "jesbu1/auto_eval_rfm", "jesbu1/egodex_rfm", "jesbu1/epic_rfm", "jesbu1/fino_net_rfm", "jesbu1/failsafe_rfm", "jesbu1/hand_paired_rfm", "jesbu1/galaxea_rfm", "jesbu1/h2r_rfm", "jesbu1/humanoid_everyday_rfm", "jesbu1/molmoact_rfm", "jesbu1/motif_rfm", "jesbu1/oxe_rfm", "jesbu1/oxe_rfm_eval", "jesbu1/ph2d_rfm", "jesbu1/racer_rfm", "jesbu1/roboarena_0825_rfm", "jesbu1/soar_rfm", "ykorkmaz/libero_failure_rfm", "aliangdw/usc_xarm_policy_ranking", "aliangdw/usc_franka_policy_ranking", "aliangdw/utd_so101_policy_ranking", "aliangdw/utd_so101_human", "jesbu1/utd_so101_clean_policy_ranking_top", "jesbu1/utd_so101_clean_policy_ranking_wrist", "jesbu1/mit_franka_p-rank_rfm", "jesbu1/usc_koch_p_ranking_rfm", ] # Global server state _server_state = 
{ "server_url": None, "base_url": "https://robometer.a.pinggy.link", # Default: Pinggy tunnel or use http://HOST for port scan } def discover_available_models( base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010) ) -> List[Tuple[str, str]]: """Discover available models by pinging the base URL as-is, or ports in the specified range. If base_url is a full URL (e.g. https://robometer.a.pinggy.link), it is tried as-is first. Otherwise we try base_url:8000, base_url:8001, ... up to end_port. Returns: List of tuples: [(server_url, model_name), ...] """ base_url = base_url.strip().rstrip("/") if not base_url: return [] available_models = [] # Try base_url as-is first (for Pinggy/tunnel URLs like https://robometer.a.pinggy.link) try: health_url = f"{base_url}/health" health_response = requests.get(health_url, timeout=5.0) if health_response.status_code == 200: try: model_info_url = f"{base_url}/model_info" model_info_response = requests.get(model_info_url, timeout=5.0) if model_info_response.status_code == 200: model_info_data = model_info_response.json() model_name = model_info_data.get("model_path", base_url) available_models.append((base_url, model_name)) else: available_models.append((base_url, base_url)) except Exception: available_models.append((base_url, base_url)) return available_models except requests.exceptions.RequestException: pass # Port scan: base_url is a host (e.g. 
http://40.119.56.66), try ports in range start_port, end_port = port_range for port in range(start_port, end_port + 1): server_url = f"{base_url}:{port}" try: health_url = f"{server_url}/health" health_response = requests.get(health_url, timeout=2.0) if health_response.status_code == 200: try: model_info_url = f"{server_url}/model_info" model_info_response = requests.get(model_info_url, timeout=2.0) if model_info_response.status_code == 200: model_info_data = model_info_response.json() model_name = model_info_data.get("model_path", f"Model on port {port}") available_models.append((server_url, model_name)) else: available_models.append((server_url, f"Model on port {port}")) except Exception: available_models.append((server_url, f"Model on port {port}")) except requests.exceptions.RequestException: continue return available_models def get_model_info_for_url(server_url: str) -> Optional[str]: """Get formatted model info for a given server URL.""" if not server_url: return None try: model_info_url = server_url.rstrip("/") + "/model_info" model_info_response = requests.get(model_info_url, timeout=5.0) if model_info_response.status_code == 200: model_info_data = model_info_response.json() return format_model_info(model_info_data) except Exception as e: logger.warning(f"Could not fetch model info: {e}") return None def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]: """Check server health and get model info.""" if not server_url: return "Please provide a server URL.", None, None try: url = server_url.rstrip("/") + "/health" response = requests.get(url, timeout=5.0) response.raise_for_status() health_data = response.json() # Also try to get GPU status for more info try: status_url = server_url.rstrip("/") + "/gpu_status" status_response = requests.get(status_url, timeout=5.0) if status_response.status_code == 200: status_data = status_response.json() health_data.update(status_data) except: pass # Try to get model info model_info_text = 
get_model_info_for_url(server_url) _server_state["server_url"] = server_url return ( f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available", health_data, model_info_text, ) except requests.exceptions.RequestException as e: return f"Error connecting to server: {str(e)}", None, None def format_model_info(model_info: dict) -> str: """Format model info and experiment config as markdown.""" lines = ["## Model Information\n"] # Model path model_path = model_info.get("model_path", "Unknown") lines.append(f"**Model Path:** `{model_path}`\n") # Number of GPUs num_gpus = model_info.get("num_gpus", "Unknown") lines.append(f"**Number of GPUs:** {num_gpus}\n") # Model architecture model_arch = model_info.get("model_architecture", {}) if model_arch and "error" not in model_arch: lines.append("\n## Model Architecture\n") model_class = model_arch.get("model_class", "Unknown") model_module = model_arch.get("model_module", "Unknown") lines.append(f"- **Model Class:** `{model_class}`\n") lines.append(f"- **Module:** `{model_module}`\n") # Parameter counts total_params = model_arch.get("total_parameters") trainable_params = model_arch.get("trainable_parameters") frozen_params = model_arch.get("frozen_parameters") trainable_pct = model_arch.get("trainable_percentage") if total_params is not None: lines.append(f"\n### Parameter Statistics\n") lines.append(f"- **Total Parameters:** {total_params:,}\n") if trainable_params is not None: lines.append(f"- **Trainable Parameters:** {trainable_params:,}\n") if frozen_params is not None: lines.append(f"- **Frozen Parameters:** {frozen_params:,}\n") if trainable_pct is not None: lines.append(f"- **Trainable Percentage:** {trainable_pct:.2f}%\n") # Architecture summary arch_summary = model_arch.get("architecture_summary", []) if arch_summary: lines.append(f"\n### Architecture Summary (Top-Level Modules)\n") for module_info in arch_summary[:10]: # Show first 10 modules name = 
module_info.get("name", "Unknown") module_type = module_info.get("type", "Unknown") params = module_info.get("parameters", 0) lines.append(f"- **{name}** (`{module_type}`): {params:,} parameters\n") # Experiment config exp_config = model_info.get("experiment_config", {}) if exp_config: lines.append("\n## Experiment Configuration\n") # Model config model_cfg = exp_config.get("model", {}) if model_cfg: lines.append("### Model Configuration\n") lines.append(f"- **Base Model:** `{model_cfg.get('base_model_id', 'N/A')}`\n") lines.append(f"- **Model Type:** `{model_cfg.get('model_type', 'N/A')}`\n") lines.append(f"- **Train Progress Head:** {model_cfg.get('train_progress_head', False)}\n") lines.append(f"- **Train Preference Head:** {model_cfg.get('train_preference_head', False)}\n") lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n") lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n") lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n") # Data config data_cfg = exp_config.get("data", {}) if data_cfg: lines.append("\n### Data Configuration\n") lines.append(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n") lines.append( f"- **Resized Dimensions:** {data_cfg.get('resized_height', 'N/A')}x{data_cfg.get('resized_width', 'N/A')}\n" ) train_datasets = data_cfg.get("train_datasets", []) if train_datasets: lines.append(f"- **Train Datasets:** {', '.join(train_datasets)}\n") eval_datasets = data_cfg.get("eval_datasets", []) if eval_datasets: lines.append(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n") # Training config training_cfg = exp_config.get("training", {}) if training_cfg: lines.append("\n### Training Configuration\n") lines.append(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n") lines.append(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n") lines.append( f"- **Gradient Accumulation Steps:** 
{training_cfg.get('gradient_accumulation_steps', 'N/A')}\n" ) lines.append(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n") return "".join(lines) def load_rbm_dataset(dataset_name, config_name): """Load an RBM-format dataset from HuggingFace Hub.""" try: if not dataset_name or not config_name: return None, "Please provide both dataset name and configuration" dataset = load_dataset_hf(dataset_name, name=config_name, split="train") if len(dataset) == 0: return None, f"Dataset {dataset_name}/{config_name} is empty" return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}" except Exception as e: error_msg = str(e) if "not found" in error_msg.lower(): return None, f"Dataset or configuration not found: {dataset_name}/{config_name}" elif "authentication" in error_msg.lower(): return None, f"Authentication required for {dataset_name}" else: return None, f"Error loading dataset: {error_msg}" def get_available_configs(dataset_name): """Get available configurations for a dataset.""" try: configs = get_dataset_config_names(dataset_name) return configs except Exception as e: logger.warning(f"Error getting configs for {dataset_name}: {e}") return [] def get_trajectory_video_path(dataset, index, dataset_name): """Get video path and metadata from a trajectory in the dataset.""" try: item = dataset[int(index)] frames_data = item["frames"] if isinstance(frames_data, str): # Construct HuggingFace Hub URL if dataset_name: video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}" else: video_path = f"https://huggingface.co/datasets/rewardfm/rbm-1m/resolve/main/{frames_data}" task = item.get("task", "Complete the task") quality_label = item.get("quality_label", None) partial_success = item.get("partial_success", None) return video_path, task, quality_label, partial_success else: return None, None, None, None except Exception as e: logger.error(f"Error getting trajectory video path: {e}") return None, None, None, 
None  # completes `return None, None, None, None` of get_trajectory_video_path above


def process_single_video(
    video_path: str,
    task_text: str = "Complete the task",
    server_url: str = "",
    fps: float = 1.0,
    use_frame_steps: bool = False,
) -> Tuple[Optional[str], Optional[str]]:
    """Process single video for progress and success predictions using eval server.

    Args:
        video_path: Local path (or URL) of the video to analyze.
        task_text: Natural-language task description sent with the trajectory.
        server_url: Eval server URL; falls back to the module-level _server_state.
        fps: Frames per second to sample from the video.
        use_frame_steps: Forwarded to the server as extra form data when True.

    Returns:
        (path_to_plot_png, markdown_info) on success; (None, error_message) otherwise.
    """
    # Get server URL from state if not provided
    if not server_url:
        server_url = _server_state.get("server_url")
    if not server_url:
        return None, "Please select a model from the dropdown above and ensure it's connected."
    if video_path is None:
        return None, "Please provide a video."
    try:
        frames_array = extract_frames(video_path, fps=fps)
        if frames_array is None or frames_array.size == 0:
            return None, "Could not extract frames from video."
        # Convert frames to (T, H, W, C) numpy array with uint8 values
        if frames_array.dtype != np.uint8:
            frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)
        num_frames = frames_array.shape[0]
        frames_shape = frames_array.shape  # (T, H, W, C)
        # Create target progress (placeholder - would be None in real use);
        # the server only needs well-formed fields, not real labels, at inference.
        target_progress = np.linspace(0.0, 1.0, num=num_frames).tolist()
        success_label = [1.0 if prog > 0.5 else 0.0 for prog in target_progress]
        # predict_last_frame_mask: server collator requires a list (1.0 per frame = no masking for inference)
        predict_last_frame_mask = [1.0] * num_frames
        # Create Trajectory
        trajectory = Trajectory(
            task=task_text,
            frames=frames_array,
            frames_shape=frames_shape,
            target_progress=target_progress,
            success_label=success_label,
            predict_last_frame_mask=predict_last_frame_mask,
            metadata={"source": "gradio_app"},
        )
        # Create ProgressSample
        progress_sample = ProgressSample(
            trajectory=trajectory,
            data_gen_strategy="demo",
        )
        # Build payload and send to server
        files, sample_data = build_payload([progress_sample])
        # Add use_frame_steps flag as extra form data (None = flag omitted entirely)
        extra_data = {"use_frame_steps": use_frame_steps} if use_frame_steps else None
        response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0, extra_form_data=extra_data)
        # Process response -- keys follow the eval server's batch-response schema
        outputs_progress = response.get("outputs_progress", {})
        progress_pred = outputs_progress.get("progress_pred", [])
        outputs_success = response.get("outputs_success", {})
        success_probs = outputs_success.get("success_probs", []) if outputs_success else None
        # Extract progress predictions
        if progress_pred and len(progress_pred) > 0:
            progress_array = np.array(progress_pred[0])  # First sample
        else:
            progress_array = np.array([])
        # Extract success predictions if available
        success_array = None
        if success_probs and len(success_probs) > 0:
            success_array = np.array(success_probs[0])
        # Convert success_array to binary if available
        success_binary = None
        if success_array is not None:
            success_binary = (success_array > 0.5).astype(float)
        # Create combined plot using shared helper function
        fig = create_combined_progress_success_plot(
            progress_pred=progress_array if len(progress_array) > 0 else np.array([0.0]),
            num_frames=num_frames,
            success_binary=success_binary,
            success_probs=success_array,
            success_labels=None,  # No ground truth labels available
            is_discrete_mode=False,
            title=f"Progress & Success - {task_text}",
        )
        # Save to temporary file (delete=False so Gradio can serve it afterwards)
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        fig.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
        plt.close(fig)
        progress_plot = tmp_file.name
        info_text = f"**Frames processed:** {num_frames}\n"
        if len(progress_array) > 0:
            info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"
        if success_array is not None and len(success_array) > 0:
            info_text += f"**Final success probability:** {success_array[-1]:.3f}\n"
        # Return combined plot (which includes success if available)
        return progress_plot, info_text
    except Exception as e:
        # Surface any failure (network, codec, server schema) as a user-visible message.
        return None, f"Error processing video: {str(e)}"


def process_two_videos(
    video_a_path: str,
    video_b_path: str,
    task_text: str = "Complete the task",
    prediction_type: str = "preference",
    server_url: str = "",
    fps: float = 1.0,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Process two videos for preference or progress prediction using eval server.

    Args:
        video_a_path: Path/URL of video A (treated as "chosen" for preference).
        video_b_path: Path/URL of video B (treated as "rejected" for preference).
        task_text: Task description shared by both trajectories.
        prediction_type: "preference" or "progress".
        server_url: Eval server URL; falls back to the module-level _server_state.
        fps: Frames per second to sample from each video.

    Returns:
        (result_markdown, video_a_path, video_b_path); paths are None on error.
    """
    # Get server URL from state if not provided
    if not server_url:
        server_url = _server_state.get("server_url")
    if not server_url:
        return "Please select a model from the dropdown above and ensure it's connected.", None, None
    if video_a_path is None or video_b_path is None:
        return "Please provide both videos.", None, None
    try:
        frames_array_a = extract_frames(video_a_path, fps=fps)
        frames_array_b = extract_frames(video_b_path, fps=fps)
        if frames_array_a is None or frames_array_a.size == 0:
            return "Could not extract frames from video A.", None, None
        if frames_array_b is None or frames_array_b.size == 0:
            return "Could not extract frames from video B.", None, None
        # Convert frames to uint8
        if frames_array_a.dtype != np.uint8:
            frames_array_a = np.clip(frames_array_a, 0, 255).astype(np.uint8)
        if frames_array_b.dtype != np.uint8:
            frames_array_b = np.clip(frames_array_b, 0, 255).astype(np.uint8)
        num_frames_a = frames_array_a.shape[0]
        num_frames_b = frames_array_b.shape[0]
        frames_shape_a = frames_array_a.shape
        frames_shape_b = frames_array_b.shape
        # Create target progress for both trajectories (placeholder labels for inference)
        target_progress_a = np.linspace(0.0, 1.0, num=num_frames_a).tolist()
        target_progress_b = np.linspace(0.0, 1.0, num=num_frames_b).tolist()
        success_label_a = [1.0 if prog > 0.5 else 0.0 for prog in target_progress_a]
        success_label_b = [1.0 if prog > 0.5 else 0.0 for prog in target_progress_b]
        # predict_last_frame_mask: server collator requires a list per trajectory (1.0 = no masking)
        mask_a = [1.0] * num_frames_a
        mask_b = [1.0] * num_frames_b
        # Create trajectories
        trajectory_a = Trajectory(
            task=task_text,
            frames=frames_array_a,
            frames_shape=frames_shape_a,
            target_progress=target_progress_a,
            success_label=success_label_a,
            predict_last_frame_mask=mask_a,
            metadata={"source": "gradio_app", "trajectory": "A"},
        )
        trajectory_b = Trajectory(
            task=task_text,
            frames=frames_array_b,
            frames_shape=frames_shape_b,
            target_progress=target_progress_b,
            success_label=success_label_b,
            predict_last_frame_mask=mask_b,
            metadata={"source": "gradio_app", "trajectory": "B"},
        )
        if prediction_type == "preference":
            # Create PreferenceSample (A = chosen, B = rejected)
            preference_sample = PreferenceSample(
                chosen_trajectory=trajectory_a,
                rejected_trajectory=trajectory_b,
                data_gen_strategy="demo",
            )
            # Build payload and send to server
            files, sample_data = build_payload([preference_sample])
            response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
            # Process response
            outputs_preference = response.get("outputs_preference", {})
            predictions = outputs_preference.get("predictions", [])
            prediction_probs = outputs_preference.get("prediction_probs", [])
            result_text = f"**Preference Prediction:**\n"
            if prediction_probs and len(prediction_probs) > 0:
                # prediction_probs[0] is P(A preferred) for the single submitted pair
                prob = prediction_probs[0]
                result_text += f"- Probability (A preferred): {prob:.3f}\n"
                result_text += f"- Interpretation: {'Video A is preferred' if prob > 0.5 else 'Video B is preferred'}\n"
            else:
                result_text += "Could not extract preference prediction from server response.\n"
        elif prediction_type == "progress":
            # Create ProgressSamples for both videos
            progress_sample_a = ProgressSample(
                trajectory=trajectory_a,
                data_gen_strategy="demo",
            )
            progress_sample_b = ProgressSample(
                trajectory=trajectory_b,
                data_gen_strategy="demo",
            )
            # Build payload and send to server (one batch, two samples)
            files, sample_data = build_payload([progress_sample_a, progress_sample_b])
            response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
            # Process response
            outputs_progress = response.get("outputs_progress", {})
            progress_pred = outputs_progress.get("progress_pred", [])
            result_text = f"**Progress Comparison:**\n"
            if progress_pred and len(progress_pred) >= 2:
                progress_a = np.array(progress_pred[0])
                progress_b = np.array(progress_pred[1])
                # Compare the final-frame progress of each trajectory
                final_progress_a = float(progress_a[-1]) if len(progress_a) > 0 else 0.0
                final_progress_b = float(progress_b[-1]) if len(progress_b) > 0
else 0.0 result_text += f"- Video A final progress: {final_progress_a:.3f}\n" result_text += f"- Video B final progress: {final_progress_b:.3f}\n" result_text += f"- Difference: {abs(final_progress_a - final_progress_b):.3f}\n" if final_progress_a > final_progress_b: result_text += f"- Video A has higher progress\n" elif final_progress_b > final_progress_a: result_text += f"- Video B has higher progress\n" else: result_text += f"- Both videos have equal progress\n" else: result_text += "Could not extract progress predictions from server response.\n" # Return result text and both video paths return result_text, video_a_path, video_b_path except Exception as e: return f"Error processing videos: {str(e)}", None, None # Create Gradio interface try: # Try with theme (Gradio 4.0+) demo = gr.Blocks(title="Robometer Evaluation Server", theme=gr.themes.Soft()) except TypeError: # Fallback for older Gradio versions without theme support demo = gr.Blocks(title="Robometer Evaluation Server") with demo: gr.Markdown( """ # Robometer Evaluation Server """ ) # Hidden state to store server URL and model mapping (define before use) server_url_state = gr.State(value=None) model_url_mapping_state = gr.State(value={}) # Maps model_name -> server_url # Function definitions for event handlers def discover_and_select_models(base_url: str): """Discover models and update dropdown.""" if not base_url: return ( gr.update(choices=[], value=None), gr.update(value="Please provide a base URL", visible=True), gr.update(value="", visible=True), None, {}, # Empty mapping ) _server_state["base_url"] = base_url models = discover_available_models(base_url, port_range=(8000, 8010)) if not models: return ( gr.update(choices=[], value=None), gr.update(value="❌ No models found on ports 8000-8010. 
Make sure servers are running.", visible=True), gr.update(value="", visible=True), None, {}, # Empty mapping ) # Format choices: show model_name in dropdown # Store mapping of model_name to URL in state choices = [] url_map = {} for url, name in models: choices.append(name) url_map[name] = url # Auto-select first model selected_choice = choices[0] if choices else None selected_url = url_map.get(selected_choice) if selected_choice else None # Get model info for selected model model_info_text = get_model_info_for_url(selected_url) if selected_url else "" status_text = f"✅ Found {len(models)} model(s). Auto-selected first model." _server_state["server_url"] = selected_url return ( gr.update(choices=choices, value=selected_choice), gr.update(value=status_text, visible=True), gr.update(value=model_info_text, visible=True), selected_url, url_map, # Return mapping for state ) def on_model_selected(model_choice: str, url_mapping: dict): """Handle model selection change.""" if not model_choice: return ( gr.update(value="No model selected", visible=True), gr.update(value="", visible=True), None, ) # Get URL from mapping server_url = url_mapping.get(model_choice) if url_mapping else None if not server_url: return ( gr.update( value="Could not find server URL for selected model. Please rediscover models.", visible=True ), gr.update(value="", visible=True), None, ) # Get model info model_info_text = get_model_info_for_url(server_url) or "" status, health_data, _ = check_server_health(server_url) _server_state["server_url"] = server_url return ( gr.update(value=status, visible=True), gr.update(value=model_info_text, visible=True), server_url, ) # Use Gradio's built-in Sidebar component (collapsible by default) with gr.Sidebar(): gr.Markdown("### 🔧 Model Configuration") base_url_input = gr.Textbox( label="Base Server URL", placeholder="https://robometer.a.pinggy.link or http://40.119.56.66", value="https://robometer.a.pinggy.link", interactive=True, info="Full URL (e.g. 
Pinggy tunnel) or host; discovery tries URL as-is first, then ports 8000-8010", ) discover_btn = gr.Button("🔍 Discover Models", variant="primary", size="lg") model_dropdown = gr.Dropdown( label="Select Model", choices=[], value=None, interactive=True, info="Click Discover to find the eval server (single URL or ports 8000-8010)", ) server_status = gr.Markdown("Click 'Discover Models' to find available models") gr.Markdown("---") gr.Markdown("### 📋 Model Information") model_info_display = gr.Markdown("") # Event handlers for sidebar discover_btn.click( fn=discover_and_select_models, inputs=[base_url_input], outputs=[model_dropdown, server_status, model_info_display, server_url_state, model_url_mapping_state], ) model_dropdown.change( fn=on_model_selected, inputs=[model_dropdown, model_url_mapping_state], outputs=[server_status, model_info_display, server_url_state], ) # Main content area with tabs with gr.Tabs(): with gr.Tab("Progress Prediction"): with gr.Row(): with gr.Column(): single_video_input = gr.Video(label="Upload Video", height=300) task_text_input = gr.Textbox( label="Task Description", placeholder="Describe the task (e.g., 'Pick up the red block')", value="Complete the task", ) fps_input_single = gr.Slider( label="FPS (Frames Per Second)", minimum=0.1, maximum=10.0, value=1.0, step=0.1, info="Frames per second to extract from video (higher = more frames)", ) use_frame_steps_single = gr.Checkbox( label="Per Frame Progress Prediction", value=False, info="If enabled, predict progress per frame rather than feeding the entire video at once", ) analyze_single_btn = gr.Button("Compute Progress", variant="primary") gr.Markdown("---") gr.Markdown("**OR Select from Dataset**") gr.Markdown("---") with gr.Accordion("📁 Select from Dataset", open=False): dataset_name_single = gr.Dropdown( choices=PREDEFINED_DATASETS, value="jesbu1/oxe_rfm", label="Dataset Name", allow_custom_value=True, ) config_name_single = gr.Dropdown( choices=[], value="", label="Configuration 
Name", allow_custom_value=True ) with gr.Row(): refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm") load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm") dataset_status_single = gr.Markdown("", visible=False) with gr.Row(): prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm") trajectory_slider = gr.Slider( minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True ) next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm") trajectory_metadata = gr.Markdown("", visible=False) use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary") with gr.Column(): progress_plot = gr.Image(label="Progress & Success Prediction", height=400) info_output = gr.Markdown("") # State variables for dataset current_dataset_single = gr.State(None) def update_config_choices_single(dataset_name): """Update config choices when dataset changes.""" if not dataset_name: return gr.update(choices=[], value="") try: configs = get_available_configs(dataset_name) if configs: return gr.update(choices=configs, value=configs[0]) else: return gr.update(choices=[], value="") except Exception as e: logger.warning(f"Could not fetch configs: {e}") return gr.update(choices=[], value="") def load_dataset_single(dataset_name, config_name): """Load dataset and update slider.""" dataset, status = load_rbm_dataset(dataset_name, config_name) if dataset is not None: max_index = len(dataset) - 1 return ( dataset, gr.update(value=status, visible=True), gr.update( maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})" ), ) else: return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False) def use_dataset_video(dataset, index, dataset_name): """Load video from dataset and update inputs.""" if dataset is None: return ( None, "Complete the task", gr.update(value="No dataset loaded", visible=True), gr.update(visible=False), ) 
video_path, task, quality_label, partial_success = get_trajectory_video_path( dataset, index, dataset_name ) if video_path: # Build metadata text metadata_lines = [] if quality_label: metadata_lines.append(f"**Quality Label:** {quality_label}") if partial_success is not None: metadata_lines.append(f"**Partial Success:** {partial_success:.3f}") metadata_text = "\n".join(metadata_lines) if metadata_lines else "" status_text = f"✅ Loaded trajectory {index} from dataset" if metadata_text: status_text += f"\n\n{metadata_text}" return ( video_path, task, gr.update(value=status_text, visible=True), gr.update(value=metadata_text, visible=bool(metadata_text)), ) else: return ( None, "Complete the task", gr.update(value="❌ Error loading trajectory", visible=True), gr.update(visible=False), ) def next_trajectory(dataset, current_idx, dataset_name): """Go to next trajectory.""" if dataset is None: return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False) next_idx = min(current_idx + 1, len(dataset) - 1) video_path, task, quality_label, partial_success = get_trajectory_video_path( dataset, next_idx, dataset_name ) if video_path: # Build metadata text metadata_lines = [] if quality_label: metadata_lines.append(f"**Quality Label:** {quality_label}") if partial_success is not None: metadata_lines.append(f"**Partial Success:** {partial_success:.3f}") metadata_text = "\n".join(metadata_lines) if metadata_lines else "" return ( next_idx, video_path, task, gr.update(value=metadata_text, visible=bool(metadata_text)), gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True), ) else: return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False) def prev_trajectory(dataset, current_idx, dataset_name): """Go to previous trajectory.""" if dataset is None: return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False) prev_idx = max(current_idx - 1, 0) video_path, task, 
quality_label, partial_success = get_trajectory_video_path( dataset, prev_idx, dataset_name ) if video_path: # Build metadata text metadata_lines = [] if quality_label: metadata_lines.append(f"**Quality Label:** {quality_label}") if partial_success is not None: metadata_lines.append(f"**Partial Success:** {partial_success:.3f}") metadata_text = "\n".join(metadata_lines) if metadata_lines else "" return ( prev_idx, video_path, task, gr.update(value=metadata_text, visible=bool(metadata_text)), gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True), ) else: return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False) def update_trajectory_on_slider_change(dataset, index, dataset_name): """Update trajectory metadata when slider changes.""" if dataset is None: return gr.update(visible=False), gr.update(visible=False) video_path, task, quality_label, partial_success = get_trajectory_video_path( dataset, index, dataset_name ) if video_path: # Build metadata text metadata_lines = [] if quality_label: metadata_lines.append(f"**Quality Label:** {quality_label}") if partial_success is not None: metadata_lines.append(f"**Partial Success:** {partial_success:.3f}") metadata_text = "\n".join(metadata_lines) if metadata_lines else "" return ( gr.update(value=metadata_text, visible=bool(metadata_text)), gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True), ) else: return gr.update(visible=False), gr.update(visible=False) # Dataset selection handlers dataset_name_single.change( fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single] ) refresh_configs_btn.click( fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single] ) load_dataset_btn.click( fn=load_dataset_single, inputs=[dataset_name_single, config_name_single], outputs=[current_dataset_single, dataset_status_single, trajectory_slider], ) use_dataset_video_btn.click( 
fn=use_dataset_video, inputs=[current_dataset_single, trajectory_slider, dataset_name_single], outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata], ) # Navigation buttons next_traj_btn.click( fn=next_trajectory, inputs=[current_dataset_single, trajectory_slider, dataset_name_single], outputs=[ trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single, ], ) prev_traj_btn.click( fn=prev_trajectory, inputs=[current_dataset_single, trajectory_slider, dataset_name_single], outputs=[ trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single, ], ) # Update metadata when slider changes trajectory_slider.change( fn=update_trajectory_on_slider_change, inputs=[current_dataset_single, trajectory_slider, dataset_name_single], outputs=[trajectory_metadata, dataset_status_single], ) analyze_single_btn.click( fn=process_single_video, inputs=[ single_video_input, task_text_input, server_url_state, fps_input_single, use_frame_steps_single, ], outputs=[progress_plot, info_output], api_name="process_single_video", ) with gr.Tab("Preference Analysis"): # Full-width row: two videos side by side with gr.Row(): video_a_input = gr.Video(label="Video A", height=320) video_b_input = gr.Video(label="Video B", height=320) task_text_dual = gr.Textbox( label="Task Description", placeholder="Describe the task", value="Complete the task", ) analyze_dual_btn = gr.Button("Compute Preference", variant="primary") gr.Markdown("---") gr.Markdown("**OR Select from Dataset**") gr.Markdown("---") with gr.Accordion("📁 Video A - Select from Dataset", open=False): dataset_name_a = gr.Dropdown( choices=PREDEFINED_DATASETS, value="jesbu1/oxe_rfm", label="Dataset Name", allow_custom_value=True, ) config_name_a = gr.Dropdown( choices=[], value="", label="Configuration Name", allow_custom_value=True ) with gr.Row(): refresh_configs_btn_a = gr.Button("🔄 Refresh Configs", variant="secondary", 
size="sm")
                load_dataset_btn_a = gr.Button("Load Dataset", variant="secondary", size="sm")
            dataset_status_a = gr.Markdown("", visible=False)
            with gr.Row():
                prev_traj_btn_a = gr.Button("⬅️ Prev", variant="secondary", size="sm")
                trajectory_slider_a = gr.Slider(
                    minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
                )
                next_traj_btn_a = gr.Button("Next ➡️", variant="secondary", size="sm")
            trajectory_metadata_a = gr.Markdown("", visible=False)
            use_dataset_video_btn_a = gr.Button("Use Selected Video for A", variant="secondary")

        # Dataset picker for Video B (identical layout to Video A's).
        with gr.Accordion("📁 Video B - Select from Dataset", open=False):
            dataset_name_b = gr.Dropdown(
                choices=PREDEFINED_DATASETS,
                value="jesbu1/oxe_rfm",
                label="Dataset Name",
                allow_custom_value=True,
            )
            config_name_b = gr.Dropdown(
                choices=[], value="", label="Configuration Name", allow_custom_value=True
            )
            with gr.Row():
                refresh_configs_btn_b = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
                load_dataset_btn_b = gr.Button("Load Dataset", variant="secondary", size="sm")
            dataset_status_b = gr.Markdown("", visible=False)
            with gr.Row():
                prev_traj_btn_b = gr.Button("⬅️ Prev", variant="secondary", size="sm")
                trajectory_slider_b = gr.Slider(
                    minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
                )
                next_traj_btn_b = gr.Button("Next ➡️", variant="secondary", size="sm")
            trajectory_metadata_b = gr.Markdown("", visible=False)
            use_dataset_video_btn_b = gr.Button("Use Selected Video for B", variant="secondary")

        gr.Markdown("---")
        gr.Markdown("### Preference result")
        result_text = gr.Markdown("")

        # State variables for datasets
        current_dataset_a = gr.State(None)
        current_dataset_b = gr.State(None)

        # Helper functions for Video A
        def update_config_choices_a(dataset_name):
            """Update config choices for Video A when dataset changes."""
            if not dataset_name:
                return gr.update(choices=[], value="")
            try:
                configs = get_available_configs(dataset_name)
                if configs:
                    # Default to the first available configuration.
                    return gr.update(choices=configs, value=configs[0])
else:
                    return gr.update(choices=[], value="")
            except Exception as e:
                # Hub/network failures are non-fatal: fall back to an empty config list.
                logger.warning(f"Could not fetch configs: {e}")
                return gr.update(choices=[], value="")

        def load_dataset_a(dataset_name, config_name):
            """Load dataset A and update slider."""
            dataset, status = load_rbm_dataset(dataset_name, config_name)
            if dataset is not None:
                max_index = len(dataset) - 1
                return (
                    dataset,
                    gr.update(value=status, visible=True),
                    gr.update(
                        maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
                    ),
                )
            else:
                # Load failed: surface the status message and disable the slider.
                return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)

        def use_dataset_video_a(dataset, index, dataset_name):
            """Load video A from dataset and update input."""
            if dataset is None:
                return (
                    None,
                    gr.update(value="No dataset loaded", visible=True),
                    gr.update(visible=False),
                )
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, index, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                status_text = f"✅ Loaded trajectory {index} from dataset for Video A"
                if metadata_text:
                    status_text += f"\n\n{metadata_text}"
                return (
                    video_path,
                    gr.update(value=status_text, visible=True),
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                )
            else:
                return (
                    None,
                    gr.update(value="❌ Error loading trajectory", visible=True),
                    gr.update(visible=False),
                )

        def next_trajectory_a(dataset, current_idx, dataset_name):
            """Go to next trajectory for Video A."""
            if dataset is None:
                return 0, None, gr.update(visible=False), gr.update(visible=False)
            # Clamp to the last trajectory index.
            next_idx = min(current_idx + 1, len(dataset) - 1)
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, next_idx, dataset_name
            )
            if video_path:
                # Build metadata text
metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                return (
                    next_idx,
                    video_path,
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                    gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
                )
            else:
                # Loading failed: keep the current index and hide metadata/status.
                return current_idx, None, gr.update(visible=False), gr.update(visible=False)

        def prev_trajectory_a(dataset, current_idx, dataset_name):
            """Go to previous trajectory for Video A."""
            if dataset is None:
                return 0, None, gr.update(visible=False), gr.update(visible=False)
            # Clamp at index 0.
            prev_idx = max(current_idx - 1, 0)
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, prev_idx, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                return (
                    prev_idx,
                    video_path,
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                    gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
                )
            else:
                return current_idx, None, gr.update(visible=False), gr.update(visible=False)

        def update_trajectory_on_slider_change_a(dataset, index, dataset_name):
            """Update trajectory metadata when slider changes for Video A."""
            if dataset is None:
                return gr.update(visible=False), gr.update(visible=False)
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, index, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                return (
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                    gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
                )
            else:
                return gr.update(visible=False), gr.update(visible=False)

        # Helper functions for Video B (same as Video A)
        def update_config_choices_b(dataset_name):
            """Update config choices for Video B when dataset changes."""
            if not dataset_name:
                return gr.update(choices=[], value="")
            try:
                configs = get_available_configs(dataset_name)
                if configs:
                    return gr.update(choices=configs, value=configs[0])
                else:
                    return gr.update(choices=[], value="")
            except Exception as e:
                # Non-fatal: leave the config dropdown empty on hub errors.
                logger.warning(f"Could not fetch configs: {e}")
                return gr.update(choices=[], value="")

        def load_dataset_b(dataset_name, config_name):
            """Load dataset B and update slider."""
            dataset, status = load_rbm_dataset(dataset_name, config_name)
            if dataset is not None:
                max_index = len(dataset) - 1
                return (
                    dataset,
                    gr.update(value=status, visible=True),
                    gr.update(
                        maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
                    ),
                )
            else:
                # Load failed: show the status message and disable the slider.
                return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)

        def use_dataset_video_b(dataset, index, dataset_name):
            """Load video B from dataset and update input."""
            if dataset is None:
                return (
                    None,
                    gr.update(value="No dataset loaded", visible=True),
                    gr.update(visible=False),
                )
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, index, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                status_text = f"✅ Loaded trajectory {index} from dataset for Video B"
                if metadata_text:
                    status_text += f"\n\n{metadata_text}"
return (
                    video_path,
                    gr.update(value=status_text, visible=True),
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                )
            else:
                return (
                    None,
                    gr.update(value="❌ Error loading trajectory", visible=True),
                    gr.update(visible=False),
                )

        def next_trajectory_b(dataset, current_idx, dataset_name):
            """Go to next trajectory for Video B."""
            if dataset is None:
                return 0, None, gr.update(visible=False), gr.update(visible=False)
            # Clamp to the last trajectory index.
            next_idx = min(current_idx + 1, len(dataset) - 1)
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, next_idx, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                return (
                    next_idx,
                    video_path,
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                    gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
                )
            else:
                return current_idx, None, gr.update(visible=False), gr.update(visible=False)

        def prev_trajectory_b(dataset, current_idx, dataset_name):
            """Go to previous trajectory for Video B."""
            if dataset is None:
                return 0, None, gr.update(visible=False), gr.update(visible=False)
            # Clamp at index 0.
            prev_idx = max(current_idx - 1, 0)
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, prev_idx, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                return (
                    prev_idx,
                    video_path,
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                    gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
                )
            else:
                return current_idx, None, gr.update(visible=False), gr.update(visible=False)

        def update_trajectory_on_slider_change_b(dataset, index, dataset_name):
            """Update trajectory metadata when slider changes for Video B."""
            if dataset is None:
                return gr.update(visible=False), gr.update(visible=False)
            video_path, task, quality_label, partial_success = get_trajectory_video_path(
                dataset, index, dataset_name
            )
            if video_path:
                # Build metadata text
                metadata_lines = []
                if quality_label:
                    metadata_lines.append(f"**Quality Label:** {quality_label}")
                if partial_success is not None:
                    metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
                metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
                return (
                    gr.update(value=metadata_text, visible=bool(metadata_text)),
                    gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
                )
            else:
                return gr.update(visible=False), gr.update(visible=False)

        # Video A dataset selection handlers
        dataset_name_a.change(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
        refresh_configs_btn_a.click(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
        load_dataset_btn_a.click(
            fn=load_dataset_a,
            inputs=[dataset_name_a, config_name_a],
            outputs=[current_dataset_a, dataset_status_a, trajectory_slider_a],
        )
        use_dataset_video_btn_a.click(
            fn=use_dataset_video_a,
            inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
            outputs=[video_a_input, dataset_status_a, trajectory_metadata_a],
        )
        next_traj_btn_a.click(
            fn=next_trajectory_a,
            inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
            outputs=[
                trajectory_slider_a,
                video_a_input,
                trajectory_metadata_a,
                dataset_status_a,
            ],
        )
        prev_traj_btn_a.click(
            fn=prev_trajectory_a,
            inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
            outputs=[
                trajectory_slider_a,
                video_a_input,
                trajectory_metadata_a,
                dataset_status_a,
            ],
        )
        trajectory_slider_a.change(
            fn=update_trajectory_on_slider_change_a,
            inputs=[current_dataset_a, trajectory_slider_a,
dataset_name_a],
            outputs=[trajectory_metadata_a, dataset_status_a],
        )

        # Video B dataset selection handlers
        dataset_name_b.change(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
        refresh_configs_btn_b.click(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
        load_dataset_btn_b.click(
            fn=load_dataset_b,
            inputs=[dataset_name_b, config_name_b],
            outputs=[current_dataset_b, dataset_status_b, trajectory_slider_b],
        )
        use_dataset_video_btn_b.click(
            fn=use_dataset_video_b,
            inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
            outputs=[video_b_input, dataset_status_b, trajectory_metadata_b],
        )
        next_traj_btn_b.click(
            fn=next_trajectory_b,
            inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
            outputs=[
                trajectory_slider_b,
                video_b_input,
                trajectory_metadata_b,
                dataset_status_b,
            ],
        )
        prev_traj_btn_b.click(
            fn=prev_trajectory_b,
            inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
            outputs=[
                trajectory_slider_b,
                video_b_input,
                trajectory_metadata_b,
                dataset_status_b,
            ],
        )
        trajectory_slider_b.change(
            fn=update_trajectory_on_slider_change_b,
            inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
            outputs=[trajectory_metadata_b, dataset_status_b],
        )

        def run_preference_comparison(video_a, video_b, task_text, server_url):
            """Run a preference query over two videos and return only the text result."""
            # Thin wrapper around process_two_videos in "preference" mode; the
            # two discarded values are presumably plot outputs -- verify against
            # process_two_videos' definition.
            # NOTE(review): fps is fixed at 1.0 here -- confirm this matches the
            # eval server's expected sampling rate for preference queries.
            result, _, _ = process_two_videos(
                video_a, video_b, task_text, "preference", server_url, fps=1.0
            )
            return result

        analyze_dual_btn.click(
            fn=run_preference_comparison,
            inputs=[
                video_a_input,
                video_b_input,
                task_text_dual,
                server_url_state,
            ],
            outputs=[result_text],
            api_name="process_two_videos",
        )


def main():
    """Launch the Gradio app."""
    # Bind to 0.0.0.0 so the app is reachable inside containers / HF Spaces.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,  # Show full error messages
    )


if __name__ == "__main__":
    main()