Anthony Liang committed on
Commit
d2a5693
·
1 Parent(s): f506da8

update interface

Browse files
Files changed (1) hide show
  1. app.py +240 -196
app.py CHANGED
@@ -12,9 +12,14 @@ from typing import Optional, Tuple
12
  import logging
13
 
14
  import gradio as gr
15
- import spaces # Required for ZeroGPU
 
 
 
 
16
  import matplotlib
17
- matplotlib.use('Agg') # Use non-interactive backend
 
18
  import matplotlib.pyplot as plt
19
  import numpy as np
20
  import requests
@@ -24,6 +29,7 @@ from typing import Any, Optional, Tuple
24
 
25
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
26
  from rfm.evals.eval_utils import build_payload, post_batch_npy
 
27
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
28
 
29
  logger = logging.getLogger(__name__)
@@ -57,7 +63,7 @@ PREDEFINED_DATASETS = [
57
  "aliangdw/usc_xarm_policy_ranking",
58
  "aliangdw/usc_franka_policy_ranking",
59
  "aliangdw/utd_so101_policy_ranking",
60
- "aliangdw/utd_so101_human"
61
  ]
62
 
63
  # Global server state
@@ -65,17 +71,18 @@ _server_state = {
65
  "server_url": None,
66
  }
67
 
 
68
  def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
69
  """Check server health and get model info."""
70
  if not server_url:
71
  return "Please provide a server URL.", None, None
72
-
73
  try:
74
  url = server_url.rstrip("/") + "/health"
75
  response = requests.get(url, timeout=5.0)
76
  response.raise_for_status()
77
  health_data = response.json()
78
-
79
  # Also try to get GPU status for more info
80
  try:
81
  status_url = server_url.rstrip("/") + "/gpu_status"
@@ -85,7 +92,7 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
85
  health_data.update(status_data)
86
  except:
87
  pass
88
-
89
  # Try to get model info
90
  model_info_text = None
91
  try:
@@ -96,9 +103,13 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
96
  model_info_text = format_model_info(model_info_data)
97
  except Exception as e:
98
  logger.warning(f"Could not fetch model info: {e}")
99
-
100
  _server_state["server_url"] = server_url
101
- return f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available", health_data, model_info_text
 
 
 
 
102
  except requests.exceptions.RequestException as e:
103
  return f"Error connecting to server: {str(e)}", None, None
104
 
@@ -106,31 +117,31 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
106
  def format_model_info(model_info: dict) -> str:
107
  """Format model info and experiment config as markdown."""
108
  lines = ["## Model Information\n"]
109
-
110
  # Model path
111
  model_path = model_info.get("model_path", "Unknown")
112
  lines.append(f"**Model Path:** `{model_path}`\n")
113
-
114
  # Number of GPUs
115
  num_gpus = model_info.get("num_gpus", "Unknown")
116
  lines.append(f"**Number of GPUs:** {num_gpus}\n")
117
-
118
  # Model architecture
119
  model_arch = model_info.get("model_architecture", {})
120
  if model_arch and "error" not in model_arch:
121
  lines.append("\n## Model Architecture\n")
122
-
123
  model_class = model_arch.get("model_class", "Unknown")
124
  model_module = model_arch.get("model_module", "Unknown")
125
  lines.append(f"- **Model Class:** `{model_class}`\n")
126
  lines.append(f"- **Module:** `{model_module}`\n")
127
-
128
  # Parameter counts
129
  total_params = model_arch.get("total_parameters")
130
  trainable_params = model_arch.get("trainable_parameters")
131
  frozen_params = model_arch.get("frozen_parameters")
132
  trainable_pct = model_arch.get("trainable_percentage")
133
-
134
  if total_params is not None:
135
  lines.append(f"\n### Parameter Statistics\n")
136
  lines.append(f"- **Total Parameters:** {total_params:,}\n")
@@ -140,7 +151,7 @@ def format_model_info(model_info: dict) -> str:
140
  lines.append(f"- **Frozen Parameters:** {frozen_params:,}\n")
141
  if trainable_pct is not None:
142
  lines.append(f"- **Trainable Percentage:** {trainable_pct:.2f}%\n")
143
-
144
  # Architecture summary
145
  arch_summary = model_arch.get("architecture_summary", [])
146
  if arch_summary:
@@ -150,12 +161,12 @@ def format_model_info(model_info: dict) -> str:
150
  module_type = module_info.get("type", "Unknown")
151
  params = module_info.get("parameters", 0)
152
  lines.append(f"- **{name}** (`{module_type}`): {params:,} parameters\n")
153
-
154
  # Experiment config
155
  exp_config = model_info.get("experiment_config", {})
156
  if exp_config:
157
  lines.append("\n## Experiment Configuration\n")
158
-
159
  # Model config
160
  model_cfg = exp_config.get("model", {})
161
  if model_cfg:
@@ -168,29 +179,33 @@ def format_model_info(model_info: dict) -> str:
168
  lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
169
  lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
170
  lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")
171
-
172
  # Data config
173
  data_cfg = exp_config.get("data", {})
174
  if data_cfg:
175
  lines.append("\n### Data Configuration\n")
176
  lines.append(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n")
177
- lines.append(f"- **Resized Dimensions:** {data_cfg.get('resized_height', 'N/A')}x{data_cfg.get('resized_width', 'N/A')}\n")
178
- train_datasets = data_cfg.get('train_datasets', [])
 
 
179
  if train_datasets:
180
  lines.append(f"- **Train Datasets:** {', '.join(train_datasets)}\n")
181
- eval_datasets = data_cfg.get('eval_datasets', [])
182
  if eval_datasets:
183
  lines.append(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n")
184
-
185
  # Training config
186
  training_cfg = exp_config.get("training", {})
187
  if training_cfg:
188
  lines.append("\n### Training Configuration\n")
189
  lines.append(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n")
190
  lines.append(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n")
191
- lines.append(f"- **Gradient Accumulation Steps:** {training_cfg.get('gradient_accumulation_steps', 'N/A')}\n")
 
 
192
  lines.append(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n")
193
-
194
  return "".join(lines)
195
 
196
 
@@ -199,12 +214,12 @@ def load_rfm_dataset(dataset_name, config_name):
199
  try:
200
  if not dataset_name or not config_name:
201
  return None, "Please provide both dataset name and configuration"
202
-
203
  dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
204
-
205
  if len(dataset) == 0:
206
  return None, f"Dataset {dataset_name}/{config_name} is empty"
207
-
208
  return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
209
  except Exception as e:
210
  error_msg = str(e)
@@ -231,18 +246,18 @@ def get_trajectory_video_path(dataset, index, dataset_name):
231
  try:
232
  item = dataset[int(index)]
233
  frames_data = item["frames"]
234
-
235
  if isinstance(frames_data, str):
236
  # Construct HuggingFace Hub URL
237
  if dataset_name:
238
  video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
239
  else:
240
  video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
241
-
242
  task = item.get("task", "Complete the task")
243
  quality_label = item.get("quality_label", None)
244
  partial_success = item.get("partial_success", None)
245
-
246
  return video_path, task, quality_label, partial_success
247
  else:
248
  return None, None, None, None
@@ -267,7 +282,7 @@ def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
267
  # Check if it's a URL or local file
268
  is_url = video_path.startswith(("http://", "https://"))
269
  is_local_file = os.path.exists(video_path) if not is_url else False
270
-
271
  if not is_url and not is_local_file:
272
  logger.warning(f"Video path does not exist: {video_path}")
273
  return None
@@ -304,7 +319,7 @@ def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
304
  frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
305
 
306
  frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
307
- del vr
308
  return frames_array
309
  except Exception as e:
310
  logger.error(f"Error extracting frames from {video_path}: {e}")
@@ -316,26 +331,26 @@ def process_single_video(
316
  task_text: str = "Complete the task",
317
  server_url: str = "",
318
  fps: float = 1.0,
319
- ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
320
  """Process single video for progress and success predictions using eval server."""
321
  if not server_url:
322
- return None, None, "Please provide a server URL and check connection first."
323
-
324
  if not _server_state.get("server_url"):
325
- return None, None, "Server not connected. Please check server connection first."
326
-
327
  if video_path is None:
328
- return None, None, "Please provide a video."
329
 
330
  try:
331
  frames_array = extract_frames(video_path, fps=fps)
332
  if frames_array is None or frames_array.size == 0:
333
- return None, None, "Could not extract frames from video."
334
 
335
  # Convert frames to (T, H, W, C) numpy array with uint8 values
336
  if frames_array.dtype != np.uint8:
337
  frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)
338
-
339
  num_frames = frames_array.shape[0]
340
  frames_shape = frames_array.shape # (T, H, W, C)
341
 
@@ -366,25 +381,54 @@ def process_single_video(
366
  # Process response
367
  outputs_progress = response.get("outputs_progress", {})
368
  progress_pred = outputs_progress.get("progress_pred", [])
369
-
 
 
370
  # Extract progress predictions
371
  if progress_pred and len(progress_pred) > 0:
372
  progress_array = np.array(progress_pred[0]) # First sample
373
  else:
374
  progress_array = np.array([])
375
 
376
- # Create plots
377
- progress_plot = create_progress_plot(progress_array, num_frames)
378
- success_plot = None # Success predictions not always available from server
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  info_text = f"**Frames processed:** {num_frames}\n"
381
  if len(progress_array) > 0:
382
  info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"
 
 
383
 
384
- return progress_plot, success_plot, info_text
 
385
 
386
  except Exception as e:
387
- return None, None, f"Error processing video: {str(e)}"
388
 
389
 
390
  def process_dual_videos(
@@ -398,7 +442,7 @@ def process_dual_videos(
398
  """Process two videos for preference or similarity prediction using eval server."""
399
  if not server_url:
400
  return "Please provide a server URL and check connection first.", None
401
-
402
  if not _server_state.get("server_url"):
403
  return "Server not connected. Please check server connection first.", None
404
 
@@ -475,6 +519,47 @@ def process_dual_videos(
475
  else:
476
  result_text += "Could not extract preference prediction from server response.\n"
477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  else: # similarity - not yet implemented in eval server response format
479
  result_text = "Similarity prediction not yet supported in eval server response format."
480
 
@@ -489,107 +574,49 @@ def process_dual_videos(
489
  return f"Error processing videos: {str(e)}", None
490
 
491
 
492
- def create_progress_plot(progress_pred: np.ndarray, num_frames: int) -> str:
493
- """Create progress prediction plot."""
494
- plt.rcParams['font.family'] = 'DejaVu Sans'
495
- plt.rcParams['font.size'] = 16
496
-
497
- fig, ax = plt.subplots(figsize=(10, 6))
498
-
499
- if len(progress_pred) > 0:
500
- frame_indices = np.arange(len(progress_pred))
501
- ax.plot(frame_indices, progress_pred, 'b-', linewidth=3, marker='o', markersize=8, label='Progress Prediction')
502
- else:
503
- ax.text(0.5, 0.5, 'No progress prediction available',
504
- horizontalalignment='center', verticalalignment='center',
505
- transform=ax.transAxes, fontsize=18)
506
-
507
- ax.set_xlabel('Frame Index', fontsize=18, fontweight='bold')
508
- ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
509
- ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
510
- ax.set_ylim([0, 1])
511
-
512
- plt.tight_layout()
513
-
514
- tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
515
- plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
516
- plt.close()
517
-
518
- return tmp_file.name
519
-
520
-
521
- def create_success_plot(success_probs: np.ndarray, num_frames: int) -> str:
522
- """Create success probability plot."""
523
- plt.rcParams['font.family'] = 'DejaVu Sans'
524
- plt.rcParams['font.size'] = 16
525
-
526
- fig, ax = plt.subplots(figsize=(10, 6))
527
-
528
- if len(success_probs) > 0:
529
- frame_indices = np.arange(len(success_probs))
530
- ax.plot(frame_indices, success_probs, 'g-', linewidth=3, marker='s', markersize=8, label='Success Probability')
531
- ax.axhline(y=0.5, color='r', linestyle='--', linewidth=2, label='Decision Threshold (0.5)')
532
- else:
533
- ax.text(0.5, 0.5, 'No success prediction available',
534
- horizontalalignment='center', verticalalignment='center',
535
- transform=ax.transAxes, fontsize=18)
536
-
537
- ax.set_xlabel('Frame Index', fontsize=18, fontweight='bold')
538
- ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
539
- ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
540
- ax.set_ylim([0, 1])
541
-
542
- plt.tight_layout()
543
-
544
- tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
545
- plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
546
- plt.close()
547
-
548
- return tmp_file.name
549
-
550
  def create_comparison_plot(frames_a: list, frames_b: list, prediction_type: str) -> str:
551
  """Create side-by-side comparison plot of two videos."""
552
- plt.rcParams['font.family'] = 'DejaVu Sans'
553
- plt.rcParams['font.size'] = 16
554
-
555
  fig, axes = plt.subplots(2, min(8, max(len(frames_a), len(frames_b))), figsize=(16, 4))
556
-
557
  if len(axes.shape) == 1:
558
  axes = axes.reshape(2, -1)
559
-
560
  # Sample frames to display
561
  num_display = min(8, max(len(frames_a), len(frames_b)))
562
  indices_a = np.linspace(0, len(frames_a) - 1, num_display, dtype=int) if len(frames_a) > 1 else [0]
563
  indices_b = np.linspace(0, len(frames_b) - 1, num_display, dtype=int) if len(frames_b) > 1 else [0]
564
-
565
  # Display frames from video A (top row)
566
  for idx, frame_idx in enumerate(indices_a):
567
  if frame_idx < len(frames_a):
568
  axes[0, idx].imshow(frames_a[frame_idx])
569
- axes[0, idx].axis('off')
570
- axes[0, idx].set_title(f'Frame {frame_idx}', fontsize=12)
571
-
572
  # Display frames from video B (bottom row)
573
  for idx, frame_idx in enumerate(indices_b):
574
  if frame_idx < len(frames_b):
575
  axes[1, idx].imshow(frames_b[frame_idx])
576
- axes[1, idx].axis('off')
577
- axes[1, idx].set_title(f'Frame {frame_idx}', fontsize=12)
578
-
579
  # Add row labels
580
- fig.text(0.02, 0.75, 'Video A', rotation=90, fontsize=18, fontweight='bold', va='center')
581
- fig.text(0.02, 0.25, 'Video B', rotation=90, fontsize=18, fontweight='bold', va='center')
582
-
583
  title = f"{prediction_type.capitalize()} Comparison: Video A vs Video B"
584
- fig.suptitle(title, fontsize=20, fontweight='bold', y=0.98)
585
-
586
  plt.tight_layout()
587
-
588
  # Save to temporary file
589
- tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
590
- plt.savefig(tmp_file.name, dpi=150, bbox_inches='tight')
591
  plt.close()
592
-
593
  return tmp_file.name
594
 
595
 
@@ -619,7 +646,7 @@ with demo:
619
  with gr.Tab("Server Setup"):
620
  gr.Markdown("### Connect to Eval Server")
621
  gr.Markdown("Enter the eval server URL and check connection.")
622
-
623
  with gr.Row():
624
  with gr.Column(scale=3):
625
  server_url_input = gr.Textbox(
@@ -630,7 +657,7 @@ with demo:
630
  )
631
  with gr.Column(scale=1):
632
  check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
633
-
634
  server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
635
  model_info_display = gr.Markdown("", visible=False)
636
 
@@ -641,7 +668,7 @@ with demo:
641
  return status, gr.update(value=model_info_text, visible=True)
642
  else:
643
  return status, gr.update(visible=False)
644
-
645
  check_connection_btn.click(
646
  fn=on_check_connection,
647
  inputs=[server_url_input],
@@ -651,7 +678,7 @@ with demo:
651
  with gr.Tab("Progress Prediction"):
652
  gr.Markdown("### Progress & Success Prediction")
653
  gr.Markdown("Upload a video or select one from a dataset to get progress predictions.")
654
-
655
  with gr.Row():
656
  with gr.Column():
657
  with gr.Accordion("📁 Select from Dataset", open=False):
@@ -659,37 +686,29 @@ with demo:
659
  choices=PREDEFINED_DATASETS,
660
  value="jesbu1/oxe_rfm",
661
  label="Dataset Name",
662
- allow_custom_value=True
663
  )
664
  config_name_single = gr.Dropdown(
665
- choices=[],
666
- value="",
667
- label="Configuration Name",
668
- allow_custom_value=True
669
  )
670
  with gr.Row():
671
  refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
672
  load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
673
-
674
  dataset_status_single = gr.Markdown("", visible=False)
675
  with gr.Row():
676
  prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
677
  trajectory_slider = gr.Slider(
678
- minimum=0,
679
- maximum=0,
680
- step=1,
681
- value=0,
682
- label="Trajectory Index",
683
- interactive=True
684
  )
685
  next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
686
  trajectory_metadata = gr.Markdown("", visible=False)
687
  use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
688
-
689
  gr.Markdown("---")
690
  gr.Markdown("**OR**")
691
  gr.Markdown("---")
692
-
693
  single_video_input = gr.Video(label="Upload Video", height=300)
694
  task_text_input = gr.Textbox(
695
  label="Task Description",
@@ -707,13 +726,12 @@ with demo:
707
  analyze_single_btn = gr.Button("Analyze Video", variant="primary")
708
 
709
  with gr.Column():
710
- progress_plot = gr.Image(label="Progress Prediction", height=400)
711
- success_plot = gr.Image(label="Success Prediction", height=400)
712
  info_output = gr.Markdown("")
713
-
714
  # State variables for dataset
715
  current_dataset_single = gr.State(None)
716
-
717
  def update_config_choices_single(dataset_name):
718
  """Update config choices when dataset changes."""
719
  if not dataset_name:
@@ -727,7 +745,7 @@ with demo:
727
  except Exception as e:
728
  logger.warning(f"Could not fetch configs: {e}")
729
  return gr.update(choices=[], value="")
730
-
731
  def load_dataset_single(dataset_name, config_name):
732
  """Load dataset and update slider."""
733
  dataset, status = load_rfm_dataset(dataset_name, config_name)
@@ -736,16 +754,23 @@ with demo:
736
  return (
737
  dataset,
738
  gr.update(value=status, visible=True),
739
- gr.update(maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})")
 
 
740
  )
741
  else:
742
  return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)
743
-
744
  def use_dataset_video(dataset, index, dataset_name):
745
  """Load video from dataset and update inputs."""
746
  if dataset is None:
747
- return None, "Complete the task", gr.update(value="No dataset loaded", visible=True), gr.update(visible=False)
748
-
 
 
 
 
 
749
  video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
750
  if video_path:
751
  # Build metadata text
@@ -754,28 +779,35 @@ with demo:
754
  metadata_lines.append(f"**Quality Label:** {quality_label}")
755
  if partial_success is not None:
756
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
757
-
758
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
759
  status_text = f"✅ Loaded trajectory {index} from dataset"
760
  if metadata_text:
761
  status_text += f"\n\n{metadata_text}"
762
-
763
  return (
764
- video_path,
765
- task,
766
  gr.update(value=status_text, visible=True),
767
- gr.update(value=metadata_text, visible=bool(metadata_text))
768
  )
769
  else:
770
- return None, "Complete the task", gr.update(value="❌ Error loading trajectory", visible=True), gr.update(visible=False)
771
-
 
 
 
 
 
772
  def next_trajectory(dataset, current_idx, dataset_name):
773
  """Go to next trajectory."""
774
  if dataset is None:
775
  return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
776
  next_idx = min(current_idx + 1, len(dataset) - 1)
777
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, next_idx, dataset_name)
778
-
 
 
779
  if video_path:
780
  # Build metadata text
781
  metadata_lines = []
@@ -783,25 +815,27 @@ with demo:
783
  metadata_lines.append(f"**Quality Label:** {quality_label}")
784
  if partial_success is not None:
785
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
786
-
787
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
788
  return (
789
  next_idx,
790
  video_path,
791
  task,
792
  gr.update(value=metadata_text, visible=bool(metadata_text)),
793
- gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True)
794
  )
795
  else:
796
  return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
797
-
798
  def prev_trajectory(dataset, current_idx, dataset_name):
799
  """Go to previous trajectory."""
800
  if dataset is None:
801
  return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
802
  prev_idx = max(current_idx - 1, 0)
803
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, prev_idx, dataset_name)
804
-
 
 
805
  if video_path:
806
  # Build metadata text
807
  metadata_lines = []
@@ -809,23 +843,23 @@ with demo:
809
  metadata_lines.append(f"**Quality Label:** {quality_label}")
810
  if partial_success is not None:
811
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
812
-
813
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
814
  return (
815
  prev_idx,
816
  video_path,
817
  task,
818
  gr.update(value=metadata_text, visible=bool(metadata_text)),
819
- gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True)
820
  )
821
  else:
822
  return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
823
-
824
  def update_trajectory_on_slider_change(dataset, index, dataset_name):
825
  """Update trajectory metadata when slider changes."""
826
  if dataset is None:
827
  return gr.update(visible=False), gr.update(visible=False)
828
-
829
  video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
830
  if video_path:
831
  # Build metadata text
@@ -834,64 +868,73 @@ with demo:
834
  metadata_lines.append(f"**Quality Label:** {quality_label}")
835
  if partial_success is not None:
836
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
837
-
838
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
839
  return (
840
  gr.update(value=metadata_text, visible=bool(metadata_text)),
841
- gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True)
842
  )
843
  else:
844
  return gr.update(visible=False), gr.update(visible=False)
845
-
846
  # Dataset selection handlers
847
  dataset_name_single.change(
848
- fn=update_config_choices_single,
849
- inputs=[dataset_name_single],
850
- outputs=[config_name_single]
851
  )
852
-
853
  refresh_configs_btn.click(
854
- fn=update_config_choices_single,
855
- inputs=[dataset_name_single],
856
- outputs=[config_name_single]
857
  )
858
-
859
  load_dataset_btn.click(
860
  fn=load_dataset_single,
861
  inputs=[dataset_name_single, config_name_single],
862
- outputs=[current_dataset_single, dataset_status_single, trajectory_slider]
863
  )
864
-
865
  use_dataset_video_btn.click(
866
  fn=use_dataset_video,
867
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
868
- outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata]
869
  )
870
-
871
  # Navigation buttons
872
  next_traj_btn.click(
873
  fn=next_trajectory,
874
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
875
- outputs=[trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single]
 
 
 
 
 
 
876
  )
877
-
878
  prev_traj_btn.click(
879
  fn=prev_trajectory,
880
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
881
- outputs=[trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single]
 
 
 
 
 
 
882
  )
883
-
884
  # Update metadata when slider changes
885
  trajectory_slider.change(
886
  fn=update_trajectory_on_slider_change,
887
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
888
- outputs=[trajectory_metadata, dataset_status_single]
889
  )
890
-
891
  analyze_single_btn.click(
892
  fn=process_single_video,
893
  inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
894
- outputs=[progress_plot, success_plot, info_output],
 
895
  )
896
 
897
  with gr.Tab("Preference/Similarity Analysis"):
@@ -906,7 +949,7 @@ with demo:
906
  value="Complete the task",
907
  )
908
  prediction_type = gr.Radio(
909
- choices=["preference", "similarity"],
910
  value="preference",
911
  label="Prediction Type",
912
  )
@@ -928,16 +971,17 @@ with demo:
928
  fn=process_dual_videos,
929
  inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
930
  outputs=[result_text, comparison_plot],
 
931
  )
932
 
933
 
934
  def main():
935
  """Launch the Gradio app."""
936
  import sys
937
-
938
  # Check if reload mode is requested
939
  watch_files = os.getenv("GRADIO_WATCH", "0") == "1" or "--reload" in sys.argv
940
-
941
  demo.launch(
942
  server_name="0.0.0.0",
943
  server_port=7860,
 
12
  import logging
13
 
14
  import gradio as gr
15
+
16
+ try:
17
+ import spaces # Required for ZeroGPU on Hugging Face Spaces
18
+ except ImportError:
19
+ spaces = None # Not available when running locally
20
  import matplotlib
21
+
22
+ matplotlib.use("Agg") # Use non-interactive backend
23
  import matplotlib.pyplot as plt
24
  import numpy as np
25
  import requests
 
29
 
30
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
31
  from rfm.evals.eval_utils import build_payload, post_batch_npy
32
+ from rfm.evals.eval_viz_utils import create_combined_progress_success_plot
33
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
34
 
35
  logger = logging.getLogger(__name__)
 
63
  "aliangdw/usc_xarm_policy_ranking",
64
  "aliangdw/usc_franka_policy_ranking",
65
  "aliangdw/utd_so101_policy_ranking",
66
+ "aliangdw/utd_so101_human",
67
  ]
68
 
69
  # Global server state
 
71
  "server_url": None,
72
  }
73
 
74
+
75
  def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
76
  """Check server health and get model info."""
77
  if not server_url:
78
  return "Please provide a server URL.", None, None
79
+
80
  try:
81
  url = server_url.rstrip("/") + "/health"
82
  response = requests.get(url, timeout=5.0)
83
  response.raise_for_status()
84
  health_data = response.json()
85
+
86
  # Also try to get GPU status for more info
87
  try:
88
  status_url = server_url.rstrip("/") + "/gpu_status"
 
92
  health_data.update(status_data)
93
  except:
94
  pass
95
+
96
  # Try to get model info
97
  model_info_text = None
98
  try:
 
103
  model_info_text = format_model_info(model_info_data)
104
  except Exception as e:
105
  logger.warning(f"Could not fetch model info: {e}")
106
+
107
  _server_state["server_url"] = server_url
108
+ return (
109
+ f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available",
110
+ health_data,
111
+ model_info_text,
112
+ )
113
  except requests.exceptions.RequestException as e:
114
  return f"Error connecting to server: {str(e)}", None, None
115
 
 
117
  def format_model_info(model_info: dict) -> str:
118
  """Format model info and experiment config as markdown."""
119
  lines = ["## Model Information\n"]
120
+
121
  # Model path
122
  model_path = model_info.get("model_path", "Unknown")
123
  lines.append(f"**Model Path:** `{model_path}`\n")
124
+
125
  # Number of GPUs
126
  num_gpus = model_info.get("num_gpus", "Unknown")
127
  lines.append(f"**Number of GPUs:** {num_gpus}\n")
128
+
129
  # Model architecture
130
  model_arch = model_info.get("model_architecture", {})
131
  if model_arch and "error" not in model_arch:
132
  lines.append("\n## Model Architecture\n")
133
+
134
  model_class = model_arch.get("model_class", "Unknown")
135
  model_module = model_arch.get("model_module", "Unknown")
136
  lines.append(f"- **Model Class:** `{model_class}`\n")
137
  lines.append(f"- **Module:** `{model_module}`\n")
138
+
139
  # Parameter counts
140
  total_params = model_arch.get("total_parameters")
141
  trainable_params = model_arch.get("trainable_parameters")
142
  frozen_params = model_arch.get("frozen_parameters")
143
  trainable_pct = model_arch.get("trainable_percentage")
144
+
145
  if total_params is not None:
146
  lines.append(f"\n### Parameter Statistics\n")
147
  lines.append(f"- **Total Parameters:** {total_params:,}\n")
 
151
  lines.append(f"- **Frozen Parameters:** {frozen_params:,}\n")
152
  if trainable_pct is not None:
153
  lines.append(f"- **Trainable Percentage:** {trainable_pct:.2f}%\n")
154
+
155
  # Architecture summary
156
  arch_summary = model_arch.get("architecture_summary", [])
157
  if arch_summary:
 
161
  module_type = module_info.get("type", "Unknown")
162
  params = module_info.get("parameters", 0)
163
  lines.append(f"- **{name}** (`{module_type}`): {params:,} parameters\n")
164
+
165
  # Experiment config
166
  exp_config = model_info.get("experiment_config", {})
167
  if exp_config:
168
  lines.append("\n## Experiment Configuration\n")
169
+
170
  # Model config
171
  model_cfg = exp_config.get("model", {})
172
  if model_cfg:
 
179
  lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
180
  lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
181
  lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")
182
+
183
  # Data config
184
  data_cfg = exp_config.get("data", {})
185
  if data_cfg:
186
  lines.append("\n### Data Configuration\n")
187
  lines.append(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n")
188
+ lines.append(
189
+ f"- **Resized Dimensions:** {data_cfg.get('resized_height', 'N/A')}x{data_cfg.get('resized_width', 'N/A')}\n"
190
+ )
191
+ train_datasets = data_cfg.get("train_datasets", [])
192
  if train_datasets:
193
  lines.append(f"- **Train Datasets:** {', '.join(train_datasets)}\n")
194
+ eval_datasets = data_cfg.get("eval_datasets", [])
195
  if eval_datasets:
196
  lines.append(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n")
197
+
198
  # Training config
199
  training_cfg = exp_config.get("training", {})
200
  if training_cfg:
201
  lines.append("\n### Training Configuration\n")
202
  lines.append(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n")
203
  lines.append(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n")
204
+ lines.append(
205
+ f"- **Gradient Accumulation Steps:** {training_cfg.get('gradient_accumulation_steps', 'N/A')}\n"
206
+ )
207
  lines.append(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n")
208
+
209
  return "".join(lines)
210
 
211
 
 
214
  try:
215
  if not dataset_name or not config_name:
216
  return None, "Please provide both dataset name and configuration"
217
+
218
  dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
219
+
220
  if len(dataset) == 0:
221
  return None, f"Dataset {dataset_name}/{config_name} is empty"
222
+
223
  return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
224
  except Exception as e:
225
  error_msg = str(e)
 
246
  try:
247
  item = dataset[int(index)]
248
  frames_data = item["frames"]
249
+
250
  if isinstance(frames_data, str):
251
  # Construct HuggingFace Hub URL
252
  if dataset_name:
253
  video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
254
  else:
255
  video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
256
+
257
  task = item.get("task", "Complete the task")
258
  quality_label = item.get("quality_label", None)
259
  partial_success = item.get("partial_success", None)
260
+
261
  return video_path, task, quality_label, partial_success
262
  else:
263
  return None, None, None, None
 
282
  # Check if it's a URL or local file
283
  is_url = video_path.startswith(("http://", "https://"))
284
  is_local_file = os.path.exists(video_path) if not is_url else False
285
+
286
  if not is_url and not is_local_file:
287
  logger.warning(f"Video path does not exist: {video_path}")
288
  return None
 
319
  frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
320
 
321
  frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
322
+ del vr
323
  return frames_array
324
  except Exception as e:
325
  logger.error(f"Error extracting frames from {video_path}: {e}")
 
331
  task_text: str = "Complete the task",
332
  server_url: str = "",
333
  fps: float = 1.0,
334
+ ) -> Tuple[Optional[str], Optional[str]]:
335
  """Process single video for progress and success predictions using eval server."""
336
  if not server_url:
337
+ return None, "Please provide a server URL and check connection first."
338
+
339
  if not _server_state.get("server_url"):
340
+ return None, "Server not connected. Please check server connection first."
341
+
342
  if video_path is None:
343
+ return None, "Please provide a video."
344
 
345
  try:
346
  frames_array = extract_frames(video_path, fps=fps)
347
  if frames_array is None or frames_array.size == 0:
348
+ return None, "Could not extract frames from video."
349
 
350
  # Convert frames to (T, H, W, C) numpy array with uint8 values
351
  if frames_array.dtype != np.uint8:
352
  frames_array = np.clip(frames_array, 0, 255).astype(np.uint8)
353
+
354
  num_frames = frames_array.shape[0]
355
  frames_shape = frames_array.shape # (T, H, W, C)
356
 
 
381
  # Process response
382
  outputs_progress = response.get("outputs_progress", {})
383
  progress_pred = outputs_progress.get("progress_pred", [])
384
+ outputs_success = response.get("outputs_success", {})
385
+ success_probs = outputs_success.get("success_probs", []) if outputs_success else None
386
+
387
  # Extract progress predictions
388
  if progress_pred and len(progress_pred) > 0:
389
  progress_array = np.array(progress_pred[0]) # First sample
390
  else:
391
  progress_array = np.array([])
392
 
393
+ # Extract success predictions if available
394
+ success_array = None
395
+ if success_probs and len(success_probs) > 0:
396
+ success_array = np.array(success_probs[0])
397
+
398
+ # Convert success_array to binary if available
399
+ success_binary = None
400
+ if success_array is not None:
401
+ success_binary = (success_array > 0.5).astype(float)
402
+
403
+ # Create combined plot using shared helper function
404
+ fig = create_combined_progress_success_plot(
405
+ progress_pred=progress_array if len(progress_array) > 0 else np.array([0.0]),
406
+ num_frames=num_frames,
407
+ success_binary=success_binary,
408
+ success_probs=success_array,
409
+ success_labels=None, # No ground truth labels available
410
+ is_discrete_mode=False,
411
+ num_bins=10,
412
+ title=f"Progress & Success - {task_text}",
413
+ )
414
+
415
+ # Save to temporary file
416
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
417
+ fig.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
418
+ plt.close(fig)
419
+ progress_plot = tmp_file.name
420
 
421
  info_text = f"**Frames processed:** {num_frames}\n"
422
  if len(progress_array) > 0:
423
  info_text += f"**Final progress:** {progress_array[-1]:.3f}\n"
424
+ if success_array is not None and len(success_array) > 0:
425
+ info_text += f"**Final success probability:** {success_array[-1]:.3f}\n"
426
 
427
+ # Return combined plot (which includes success if available)
428
+ return progress_plot, info_text
429
 
430
  except Exception as e:
431
+ return None, f"Error processing video: {str(e)}"
432
 
433
 
434
  def process_dual_videos(
 
442
  """Process two videos for preference or similarity prediction using eval server."""
443
  if not server_url:
444
  return "Please provide a server URL and check connection first.", None
445
+
446
  if not _server_state.get("server_url"):
447
  return "Server not connected. Please check server connection first.", None
448
 
 
519
  else:
520
  result_text += "Could not extract preference prediction from server response.\n"
521
 
522
+ elif prediction_type == "progress":
523
+ # Create ProgressSamples for both videos
524
+ from rfm.data.dataset_types import ProgressSample
525
+
526
+ progress_sample_a = ProgressSample(
527
+ trajectory=trajectory_a,
528
+ data_gen_strategy="demo",
529
+ )
530
+ progress_sample_b = ProgressSample(
531
+ trajectory=trajectory_b,
532
+ data_gen_strategy="demo",
533
+ )
534
+
535
+ # Build payload and send to server
536
+ files, sample_data = build_payload([progress_sample_a, progress_sample_b])
537
+ response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
538
+
539
+ # Process response
540
+ outputs_progress = response.get("outputs_progress", {})
541
+ progress_pred = outputs_progress.get("progress_pred", [])
542
+
543
+ result_text = f"**Progress Comparison:**\n"
544
+ if progress_pred and len(progress_pred) >= 2:
545
+ progress_a = np.array(progress_pred[0])
546
+ progress_b = np.array(progress_pred[1])
547
+
548
+ final_progress_a = float(progress_a[-1]) if len(progress_a) > 0 else 0.0
549
+ final_progress_b = float(progress_b[-1]) if len(progress_b) > 0 else 0.0
550
+
551
+ result_text += f"- Video A final progress: {final_progress_a:.3f}\n"
552
+ result_text += f"- Video B final progress: {final_progress_b:.3f}\n"
553
+ result_text += f"- Difference: {abs(final_progress_a - final_progress_b):.3f}\n"
554
+ if final_progress_a > final_progress_b:
555
+ result_text += f"- Video A has higher progress\n"
556
+ elif final_progress_b > final_progress_a:
557
+ result_text += f"- Video B has higher progress\n"
558
+ else:
559
+ result_text += f"- Both videos have equal progress\n"
560
+ else:
561
+ result_text += "Could not extract progress predictions from server response.\n"
562
+
563
  else: # similarity - not yet implemented in eval server response format
564
  result_text = "Similarity prediction not yet supported in eval server response format."
565
 
 
574
  return f"Error processing videos: {str(e)}", None
575
 
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  def create_comparison_plot(frames_a: list, frames_b: list, prediction_type: str) -> str:
578
  """Create side-by-side comparison plot of two videos."""
579
+ plt.rcParams["font.family"] = "DejaVu Sans"
580
+ plt.rcParams["font.size"] = 16
581
+
582
  fig, axes = plt.subplots(2, min(8, max(len(frames_a), len(frames_b))), figsize=(16, 4))
583
+
584
  if len(axes.shape) == 1:
585
  axes = axes.reshape(2, -1)
586
+
587
  # Sample frames to display
588
  num_display = min(8, max(len(frames_a), len(frames_b)))
589
  indices_a = np.linspace(0, len(frames_a) - 1, num_display, dtype=int) if len(frames_a) > 1 else [0]
590
  indices_b = np.linspace(0, len(frames_b) - 1, num_display, dtype=int) if len(frames_b) > 1 else [0]
591
+
592
  # Display frames from video A (top row)
593
  for idx, frame_idx in enumerate(indices_a):
594
  if frame_idx < len(frames_a):
595
  axes[0, idx].imshow(frames_a[frame_idx])
596
+ axes[0, idx].axis("off")
597
+ axes[0, idx].set_title(f"Frame {frame_idx}", fontsize=12)
598
+
599
  # Display frames from video B (bottom row)
600
  for idx, frame_idx in enumerate(indices_b):
601
  if frame_idx < len(frames_b):
602
  axes[1, idx].imshow(frames_b[frame_idx])
603
+ axes[1, idx].axis("off")
604
+ axes[1, idx].set_title(f"Frame {frame_idx}", fontsize=12)
605
+
606
  # Add row labels
607
+ fig.text(0.02, 0.75, "Video A", rotation=90, fontsize=18, fontweight="bold", va="center")
608
+ fig.text(0.02, 0.25, "Video B", rotation=90, fontsize=18, fontweight="bold", va="center")
609
+
610
  title = f"{prediction_type.capitalize()} Comparison: Video A vs Video B"
611
+ fig.suptitle(title, fontsize=20, fontweight="bold", y=0.98)
612
+
613
  plt.tight_layout()
614
+
615
  # Save to temporary file
616
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
617
+ plt.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
618
  plt.close()
619
+
620
  return tmp_file.name
621
 
622
 
 
646
  with gr.Tab("Server Setup"):
647
  gr.Markdown("### Connect to Eval Server")
648
  gr.Markdown("Enter the eval server URL and check connection.")
649
+
650
  with gr.Row():
651
  with gr.Column(scale=3):
652
  server_url_input = gr.Textbox(
 
657
  )
658
  with gr.Column(scale=1):
659
  check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
660
+
661
  server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
662
  model_info_display = gr.Markdown("", visible=False)
663
 
 
668
  return status, gr.update(value=model_info_text, visible=True)
669
  else:
670
  return status, gr.update(visible=False)
671
+
672
  check_connection_btn.click(
673
  fn=on_check_connection,
674
  inputs=[server_url_input],
 
678
  with gr.Tab("Progress Prediction"):
679
  gr.Markdown("### Progress & Success Prediction")
680
  gr.Markdown("Upload a video or select one from a dataset to get progress predictions.")
681
+
682
  with gr.Row():
683
  with gr.Column():
684
  with gr.Accordion("📁 Select from Dataset", open=False):
 
686
  choices=PREDEFINED_DATASETS,
687
  value="jesbu1/oxe_rfm",
688
  label="Dataset Name",
689
+ allow_custom_value=True,
690
  )
691
  config_name_single = gr.Dropdown(
692
+ choices=[], value="", label="Configuration Name", allow_custom_value=True
 
 
 
693
  )
694
  with gr.Row():
695
  refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
696
  load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
697
+
698
  dataset_status_single = gr.Markdown("", visible=False)
699
  with gr.Row():
700
  prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
701
  trajectory_slider = gr.Slider(
702
+ minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
 
 
 
 
 
703
  )
704
  next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
705
  trajectory_metadata = gr.Markdown("", visible=False)
706
  use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
707
+
708
  gr.Markdown("---")
709
  gr.Markdown("**OR**")
710
  gr.Markdown("---")
711
+
712
  single_video_input = gr.Video(label="Upload Video", height=300)
713
  task_text_input = gr.Textbox(
714
  label="Task Description",
 
726
  analyze_single_btn = gr.Button("Analyze Video", variant="primary")
727
 
728
  with gr.Column():
729
+ progress_plot = gr.Image(label="Progress & Success Prediction", height=400)
 
730
  info_output = gr.Markdown("")
731
+
732
  # State variables for dataset
733
  current_dataset_single = gr.State(None)
734
+
735
  def update_config_choices_single(dataset_name):
736
  """Update config choices when dataset changes."""
737
  if not dataset_name:
 
745
  except Exception as e:
746
  logger.warning(f"Could not fetch configs: {e}")
747
  return gr.update(choices=[], value="")
748
+
749
  def load_dataset_single(dataset_name, config_name):
750
  """Load dataset and update slider."""
751
  dataset, status = load_rfm_dataset(dataset_name, config_name)
 
754
  return (
755
  dataset,
756
  gr.update(value=status, visible=True),
757
+ gr.update(
758
+ maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
759
+ ),
760
  )
761
  else:
762
  return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)
763
+
764
  def use_dataset_video(dataset, index, dataset_name):
765
  """Load video from dataset and update inputs."""
766
  if dataset is None:
767
+ return (
768
+ None,
769
+ "Complete the task",
770
+ gr.update(value="No dataset loaded", visible=True),
771
+ gr.update(visible=False),
772
+ )
773
+
774
  video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
775
  if video_path:
776
  # Build metadata text
 
779
  metadata_lines.append(f"**Quality Label:** {quality_label}")
780
  if partial_success is not None:
781
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
782
+
783
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
784
  status_text = f"✅ Loaded trajectory {index} from dataset"
785
  if metadata_text:
786
  status_text += f"\n\n{metadata_text}"
787
+
788
  return (
789
+ video_path,
790
+ task,
791
  gr.update(value=status_text, visible=True),
792
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
793
  )
794
  else:
795
+ return (
796
+ None,
797
+ "Complete the task",
798
+ gr.update(value="❌ Error loading trajectory", visible=True),
799
+ gr.update(visible=False),
800
+ )
801
+
802
  def next_trajectory(dataset, current_idx, dataset_name):
803
  """Go to next trajectory."""
804
  if dataset is None:
805
  return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
806
  next_idx = min(current_idx + 1, len(dataset) - 1)
807
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
808
+ dataset, next_idx, dataset_name
809
+ )
810
+
811
  if video_path:
812
  # Build metadata text
813
  metadata_lines = []
 
815
  metadata_lines.append(f"**Quality Label:** {quality_label}")
816
  if partial_success is not None:
817
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
818
+
819
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
820
  return (
821
  next_idx,
822
  video_path,
823
  task,
824
  gr.update(value=metadata_text, visible=bool(metadata_text)),
825
+ gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
826
  )
827
  else:
828
  return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
829
+
830
  def prev_trajectory(dataset, current_idx, dataset_name):
831
  """Go to previous trajectory."""
832
  if dataset is None:
833
  return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
834
  prev_idx = max(current_idx - 1, 0)
835
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
836
+ dataset, prev_idx, dataset_name
837
+ )
838
+
839
  if video_path:
840
  # Build metadata text
841
  metadata_lines = []
 
843
  metadata_lines.append(f"**Quality Label:** {quality_label}")
844
  if partial_success is not None:
845
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
846
+
847
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
848
  return (
849
  prev_idx,
850
  video_path,
851
  task,
852
  gr.update(value=metadata_text, visible=bool(metadata_text)),
853
+ gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
854
  )
855
  else:
856
  return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
857
+
858
  def update_trajectory_on_slider_change(dataset, index, dataset_name):
859
  """Update trajectory metadata when slider changes."""
860
  if dataset is None:
861
  return gr.update(visible=False), gr.update(visible=False)
862
+
863
  video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
864
  if video_path:
865
  # Build metadata text
 
868
  metadata_lines.append(f"**Quality Label:** {quality_label}")
869
  if partial_success is not None:
870
  metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
871
+
872
  metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
873
  return (
874
  gr.update(value=metadata_text, visible=bool(metadata_text)),
875
+ gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
876
  )
877
  else:
878
  return gr.update(visible=False), gr.update(visible=False)
879
+
880
  # Dataset selection handlers
881
  dataset_name_single.change(
882
+ fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single]
 
 
883
  )
884
+
885
  refresh_configs_btn.click(
886
+ fn=update_config_choices_single, inputs=[dataset_name_single], outputs=[config_name_single]
 
 
887
  )
888
+
889
  load_dataset_btn.click(
890
  fn=load_dataset_single,
891
  inputs=[dataset_name_single, config_name_single],
892
+ outputs=[current_dataset_single, dataset_status_single, trajectory_slider],
893
  )
894
+
895
  use_dataset_video_btn.click(
896
  fn=use_dataset_video,
897
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
898
+ outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata],
899
  )
900
+
901
  # Navigation buttons
902
  next_traj_btn.click(
903
  fn=next_trajectory,
904
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
905
+ outputs=[
906
+ trajectory_slider,
907
+ single_video_input,
908
+ task_text_input,
909
+ trajectory_metadata,
910
+ dataset_status_single,
911
+ ],
912
  )
913
+
914
  prev_traj_btn.click(
915
  fn=prev_trajectory,
916
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
917
+ outputs=[
918
+ trajectory_slider,
919
+ single_video_input,
920
+ task_text_input,
921
+ trajectory_metadata,
922
+ dataset_status_single,
923
+ ],
924
  )
925
+
926
  # Update metadata when slider changes
927
  trajectory_slider.change(
928
  fn=update_trajectory_on_slider_change,
929
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
930
+ outputs=[trajectory_metadata, dataset_status_single],
931
  )
932
+
933
  analyze_single_btn.click(
934
  fn=process_single_video,
935
  inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
936
+ outputs=[progress_plot, info_output],
937
+ api_name="process_single_video",
938
  )
939
 
940
  with gr.Tab("Preference/Similarity Analysis"):
 
949
  value="Complete the task",
950
  )
951
  prediction_type = gr.Radio(
952
+ choices=["preference", "similarity", "progress"],
953
  value="preference",
954
  label="Prediction Type",
955
  )
 
971
  fn=process_dual_videos,
972
  inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
973
  outputs=[result_text, comparison_plot],
974
+ api_name="process_dual_videos",
975
  )
976
 
977
 
978
  def main():
979
  """Launch the Gradio app."""
980
  import sys
981
+
982
  # Check if reload mode is requested
983
  watch_files = os.getenv("GRADIO_WATCH", "0") == "1" or "--reload" in sys.argv
984
+
985
  demo.launch(
986
  server_name="0.0.0.0",
987
  server_port=7860,