Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Gradio app for RFM (Reward Foundation Model) inference visualization. | |
| Supports single video (progress/success) and dual video (preference/similarity) predictions. | |
| Uses eval server for inference instead of loading models locally. | |
| """ | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import logging | |
| import gradio as gr | |
| try: | |
| import spaces # Required for ZeroGPU on Hugging Face Spaces | |
| except ImportError: | |
| spaces = None # Not available when running locally | |
| import matplotlib | |
| matplotlib.use("Agg") # Use non-interactive backend | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import requests | |
| from typing import Any, List, Optional, Tuple | |
| from dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample | |
| from eval_utils import build_payload, post_batch_npy | |
| from eval_viz_utils import create_combined_progress_success_plot, extract_frames | |
| from datasets import load_dataset as load_dataset_hf, get_dataset_config_names | |
| logger = logging.getLogger(__name__) | |
# Predefined dataset names (same as visualizer).
# These populate the "Dataset Name" dropdown in the UI; users may also type a
# custom HuggingFace dataset id (allow_custom_value=True on the dropdown).
PREDEFINED_DATASETS = [
    "abraranwar/agibotworld_alpha_rfm",
    "abraranwar/libero_rfm",
    "abraranwar/usc_koch_rewind_rfm",
    "aliangdw/metaworld",
    "anqil/rh20t_rfm",
    "anqil/rh20t_subset_rfm",
    "jesbu1/auto_eval_rfm",
    "jesbu1/egodex_rfm",
    "jesbu1/epic_rfm",
    "jesbu1/fino_net_rfm",
    "jesbu1/failsafe_rfm",
    "jesbu1/hand_paired_rfm",
    "jesbu1/galaxea_rfm",
    "jesbu1/h2r_rfm",
    "jesbu1/humanoid_everyday_rfm",
    "jesbu1/molmoact_rfm",
    "jesbu1/motif_rfm",
    "jesbu1/oxe_rfm",
    "jesbu1/oxe_rfm_eval",
    "jesbu1/ph2d_rfm",
    "jesbu1/racer_rfm",
    "jesbu1/roboarena_0825_rfm",
    "jesbu1/soar_rfm",
    "ykorkmaz/libero_failure_rfm",
    "aliangdw/usc_xarm_policy_ranking",
    "aliangdw/usc_franka_policy_ranking",
    "aliangdw/utd_so101_policy_ranking",
    "aliangdw/utd_so101_human",
    "jesbu1/utd_so101_clean_policy_ranking_top",
    "jesbu1/utd_so101_clean_policy_ranking_wrist",
    "jesbu1/mit_franka_p-rank_rfm",
    "jesbu1/usc_koch_p_ranking_rfm",
]
# Global server state, mutated by the discovery/selection handlers below.
# "server_url" is the currently selected eval-server endpoint (None until a
# model is discovered/selected); "base_url" is the host probed for models.
_server_state = {
    "server_url": None,
    "base_url": "http://40.119.56.66",  # Default base URL
}
def discover_available_models(
    base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)
) -> List[Tuple[str, str]]:
    """Discover available models by pinging ports in the specified range.

    Probes ``{base_url}:{port}/health`` for each port; on success, tries
    ``/model_info`` to obtain a human-readable model name.

    Args:
        base_url: Scheme+host to probe, without a port (trailing "/" stripped).
        port_range: Inclusive (start, end) port range to scan.

    Returns:
        List of tuples: [(server_url, model_name), ...]
    """
    available_models = []
    start_port, end_port = port_range
    for port in range(start_port, end_port + 1):
        server_url = f"{base_url.rstrip('/')}:{port}"
        try:
            # Check health endpoint
            health_response = requests.get(f"{server_url}/health", timeout=2.0)
        except requests.exceptions.RequestException:
            # Port not available, continue
            continue
        if health_response.status_code != 200:
            continue
        # Health check passed; use a generic name unless /model_info supplies one.
        model_name = f"Model on port {port}"
        try:
            model_info_response = requests.get(f"{server_url}/model_info", timeout=2.0)
            if model_info_response.status_code == 200:
                model_name = model_info_response.json().get("model_path", model_name)
        except (requests.exceptions.RequestException, ValueError):
            # Was a bare `except:` — narrowed to network/JSON-decoding errors so
            # genuine programming errors are no longer silently swallowed.
            pass
        available_models.append((server_url, model_name))
    return available_models
def get_model_info_for_url(server_url: str) -> Optional[str]:
    """Fetch /model_info from the server and return it formatted as markdown.

    Returns None when no URL is given, the request fails, or the server does
    not answer with HTTP 200.
    """
    if not server_url:
        return None
    endpoint = f"{server_url.rstrip('/')}/model_info"
    try:
        resp = requests.get(endpoint, timeout=5.0)
        if resp.status_code == 200:
            return format_model_info(resp.json())
    except Exception as e:
        logger.warning(f"Could not fetch model info: {e}")
    return None
def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
    """Check server health and get model info.

    Side effect: on success, records ``server_url`` in ``_server_state``.

    Returns:
        (status_message, health_data_dict_or_None, model_info_markdown_or_None)
    """
    if not server_url:
        return "Please provide a server URL.", None, None
    try:
        url = server_url.rstrip("/") + "/health"
        response = requests.get(url, timeout=5.0)
        response.raise_for_status()
        health_data = response.json()
        # Also try to get GPU status for more info (best-effort merge).
        try:
            status_url = server_url.rstrip("/") + "/gpu_status"
            status_response = requests.get(status_url, timeout=5.0)
            if status_response.status_code == 200:
                health_data.update(status_response.json())
        except (requests.exceptions.RequestException, ValueError):
            # Was a bare `except:` — GPU status is optional, but unrelated
            # programming errors should not be silently swallowed here.
            pass
        # Try to get model info
        model_info_text = get_model_info_for_url(server_url)
        _server_state["server_url"] = server_url
        return (
            f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available",
            health_data,
            model_info_text,
        )
    except requests.exceptions.RequestException as e:
        return f"Error connecting to server: {str(e)}", None, None
def format_model_info(model_info: dict) -> str:
    """Render the server's /model_info payload as a markdown report.

    All sections are optional: architecture is skipped when absent or when the
    server reported an "error" key; experiment-config subsections appear only
    when present in the payload.
    """
    lines = ["## Model Information\n"]
    add = lines.append
    add(f"**Model Path:** `{model_info.get('model_path', 'Unknown')}`\n")
    add(f"**Number of GPUs:** {model_info.get('num_gpus', 'Unknown')}\n")

    arch = model_info.get("model_architecture", {})
    if arch and "error" not in arch:
        add("\n## Model Architecture\n")
        add(f"- **Model Class:** `{arch.get('model_class', 'Unknown')}`\n")
        add(f"- **Module:** `{arch.get('model_module', 'Unknown')}`\n")
        # Parameter counts — each stat is emitted only when the server sent it.
        total = arch.get("total_parameters")
        trainable = arch.get("trainable_parameters")
        frozen = arch.get("frozen_parameters")
        pct = arch.get("trainable_percentage")
        if total is not None:
            add("\n### Parameter Statistics\n")
            add(f"- **Total Parameters:** {total:,}\n")
        if trainable is not None:
            add(f"- **Trainable Parameters:** {trainable:,}\n")
        if frozen is not None:
            add(f"- **Frozen Parameters:** {frozen:,}\n")
        if pct is not None:
            add(f"- **Trainable Percentage:** {pct:.2f}%\n")
        summary = arch.get("architecture_summary", [])
        if summary:
            add("\n### Architecture Summary (Top-Level Modules)\n")
            for module_info in summary[:10]:  # Show first 10 modules
                add(
                    f"- **{module_info.get('name', 'Unknown')}** "
                    f"(`{module_info.get('type', 'Unknown')}`): "
                    f"{module_info.get('parameters', 0):,} parameters\n"
                )

    exp_cfg = model_info.get("experiment_config", {})
    if exp_cfg:
        add("\n## Experiment Configuration\n")
        model_cfg = exp_cfg.get("model", {})
        if model_cfg:
            add("### Model Configuration\n")
            add(f"- **Base Model:** `{model_cfg.get('base_model_id', 'N/A')}`\n")
            add(f"- **Model Type:** `{model_cfg.get('model_type', 'N/A')}`\n")
            # The boolean flags share one rendering pattern; table-drive them.
            for label, key in (
                ("Train Progress Head", "train_progress_head"),
                ("Train Preference Head", "train_preference_head"),
                ("Train Similarity Head", "train_similarity_head"),
                ("Train Success Head", "train_success_head"),
                ("Use PEFT", "use_peft"),
                ("Use Unsloth", "use_unsloth"),
            ):
                add(f"- **{label}:** {model_cfg.get(key, False)}\n")
        data_cfg = exp_cfg.get("data", {})
        if data_cfg:
            add("\n### Data Configuration\n")
            add(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n")
            add(
                f"- **Resized Dimensions:** {data_cfg.get('resized_height', 'N/A')}x{data_cfg.get('resized_width', 'N/A')}\n"
            )
            train_datasets = data_cfg.get("train_datasets", [])
            if train_datasets:
                add(f"- **Train Datasets:** {', '.join(train_datasets)}\n")
            eval_datasets = data_cfg.get("eval_datasets", [])
            if eval_datasets:
                add(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n")
        training_cfg = exp_cfg.get("training", {})
        if training_cfg:
            add("\n### Training Configuration\n")
            add(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n")
            add(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n")
            add(
                f"- **Gradient Accumulation Steps:** {training_cfg.get('gradient_accumulation_steps', 'N/A')}\n"
            )
            add(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n")
    return "".join(lines)
def load_rfm_dataset(dataset_name, config_name):
    """Load the RFM dataset from HuggingFace Hub.

    Returns (dataset, status_message); dataset is None on any failure.
    """
    if not dataset_name or not config_name:
        return None, "Please provide both dataset name and configuration"
    try:
        dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
        if len(dataset) == 0:
            return None, f"Dataset {dataset_name}/{config_name} is empty"
        return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
    except Exception as e:
        # Map common HF Hub failures onto friendlier messages.
        lowered = str(e).lower()
        if "not found" in lowered:
            return None, f"Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"Authentication required for {dataset_name}"
        return None, f"Error loading dataset: {str(e)}"
def get_available_configs(dataset_name):
    """Return the dataset's config names, or [] when lookup fails."""
    try:
        return get_dataset_config_names(dataset_name)
    except Exception as e:
        logger.warning(f"Error getting configs for {dataset_name}: {e}")
        return []
def get_trajectory_video_path(dataset, index, dataset_name):
    """Resolve one trajectory's video URL and metadata.

    Returns (video_path, task, quality_label, partial_success); all four are
    None when the item's "frames" field is not a string path or lookup fails.
    """
    try:
        item = dataset[int(index)]
        frames_data = item["frames"]
        if not isinstance(frames_data, str):
            return None, None, None, None
        # Relative frame paths resolve against the HuggingFace Hub dataset
        # repo; fall back to the default repo when no name was given.
        repo = dataset_name if dataset_name else "aliangdw/rfm"
        video_path = f"https://huggingface.co/datasets/{repo}/resolve/main/{frames_data}"
        return (
            video_path,
            item.get("task", "Complete the task"),
            item.get("quality_label", None),
            item.get("partial_success", None),
        )
    except Exception as e:
        logger.error(f"Error getting trajectory video path: {e}")
        return None, None, None, None
def process_single_video(
    video_path: str,
    task_text: str = "Complete the task",
    server_url: str = "",
    fps: float = 1.0,
    use_frame_steps: bool = False,
) -> Tuple[Optional[str], Optional[str]]:
    """Process single video for progress and success predictions using eval server.

    Args:
        video_path: Path (or URL) of the video to analyze.
        task_text: Natural-language task description sent with the trajectory.
        server_url: Eval server URL; falls back to the globally selected one.
        fps: Frames per second to sample from the video.
        use_frame_steps: If True, sent as extra form data so the server
            processes frames incrementally (see the UI checkbox description).

    Returns:
        (path_to_plot_png, info_markdown) on success; (None, error_message)
        on any failure.
    """
    # Get server URL from state if not provided
    if not server_url:
        server_url = _server_state.get("server_url")
    if not server_url:
        return None, "Please select a model from the dropdown above and ensure it's connected."
    if video_path is None:
        return None, "Please provide a video."
    try:
        frames_array = extract_frames(video_path, fps=fps)
        if frames_array is None or frames_array.size == 0:
            return None, "Could not extract frames from video."
        # Convert frames to (T, H, W, C) numpy array with uint8 values
        if frames_array.dtype != np.uint8:
            frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)
        num_frames = frames_array.shape[0]
        frames_shape = frames_array.shape  # (T, H, W, C)
        # Create target progress (placeholder - would be None in real use);
        # linear ramp 0->1, with "success" defined as progress > 0.5.
        target_progress = np.linspace(0.0, 1.0, num=num_frames).tolist()
        success_label = [1.0 if prog > 0.5 else 0.0 for prog in target_progress]
        # Create Trajectory
        trajectory = Trajectory(
            task=task_text,
            frames=frames_array,
            frames_shape=frames_shape,
            target_progress=target_progress,
            success_label=success_label,
            metadata={"source": "gradio_app"},
        )
        # Create ProgressSample
        progress_sample = ProgressSample(
            trajectory=trajectory,
            data_gen_strategy="demo",
        )
        # Build payload and send to server
        files, sample_data = build_payload([progress_sample])
        # Add use_frame_steps flag as extra form data (only when enabled)
        extra_data = {"use_frame_steps": use_frame_steps} if use_frame_steps else None
        response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0, extra_form_data=extra_data)
        # Process response
        outputs_progress = response.get("outputs_progress", {})
        progress_pred = outputs_progress.get("progress_pred", [])
        outputs_success = response.get("outputs_success", {})
        success_probs = outputs_success.get("success_probs", []) if outputs_success else None
        # Extract progress predictions (only one sample was sent)
        if progress_pred and len(progress_pred) > 0:
            progress_array = np.array(progress_pred[0])  # First sample
        else:
            progress_array = np.array([])
        # Extract success predictions if available
        success_array = None
        if success_probs and len(success_probs) > 0:
            success_array = np.array(success_probs[0])
        # Convert success_array to binary if available (threshold 0.5)
        success_binary = None
        if success_array is not None:
            success_binary = (success_array > 0.5).astype(float)
        # Create combined plot using shared helper function
        fig = create_combined_progress_success_plot(
            progress_pred=progress_array if len(progress_array) > 0 else np.array([0.0]),
            num_frames=num_frames,
            success_binary=success_binary,
            success_probs=success_array,
            success_labels=None,  # No ground truth labels available
            is_discrete_mode=False,
            title=f"Progress & Success - {task_text}",
        )
        # Save to temporary file (delete=False: Gradio serves it afterwards)
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        fig.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
        plt.close(fig)
        progress_plot = tmp_file.name
        info_text = f"**Frames processed:** {num_frames}\n"
        if len(progress_array) > 0:
            info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"
        if success_array is not None and len(success_array) > 0:
            info_text += f"**Final success probability:** {success_array[-1]:.3f}\n"
        # Return combined plot (which includes success if available)
        return progress_plot, info_text
    except Exception as e:
        return None, f"Error processing video: {str(e)}"
def process_two_videos(
    video_a_path: str,
    video_b_path: str,
    task_text: str = "Complete the task",
    prediction_type: str = "preference",
    server_url: str = "",
    fps: float = 1.0,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Process two videos for preference, similarity, or progress prediction using eval server.

    Args:
        video_a_path: Path/URL of video A.
        video_b_path: Path/URL of video B.
        task_text: Task description shared by both trajectories.
        prediction_type: One of "preference", "progress", or "similarity".
        server_url: Eval server URL; falls back to the globally selected one.
        fps: Frames per second to sample from each video.

    Returns:
        (result_markdown_or_error, video_a_path, video_b_path); the video
        paths are None whenever processing failed.
    """
    # Get server URL from state if not provided
    if not server_url:
        server_url = _server_state.get("server_url")
    if not server_url:
        return "Please select a model from the dropdown above and ensure it's connected.", None, None
    if video_a_path is None or video_b_path is None:
        return "Please provide both videos.", None, None
    # Fail fast on unsupported types. Previously an unknown prediction_type
    # reached the final return with `result_text` unbound, and the resulting
    # UnboundLocalError surfaced as an opaque "Error processing videos" message.
    if prediction_type not in ("preference", "progress", "similarity"):
        return f"Unknown prediction type: {prediction_type}", None, None
    try:
        frames_array_a = extract_frames(video_a_path, fps=fps)
        frames_array_b = extract_frames(video_b_path, fps=fps)
        if frames_array_a is None or frames_array_a.size == 0:
            return "Could not extract frames from video A.", None, None
        if frames_array_b is None or frames_array_b.size == 0:
            return "Could not extract frames from video B.", None, None
        # Convert frames to uint8
        if frames_array_a.dtype != np.uint8:
            frames_array_a = np.clip(frames_array_a, 0, 255).astype(np.uint8)
        if frames_array_b.dtype != np.uint8:
            frames_array_b = np.clip(frames_array_b, 0, 255).astype(np.uint8)
        num_frames_a = frames_array_a.shape[0]
        num_frames_b = frames_array_b.shape[0]
        frames_shape_a = frames_array_a.shape
        frames_shape_b = frames_array_b.shape
        # Create target progress for both trajectories (placeholder ramps)
        target_progress_a = np.linspace(0.0, 1.0, num=num_frames_a).tolist()
        target_progress_b = np.linspace(0.0, 1.0, num=num_frames_b).tolist()
        success_label_a = [1.0 if prog > 0.5 else 0.0 for prog in target_progress_a]
        success_label_b = [1.0 if prog > 0.5 else 0.0 for prog in target_progress_b]
        # Create trajectories
        trajectory_a = Trajectory(
            task=task_text,
            frames=frames_array_a,
            frames_shape=frames_shape_a,
            target_progress=target_progress_a,
            success_label=success_label_a,
            metadata={"source": "gradio_app", "trajectory": "A"},
        )
        trajectory_b = Trajectory(
            task=task_text,
            frames=frames_array_b,
            frames_shape=frames_shape_b,
            target_progress=target_progress_b,
            success_label=success_label_b,
            metadata={"source": "gradio_app", "trajectory": "B"},
        )
        if prediction_type == "preference":
            # Create PreferenceSample (A = chosen, B = rejected)
            preference_sample = PreferenceSample(
                chosen_trajectory=trajectory_a,
                rejected_trajectory=trajectory_b,
                data_gen_strategy="demo",
            )
            # Build payload and send to server
            files, sample_data = build_payload([preference_sample])
            response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
            # Process response
            outputs_preference = response.get("outputs_preference", {})
            predictions = outputs_preference.get("predictions", [])
            prediction_probs = outputs_preference.get("prediction_probs", [])
            result_text = f"**Preference Prediction:**\n"
            if prediction_probs and len(prediction_probs) > 0:
                prob = prediction_probs[0]
                result_text += f"- Probability (A preferred): {prob:.3f}\n"
                result_text += f"- Interpretation: {'Video A is preferred' if prob > 0.5 else 'Video B is preferred'}\n"
            else:
                result_text += "Could not extract preference prediction from server response.\n"
        elif prediction_type == "progress":
            # Create ProgressSamples for both videos.
            # (ProgressSample is already imported at module level; the previous
            # redundant local import was removed.)
            progress_sample_a = ProgressSample(
                trajectory=trajectory_a,
                data_gen_strategy="demo",
            )
            progress_sample_b = ProgressSample(
                trajectory=trajectory_b,
                data_gen_strategy="demo",
            )
            # Build payload and send to server
            files, sample_data = build_payload([progress_sample_a, progress_sample_b])
            response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
            # Process response
            outputs_progress = response.get("outputs_progress", {})
            progress_pred = outputs_progress.get("progress_pred", [])
            result_text = f"**Progress Comparison:**\n"
            if progress_pred and len(progress_pred) >= 2:
                progress_a = np.array(progress_pred[0])
                progress_b = np.array(progress_pred[1])
                final_progress_a = float(progress_a[-1]) if len(progress_a) > 0 else 0.0
                final_progress_b = float(progress_b[-1]) if len(progress_b) > 0 else 0.0
                result_text += f"- Video A final progress: {final_progress_a:.3f}\n"
                result_text += f"- Video B final progress: {final_progress_b:.3f}\n"
                result_text += f"- Difference: {abs(final_progress_a - final_progress_b):.3f}\n"
                if final_progress_a > final_progress_b:
                    result_text += f"- Video A has higher progress\n"
                elif final_progress_b > final_progress_a:
                    result_text += f"- Video B has higher progress\n"
                else:
                    result_text += f"- Both videos have equal progress\n"
            else:
                result_text += "Could not extract progress predictions from server response.\n"
        else:  # prediction_type == "similarity"
            # For similarity inference, we have two videos:
            # - Video A as reference trajectory
            # - Video B as similar trajectory
            # diff_trajectory is None in inference mode (only need similarity between ref and sim)
            similarity_sample = SimilaritySample(
                ref_trajectory=trajectory_a,
                sim_trajectory=trajectory_b,
                diff_trajectory=None,  # None in inference mode
                data_gen_strategy="demo",
            )
            # Build payload and send to server
            files, sample_data = build_payload([similarity_sample])
            response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
            # Process response - we only care about sim_score_ref_sim (similarity between Video A and Video B)
            outputs_similarity = response.get("outputs_similarity", {})
            sim_score_ref_sim = outputs_similarity.get("sim_score_ref_sim", [])
            result_text = f"**Similarity Prediction:**\n"
            if sim_score_ref_sim and len(sim_score_ref_sim) > 0:
                sim_score = sim_score_ref_sim[0]
                if sim_score is not None:
                    result_text += f"- Similarity score (Video A vs Video B): {sim_score:.3f}\n"
                    # Interpret similarity score (higher = more similar)
                    if sim_score > 0.7:
                        result_text += f"- Interpretation: High similarity - videos are very similar\n"
                    elif sim_score > 0.4:
                        result_text += f"- Interpretation: Moderate similarity - videos share some similarities\n"
                    else:
                        result_text += f"- Interpretation: Low similarity - videos are quite different\n"
                else:
                    result_text += "Could not extract similarity score from server response.\n"
            else:
                result_text += "Could not extract similarity prediction from server response.\n"
        # Return result text and both video paths
        return result_text, video_a_path, video_b_path
    except Exception as e:
        return f"Error processing videos: {str(e)}", None, None
# Create Gradio interface.
# Built at import time so `demo` can be launched by the hosting entry point.
try:
    # Try with theme (Gradio 4.0+)
    demo = gr.Blocks(title="RFM Evaluation Server", theme=gr.themes.Soft())
except TypeError:
    # Fallback for older Gradio versions without theme support
    demo = gr.Blocks(title="RFM Evaluation Server")
with demo:
    # Page title banner.
    gr.Markdown(
        """
        # RFM (Reward Foundation Model) Evaluation Server
        """
    )
    # Hidden state to store server URL and model mapping (define before use,
    # since the sidebar event handlers below reference them as outputs).
    server_url_state = gr.State(value=None)
    model_url_mapping_state = gr.State(value={})  # Maps model_name -> server_url
| # Function definitions for event handlers | |
| def discover_and_select_models(base_url: str): | |
| """Discover models and update dropdown.""" | |
| if not base_url: | |
| return ( | |
| gr.update(choices=[], value=None), | |
| gr.update(value="Please provide a base URL", visible=True), | |
| gr.update(value="", visible=True), | |
| None, | |
| {}, # Empty mapping | |
| ) | |
| _server_state["base_url"] = base_url | |
| models = discover_available_models(base_url, port_range=(8000, 8010)) | |
| if not models: | |
| return ( | |
| gr.update(choices=[], value=None), | |
| gr.update(value="❌ No models found on ports 8000-8010. Make sure servers are running.", visible=True), | |
| gr.update(value="", visible=True), | |
| None, | |
| {}, # Empty mapping | |
| ) | |
| # Format choices: show model_name in dropdown | |
| # Store mapping of model_name to URL in state | |
| choices = [] | |
| url_map = {} | |
| for url, name in models: | |
| choices.append(name) | |
| url_map[name] = url | |
| # Auto-select first model | |
| selected_choice = choices[0] if choices else None | |
| selected_url = url_map.get(selected_choice) if selected_choice else None | |
| # Get model info for selected model | |
| model_info_text = get_model_info_for_url(selected_url) if selected_url else "" | |
| status_text = f"✅ Found {len(models)} model(s). Auto-selected first model." | |
| _server_state["server_url"] = selected_url | |
| return ( | |
| gr.update(choices=choices, value=selected_choice), | |
| gr.update(value=status_text, visible=True), | |
| gr.update(value=model_info_text, visible=True), | |
| selected_url, | |
| url_map, # Return mapping for state | |
| ) | |
| def on_model_selected(model_choice: str, url_mapping: dict): | |
| """Handle model selection change.""" | |
| if not model_choice: | |
| return ( | |
| gr.update(value="No model selected", visible=True), | |
| gr.update(value="", visible=True), | |
| None, | |
| ) | |
| # Get URL from mapping | |
| server_url = url_mapping.get(model_choice) if url_mapping else None | |
| if not server_url: | |
| return ( | |
| gr.update( | |
| value="Could not find server URL for selected model. Please rediscover models.", visible=True | |
| ), | |
| gr.update(value="", visible=True), | |
| None, | |
| ) | |
| # Get model info | |
| model_info_text = get_model_info_for_url(server_url) or "" | |
| status, health_data, _ = check_server_health(server_url) | |
| _server_state["server_url"] = server_url | |
| return ( | |
| gr.update(value=status, visible=True), | |
| gr.update(value=model_info_text, visible=True), | |
| server_url, | |
| ) | |
    # Use Gradio's built-in Sidebar component (collapsible by default)
    with gr.Sidebar():
        gr.Markdown("### 🔧 Model Configuration")
        base_url_input = gr.Textbox(
            label="Base Server URL",
            placeholder="http://40.119.56.66",
            value="http://40.119.56.66",
            interactive=True,
        )
        discover_btn = gr.Button("🔍 Discover Models", variant="primary", size="lg")
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=[],  # populated by discover_and_select_models
            value=None,
            interactive=True,
            info="Models will be discovered on ports 8000-8010",
        )
        server_status = gr.Markdown("Click 'Discover Models' to find available models")
        gr.Markdown("---")
        gr.Markdown("### 📋 Model Information")
        model_info_display = gr.Markdown("")
    # Event handlers for sidebar.
    # Discovery repopulates the dropdown, status line, info panel, and both
    # hidden states (server URL and name->URL mapping).
    discover_btn.click(
        fn=discover_and_select_models,
        inputs=[base_url_input],
        outputs=[model_dropdown, server_status, model_info_display, server_url_state, model_url_mapping_state],
    )
    # Selecting a model re-checks its health and refreshes the info panel.
    model_dropdown.change(
        fn=on_model_selected,
        inputs=[model_dropdown, model_url_mapping_state],
        outputs=[server_status, model_info_display, server_url_state],
    )
    # Main content area with tabs
    with gr.Tabs():
        with gr.Tab("Progress Prediction"):
            gr.Markdown("### Progress & Success Prediction")
            gr.Markdown("Upload a video or select one from a dataset to get progress predictions.")
            with gr.Row():
                with gr.Column():
                    # Left column: video input and analysis controls.
                    single_video_input = gr.Video(label="Upload Video", height=300)
                    task_text_input = gr.Textbox(
                        label="Task Description",
                        placeholder="Describe the task (e.g., 'Pick up the red block')",
                        value="Complete the task",
                    )
                    fps_input_single = gr.Slider(
                        label="FPS (Frames Per Second)",
                        minimum=0.1,
                        maximum=10.0,
                        value=1.0,
                        step=0.1,
                        info="Frames per second to extract from video (higher = more frames)",
                    )
                    use_frame_steps_single = gr.Checkbox(
                        label="Use Frame Steps",
                        value=False,
                        info="Process frames incrementally (0:1, 0:2, 0:3, etc.) for autoregressive predictions",
                    )
                    analyze_single_btn = gr.Button("Analyze Video", variant="primary")
                    gr.Markdown("---")
                    gr.Markdown("**OR Select from Dataset**")
                    gr.Markdown("---")
                    # Collapsible dataset browser as an alternative to upload.
                    with gr.Accordion("📁 Select from Dataset", open=False):
                        dataset_name_single = gr.Dropdown(
                            choices=PREDEFINED_DATASETS,
                            value="jesbu1/oxe_rfm",
                            label="Dataset Name",
                            allow_custom_value=True,
                        )
                        config_name_single = gr.Dropdown(
                            choices=[], value="", label="Configuration Name", allow_custom_value=True
                        )
                        with gr.Row():
                            refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
                            load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
                        dataset_status_single = gr.Markdown("", visible=False)
                        with gr.Row():
                            prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
                            trajectory_slider = gr.Slider(
                                minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
                            )
                            next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
                        trajectory_metadata = gr.Markdown("", visible=False)
                        use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
                with gr.Column():
                    # Right column: prediction outputs.
                    progress_plot = gr.Image(label="Progress & Success Prediction", height=400)
                    info_output = gr.Markdown("")
            # State variables for dataset (holds the loaded HF dataset object)
            current_dataset_single = gr.State(None)
| def update_config_choices_single(dataset_name): | |
| """Update config choices when dataset changes.""" | |
| if not dataset_name: | |
| return gr.update(choices=[], value="") | |
| try: | |
| configs = get_available_configs(dataset_name) | |
| if configs: | |
| return gr.update(choices=configs, value=configs[0]) | |
| else: | |
| return gr.update(choices=[], value="") | |
| except Exception as e: | |
| logger.warning(f"Could not fetch configs: {e}") | |
| return gr.update(choices=[], value="") | |
| def load_dataset_single(dataset_name, config_name): | |
| """Load dataset and update slider.""" | |
| dataset, status = load_rfm_dataset(dataset_name, config_name) | |
| if dataset is not None: | |
| max_index = len(dataset) - 1 | |
| return ( | |
| dataset, | |
| gr.update(value=status, visible=True), | |
| gr.update( | |
| maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})" | |
| ), | |
| ) | |
| else: | |
| return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False) | |
def use_dataset_video(dataset, index, dataset_name):
    """Copy the chosen dataset trajectory into the single-video inputs.

    Returns (video path, task text, status update, metadata update).
    """
    if dataset is None:
        return (
            None,
            "Complete the task",
            gr.update(value="No dataset loaded", visible=True),
            gr.update(visible=False),
        )
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, index, dataset_name
    )
    if not video_path:
        return (
            None,
            "Complete the task",
            gr.update(value="❌ Error loading trajectory", visible=True),
            gr.update(visible=False),
        )
    # Collect any optional trajectory metadata for display.
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)  # "" when there is no metadata
    status_text = f"✅ Loaded trajectory {index} from dataset"
    if metadata_text:
        status_text += f"\n\n{metadata_text}"
    return (
        video_path,
        task,
        gr.update(value=status_text, visible=True),
        gr.update(value=metadata_text, visible=bool(metadata_text)),
    )
def next_trajectory(dataset, current_idx, dataset_name):
    """Advance to the next trajectory (clamped at the end of the dataset)."""
    if dataset is None:
        return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
    next_idx = min(current_idx + 1, len(dataset) - 1)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, next_idx, dataset_name
    )
    if not video_path:
        # Keep the slider where it was when the trajectory can't be loaded.
        return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        next_idx,
        video_path,
        task,
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
    )
def prev_trajectory(dataset, current_idx, dataset_name):
    """Step back to the previous trajectory (clamped at index 0)."""
    if dataset is None:
        return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
    prev_idx = max(current_idx - 1, 0)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, prev_idx, dataset_name
    )
    if not video_path:
        # Keep the slider where it was when the trajectory can't be loaded.
        return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        prev_idx,
        video_path,
        task,
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
    )
def update_trajectory_on_slider_change(dataset, index, dataset_name):
    """Refresh the metadata/status panes when the trajectory slider moves."""
    if dataset is None:
        return gr.update(visible=False), gr.update(visible=False)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, index, dataset_name
    )
    if not video_path:
        return gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
    )
# Dataset selection handlers
# Changing the dataset name (or clicking refresh) re-queries the HF dataset
# configs and repopulates the config dropdown.
dataset_name_single.change(
    fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single]
)
refresh_configs_btn.click(
    fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single]
)
# Load the dataset into gr.State and re-bound the trajectory slider.
load_dataset_btn.click(
    fn=load_dataset_single,
    inputs=[dataset_name_single, config_name_single],
    outputs=[current_dataset_single, dataset_status_single, trajectory_slider],
)
# Copy the selected trajectory's video and task text into the analysis inputs.
use_dataset_video_btn.click(
    fn=use_dataset_video,
    inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
    outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata],
)
# Navigation buttons
# Next/prev also write back to trajectory_slider so the slider tracks position.
next_traj_btn.click(
    fn=next_trajectory,
    inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
    outputs=[
        trajectory_slider,
        single_video_input,
        task_text_input,
        trajectory_metadata,
        dataset_status_single,
    ],
)
prev_traj_btn.click(
    fn=prev_trajectory,
    inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
    outputs=[
        trajectory_slider,
        single_video_input,
        task_text_input,
        trajectory_metadata,
        dataset_status_single,
    ],
)
# Update metadata when slider changes
trajectory_slider.change(
    fn=update_trajectory_on_slider_change,
    inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
    outputs=[trajectory_metadata, dataset_status_single],
)
# Run inference via the eval server and render the progress/success plot.
# api_name exposes this handler on the Gradio API for programmatic clients.
analyze_single_btn.click(
    fn=process_single_video,
    inputs=[
        single_video_input,
        task_text_input,
        server_url_state,
        fps_input_single,
        use_frame_steps_single,
    ],
    outputs=[progress_plot, info_output],
    api_name="process_single_video",
)
# Second tab: compare two videos (preference / similarity / progress).
with gr.Tab("Preference/Similarity Analysis"):
    gr.Markdown("### Preference & Similarity Prediction")
    with gr.Row():
        # Left column: upload inputs, prediction settings, and per-video
        # dataset pickers (accordions) as an alternative to manual upload.
        with gr.Column():
            video_a_input = gr.Video(label="Video A", height=250)
            video_b_input = gr.Video(label="Video B", height=250)
            task_text_dual = gr.Textbox(
                label="Task Description",
                placeholder="Describe the task",
                value="Complete the task",
            )
            prediction_type = gr.Radio(
                choices=["preference", "similarity", "progress"],
                value="preference",
                label="Prediction Type",
            )
            fps_input_dual = gr.Slider(
                label="FPS (Frames Per Second)",
                minimum=0.1,
                maximum=10.0,
                value=1.0,
                step=0.1,
                info="Frames per second to extract from videos (higher = more frames)",
            )
            analyze_dual_btn = gr.Button("Compare Videos", variant="primary")
            gr.Markdown("---")
            gr.Markdown("**OR Select from Dataset**")
            gr.Markdown("---")
            # Dataset picker for Video A (mirrors the single-video tab picker).
            with gr.Accordion("📁 Video A - Select from Dataset", open=False):
                dataset_name_a = gr.Dropdown(
                    choices=PREDEFINED_DATASETS,
                    value="jesbu1/oxe_rfm",
                    label="Dataset Name",
                    allow_custom_value=True,
                )
                config_name_a = gr.Dropdown(
                    choices=[], value="", label="Configuration Name", allow_custom_value=True
                )
                with gr.Row():
                    refresh_configs_btn_a = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
                    load_dataset_btn_a = gr.Button("Load Dataset", variant="secondary", size="sm")
                dataset_status_a = gr.Markdown("", visible=False)
                with gr.Row():
                    prev_traj_btn_a = gr.Button("⬅️ Prev", variant="secondary", size="sm")
                    trajectory_slider_a = gr.Slider(
                        minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
                    )
                    next_traj_btn_a = gr.Button("Next ➡️", variant="secondary", size="sm")
                trajectory_metadata_a = gr.Markdown("", visible=False)
                use_dataset_video_btn_a = gr.Button("Use Selected Video for A", variant="secondary")
            # Dataset picker for Video B (identical structure to Video A's).
            with gr.Accordion("📁 Video B - Select from Dataset", open=False):
                dataset_name_b = gr.Dropdown(
                    choices=PREDEFINED_DATASETS,
                    value="jesbu1/oxe_rfm",
                    label="Dataset Name",
                    allow_custom_value=True,
                )
                config_name_b = gr.Dropdown(
                    choices=[], value="", label="Configuration Name", allow_custom_value=True
                )
                with gr.Row():
                    refresh_configs_btn_b = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
                    load_dataset_btn_b = gr.Button("Load Dataset", variant="secondary", size="sm")
                dataset_status_b = gr.Markdown("", visible=False)
                with gr.Row():
                    prev_traj_btn_b = gr.Button("⬅️ Prev", variant="secondary", size="sm")
                    trajectory_slider_b = gr.Slider(
                        minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
                    )
                    next_traj_btn_b = gr.Button("Next ➡️", variant="secondary", size="sm")
                trajectory_metadata_b = gr.Markdown("", visible=False)
                use_dataset_video_btn_b = gr.Button("Use Selected Video for B", variant="secondary")
        # Right column: playback of the two compared videos plus the result text.
        with gr.Column():
            # Videos displayed side by side
            with gr.Row():
                video_a_display = gr.Video(label="Video A", height=400)
                video_b_display = gr.Video(label="Video B", height=400)
            # Result text at the bottom
            result_text = gr.Markdown("")
    # State variables for datasets
    # Per-session holders for the loaded HF datasets backing each picker.
    current_dataset_a = gr.State(None)
    current_dataset_b = gr.State(None)
| # Helper functions for Video A | |
def update_config_choices_a(dataset_name):
    """Repopulate Video A's config dropdown after its dataset name changes."""
    if not dataset_name:
        return gr.update(choices=[], value="")
    try:
        configs = get_available_configs(dataset_name)
    except Exception as e:
        # Fetch failures just clear the dropdown rather than erroring the UI.
        logger.warning(f"Could not fetch configs: {e}")
        return gr.update(choices=[], value="")
    if not configs:
        return gr.update(choices=[], value="")
    return gr.update(choices=configs, value=configs[0])
def load_dataset_a(dataset_name, config_name):
    """Load Video A's dataset and re-bound its trajectory slider."""
    dataset, status = load_rfm_dataset(dataset_name, config_name)
    status_update = gr.update(value=status, visible=True)
    if dataset is None:
        # Load failed — show the status message and lock the slider.
        return None, status_update, gr.update(maximum=0, value=0, interactive=False)
    max_index = len(dataset) - 1
    slider_update = gr.update(
        maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
    )
    return dataset, status_update, slider_update
def use_dataset_video_a(dataset, index, dataset_name):
    """Copy the selected dataset trajectory into the Video A input."""
    if dataset is None:
        return (
            None,
            gr.update(value="No dataset loaded", visible=True),
            gr.update(visible=False),
        )
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, index, dataset_name
    )
    if not video_path:
        return (
            None,
            gr.update(value="❌ Error loading trajectory", visible=True),
            gr.update(visible=False),
        )
    # Collect any optional trajectory metadata for display.
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)  # "" when there is no metadata
    status_text = f"✅ Loaded trajectory {index} from dataset for Video A"
    if metadata_text:
        status_text += f"\n\n{metadata_text}"
    return (
        video_path,
        gr.update(value=status_text, visible=True),
        gr.update(value=metadata_text, visible=bool(metadata_text)),
    )
def next_trajectory_a(dataset, current_idx, dataset_name):
    """Advance Video A to the next trajectory (clamped at the dataset end)."""
    if dataset is None:
        return 0, None, gr.update(visible=False), gr.update(visible=False)
    next_idx = min(current_idx + 1, len(dataset) - 1)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, next_idx, dataset_name
    )
    if not video_path:
        # Keep the slider in place when the trajectory can't be loaded.
        return current_idx, None, gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        next_idx,
        video_path,
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
    )
def prev_trajectory_a(dataset, current_idx, dataset_name):
    """Step Video A back to the previous trajectory (clamped at index 0)."""
    if dataset is None:
        return 0, None, gr.update(visible=False), gr.update(visible=False)
    prev_idx = max(current_idx - 1, 0)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, prev_idx, dataset_name
    )
    if not video_path:
        # Keep the slider in place when the trajectory can't be loaded.
        return current_idx, None, gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        prev_idx,
        video_path,
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
    )
def update_trajectory_on_slider_change_a(dataset, index, dataset_name):
    """Refresh Video A's metadata/status panes when its slider moves."""
    if dataset is None:
        return gr.update(visible=False), gr.update(visible=False)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, index, dataset_name
    )
    if not video_path:
        return gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
    )
| # Helper functions for Video B (same as Video A) | |
def update_config_choices_b(dataset_name):
    """Repopulate Video B's config dropdown after its dataset name changes."""
    if not dataset_name:
        return gr.update(choices=[], value="")
    try:
        configs = get_available_configs(dataset_name)
    except Exception as e:
        # Fetch failures just clear the dropdown rather than erroring the UI.
        logger.warning(f"Could not fetch configs: {e}")
        return gr.update(choices=[], value="")
    if not configs:
        return gr.update(choices=[], value="")
    return gr.update(choices=configs, value=configs[0])
def load_dataset_b(dataset_name, config_name):
    """Load Video B's dataset and re-bound its trajectory slider."""
    dataset, status = load_rfm_dataset(dataset_name, config_name)
    status_update = gr.update(value=status, visible=True)
    if dataset is None:
        # Load failed — show the status message and lock the slider.
        return None, status_update, gr.update(maximum=0, value=0, interactive=False)
    max_index = len(dataset) - 1
    slider_update = gr.update(
        maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
    )
    return dataset, status_update, slider_update
def use_dataset_video_b(dataset, index, dataset_name):
    """Copy the selected dataset trajectory into the Video B input."""
    if dataset is None:
        return (
            None,
            gr.update(value="No dataset loaded", visible=True),
            gr.update(visible=False),
        )
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, index, dataset_name
    )
    if not video_path:
        return (
            None,
            gr.update(value="❌ Error loading trajectory", visible=True),
            gr.update(visible=False),
        )
    # Collect any optional trajectory metadata for display.
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)  # "" when there is no metadata
    status_text = f"✅ Loaded trajectory {index} from dataset for Video B"
    if metadata_text:
        status_text += f"\n\n{metadata_text}"
    return (
        video_path,
        gr.update(value=status_text, visible=True),
        gr.update(value=metadata_text, visible=bool(metadata_text)),
    )
def next_trajectory_b(dataset, current_idx, dataset_name):
    """Advance Video B to the next trajectory (clamped at the dataset end)."""
    if dataset is None:
        return 0, None, gr.update(visible=False), gr.update(visible=False)
    next_idx = min(current_idx + 1, len(dataset) - 1)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, next_idx, dataset_name
    )
    if not video_path:
        # Keep the slider in place when the trajectory can't be loaded.
        return current_idx, None, gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        next_idx,
        video_path,
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
    )
def prev_trajectory_b(dataset, current_idx, dataset_name):
    """Step Video B back to the previous trajectory (clamped at index 0)."""
    if dataset is None:
        return 0, None, gr.update(visible=False), gr.update(visible=False)
    prev_idx = max(current_idx - 1, 0)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, prev_idx, dataset_name
    )
    if not video_path:
        # Keep the slider in place when the trajectory can't be loaded.
        return current_idx, None, gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        prev_idx,
        video_path,
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
    )
def update_trajectory_on_slider_change_b(dataset, index, dataset_name):
    """Refresh Video B's metadata/status panes when its slider moves."""
    if dataset is None:
        return gr.update(visible=False), gr.update(visible=False)
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, index, dataset_name
    )
    if not video_path:
        return gr.update(visible=False), gr.update(visible=False)
    parts = []
    if quality_label:
        parts.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        parts.append(f"**Partial Success:** {partial_success:.3f}")
    metadata_text = "\n".join(parts)
    return (
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
    )
# Video A dataset selection handlers
# Name change / refresh re-queries configs; load stores the dataset in state
# and re-bounds the slider; "use" copies the selection into video_a_input.
dataset_name_a.change(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
refresh_configs_btn_a.click(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
load_dataset_btn_a.click(
    fn=load_dataset_a,
    inputs=[dataset_name_a, config_name_a],
    outputs=[current_dataset_a, dataset_status_a, trajectory_slider_a],
)
use_dataset_video_btn_a.click(
    fn=use_dataset_video_a,
    inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
    outputs=[video_a_input, dataset_status_a, trajectory_metadata_a],
)
# Next/prev also write back to trajectory_slider_a so the slider tracks position.
next_traj_btn_a.click(
    fn=next_trajectory_a,
    inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
    outputs=[
        trajectory_slider_a,
        video_a_input,
        trajectory_metadata_a,
        dataset_status_a,
    ],
)
prev_traj_btn_a.click(
    fn=prev_trajectory_a,
    inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
    outputs=[
        trajectory_slider_a,
        video_a_input,
        trajectory_metadata_a,
        dataset_status_a,
    ],
)
trajectory_slider_a.change(
    fn=update_trajectory_on_slider_change_a,
    inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
    outputs=[trajectory_metadata_a, dataset_status_a],
)
# Video B dataset selection handlers
# Mirrors the Video A wiring above with the _b components and handlers.
dataset_name_b.change(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
refresh_configs_btn_b.click(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
load_dataset_btn_b.click(
    fn=load_dataset_b,
    inputs=[dataset_name_b, config_name_b],
    outputs=[current_dataset_b, dataset_status_b, trajectory_slider_b],
)
use_dataset_video_btn_b.click(
    fn=use_dataset_video_b,
    inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
    outputs=[video_b_input, dataset_status_b, trajectory_metadata_b],
)
next_traj_btn_b.click(
    fn=next_trajectory_b,
    inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
    outputs=[
        trajectory_slider_b,
        video_b_input,
        trajectory_metadata_b,
        dataset_status_b,
    ],
)
prev_traj_btn_b.click(
    fn=prev_trajectory_b,
    inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
    outputs=[
        trajectory_slider_b,
        video_b_input,
        trajectory_metadata_b,
        dataset_status_b,
    ],
)
trajectory_slider_b.change(
    fn=update_trajectory_on_slider_change_b,
    inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
    outputs=[trajectory_metadata_b, dataset_status_b],
)
# Compare both videos via the eval server; results land in result_text and
# the two display players. api_name exposes this on the Gradio API.
analyze_dual_btn.click(
    fn=process_two_videos,
    inputs=[
        video_a_input,
        video_b_input,
        task_text_dual,
        prediction_type,
        server_url_state,
        fps_input_dual,
    ],
    outputs=[result_text, video_a_display, video_b_display],
    api_name="process_two_videos",
)
def main():
    """Launch the Gradio app.

    Binds to all interfaces on port 7860 so the app is reachable inside a
    container (e.g. a Hugging Face Space). Hot-reload is provided by the
    `gradio` CLI wrapper, not by `launch()`, so no reload flag is handled
    here. (The previous GRADIO_WATCH / --reload detection was dead code:
    its result was never used.)
    """
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,  # Show full error messages in the UI
    )


if __name__ == "__main__":
    main()