Spaces:
Running
Running
Anthony Liang
commited on
Commit
·
d2a5693
1
Parent(s):
f506da8
update interface
Browse files
app.py
CHANGED
|
@@ -12,9 +12,14 @@ from typing import Optional, Tuple
|
|
| 12 |
import logging
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
import matplotlib
|
| 17 |
-
|
|
|
|
| 18 |
import matplotlib.pyplot as plt
|
| 19 |
import numpy as np
|
| 20 |
import requests
|
|
@@ -24,6 +29,7 @@ from typing import Any, Optional, Tuple
|
|
| 24 |
|
| 25 |
from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
|
| 26 |
from rfm.evals.eval_utils import build_payload, post_batch_npy
|
|
|
|
| 27 |
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
|
| 28 |
|
| 29 |
logger = logging.getLogger(__name__)
|
|
@@ -57,7 +63,7 @@ PREDEFINED_DATASETS = [
|
|
| 57 |
"aliangdw/usc_xarm_policy_ranking",
|
| 58 |
"aliangdw/usc_franka_policy_ranking",
|
| 59 |
"aliangdw/utd_so101_policy_ranking",
|
| 60 |
-
"aliangdw/utd_so101_human"
|
| 61 |
]
|
| 62 |
|
| 63 |
# Global server state
|
|
@@ -65,17 +71,18 @@ _server_state = {
|
|
| 65 |
"server_url": None,
|
| 66 |
}
|
| 67 |
|
|
|
|
| 68 |
def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
|
| 69 |
"""Check server health and get model info."""
|
| 70 |
if not server_url:
|
| 71 |
return "Please provide a server URL.", None, None
|
| 72 |
-
|
| 73 |
try:
|
| 74 |
url = server_url.rstrip("/") + "/health"
|
| 75 |
response = requests.get(url, timeout=5.0)
|
| 76 |
response.raise_for_status()
|
| 77 |
health_data = response.json()
|
| 78 |
-
|
| 79 |
# Also try to get GPU status for more info
|
| 80 |
try:
|
| 81 |
status_url = server_url.rstrip("/") + "/gpu_status"
|
|
@@ -85,7 +92,7 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
|
|
| 85 |
health_data.update(status_data)
|
| 86 |
except:
|
| 87 |
pass
|
| 88 |
-
|
| 89 |
# Try to get model info
|
| 90 |
model_info_text = None
|
| 91 |
try:
|
|
@@ -96,9 +103,13 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
|
|
| 96 |
model_info_text = format_model_info(model_info_data)
|
| 97 |
except Exception as e:
|
| 98 |
logger.warning(f"Could not fetch model info: {e}")
|
| 99 |
-
|
| 100 |
_server_state["server_url"] = server_url
|
| 101 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
except requests.exceptions.RequestException as e:
|
| 103 |
return f"Error connecting to server: {str(e)}", None, None
|
| 104 |
|
|
@@ -106,31 +117,31 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
|
|
| 106 |
def format_model_info(model_info: dict) -> str:
|
| 107 |
"""Format model info and experiment config as markdown."""
|
| 108 |
lines = ["## Model Information\n"]
|
| 109 |
-
|
| 110 |
# Model path
|
| 111 |
model_path = model_info.get("model_path", "Unknown")
|
| 112 |
lines.append(f"**Model Path:** `{model_path}`\n")
|
| 113 |
-
|
| 114 |
# Number of GPUs
|
| 115 |
num_gpus = model_info.get("num_gpus", "Unknown")
|
| 116 |
lines.append(f"**Number of GPUs:** {num_gpus}\n")
|
| 117 |
-
|
| 118 |
# Model architecture
|
| 119 |
model_arch = model_info.get("model_architecture", {})
|
| 120 |
if model_arch and "error" not in model_arch:
|
| 121 |
lines.append("\n## Model Architecture\n")
|
| 122 |
-
|
| 123 |
model_class = model_arch.get("model_class", "Unknown")
|
| 124 |
model_module = model_arch.get("model_module", "Unknown")
|
| 125 |
lines.append(f"- **Model Class:** `{model_class}`\n")
|
| 126 |
lines.append(f"- **Module:** `{model_module}`\n")
|
| 127 |
-
|
| 128 |
# Parameter counts
|
| 129 |
total_params = model_arch.get("total_parameters")
|
| 130 |
trainable_params = model_arch.get("trainable_parameters")
|
| 131 |
frozen_params = model_arch.get("frozen_parameters")
|
| 132 |
trainable_pct = model_arch.get("trainable_percentage")
|
| 133 |
-
|
| 134 |
if total_params is not None:
|
| 135 |
lines.append(f"\n### Parameter Statistics\n")
|
| 136 |
lines.append(f"- **Total Parameters:** {total_params:,}\n")
|
|
@@ -140,7 +151,7 @@ def format_model_info(model_info: dict) -> str:
|
|
| 140 |
lines.append(f"- **Frozen Parameters:** {frozen_params:,}\n")
|
| 141 |
if trainable_pct is not None:
|
| 142 |
lines.append(f"- **Trainable Percentage:** {trainable_pct:.2f}%\n")
|
| 143 |
-
|
| 144 |
# Architecture summary
|
| 145 |
arch_summary = model_arch.get("architecture_summary", [])
|
| 146 |
if arch_summary:
|
|
@@ -150,12 +161,12 @@ def format_model_info(model_info: dict) -> str:
|
|
| 150 |
module_type = module_info.get("type", "Unknown")
|
| 151 |
params = module_info.get("parameters", 0)
|
| 152 |
lines.append(f"- **{name}** (`{module_type}`): {params:,} parameters\n")
|
| 153 |
-
|
| 154 |
# Experiment config
|
| 155 |
exp_config = model_info.get("experiment_config", {})
|
| 156 |
if exp_config:
|
| 157 |
lines.append("\n## Experiment Configuration\n")
|
| 158 |
-
|
| 159 |
# Model config
|
| 160 |
model_cfg = exp_config.get("model", {})
|
| 161 |
if model_cfg:
|
|
@@ -168,29 +179,33 @@ def format_model_info(model_info: dict) -> str:
|
|
| 168 |
lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
|
| 169 |
lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
|
| 170 |
lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")
|
| 171 |
-
|
| 172 |
# Data config
|
| 173 |
data_cfg = exp_config.get("data", {})
|
| 174 |
if data_cfg:
|
| 175 |
lines.append("\n### Data Configuration\n")
|
| 176 |
lines.append(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n")
|
| 177 |
-
lines.append(
|
| 178 |
-
|
|
|
|
|
|
|
| 179 |
if train_datasets:
|
| 180 |
lines.append(f"- **Train Datasets:** {', '.join(train_datasets)}\n")
|
| 181 |
-
eval_datasets = data_cfg.get(
|
| 182 |
if eval_datasets:
|
| 183 |
lines.append(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n")
|
| 184 |
-
|
| 185 |
# Training config
|
| 186 |
training_cfg = exp_config.get("training", {})
|
| 187 |
if training_cfg:
|
| 188 |
lines.append("\n### Training Configuration\n")
|
| 189 |
lines.append(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n")
|
| 190 |
lines.append(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n")
|
| 191 |
-
lines.append(
|
|
|
|
|
|
|
| 192 |
lines.append(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n")
|
| 193 |
-
|
| 194 |
return "".join(lines)
|
| 195 |
|
| 196 |
|
|
@@ -199,12 +214,12 @@ def load_rfm_dataset(dataset_name, config_name):
|
|
| 199 |
try:
|
| 200 |
if not dataset_name or not config_name:
|
| 201 |
return None, "Please provide both dataset name and configuration"
|
| 202 |
-
|
| 203 |
dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
|
| 204 |
-
|
| 205 |
if len(dataset) == 0:
|
| 206 |
return None, f"Dataset {dataset_name}/{config_name} is empty"
|
| 207 |
-
|
| 208 |
return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
|
| 209 |
except Exception as e:
|
| 210 |
error_msg = str(e)
|
|
@@ -231,18 +246,18 @@ def get_trajectory_video_path(dataset, index, dataset_name):
|
|
| 231 |
try:
|
| 232 |
item = dataset[int(index)]
|
| 233 |
frames_data = item["frames"]
|
| 234 |
-
|
| 235 |
if isinstance(frames_data, str):
|
| 236 |
# Construct HuggingFace Hub URL
|
| 237 |
if dataset_name:
|
| 238 |
video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
|
| 239 |
else:
|
| 240 |
video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
|
| 241 |
-
|
| 242 |
task = item.get("task", "Complete the task")
|
| 243 |
quality_label = item.get("quality_label", None)
|
| 244 |
partial_success = item.get("partial_success", None)
|
| 245 |
-
|
| 246 |
return video_path, task, quality_label, partial_success
|
| 247 |
else:
|
| 248 |
return None, None, None, None
|
|
@@ -267,7 +282,7 @@ def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
|
|
| 267 |
# Check if it's a URL or local file
|
| 268 |
is_url = video_path.startswith(("http://", "https://"))
|
| 269 |
is_local_file = os.path.exists(video_path) if not is_url else False
|
| 270 |
-
|
| 271 |
if not is_url and not is_local_file:
|
| 272 |
logger.warning(f"Video path does not exist: {video_path}")
|
| 273 |
return None
|
|
@@ -304,7 +319,7 @@ def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
|
|
| 304 |
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
| 305 |
|
| 306 |
frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
|
| 307 |
-
del vr
|
| 308 |
return frames_array
|
| 309 |
except Exception as e:
|
| 310 |
logger.error(f"Error extracting frames from {video_path}: {e}")
|
|
@@ -316,26 +331,26 @@ def process_single_video(
|
|
| 316 |
task_text: str = "Complete the task",
|
| 317 |
server_url: str = "",
|
| 318 |
fps: float = 1.0,
|
| 319 |
-
) -> Tuple[Optional[str], Optional[str]
|
| 320 |
"""Process single video for progress and success predictions using eval server."""
|
| 321 |
if not server_url:
|
| 322 |
-
return None,
|
| 323 |
-
|
| 324 |
if not _server_state.get("server_url"):
|
| 325 |
-
return None,
|
| 326 |
-
|
| 327 |
if video_path is None:
|
| 328 |
-
return None,
|
| 329 |
|
| 330 |
try:
|
| 331 |
frames_array = extract_frames(video_path, fps=fps)
|
| 332 |
if frames_array is None or frames_array.size == 0:
|
| 333 |
-
return None,
|
| 334 |
|
| 335 |
# Convert frames to (T, H, W, C) numpy array with uint8 values
|
| 336 |
if frames_array.dtype != np.uint8:
|
| 337 |
frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)
|
| 338 |
-
|
| 339 |
num_frames = frames_array.shape[0]
|
| 340 |
frames_shape = frames_array.shape # (T, H, W, C)
|
| 341 |
|
|
@@ -366,25 +381,54 @@ def process_single_video(
|
|
| 366 |
# Process response
|
| 367 |
outputs_progress = response.get("outputs_progress", {})
|
| 368 |
progress_pred = outputs_progress.get("progress_pred", [])
|
| 369 |
-
|
|
|
|
|
|
|
| 370 |
# Extract progress predictions
|
| 371 |
if progress_pred and len(progress_pred) > 0:
|
| 372 |
progress_array = np.array(progress_pred[0]) # First sample
|
| 373 |
else:
|
| 374 |
progress_array = np.array([])
|
| 375 |
|
| 376 |
-
#
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
info_text = f"**Frames processed:** {num_frames}\n"
|
| 381 |
if len(progress_array) > 0:
|
| 382 |
info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"
|
|
|
|
|
|
|
| 383 |
|
| 384 |
-
|
|
|
|
| 385 |
|
| 386 |
except Exception as e:
|
| 387 |
-
return None,
|
| 388 |
|
| 389 |
|
| 390 |
def process_dual_videos(
|
|
@@ -398,7 +442,7 @@ def process_dual_videos(
|
|
| 398 |
"""Process two videos for preference or similarity prediction using eval server."""
|
| 399 |
if not server_url:
|
| 400 |
return "Please provide a server URL and check connection first.", None
|
| 401 |
-
|
| 402 |
if not _server_state.get("server_url"):
|
| 403 |
return "Server not connected. Please check server connection first.", None
|
| 404 |
|
|
@@ -475,6 +519,47 @@ def process_dual_videos(
|
|
| 475 |
else:
|
| 476 |
result_text += "Could not extract preference prediction from server response.\n"
|
| 477 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
else: # similarity - not yet implemented in eval server response format
|
| 479 |
result_text = "Similarity prediction not yet supported in eval server response format."
|
| 480 |
|
|
@@ -489,107 +574,49 @@ def process_dual_videos(
|
|
| 489 |
return f"Error processing videos: {str(e)}", None
|
| 490 |
|
| 491 |
|
| 492 |
-
def create_progress_plot(progress_pred: np.ndarray, num_frames: int) -> str:
|
| 493 |
-
"""Create progress prediction plot."""
|
| 494 |
-
plt.rcParams['font.family'] = 'DejaVu Sans'
|
| 495 |
-
plt.rcParams['font.size'] = 16
|
| 496 |
-
|
| 497 |
-
fig, ax = plt.subplots(figsize=(10, 6))
|
| 498 |
-
|
| 499 |
-
if len(progress_pred) > 0:
|
| 500 |
-
frame_indices = np.arange(len(progress_pred))
|
| 501 |
-
ax.plot(frame_indices, progress_pred, 'b-', linewidth=3, marker='o', markersize=8, label='Progress Prediction')
|
| 502 |
-
else:
|
| 503 |
-
ax.text(0.5, 0.5, 'No progress prediction available',
|
| 504 |
-
horizontalalignment='center', verticalalignment='center',
|
| 505 |
-
transform=ax.transAxes, fontsize=18)
|
| 506 |
-
|
| 507 |
-
ax.set_xlabel('Frame Index', fontsize=18, fontweight='bold')
|
| 508 |
-
ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
|
| 509 |
-
ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
|
| 510 |
-
ax.set_ylim([0, 1])
|
| 511 |
-
|
| 512 |
-
plt.tight_layout()
|
| 513 |
-
|
| 514 |
-
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
| 515 |
-
plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
|
| 516 |
-
plt.close()
|
| 517 |
-
|
| 518 |
-
return tmp_file.name
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
def create_success_plot(success_probs: np.ndarray, num_frames: int) -> str:
|
| 522 |
-
"""Create success probability plot."""
|
| 523 |
-
plt.rcParams['font.family'] = 'DejaVu Sans'
|
| 524 |
-
plt.rcParams['font.size'] = 16
|
| 525 |
-
|
| 526 |
-
fig, ax = plt.subplots(figsize=(10, 6))
|
| 527 |
-
|
| 528 |
-
if len(success_probs) > 0:
|
| 529 |
-
frame_indices = np.arange(len(success_probs))
|
| 530 |
-
ax.plot(frame_indices, success_probs, 'g-', linewidth=3, marker='s', markersize=8, label='Success Probability')
|
| 531 |
-
ax.axhline(y=0.5, color='r', linestyle='--', linewidth=2, label='Decision Threshold (0.5)')
|
| 532 |
-
else:
|
| 533 |
-
ax.text(0.5, 0.5, 'No success prediction available',
|
| 534 |
-
horizontalalignment='center', verticalalignment='center',
|
| 535 |
-
transform=ax.transAxes, fontsize=18)
|
| 536 |
-
|
| 537 |
-
ax.set_xlabel('Frame Index', fontsize=18, fontweight='bold')
|
| 538 |
-
ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
|
| 539 |
-
ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
|
| 540 |
-
ax.set_ylim([0, 1])
|
| 541 |
-
|
| 542 |
-
plt.tight_layout()
|
| 543 |
-
|
| 544 |
-
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
| 545 |
-
plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
|
| 546 |
-
plt.close()
|
| 547 |
-
|
| 548 |
-
return tmp_file.name
|
| 549 |
-
|
| 550 |
def create_comparison_plot(frames_a: list, frames_b: list, prediction_type: str) -> str:
|
| 551 |
"""Create side-by-side comparison plot of two videos."""
|
| 552 |
-
plt.rcParams[
|
| 553 |
-
plt.rcParams[
|
| 554 |
-
|
| 555 |
fig, axes = plt.subplots(2, min(8, max(len(frames_a), len(frames_b))), figsize=(16, 4))
|
| 556 |
-
|
| 557 |
if len(axes.shape) == 1:
|
| 558 |
axes = axes.reshape(2, -1)
|
| 559 |
-
|
| 560 |
# Sample frames to display
|
| 561 |
num_display = min(8, max(len(frames_a), len(frames_b)))
|
| 562 |
indices_a = np.linspace(0, len(frames_a) - 1, num_display, dtype=int) if len(frames_a) > 1 else [0]
|
| 563 |
indices_b = np.linspace(0, len(frames_b) - 1, num_display, dtype=int) if len(frames_b) > 1 else [0]
|
| 564 |
-
|
| 565 |
# Display frames from video A (top row)
|
| 566 |
for idx, frame_idx in enumerate(indices_a):
|
| 567 |
if frame_idx < len(frames_a):
|
| 568 |
axes[0, idx].imshow(frames_a[frame_idx])
|
| 569 |
-
axes[0, idx].axis(
|
| 570 |
-
axes[0, idx].set_title(f
|
| 571 |
-
|
| 572 |
# Display frames from video B (bottom row)
|
| 573 |
for idx, frame_idx in enumerate(indices_b):
|
| 574 |
if frame_idx < len(frames_b):
|
| 575 |
axes[1, idx].imshow(frames_b[frame_idx])
|
| 576 |
-
axes[1, idx].axis(
|
| 577 |
-
axes[1, idx].set_title(f
|
| 578 |
-
|
| 579 |
# Add row labels
|
| 580 |
-
fig.text(0.02, 0.75,
|
| 581 |
-
fig.text(0.02, 0.25,
|
| 582 |
-
|
| 583 |
title = f"{prediction_type.capitalize()} Comparison: Video A vs Video B"
|
| 584 |
-
fig.suptitle(title, fontsize=20, fontweight=
|
| 585 |
-
|
| 586 |
plt.tight_layout()
|
| 587 |
-
|
| 588 |
# Save to temporary file
|
| 589 |
-
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=
|
| 590 |
-
plt.savefig(tmp_file.name, dpi=150, bbox_inches=
|
| 591 |
plt.close()
|
| 592 |
-
|
| 593 |
return tmp_file.name
|
| 594 |
|
| 595 |
|
|
@@ -619,7 +646,7 @@ with demo:
|
|
| 619 |
with gr.Tab("Server Setup"):
|
| 620 |
gr.Markdown("### Connect to Eval Server")
|
| 621 |
gr.Markdown("Enter the eval server URL and check connection.")
|
| 622 |
-
|
| 623 |
with gr.Row():
|
| 624 |
with gr.Column(scale=3):
|
| 625 |
server_url_input = gr.Textbox(
|
|
@@ -630,7 +657,7 @@ with demo:
|
|
| 630 |
)
|
| 631 |
with gr.Column(scale=1):
|
| 632 |
check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
|
| 633 |
-
|
| 634 |
server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
|
| 635 |
model_info_display = gr.Markdown("", visible=False)
|
| 636 |
|
|
@@ -641,7 +668,7 @@ with demo:
|
|
| 641 |
return status, gr.update(value=model_info_text, visible=True)
|
| 642 |
else:
|
| 643 |
return status, gr.update(visible=False)
|
| 644 |
-
|
| 645 |
check_connection_btn.click(
|
| 646 |
fn=on_check_connection,
|
| 647 |
inputs=[server_url_input],
|
|
@@ -651,7 +678,7 @@ with demo:
|
|
| 651 |
with gr.Tab("Progress Prediction"):
|
| 652 |
gr.Markdown("### Progress & Success Prediction")
|
| 653 |
gr.Markdown("Upload a video or select one from a dataset to get progress predictions.")
|
| 654 |
-
|
| 655 |
with gr.Row():
|
| 656 |
with gr.Column():
|
| 657 |
with gr.Accordion("📁 Select from Dataset", open=False):
|
|
@@ -659,37 +686,29 @@ with demo:
|
|
| 659 |
choices=PREDEFINED_DATASETS,
|
| 660 |
value="jesbu1/oxe_rfm",
|
| 661 |
label="Dataset Name",
|
| 662 |
-
allow_custom_value=True
|
| 663 |
)
|
| 664 |
config_name_single = gr.Dropdown(
|
| 665 |
-
choices=[],
|
| 666 |
-
value="",
|
| 667 |
-
label="Configuration Name",
|
| 668 |
-
allow_custom_value=True
|
| 669 |
)
|
| 670 |
with gr.Row():
|
| 671 |
refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
|
| 672 |
load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
|
| 673 |
-
|
| 674 |
dataset_status_single = gr.Markdown("", visible=False)
|
| 675 |
with gr.Row():
|
| 676 |
prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
|
| 677 |
trajectory_slider = gr.Slider(
|
| 678 |
-
minimum=0,
|
| 679 |
-
maximum=0,
|
| 680 |
-
step=1,
|
| 681 |
-
value=0,
|
| 682 |
-
label="Trajectory Index",
|
| 683 |
-
interactive=True
|
| 684 |
)
|
| 685 |
next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
|
| 686 |
trajectory_metadata = gr.Markdown("", visible=False)
|
| 687 |
use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
|
| 688 |
-
|
| 689 |
gr.Markdown("---")
|
| 690 |
gr.Markdown("**OR**")
|
| 691 |
gr.Markdown("---")
|
| 692 |
-
|
| 693 |
single_video_input = gr.Video(label="Upload Video", height=300)
|
| 694 |
task_text_input = gr.Textbox(
|
| 695 |
label="Task Description",
|
|
@@ -707,13 +726,12 @@ with demo:
|
|
| 707 |
analyze_single_btn = gr.Button("Analyze Video", variant="primary")
|
| 708 |
|
| 709 |
with gr.Column():
|
| 710 |
-
progress_plot = gr.Image(label="Progress Prediction", height=400)
|
| 711 |
-
success_plot = gr.Image(label="Success Prediction", height=400)
|
| 712 |
info_output = gr.Markdown("")
|
| 713 |
-
|
| 714 |
# State variables for dataset
|
| 715 |
current_dataset_single = gr.State(None)
|
| 716 |
-
|
| 717 |
def update_config_choices_single(dataset_name):
|
| 718 |
"""Update config choices when dataset changes."""
|
| 719 |
if not dataset_name:
|
|
@@ -727,7 +745,7 @@ with demo:
|
|
| 727 |
except Exception as e:
|
| 728 |
logger.warning(f"Could not fetch configs: {e}")
|
| 729 |
return gr.update(choices=[], value="")
|
| 730 |
-
|
| 731 |
def load_dataset_single(dataset_name, config_name):
|
| 732 |
"""Load dataset and update slider."""
|
| 733 |
dataset, status = load_rfm_dataset(dataset_name, config_name)
|
|
@@ -736,16 +754,23 @@ with demo:
|
|
| 736 |
return (
|
| 737 |
dataset,
|
| 738 |
gr.update(value=status, visible=True),
|
| 739 |
-
gr.update(
|
|
|
|
|
|
|
| 740 |
)
|
| 741 |
else:
|
| 742 |
return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)
|
| 743 |
-
|
| 744 |
def use_dataset_video(dataset, index, dataset_name):
|
| 745 |
"""Load video from dataset and update inputs."""
|
| 746 |
if dataset is None:
|
| 747 |
-
return
|
| 748 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
|
| 750 |
if video_path:
|
| 751 |
# Build metadata text
|
|
@@ -754,28 +779,35 @@ with demo:
|
|
| 754 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 755 |
if partial_success is not None:
|
| 756 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 757 |
-
|
| 758 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 759 |
status_text = f"✅ Loaded trajectory {index} from dataset"
|
| 760 |
if metadata_text:
|
| 761 |
status_text += f"\n\n{metadata_text}"
|
| 762 |
-
|
| 763 |
return (
|
| 764 |
-
video_path,
|
| 765 |
-
task,
|
| 766 |
gr.update(value=status_text, visible=True),
|
| 767 |
-
gr.update(value=metadata_text, visible=bool(metadata_text))
|
| 768 |
)
|
| 769 |
else:
|
| 770 |
-
return
|
| 771 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 772 |
def next_trajectory(dataset, current_idx, dataset_name):
|
| 773 |
"""Go to next trajectory."""
|
| 774 |
if dataset is None:
|
| 775 |
return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 776 |
next_idx = min(current_idx + 1, len(dataset) - 1)
|
| 777 |
-
video_path, task, quality_label, partial_success = get_trajectory_video_path(
|
| 778 |
-
|
|
|
|
|
|
|
| 779 |
if video_path:
|
| 780 |
# Build metadata text
|
| 781 |
metadata_lines = []
|
|
@@ -783,25 +815,27 @@ with demo:
|
|
| 783 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 784 |
if partial_success is not None:
|
| 785 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 786 |
-
|
| 787 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 788 |
return (
|
| 789 |
next_idx,
|
| 790 |
video_path,
|
| 791 |
task,
|
| 792 |
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 793 |
-
gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True)
|
| 794 |
)
|
| 795 |
else:
|
| 796 |
return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 797 |
-
|
| 798 |
def prev_trajectory(dataset, current_idx, dataset_name):
|
| 799 |
"""Go to previous trajectory."""
|
| 800 |
if dataset is None:
|
| 801 |
return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 802 |
prev_idx = max(current_idx - 1, 0)
|
| 803 |
-
video_path, task, quality_label, partial_success = get_trajectory_video_path(
|
| 804 |
-
|
|
|
|
|
|
|
| 805 |
if video_path:
|
| 806 |
# Build metadata text
|
| 807 |
metadata_lines = []
|
|
@@ -809,23 +843,23 @@ with demo:
|
|
| 809 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 810 |
if partial_success is not None:
|
| 811 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 812 |
-
|
| 813 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 814 |
return (
|
| 815 |
prev_idx,
|
| 816 |
video_path,
|
| 817 |
task,
|
| 818 |
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 819 |
-
gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True)
|
| 820 |
)
|
| 821 |
else:
|
| 822 |
return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 823 |
-
|
| 824 |
def update_trajectory_on_slider_change(dataset, index, dataset_name):
|
| 825 |
"""Update trajectory metadata when slider changes."""
|
| 826 |
if dataset is None:
|
| 827 |
return gr.update(visible=False), gr.update(visible=False)
|
| 828 |
-
|
| 829 |
video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
|
| 830 |
if video_path:
|
| 831 |
# Build metadata text
|
|
@@ -834,64 +868,73 @@ with demo:
|
|
| 834 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 835 |
if partial_success is not None:
|
| 836 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 837 |
-
|
| 838 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 839 |
return (
|
| 840 |
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 841 |
-
gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True)
|
| 842 |
)
|
| 843 |
else:
|
| 844 |
return gr.update(visible=False), gr.update(visible=False)
|
| 845 |
-
|
| 846 |
# Dataset selection handlers
|
| 847 |
dataset_name_single.change(
|
| 848 |
-
fn=update_config_choices_single,
|
| 849 |
-
inputs=[dataset_name_single],
|
| 850 |
-
outputs=[config_name_single]
|
| 851 |
)
|
| 852 |
-
|
| 853 |
refresh_configs_btn.click(
|
| 854 |
-
fn=update_config_choices_single,
|
| 855 |
-
inputs=[dataset_name_single],
|
| 856 |
-
outputs=[config_name_single]
|
| 857 |
)
|
| 858 |
-
|
| 859 |
load_dataset_btn.click(
|
| 860 |
fn=load_dataset_single,
|
| 861 |
inputs=[dataset_name_single, config_name_single],
|
| 862 |
-
outputs=[current_dataset_single, dataset_status_single, trajectory_slider]
|
| 863 |
)
|
| 864 |
-
|
| 865 |
use_dataset_video_btn.click(
|
| 866 |
fn=use_dataset_video,
|
| 867 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 868 |
-
outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata]
|
| 869 |
)
|
| 870 |
-
|
| 871 |
# Navigation buttons
|
| 872 |
next_traj_btn.click(
|
| 873 |
fn=next_trajectory,
|
| 874 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 875 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 876 |
)
|
| 877 |
-
|
| 878 |
prev_traj_btn.click(
|
| 879 |
fn=prev_trajectory,
|
| 880 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 881 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 882 |
)
|
| 883 |
-
|
| 884 |
# Update metadata when slider changes
|
| 885 |
trajectory_slider.change(
|
| 886 |
fn=update_trajectory_on_slider_change,
|
| 887 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 888 |
-
outputs=[trajectory_metadata, dataset_status_single]
|
| 889 |
)
|
| 890 |
-
|
| 891 |
analyze_single_btn.click(
|
| 892 |
fn=process_single_video,
|
| 893 |
inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
|
| 894 |
-
outputs=[progress_plot,
|
|
|
|
| 895 |
)
|
| 896 |
|
| 897 |
with gr.Tab("Preference/Similarity Analysis"):
|
|
@@ -906,7 +949,7 @@ with demo:
|
|
| 906 |
value="Complete the task",
|
| 907 |
)
|
| 908 |
prediction_type = gr.Radio(
|
| 909 |
-
choices=["preference", "similarity"],
|
| 910 |
value="preference",
|
| 911 |
label="Prediction Type",
|
| 912 |
)
|
|
@@ -928,16 +971,17 @@ with demo:
|
|
| 928 |
fn=process_dual_videos,
|
| 929 |
inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
|
| 930 |
outputs=[result_text, comparison_plot],
|
|
|
|
| 931 |
)
|
| 932 |
|
| 933 |
|
| 934 |
def main():
|
| 935 |
"""Launch the Gradio app."""
|
| 936 |
import sys
|
| 937 |
-
|
| 938 |
# Check if reload mode is requested
|
| 939 |
watch_files = os.getenv("GRADIO_WATCH", "0") == "1" or "--reload" in sys.argv
|
| 940 |
-
|
| 941 |
demo.launch(
|
| 942 |
server_name="0.0.0.0",
|
| 943 |
server_port=7860,
|
|
|
|
| 12 |
import logging
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import spaces # Required for ZeroGPU on Hugging Face Spaces
|
| 18 |
+
except ImportError:
|
| 19 |
+
spaces = None # Not available when running locally
|
| 20 |
import matplotlib
|
| 21 |
+
|
| 22 |
+
matplotlib.use("Agg") # Use non-interactive backend
|
| 23 |
import matplotlib.pyplot as plt
|
| 24 |
import numpy as np
|
| 25 |
import requests
|
|
|
|
| 29 |
|
| 30 |
from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
|
| 31 |
from rfm.evals.eval_utils import build_payload, post_batch_npy
|
| 32 |
+
from rfm.evals.eval_viz_utils import create_combined_progress_success_plot
|
| 33 |
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
|
|
|
| 63 |
"aliangdw/usc_xarm_policy_ranking",
|
| 64 |
"aliangdw/usc_franka_policy_ranking",
|
| 65 |
"aliangdw/utd_so101_policy_ranking",
|
| 66 |
+
"aliangdw/utd_so101_human",
|
| 67 |
]
|
| 68 |
|
| 69 |
# Global server state
|
|
|
|
| 71 |
"server_url": None,
|
| 72 |
}
|
| 73 |
|
| 74 |
+
|
| 75 |
def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
|
| 76 |
"""Check server health and get model info."""
|
| 77 |
if not server_url:
|
| 78 |
return "Please provide a server URL.", None, None
|
| 79 |
+
|
| 80 |
try:
|
| 81 |
url = server_url.rstrip("/") + "/health"
|
| 82 |
response = requests.get(url, timeout=5.0)
|
| 83 |
response.raise_for_status()
|
| 84 |
health_data = response.json()
|
| 85 |
+
|
| 86 |
# Also try to get GPU status for more info
|
| 87 |
try:
|
| 88 |
status_url = server_url.rstrip("/") + "/gpu_status"
|
|
|
|
| 92 |
health_data.update(status_data)
|
| 93 |
except:
|
| 94 |
pass
|
| 95 |
+
|
| 96 |
# Try to get model info
|
| 97 |
model_info_text = None
|
| 98 |
try:
|
|
|
|
| 103 |
model_info_text = format_model_info(model_info_data)
|
| 104 |
except Exception as e:
|
| 105 |
logger.warning(f"Could not fetch model info: {e}")
|
| 106 |
+
|
| 107 |
_server_state["server_url"] = server_url
|
| 108 |
+
return (
|
| 109 |
+
f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available",
|
| 110 |
+
health_data,
|
| 111 |
+
model_info_text,
|
| 112 |
+
)
|
| 113 |
except requests.exceptions.RequestException as e:
|
| 114 |
return f"Error connecting to server: {str(e)}", None, None
|
| 115 |
|
|
|
|
| 117 |
def format_model_info(model_info: dict) -> str:
|
| 118 |
"""Format model info and experiment config as markdown."""
|
| 119 |
lines = ["## Model Information\n"]
|
| 120 |
+
|
| 121 |
# Model path
|
| 122 |
model_path = model_info.get("model_path", "Unknown")
|
| 123 |
lines.append(f"**Model Path:** `{model_path}`\n")
|
| 124 |
+
|
| 125 |
# Number of GPUs
|
| 126 |
num_gpus = model_info.get("num_gpus", "Unknown")
|
| 127 |
lines.append(f"**Number of GPUs:** {num_gpus}\n")
|
| 128 |
+
|
| 129 |
# Model architecture
|
| 130 |
model_arch = model_info.get("model_architecture", {})
|
| 131 |
if model_arch and "error" not in model_arch:
|
| 132 |
lines.append("\n## Model Architecture\n")
|
| 133 |
+
|
| 134 |
model_class = model_arch.get("model_class", "Unknown")
|
| 135 |
model_module = model_arch.get("model_module", "Unknown")
|
| 136 |
lines.append(f"- **Model Class:** `{model_class}`\n")
|
| 137 |
lines.append(f"- **Module:** `{model_module}`\n")
|
| 138 |
+
|
| 139 |
# Parameter counts
|
| 140 |
total_params = model_arch.get("total_parameters")
|
| 141 |
trainable_params = model_arch.get("trainable_parameters")
|
| 142 |
frozen_params = model_arch.get("frozen_parameters")
|
| 143 |
trainable_pct = model_arch.get("trainable_percentage")
|
| 144 |
+
|
| 145 |
if total_params is not None:
|
| 146 |
lines.append(f"\n### Parameter Statistics\n")
|
| 147 |
lines.append(f"- **Total Parameters:** {total_params:,}\n")
|
|
|
|
| 151 |
lines.append(f"- **Frozen Parameters:** {frozen_params:,}\n")
|
| 152 |
if trainable_pct is not None:
|
| 153 |
lines.append(f"- **Trainable Percentage:** {trainable_pct:.2f}%\n")
|
| 154 |
+
|
| 155 |
# Architecture summary
|
| 156 |
arch_summary = model_arch.get("architecture_summary", [])
|
| 157 |
if arch_summary:
|
|
|
|
| 161 |
module_type = module_info.get("type", "Unknown")
|
| 162 |
params = module_info.get("parameters", 0)
|
| 163 |
lines.append(f"- **{name}** (`{module_type}`): {params:,} parameters\n")
|
| 164 |
+
|
| 165 |
# Experiment config
|
| 166 |
exp_config = model_info.get("experiment_config", {})
|
| 167 |
if exp_config:
|
| 168 |
lines.append("\n## Experiment Configuration\n")
|
| 169 |
+
|
| 170 |
# Model config
|
| 171 |
model_cfg = exp_config.get("model", {})
|
| 172 |
if model_cfg:
|
|
|
|
| 179 |
lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
|
| 180 |
lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
|
| 181 |
lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")
|
| 182 |
+
|
| 183 |
# Data config
|
| 184 |
data_cfg = exp_config.get("data", {})
|
| 185 |
if data_cfg:
|
| 186 |
lines.append("\n### Data Configuration\n")
|
| 187 |
lines.append(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n")
|
| 188 |
+
lines.append(
|
| 189 |
+
f"- **Resized Dimensions:** {data_cfg.get('resized_height', 'N/A')}x{data_cfg.get('resized_width', 'N/A')}\n"
|
| 190 |
+
)
|
| 191 |
+
train_datasets = data_cfg.get("train_datasets", [])
|
| 192 |
if train_datasets:
|
| 193 |
lines.append(f"- **Train Datasets:** {', '.join(train_datasets)}\n")
|
| 194 |
+
eval_datasets = data_cfg.get("eval_datasets", [])
|
| 195 |
if eval_datasets:
|
| 196 |
lines.append(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n")
|
| 197 |
+
|
| 198 |
# Training config
|
| 199 |
training_cfg = exp_config.get("training", {})
|
| 200 |
if training_cfg:
|
| 201 |
lines.append("\n### Training Configuration\n")
|
| 202 |
lines.append(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n")
|
| 203 |
lines.append(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n")
|
| 204 |
+
lines.append(
|
| 205 |
+
f"- **Gradient Accumulation Steps:** {training_cfg.get('gradient_accumulation_steps', 'N/A')}\n"
|
| 206 |
+
)
|
| 207 |
lines.append(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n")
|
| 208 |
+
|
| 209 |
return "".join(lines)
|
| 210 |
|
| 211 |
|
|
|
|
| 214 |
try:
|
| 215 |
if not dataset_name or not config_name:
|
| 216 |
return None, "Please provide both dataset name and configuration"
|
| 217 |
+
|
| 218 |
dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
|
| 219 |
+
|
| 220 |
if len(dataset) == 0:
|
| 221 |
return None, f"Dataset {dataset_name}/{config_name} is empty"
|
| 222 |
+
|
| 223 |
return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
|
| 224 |
except Exception as e:
|
| 225 |
error_msg = str(e)
|
|
|
|
| 246 |
try:
|
| 247 |
item = dataset[int(index)]
|
| 248 |
frames_data = item["frames"]
|
| 249 |
+
|
| 250 |
if isinstance(frames_data, str):
|
| 251 |
# Construct HuggingFace Hub URL
|
| 252 |
if dataset_name:
|
| 253 |
video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
|
| 254 |
else:
|
| 255 |
video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
|
| 256 |
+
|
| 257 |
task = item.get("task", "Complete the task")
|
| 258 |
quality_label = item.get("quality_label", None)
|
| 259 |
partial_success = item.get("partial_success", None)
|
| 260 |
+
|
| 261 |
return video_path, task, quality_label, partial_success
|
| 262 |
else:
|
| 263 |
return None, None, None, None
|
|
|
|
| 282 |
# Check if it's a URL or local file
|
| 283 |
is_url = video_path.startswith(("http://", "https://"))
|
| 284 |
is_local_file = os.path.exists(video_path) if not is_url else False
|
| 285 |
+
|
| 286 |
if not is_url and not is_local_file:
|
| 287 |
logger.warning(f"Video path does not exist: {video_path}")
|
| 288 |
return None
|
|
|
|
| 319 |
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
| 320 |
|
| 321 |
frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
|
| 322 |
+
del vr
|
| 323 |
return frames_array
|
| 324 |
except Exception as e:
|
| 325 |
logger.error(f"Error extracting frames from {video_path}: {e}")
|
|
|
|
| 331 |
task_text: str = "Complete the task",
|
| 332 |
server_url: str = "",
|
| 333 |
fps: float = 1.0,
|
| 334 |
+
) -> Tuple[Optional[str], Optional[str]]:
|
| 335 |
"""Process single video for progress and success predictions using eval server."""
|
| 336 |
if not server_url:
|
| 337 |
+
return None, "Please provide a server URL and check connection first."
|
| 338 |
+
|
| 339 |
if not _server_state.get("server_url"):
|
| 340 |
+
return None, "Server not connected. Please check server connection first."
|
| 341 |
+
|
| 342 |
if video_path is None:
|
| 343 |
+
return None, "Please provide a video."
|
| 344 |
|
| 345 |
try:
|
| 346 |
frames_array = extract_frames(video_path, fps=fps)
|
| 347 |
if frames_array is None or frames_array.size == 0:
|
| 348 |
+
return None, "Could not extract frames from video."
|
| 349 |
|
| 350 |
# Convert frames to (T, H, W, C) numpy array with uint8 values
|
| 351 |
if frames_array.dtype != np.uint8:
|
| 352 |
frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)
|
| 353 |
+
|
| 354 |
num_frames = frames_array.shape[0]
|
| 355 |
frames_shape = frames_array.shape # (T, H, W, C)
|
| 356 |
|
|
|
|
| 381 |
# Process response
|
| 382 |
outputs_progress = response.get("outputs_progress", {})
|
| 383 |
progress_pred = outputs_progress.get("progress_pred", [])
|
| 384 |
+
outputs_success = response.get("outputs_success", {})
|
| 385 |
+
success_probs = outputs_success.get("success_probs", []) if outputs_success else None
|
| 386 |
+
|
| 387 |
# Extract progress predictions
|
| 388 |
if progress_pred and len(progress_pred) > 0:
|
| 389 |
progress_array = np.array(progress_pred[0]) # First sample
|
| 390 |
else:
|
| 391 |
progress_array = np.array([])
|
| 392 |
|
| 393 |
+
# Extract success predictions if available
|
| 394 |
+
success_array = None
|
| 395 |
+
if success_probs and len(success_probs) > 0:
|
| 396 |
+
success_array = np.array(success_probs[0])
|
| 397 |
+
|
| 398 |
+
# Convert success_array to binary if available
|
| 399 |
+
success_binary = None
|
| 400 |
+
if success_array is not None:
|
| 401 |
+
success_binary = (success_array > 0.5).astype(float)
|
| 402 |
+
|
| 403 |
+
# Create combined plot using shared helper function
|
| 404 |
+
fig = create_combined_progress_success_plot(
|
| 405 |
+
progress_pred=progress_array if len(progress_array) > 0 else np.array([0.0]),
|
| 406 |
+
num_frames=num_frames,
|
| 407 |
+
success_binary=success_binary,
|
| 408 |
+
success_probs=success_array,
|
| 409 |
+
success_labels=None, # No ground truth labels available
|
| 410 |
+
is_discrete_mode=False,
|
| 411 |
+
num_bins=10,
|
| 412 |
+
title=f"Progress & Success - {task_text}",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# Save to temporary file
|
| 416 |
+
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
| 417 |
+
fig.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
|
| 418 |
+
plt.close(fig)
|
| 419 |
+
progress_plot = tmp_file.name
|
| 420 |
|
| 421 |
info_text = f"**Frames processed:** {num_frames}\n"
|
| 422 |
if len(progress_array) > 0:
|
| 423 |
info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"
|
| 424 |
+
if success_array is not None and len(success_array) > 0:
|
| 425 |
+
info_text += f"**Final success probability:** {success_array[-1]:.3f}\n"
|
| 426 |
|
| 427 |
+
# Return combined plot (which includes success if available)
|
| 428 |
+
return progress_plot, info_text
|
| 429 |
|
| 430 |
except Exception as e:
|
| 431 |
+
return None, f"Error processing video: {str(e)}"
|
| 432 |
|
| 433 |
|
| 434 |
def process_dual_videos(
|
|
|
|
| 442 |
"""Process two videos for preference or similarity prediction using eval server."""
|
| 443 |
if not server_url:
|
| 444 |
return "Please provide a server URL and check connection first.", None
|
| 445 |
+
|
| 446 |
if not _server_state.get("server_url"):
|
| 447 |
return "Server not connected. Please check server connection first.", None
|
| 448 |
|
|
|
|
| 519 |
else:
|
| 520 |
result_text += "Could not extract preference prediction from server response.\n"
|
| 521 |
|
| 522 |
+
elif prediction_type == "progress":
|
| 523 |
+
# Create ProgressSamples for both videos
|
| 524 |
+
from rfm.data.dataset_types import ProgressSample
|
| 525 |
+
|
| 526 |
+
progress_sample_a = ProgressSample(
|
| 527 |
+
trajectory=trajectory_a,
|
| 528 |
+
data_gen_strategy="demo",
|
| 529 |
+
)
|
| 530 |
+
progress_sample_b = ProgressSample(
|
| 531 |
+
trajectory=trajectory_b,
|
| 532 |
+
data_gen_strategy="demo",
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# Build payload and send to server
|
| 536 |
+
files, sample_data = build_payload([progress_sample_a, progress_sample_b])
|
| 537 |
+
response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
|
| 538 |
+
|
| 539 |
+
# Process response
|
| 540 |
+
outputs_progress = response.get("outputs_progress", {})
|
| 541 |
+
progress_pred = outputs_progress.get("progress_pred", [])
|
| 542 |
+
|
| 543 |
+
result_text = f"**Progress Comparison:**\n"
|
| 544 |
+
if progress_pred and len(progress_pred) >= 2:
|
| 545 |
+
progress_a = np.array(progress_pred[0])
|
| 546 |
+
progress_b = np.array(progress_pred[1])
|
| 547 |
+
|
| 548 |
+
final_progress_a = float(progress_a[-1]) if len(progress_a) > 0 else 0.0
|
| 549 |
+
final_progress_b = float(progress_b[-1]) if len(progress_b) > 0 else 0.0
|
| 550 |
+
|
| 551 |
+
result_text += f"- Video A final progress: {final_progress_a:.3f}\n"
|
| 552 |
+
result_text += f"- Video B final progress: {final_progress_b:.3f}\n"
|
| 553 |
+
result_text += f"- Difference: {abs(final_progress_a - final_progress_b):.3f}\n"
|
| 554 |
+
if final_progress_a > final_progress_b:
|
| 555 |
+
result_text += f"- Video A has higher progress\n"
|
| 556 |
+
elif final_progress_b > final_progress_a:
|
| 557 |
+
result_text += f"- Video B has higher progress\n"
|
| 558 |
+
else:
|
| 559 |
+
result_text += f"- Both videos have equal progress\n"
|
| 560 |
+
else:
|
| 561 |
+
result_text += "Could not extract progress predictions from server response.\n"
|
| 562 |
+
|
| 563 |
else: # similarity - not yet implemented in eval server response format
|
| 564 |
result_text = "Similarity prediction not yet supported in eval server response format."
|
| 565 |
|
|
|
|
| 574 |
return f"Error processing videos: {str(e)}", None
|
| 575 |
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
def create_comparison_plot(frames_a: list, frames_b: list, prediction_type: str) -> str:
    """Create a side-by-side comparison plot of two videos and save it as a PNG.

    Frames from video A are shown on the top row and frames from video B on the
    bottom row, with up to 8 evenly-spaced frames sampled per video.

    Args:
        frames_a: Frames of video A (image arrays accepted by ``imshow``).
        frames_b: Frames of video B.
        prediction_type: Comparison label used in the figure title
            (e.g. "preference", "similarity", "progress").

    Returns:
        Path to the saved PNG file. The file is created with ``delete=False``,
        so the caller is responsible for eventual cleanup.
    """
    plt.rcParams["font.family"] = "DejaVu Sans"
    plt.rcParams["font.size"] = 16

    # Hoist the shared column count; clamp to at least 1 because
    # plt.subplots rejects 0 columns (which would crash when both
    # frame lists are empty).
    num_display = max(1, min(8, max(len(frames_a), len(frames_b))))

    fig, axes = plt.subplots(2, num_display, figsize=(16, 4))

    if len(axes.shape) == 1:
        # A single column yields a 1-D axes array; normalize to shape (2, 1).
        axes = axes.reshape(2, -1)

    # Evenly sample frame indices across each video (a 0/1-frame video maps to [0]).
    indices_a = np.linspace(0, len(frames_a) - 1, num_display, dtype=int) if len(frames_a) > 1 else [0]
    indices_b = np.linspace(0, len(frames_b) - 1, num_display, dtype=int) if len(frames_b) > 1 else [0]

    # Display frames from video A (top row)
    for idx, frame_idx in enumerate(indices_a):
        if frame_idx < len(frames_a):  # guards the empty-video case
            axes[0, idx].imshow(frames_a[frame_idx])
            axes[0, idx].axis("off")
            axes[0, idx].set_title(f"Frame {frame_idx}", fontsize=12)

    # Display frames from video B (bottom row)
    for idx, frame_idx in enumerate(indices_b):
        if frame_idx < len(frames_b):
            axes[1, idx].imshow(frames_b[frame_idx])
            axes[1, idx].axis("off")
            axes[1, idx].set_title(f"Frame {frame_idx}", fontsize=12)

    # Add row labels along the left edge
    fig.text(0.02, 0.75, "Video A", rotation=90, fontsize=18, fontweight="bold", va="center")
    fig.text(0.02, 0.25, "Video B", rotation=90, fontsize=18, fontweight="bold", va="center")

    title = f"{prediction_type.capitalize()} Comparison: Video A vs Video B"
    fig.suptitle(title, fontsize=20, fontweight="bold", y=0.98)

    plt.tight_layout()

    # Save to a temporary file. Close the handle before savefig so the
    # descriptor is not leaked and the write targets a closed file
    # (writing to a still-open NamedTemporaryFile is unreliable on Windows).
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    tmp_file.close()
    plt.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
    plt.close()

    return tmp_file.name
|
| 621 |
|
| 622 |
|
|
|
|
| 646 |
with gr.Tab("Server Setup"):
|
| 647 |
gr.Markdown("### Connect to Eval Server")
|
| 648 |
gr.Markdown("Enter the eval server URL and check connection.")
|
| 649 |
+
|
| 650 |
with gr.Row():
|
| 651 |
with gr.Column(scale=3):
|
| 652 |
server_url_input = gr.Textbox(
|
|
|
|
| 657 |
)
|
| 658 |
with gr.Column(scale=1):
|
| 659 |
check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
|
| 660 |
+
|
| 661 |
server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
|
| 662 |
model_info_display = gr.Markdown("", visible=False)
|
| 663 |
|
|
|
|
| 668 |
return status, gr.update(value=model_info_text, visible=True)
|
| 669 |
else:
|
| 670 |
return status, gr.update(visible=False)
|
| 671 |
+
|
| 672 |
check_connection_btn.click(
|
| 673 |
fn=on_check_connection,
|
| 674 |
inputs=[server_url_input],
|
|
|
|
| 678 |
with gr.Tab("Progress Prediction"):
|
| 679 |
gr.Markdown("### Progress & Success Prediction")
|
| 680 |
gr.Markdown("Upload a video or select one from a dataset to get progress predictions.")
|
| 681 |
+
|
| 682 |
with gr.Row():
|
| 683 |
with gr.Column():
|
| 684 |
with gr.Accordion("📁 Select from Dataset", open=False):
|
|
|
|
| 686 |
choices=PREDEFINED_DATASETS,
|
| 687 |
value="jesbu1/oxe_rfm",
|
| 688 |
label="Dataset Name",
|
| 689 |
+
allow_custom_value=True,
|
| 690 |
)
|
| 691 |
config_name_single = gr.Dropdown(
|
| 692 |
+
choices=[], value="", label="Configuration Name", allow_custom_value=True
|
|
|
|
|
|
|
|
|
|
| 693 |
)
|
| 694 |
with gr.Row():
|
| 695 |
refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
|
| 696 |
load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
|
| 697 |
+
|
| 698 |
dataset_status_single = gr.Markdown("", visible=False)
|
| 699 |
with gr.Row():
|
| 700 |
prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
|
| 701 |
trajectory_slider = gr.Slider(
|
| 702 |
+
minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
)
|
| 704 |
next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
|
| 705 |
trajectory_metadata = gr.Markdown("", visible=False)
|
| 706 |
use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
|
| 707 |
+
|
| 708 |
gr.Markdown("---")
|
| 709 |
gr.Markdown("**OR**")
|
| 710 |
gr.Markdown("---")
|
| 711 |
+
|
| 712 |
single_video_input = gr.Video(label="Upload Video", height=300)
|
| 713 |
task_text_input = gr.Textbox(
|
| 714 |
label="Task Description",
|
|
|
|
| 726 |
analyze_single_btn = gr.Button("Analyze Video", variant="primary")
|
| 727 |
|
| 728 |
with gr.Column():
|
| 729 |
+
progress_plot = gr.Image(label="Progress & Success Prediction", height=400)
|
|
|
|
| 730 |
info_output = gr.Markdown("")
|
| 731 |
+
|
| 732 |
# State variables for dataset
|
| 733 |
current_dataset_single = gr.State(None)
|
| 734 |
+
|
| 735 |
def update_config_choices_single(dataset_name):
|
| 736 |
"""Update config choices when dataset changes."""
|
| 737 |
if not dataset_name:
|
|
|
|
| 745 |
except Exception as e:
|
| 746 |
logger.warning(f"Could not fetch configs: {e}")
|
| 747 |
return gr.update(choices=[], value="")
|
| 748 |
+
|
| 749 |
def load_dataset_single(dataset_name, config_name):
|
| 750 |
"""Load dataset and update slider."""
|
| 751 |
dataset, status = load_rfm_dataset(dataset_name, config_name)
|
|
|
|
| 754 |
return (
|
| 755 |
dataset,
|
| 756 |
gr.update(value=status, visible=True),
|
| 757 |
+
gr.update(
|
| 758 |
+
maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
|
| 759 |
+
),
|
| 760 |
)
|
| 761 |
else:
|
| 762 |
return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)
|
| 763 |
+
|
| 764 |
def use_dataset_video(dataset, index, dataset_name):
|
| 765 |
"""Load video from dataset and update inputs."""
|
| 766 |
if dataset is None:
|
| 767 |
+
return (
|
| 768 |
+
None,
|
| 769 |
+
"Complete the task",
|
| 770 |
+
gr.update(value="No dataset loaded", visible=True),
|
| 771 |
+
gr.update(visible=False),
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
|
| 775 |
if video_path:
|
| 776 |
# Build metadata text
|
|
|
|
| 779 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 780 |
if partial_success is not None:
|
| 781 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 782 |
+
|
| 783 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 784 |
status_text = f"✅ Loaded trajectory {index} from dataset"
|
| 785 |
if metadata_text:
|
| 786 |
status_text += f"\n\n{metadata_text}"
|
| 787 |
+
|
| 788 |
return (
|
| 789 |
+
video_path,
|
| 790 |
+
task,
|
| 791 |
gr.update(value=status_text, visible=True),
|
| 792 |
+
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 793 |
)
|
| 794 |
else:
|
| 795 |
+
return (
|
| 796 |
+
None,
|
| 797 |
+
"Complete the task",
|
| 798 |
+
gr.update(value="❌ Error loading trajectory", visible=True),
|
| 799 |
+
gr.update(visible=False),
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
def next_trajectory(dataset, current_idx, dataset_name):
|
| 803 |
"""Go to next trajectory."""
|
| 804 |
if dataset is None:
|
| 805 |
return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 806 |
next_idx = min(current_idx + 1, len(dataset) - 1)
|
| 807 |
+
video_path, task, quality_label, partial_success = get_trajectory_video_path(
|
| 808 |
+
dataset, next_idx, dataset_name
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
if video_path:
|
| 812 |
# Build metadata text
|
| 813 |
metadata_lines = []
|
|
|
|
| 815 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 816 |
if partial_success is not None:
|
| 817 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 818 |
+
|
| 819 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 820 |
return (
|
| 821 |
next_idx,
|
| 822 |
video_path,
|
| 823 |
task,
|
| 824 |
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 825 |
+
gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
|
| 826 |
)
|
| 827 |
else:
|
| 828 |
return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 829 |
+
|
| 830 |
def prev_trajectory(dataset, current_idx, dataset_name):
|
| 831 |
"""Go to previous trajectory."""
|
| 832 |
if dataset is None:
|
| 833 |
return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 834 |
prev_idx = max(current_idx - 1, 0)
|
| 835 |
+
video_path, task, quality_label, partial_success = get_trajectory_video_path(
|
| 836 |
+
dataset, prev_idx, dataset_name
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
if video_path:
|
| 840 |
# Build metadata text
|
| 841 |
metadata_lines = []
|
|
|
|
| 843 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 844 |
if partial_success is not None:
|
| 845 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 846 |
+
|
| 847 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 848 |
return (
|
| 849 |
prev_idx,
|
| 850 |
video_path,
|
| 851 |
task,
|
| 852 |
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 853 |
+
gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
|
| 854 |
)
|
| 855 |
else:
|
| 856 |
return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
|
| 857 |
+
|
| 858 |
def update_trajectory_on_slider_change(dataset, index, dataset_name):
|
| 859 |
"""Update trajectory metadata when slider changes."""
|
| 860 |
if dataset is None:
|
| 861 |
return gr.update(visible=False), gr.update(visible=False)
|
| 862 |
+
|
| 863 |
video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
|
| 864 |
if video_path:
|
| 865 |
# Build metadata text
|
|
|
|
| 868 |
metadata_lines.append(f"**Quality Label:** {quality_label}")
|
| 869 |
if partial_success is not None:
|
| 870 |
metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
|
| 871 |
+
|
| 872 |
metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
|
| 873 |
return (
|
| 874 |
gr.update(value=metadata_text, visible=bool(metadata_text)),
|
| 875 |
+
gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
|
| 876 |
)
|
| 877 |
else:
|
| 878 |
return gr.update(visible=False), gr.update(visible=False)
|
| 879 |
+
|
| 880 |
# Dataset selection handlers
|
| 881 |
dataset_name_single.change(
|
| 882 |
+
fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single]
|
|
|
|
|
|
|
| 883 |
)
|
| 884 |
+
|
| 885 |
refresh_configs_btn.click(
|
| 886 |
+
fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single]
|
|
|
|
|
|
|
| 887 |
)
|
| 888 |
+
|
| 889 |
load_dataset_btn.click(
|
| 890 |
fn=load_dataset_single,
|
| 891 |
inputs=[dataset_name_single, config_name_single],
|
| 892 |
+
outputs=[current_dataset_single, dataset_status_single, trajectory_slider],
|
| 893 |
)
|
| 894 |
+
|
| 895 |
use_dataset_video_btn.click(
|
| 896 |
fn=use_dataset_video,
|
| 897 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 898 |
+
outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata],
|
| 899 |
)
|
| 900 |
+
|
| 901 |
# Navigation buttons
|
| 902 |
next_traj_btn.click(
|
| 903 |
fn=next_trajectory,
|
| 904 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 905 |
+
outputs=[
|
| 906 |
+
trajectory_slider,
|
| 907 |
+
single_video_input,
|
| 908 |
+
task_text_input,
|
| 909 |
+
trajectory_metadata,
|
| 910 |
+
dataset_status_single,
|
| 911 |
+
],
|
| 912 |
)
|
| 913 |
+
|
| 914 |
prev_traj_btn.click(
|
| 915 |
fn=prev_trajectory,
|
| 916 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 917 |
+
outputs=[
|
| 918 |
+
trajectory_slider,
|
| 919 |
+
single_video_input,
|
| 920 |
+
task_text_input,
|
| 921 |
+
trajectory_metadata,
|
| 922 |
+
dataset_status_single,
|
| 923 |
+
],
|
| 924 |
)
|
| 925 |
+
|
| 926 |
# Update metadata when slider changes
|
| 927 |
trajectory_slider.change(
|
| 928 |
fn=update_trajectory_on_slider_change,
|
| 929 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 930 |
+
outputs=[trajectory_metadata, dataset_status_single],
|
| 931 |
)
|
| 932 |
+
|
| 933 |
analyze_single_btn.click(
|
| 934 |
fn=process_single_video,
|
| 935 |
inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
|
| 936 |
+
outputs=[progress_plot, info_output],
|
| 937 |
+
api_name="process_single_video",
|
| 938 |
)
|
| 939 |
|
| 940 |
with gr.Tab("Preference/Similarity Analysis"):
|
|
|
|
| 949 |
value="Complete the task",
|
| 950 |
)
|
| 951 |
prediction_type = gr.Radio(
|
| 952 |
+
choices=["preference", "similarity", "progress"],
|
| 953 |
value="preference",
|
| 954 |
label="Prediction Type",
|
| 955 |
)
|
|
|
|
| 971 |
fn=process_dual_videos,
|
| 972 |
inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
|
| 973 |
outputs=[result_text, comparison_plot],
|
| 974 |
+
api_name="process_dual_videos",
|
| 975 |
)
|
| 976 |
|
| 977 |
|
| 978 |
def main():
|
| 979 |
"""Launch the Gradio app."""
|
| 980 |
import sys
|
| 981 |
+
|
| 982 |
# Check if reload mode is requested
|
| 983 |
watch_files = os.getenv("GRADIO_WATCH", "0") == "1" or "--reload" in sys.argv
|
| 984 |
+
|
| 985 |
demo.launch(
|
| 986 |
server_name="0.0.0.0",
|
| 987 |
server_port=7860,
|