Anthony Liang committed on
Commit
38f9df5
·
1 Parent(s): ad49410
Files changed (1) hide show
  1. app.py +458 -134
app.py CHANGED
@@ -23,13 +23,11 @@ matplotlib.use("Agg") # Use non-interactive backend
23
  import matplotlib.pyplot as plt
24
  import numpy as np
25
  import requests
26
- from PIL import Image
27
- import decord
28
  from typing import Any, Optional, Tuple
29
 
30
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
31
  from rfm.evals.eval_utils import build_payload, post_batch_npy
32
- from rfm.evals.eval_viz_utils import create_combined_progress_success_plot
33
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
34
 
35
  logger = logging.getLogger(__name__)
@@ -266,66 +264,6 @@ def get_trajectory_video_path(dataset, index, dataset_name):
266
  return None, None, None, None
267
 
268
 
269
- def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
270
- """Extract frames from video file as numpy array (T, H, W, C).
271
-
272
- Supports both local file paths and URLs (e.g., HuggingFace Hub URLs).
273
- Uses the provided ``fps`` to control how densely frames are sampled from
274
- the underlying video; there is no additional hard cap on the number of frames.
275
- """
276
- if video_path is None:
277
- return None
278
-
279
- if isinstance(video_path, tuple):
280
- video_path = video_path[0]
281
-
282
- # Check if it's a URL or local file
283
- is_url = video_path.startswith(("http://", "https://"))
284
- is_local_file = os.path.exists(video_path) if not is_url else False
285
-
286
- if not is_url and not is_local_file:
287
- logger.warning(f"Video path does not exist: {video_path}")
288
- return None
289
-
290
- try:
291
- # decord.VideoReader can handle both local files and URLs
292
- vr = decord.VideoReader(video_path, num_threads=1)
293
- total_frames = len(vr)
294
-
295
- # Determine native FPS; fall back to a reasonable default if unavailable
296
- try:
297
- native_fps = float(vr.get_avg_fps())
298
- except Exception:
299
- native_fps = 1.0
300
-
301
- # If user-specified fps is invalid or None, default to native fps
302
- if fps is None or fps <= 0:
303
- fps = native_fps
304
-
305
- # Compute how many frames we want based on desired fps
306
- # num_frames ≈ total_duration * fps = total_frames * (fps / native_fps)
307
- if native_fps > 0:
308
- desired_frames = int(round(total_frames * (fps / native_fps)))
309
- else:
310
- desired_frames = total_frames
311
-
312
- # Clamp to [1, total_frames]
313
- desired_frames = max(1, min(desired_frames, total_frames))
314
-
315
- # Evenly sample indices to match the desired number of frames
316
- if desired_frames == total_frames:
317
- frame_indices = list(range(total_frames))
318
- else:
319
- frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
320
-
321
- frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
322
- del vr
323
- return frames_array
324
- except Exception as e:
325
- logger.error(f"Error extracting frames from {video_path}: {e}")
326
- return None
327
-
328
-
329
  def process_single_video(
330
  video_path: str,
331
  task_text: str = "Complete the task",
@@ -394,7 +332,7 @@ def process_single_video(
394
  success_array = None
395
  if success_probs and len(success_probs) > 0:
396
  success_array = np.array(success_probs[0])
397
-
398
  # Convert success_array to binary if available
399
  success_binary = None
400
  if success_array is not None:
@@ -408,10 +346,9 @@ def process_single_video(
408
  success_probs=success_array,
409
  success_labels=None, # No ground truth labels available
410
  is_discrete_mode=False,
411
- num_bins=10,
412
  title=f"Progress & Success - {task_text}",
413
  )
414
-
415
  # Save to temporary file
416
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
417
  fig.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
@@ -438,25 +375,25 @@ def process_dual_videos(
438
  prediction_type: str = "preference",
439
  server_url: str = "",
440
  fps: float = 1.0,
441
- ) -> Tuple[Optional[str], Optional[str]]:
442
  """Process two videos for preference or similarity prediction using eval server."""
443
  if not server_url:
444
- return "Please provide a server URL and check connection first.", None
445
 
446
  if not _server_state.get("server_url"):
447
- return "Server not connected. Please check server connection first.", None
448
 
449
  if video_a_path is None or video_b_path is None:
450
- return "Please provide both videos.", None
451
 
452
  try:
453
  frames_array_a = extract_frames(video_a_path, fps=fps)
454
  frames_array_b = extract_frames(video_b_path, fps=fps)
455
 
456
  if frames_array_a is None or frames_array_a.size == 0:
457
- return "Could not extract frames from video A.", None
458
  if frames_array_b is None or frames_array_b.size == 0:
459
- return "Could not extract frames from video B.", None
460
 
461
  # Convert frames to uint8
462
  if frames_array_a.dtype != np.uint8:
@@ -563,81 +500,27 @@ def process_dual_videos(
563
  else: # similarity - not yet implemented in eval server response format
564
  result_text = "Similarity prediction not yet supported in eval server response format."
565
 
566
- # Create comparison plot
567
- frames_a_list = [Image.fromarray(frame) for frame in frames_array_a]
568
- frames_b_list = [Image.fromarray(frame) for frame in frames_array_b]
569
- comparison_plot = create_comparison_plot(frames_a_list, frames_b_list, prediction_type)
570
-
571
- return result_text, comparison_plot
572
 
573
  except Exception as e:
574
- return f"Error processing videos: {str(e)}", None
575
-
576
-
577
- def create_comparison_plot(frames_a: list, frames_b: list, prediction_type: str) -> str:
578
- """Create side-by-side comparison plot of two videos."""
579
- plt.rcParams["font.family"] = "DejaVu Sans"
580
- plt.rcParams["font.size"] = 16
581
-
582
- fig, axes = plt.subplots(2, min(8, max(len(frames_a), len(frames_b))), figsize=(16, 4))
583
-
584
- if len(axes.shape) == 1:
585
- axes = axes.reshape(2, -1)
586
-
587
- # Sample frames to display
588
- num_display = min(8, max(len(frames_a), len(frames_b)))
589
- indices_a = np.linspace(0, len(frames_a) - 1, num_display, dtype=int) if len(frames_a) > 1 else [0]
590
- indices_b = np.linspace(0, len(frames_b) - 1, num_display, dtype=int) if len(frames_b) > 1 else [0]
591
-
592
- # Display frames from video A (top row)
593
- for idx, frame_idx in enumerate(indices_a):
594
- if frame_idx < len(frames_a):
595
- axes[0, idx].imshow(frames_a[frame_idx])
596
- axes[0, idx].axis("off")
597
- axes[0, idx].set_title(f"Frame {frame_idx}", fontsize=12)
598
-
599
- # Display frames from video B (bottom row)
600
- for idx, frame_idx in enumerate(indices_b):
601
- if frame_idx < len(frames_b):
602
- axes[1, idx].imshow(frames_b[frame_idx])
603
- axes[1, idx].axis("off")
604
- axes[1, idx].set_title(f"Frame {frame_idx}", fontsize=12)
605
-
606
- # Add row labels
607
- fig.text(0.02, 0.75, "Video A", rotation=90, fontsize=18, fontweight="bold", va="center")
608
- fig.text(0.02, 0.25, "Video B", rotation=90, fontsize=18, fontweight="bold", va="center")
609
-
610
- title = f"{prediction_type.capitalize()} Comparison: Video A vs Video B"
611
- fig.suptitle(title, fontsize=20, fontweight="bold", y=0.98)
612
-
613
- plt.tight_layout()
614
 
615
- # Save to temporary file
616
- tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
617
- plt.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
618
- plt.close()
619
 
620
- return tmp_file.name
621
 
622
 
623
  # Create Gradio interface
624
  try:
625
  # Try with theme (Gradio 4.0+)
626
- demo = gr.Blocks(title="RFM Inference Visualizer", theme=gr.themes.Soft())
627
  except TypeError:
628
  # Fallback for older Gradio versions without theme support
629
- demo = gr.Blocks(title="RFM Inference Visualizer")
630
 
631
  with demo:
632
  gr.Markdown(
633
  """
634
- # RFM (Reward Foundation Model) Inference Visualizer
635
-
636
- Visualize progress, success, preference, and similarity predictions from the Reward Foundation Model.
637
-
638
- **Features:**
639
- - **Single Video**: Get progress and success predictions
640
- - **Dual Videos**: Compare two videos with preference or similarity predictions
641
 
642
  **Note:** This app connects to an eval server. Please provide the server URL and check connection before use.
643
  """
@@ -941,6 +824,58 @@ with demo:
941
  gr.Markdown("### Preference & Similarity Prediction")
942
  with gr.Row():
943
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  video_a_input = gr.Video(label="Video A", height=250)
945
  video_b_input = gr.Video(label="Video B", height=250)
946
  task_text_dual = gr.Textbox(
@@ -964,13 +899,402 @@ with demo:
964
  analyze_dual_btn = gr.Button("Compare Videos", variant="primary")
965
 
966
  with gr.Column():
 
 
 
 
 
 
967
  result_text = gr.Markdown("")
968
- comparison_plot = gr.Image(label="Video Comparison", height=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
969
 
970
  analyze_dual_btn.click(
971
  fn=process_dual_videos,
972
  inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
973
- outputs=[result_text, comparison_plot],
974
  api_name="process_dual_videos",
975
  )
976
 
 
23
  import matplotlib.pyplot as plt
24
  import numpy as np
25
  import requests
 
 
26
  from typing import Any, Optional, Tuple
27
 
28
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
29
  from rfm.evals.eval_utils import build_payload, post_batch_npy
30
+ from rfm.evals.eval_viz_utils import create_combined_progress_success_plot, extract_frames
31
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
32
 
33
  logger = logging.getLogger(__name__)
 
264
  return None, None, None, None
265
 
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  def process_single_video(
268
  video_path: str,
269
  task_text: str = "Complete the task",
 
332
  success_array = None
333
  if success_probs and len(success_probs) > 0:
334
  success_array = np.array(success_probs[0])
335
+
336
  # Convert success_array to binary if available
337
  success_binary = None
338
  if success_array is not None:
 
346
  success_probs=success_array,
347
  success_labels=None, # No ground truth labels available
348
  is_discrete_mode=False,
 
349
  title=f"Progress & Success - {task_text}",
350
  )
351
+
352
  # Save to temporary file
353
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
354
  fig.savefig(tmp_file.name, dpi=150, bbox_inches="tight")
 
375
  prediction_type: str = "preference",
376
  server_url: str = "",
377
  fps: float = 1.0,
378
+ ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
379
  """Process two videos for preference or similarity prediction using eval server."""
380
  if not server_url:
381
+ return "Please provide a server URL and check connection first.", None, None
382
 
383
  if not _server_state.get("server_url"):
384
+ return "Server not connected. Please check server connection first.", None, None
385
 
386
  if video_a_path is None or video_b_path is None:
387
+ return "Please provide both videos.", None, None
388
 
389
  try:
390
  frames_array_a = extract_frames(video_a_path, fps=fps)
391
  frames_array_b = extract_frames(video_b_path, fps=fps)
392
 
393
  if frames_array_a is None or frames_array_a.size == 0:
394
+ return "Could not extract frames from video A.", None, None
395
  if frames_array_b is None or frames_array_b.size == 0:
396
+ return "Could not extract frames from video B.", None, None
397
 
398
  # Convert frames to uint8
399
  if frames_array_a.dtype != np.uint8:
 
500
  else: # similarity - not yet implemented in eval server response format
501
  result_text = "Similarity prediction not yet supported in eval server response format."
502
 
503
+ # Return result text and both video paths
504
+ return result_text, video_a_path, video_b_path
 
 
 
 
505
 
506
  except Exception as e:
507
+ return f"Error processing videos: {str(e)}", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
 
 
 
 
509
 
 
510
 
511
 
512
  # Create Gradio interface
513
  try:
514
  # Try with theme (Gradio 4.0+)
515
+ demo = gr.Blocks(title="RFM Evaluation Server", theme=gr.themes.Soft())
516
  except TypeError:
517
  # Fallback for older Gradio versions without theme support
518
+ demo = gr.Blocks(title="RFM Evaluation Server")
519
 
520
  with demo:
521
  gr.Markdown(
522
  """
523
+ # RFM (Reward Foundation Model) Evaluation Server
 
 
 
 
 
 
524
 
525
  **Note:** This app connects to an eval server. Please provide the server URL and check connection before use.
526
  """
 
824
  gr.Markdown("### Preference & Similarity Prediction")
825
  with gr.Row():
826
  with gr.Column():
827
+ with gr.Accordion("📁 Video A - Select from Dataset", open=False):
828
+ dataset_name_a = gr.Dropdown(
829
+ choices=PREDEFINED_DATASETS,
830
+ value="jesbu1/oxe_rfm",
831
+ label="Dataset Name",
832
+ allow_custom_value=True,
833
+ )
834
+ config_name_a = gr.Dropdown(
835
+ choices=[], value="", label="Configuration Name", allow_custom_value=True
836
+ )
837
+ with gr.Row():
838
+ refresh_configs_btn_a = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
839
+ load_dataset_btn_a = gr.Button("Load Dataset", variant="secondary", size="sm")
840
+
841
+ dataset_status_a = gr.Markdown("", visible=False)
842
+ with gr.Row():
843
+ prev_traj_btn_a = gr.Button("⬅️ Prev", variant="secondary", size="sm")
844
+ trajectory_slider_a = gr.Slider(
845
+ minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
846
+ )
847
+ next_traj_btn_a = gr.Button("Next ➡️", variant="secondary", size="sm")
848
+ trajectory_metadata_a = gr.Markdown("", visible=False)
849
+ use_dataset_video_btn_a = gr.Button("Use Selected Video for A", variant="secondary")
850
+
851
+ with gr.Accordion("📁 Video B - Select from Dataset", open=False):
852
+ dataset_name_b = gr.Dropdown(
853
+ choices=PREDEFINED_DATASETS,
854
+ value="jesbu1/oxe_rfm",
855
+ label="Dataset Name",
856
+ allow_custom_value=True,
857
+ )
858
+ config_name_b = gr.Dropdown(
859
+ choices=[], value="", label="Configuration Name", allow_custom_value=True
860
+ )
861
+ with gr.Row():
862
+ refresh_configs_btn_b = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
863
+ load_dataset_btn_b = gr.Button("Load Dataset", variant="secondary", size="sm")
864
+
865
+ dataset_status_b = gr.Markdown("", visible=False)
866
+ with gr.Row():
867
+ prev_traj_btn_b = gr.Button("⬅️ Prev", variant="secondary", size="sm")
868
+ trajectory_slider_b = gr.Slider(
869
+ minimum=0, maximum=0, step=1, value=0, label="Trajectory Index", interactive=True
870
+ )
871
+ next_traj_btn_b = gr.Button("Next ➡️", variant="secondary", size="sm")
872
+ trajectory_metadata_b = gr.Markdown("", visible=False)
873
+ use_dataset_video_btn_b = gr.Button("Use Selected Video for B", variant="secondary")
874
+
875
+ gr.Markdown("---")
876
+ gr.Markdown("**OR Upload Videos Directly**")
877
+ gr.Markdown("---")
878
+
879
  video_a_input = gr.Video(label="Video A", height=250)
880
  video_b_input = gr.Video(label="Video B", height=250)
881
  task_text_dual = gr.Textbox(
 
899
  analyze_dual_btn = gr.Button("Compare Videos", variant="primary")
900
 
901
  with gr.Column():
902
+ # Videos displayed side by side
903
+ with gr.Row():
904
+ video_a_display = gr.Video(label="Video A", height=400)
905
+ video_b_display = gr.Video(label="Video B", height=400)
906
+
907
+ # Result text at the bottom
908
  result_text = gr.Markdown("")
909
+
910
+ # State variables for datasets
911
+ current_dataset_a = gr.State(None)
912
+ current_dataset_b = gr.State(None)
913
+
914
+ # Helper functions for Video A
915
+ def update_config_choices_a(dataset_name):
916
+ """Update config choices for Video A when dataset changes."""
917
+ if not dataset_name:
918
+ return gr.update(choices=[], value="")
919
+ try:
920
+ configs = get_available_configs(dataset_name)
921
+ if configs:
922
+ return gr.update(choices=configs, value=configs[0])
923
+ else:
924
+ return gr.update(choices=[], value="")
925
+ except Exception as e:
926
+ logger.warning(f"Could not fetch configs: {e}")
927
+ return gr.update(choices=[], value="")
928
+
929
+ def load_dataset_a(dataset_name, config_name):
930
+ """Load dataset A and update slider."""
931
+ dataset, status = load_rfm_dataset(dataset_name, config_name)
932
+ if dataset is not None:
933
+ max_index = len(dataset) - 1
934
+ return (
935
+ dataset,
936
+ gr.update(value=status, visible=True),
937
+ gr.update(
938
+ maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
939
+ ),
940
+ )
941
+ else:
942
+ return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)
943
+
944
+ def use_dataset_video_a(dataset, index, dataset_name):
945
+ """Load video A from dataset and update input."""
946
+ if dataset is None:
947
+ return (
948
+ None,
949
+ gr.update(value="No dataset loaded", visible=True),
950
+ gr.update(visible=False),
951
+ )
952
+
953
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
954
+ if video_path:
955
+ # Build metadata text
956
+ metadata_lines = []
957
+ if quality_label:
958
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
959
+ if partial_success is not None:
960
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
961
+
962
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
963
+ status_text = f"✅ Loaded trajectory {index} from dataset for Video A"
964
+ if metadata_text:
965
+ status_text += f"\n\n{metadata_text}"
966
+
967
+ return (
968
+ video_path,
969
+ gr.update(value=status_text, visible=True),
970
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
971
+ )
972
+ else:
973
+ return (
974
+ None,
975
+ gr.update(value="❌ Error loading trajectory", visible=True),
976
+ gr.update(visible=False),
977
+ )
978
+
979
+ def next_trajectory_a(dataset, current_idx, dataset_name):
980
+ """Go to next trajectory for Video A."""
981
+ if dataset is None:
982
+ return 0, None, gr.update(visible=False), gr.update(visible=False)
983
+ next_idx = min(current_idx + 1, len(dataset) - 1)
984
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
985
+ dataset, next_idx, dataset_name
986
+ )
987
+
988
+ if video_path:
989
+ # Build metadata text
990
+ metadata_lines = []
991
+ if quality_label:
992
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
993
+ if partial_success is not None:
994
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
995
+
996
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
997
+ return (
998
+ next_idx,
999
+ video_path,
1000
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1001
+ gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
1002
+ )
1003
+ else:
1004
+ return current_idx, None, gr.update(visible=False), gr.update(visible=False)
1005
+
1006
+ def prev_trajectory_a(dataset, current_idx, dataset_name):
1007
+ """Go to previous trajectory for Video A."""
1008
+ if dataset is None:
1009
+ return 0, None, gr.update(visible=False), gr.update(visible=False)
1010
+ prev_idx = max(current_idx - 1, 0)
1011
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1012
+ dataset, prev_idx, dataset_name
1013
+ )
1014
+
1015
+ if video_path:
1016
+ # Build metadata text
1017
+ metadata_lines = []
1018
+ if quality_label:
1019
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
1020
+ if partial_success is not None:
1021
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
1022
+
1023
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
1024
+ return (
1025
+ prev_idx,
1026
+ video_path,
1027
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1028
+ gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
1029
+ )
1030
+ else:
1031
+ return current_idx, None, gr.update(visible=False), gr.update(visible=False)
1032
+
1033
+ def update_trajectory_on_slider_change_a(dataset, index, dataset_name):
1034
+ """Update trajectory metadata when slider changes for Video A."""
1035
+ if dataset is None:
1036
+ return gr.update(visible=False), gr.update(visible=False)
1037
+
1038
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
1039
+ if video_path:
1040
+ # Build metadata text
1041
+ metadata_lines = []
1042
+ if quality_label:
1043
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
1044
+ if partial_success is not None:
1045
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
1046
+
1047
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
1048
+ return (
1049
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1050
+ gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
1051
+ )
1052
+ else:
1053
+ return gr.update(visible=False), gr.update(visible=False)
1054
+
1055
+ # Helper functions for Video B (same as Video A)
1056
+ def update_config_choices_b(dataset_name):
1057
+ """Update config choices for Video B when dataset changes."""
1058
+ if not dataset_name:
1059
+ return gr.update(choices=[], value="")
1060
+ try:
1061
+ configs = get_available_configs(dataset_name)
1062
+ if configs:
1063
+ return gr.update(choices=configs, value=configs[0])
1064
+ else:
1065
+ return gr.update(choices=[], value="")
1066
+ except Exception as e:
1067
+ logger.warning(f"Could not fetch configs: {e}")
1068
+ return gr.update(choices=[], value="")
1069
+
1070
+ def load_dataset_b(dataset_name, config_name):
1071
+ """Load dataset B and update slider."""
1072
+ dataset, status = load_rfm_dataset(dataset_name, config_name)
1073
+ if dataset is not None:
1074
+ max_index = len(dataset) - 1
1075
+ return (
1076
+ dataset,
1077
+ gr.update(value=status, visible=True),
1078
+ gr.update(
1079
+ maximum=max_index, value=0, interactive=True, label=f"Trajectory Index (0 to {max_index})"
1080
+ ),
1081
+ )
1082
+ else:
1083
+ return None, gr.update(value=status, visible=True), gr.update(maximum=0, value=0, interactive=False)
1084
+
1085
+ def use_dataset_video_b(dataset, index, dataset_name):
1086
+ """Load video B from dataset and update input."""
1087
+ if dataset is None:
1088
+ return (
1089
+ None,
1090
+ gr.update(value="No dataset loaded", visible=True),
1091
+ gr.update(visible=False),
1092
+ )
1093
+
1094
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
1095
+ if video_path:
1096
+ # Build metadata text
1097
+ metadata_lines = []
1098
+ if quality_label:
1099
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
1100
+ if partial_success is not None:
1101
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
1102
+
1103
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
1104
+ status_text = f"✅ Loaded trajectory {index} from dataset for Video B"
1105
+ if metadata_text:
1106
+ status_text += f"\n\n{metadata_text}"
1107
+
1108
+ return (
1109
+ video_path,
1110
+ gr.update(value=status_text, visible=True),
1111
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1112
+ )
1113
+ else:
1114
+ return (
1115
+ None,
1116
+ gr.update(value="❌ Error loading trajectory", visible=True),
1117
+ gr.update(visible=False),
1118
+ )
1119
+
1120
+ def next_trajectory_b(dataset, current_idx, dataset_name):
1121
+ """Go to next trajectory for Video B."""
1122
+ if dataset is None:
1123
+ return 0, None, gr.update(visible=False), gr.update(visible=False)
1124
+ next_idx = min(current_idx + 1, len(dataset) - 1)
1125
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1126
+ dataset, next_idx, dataset_name
1127
+ )
1128
+
1129
+ if video_path:
1130
+ # Build metadata text
1131
+ metadata_lines = []
1132
+ if quality_label:
1133
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
1134
+ if partial_success is not None:
1135
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
1136
+
1137
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
1138
+ return (
1139
+ next_idx,
1140
+ video_path,
1141
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1142
+ gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True),
1143
+ )
1144
+ else:
1145
+ return current_idx, None, gr.update(visible=False), gr.update(visible=False)
1146
+
1147
+ def prev_trajectory_b(dataset, current_idx, dataset_name):
1148
+ """Go to previous trajectory for Video B."""
1149
+ if dataset is None:
1150
+ return 0, None, gr.update(visible=False), gr.update(visible=False)
1151
+ prev_idx = max(current_idx - 1, 0)
1152
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1153
+ dataset, prev_idx, dataset_name
1154
+ )
1155
+
1156
+ if video_path:
1157
+ # Build metadata text
1158
+ metadata_lines = []
1159
+ if quality_label:
1160
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
1161
+ if partial_success is not None:
1162
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
1163
+
1164
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
1165
+ return (
1166
+ prev_idx,
1167
+ video_path,
1168
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1169
+ gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True),
1170
+ )
1171
+ else:
1172
+ return current_idx, None, gr.update(visible=False), gr.update(visible=False)
1173
+
1174
+ def update_trajectory_on_slider_change_b(dataset, index, dataset_name):
1175
+ """Update trajectory metadata when slider changes for Video B."""
1176
+ if dataset is None:
1177
+ return gr.update(visible=False), gr.update(visible=False)
1178
+
1179
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
1180
+ if video_path:
1181
+ # Build metadata text
1182
+ metadata_lines = []
1183
+ if quality_label:
1184
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
1185
+ if partial_success is not None:
1186
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
1187
+
1188
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
1189
+ return (
1190
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
1191
+ gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
1192
+ )
1193
+ else:
1194
+ return gr.update(visible=False), gr.update(visible=False)
1195
+
1196
+ # Video A dataset selection handlers
1197
+ dataset_name_a.change(
1198
+ fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a]
1199
+ )
1200
+
1201
+ refresh_configs_btn_a.click(
1202
+ fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a]
1203
+ )
1204
+
1205
+ load_dataset_btn_a.click(
1206
+ fn=load_dataset_a,
1207
+ inputs=[dataset_name_a, config_name_a],
1208
+ outputs=[current_dataset_a, dataset_status_a, trajectory_slider_a],
1209
+ )
1210
+
1211
+ use_dataset_video_btn_a.click(
1212
+ fn=use_dataset_video_a,
1213
+ inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
1214
+ outputs=[video_a_input, dataset_status_a, trajectory_metadata_a],
1215
+ )
1216
+
1217
+ next_traj_btn_a.click(
1218
+ fn=next_trajectory_a,
1219
+ inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
1220
+ outputs=[
1221
+ trajectory_slider_a,
1222
+ video_a_input,
1223
+ trajectory_metadata_a,
1224
+ dataset_status_a,
1225
+ ],
1226
+ )
1227
+
1228
+ prev_traj_btn_a.click(
1229
+ fn=prev_trajectory_a,
1230
+ inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
1231
+ outputs=[
1232
+ trajectory_slider_a,
1233
+ video_a_input,
1234
+ trajectory_metadata_a,
1235
+ dataset_status_a,
1236
+ ],
1237
+ )
1238
+
1239
+ trajectory_slider_a.change(
1240
+ fn=update_trajectory_on_slider_change_a,
1241
+ inputs=[current_dataset_a, trajectory_slider_a, dataset_name_a],
1242
+ outputs=[trajectory_metadata_a, dataset_status_a],
1243
+ )
1244
+
1245
+ # Video B dataset selection handlers
1246
+ dataset_name_b.change(
1247
+ fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b]
1248
+ )
1249
+
1250
+ refresh_configs_btn_b.click(
1251
+ fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b]
1252
+ )
1253
+
1254
+ load_dataset_btn_b.click(
1255
+ fn=load_dataset_b,
1256
+ inputs=[dataset_name_b, config_name_b],
1257
+ outputs=[current_dataset_b, dataset_status_b, trajectory_slider_b],
1258
+ )
1259
+
1260
+ use_dataset_video_btn_b.click(
1261
+ fn=use_dataset_video_b,
1262
+ inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
1263
+ outputs=[video_b_input, dataset_status_b, trajectory_metadata_b],
1264
+ )
1265
+
1266
+ next_traj_btn_b.click(
1267
+ fn=next_trajectory_b,
1268
+ inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
1269
+ outputs=[
1270
+ trajectory_slider_b,
1271
+ video_b_input,
1272
+ trajectory_metadata_b,
1273
+ dataset_status_b,
1274
+ ],
1275
+ )
1276
+
1277
+ prev_traj_btn_b.click(
1278
+ fn=prev_trajectory_b,
1279
+ inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
1280
+ outputs=[
1281
+ trajectory_slider_b,
1282
+ video_b_input,
1283
+ trajectory_metadata_b,
1284
+ dataset_status_b,
1285
+ ],
1286
+ )
1287
+
1288
+ trajectory_slider_b.change(
1289
+ fn=update_trajectory_on_slider_change_b,
1290
+ inputs=[current_dataset_b, trajectory_slider_b, dataset_name_b],
1291
+ outputs=[trajectory_metadata_b, dataset_status_b],
1292
+ )
1293
 
1294
  analyze_dual_btn.click(
1295
  fn=process_dual_videos,
1296
  inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
1297
+ outputs=[result_text, video_a_display, video_b_display],
1298
  api_name="process_dual_videos",
1299
  )
1300