Anthony Liang committed on
Commit
cc5bab9
·
1 Parent(s): 3e462dd
Files changed (4) hide show
  1. app.py +70 -52
  2. eval_utils.py +110 -6
  3. eval_viz_utils.py +1 -1
  4. samplers/eval/confusion_matrix.py +27 -27
app.py CHANGED
@@ -75,15 +75,17 @@ _server_state = {
75
  }
76
 
77
 
78
- def discover_available_models(base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)) -> List[Tuple[str, str]]:
 
 
79
  """Discover available models by pinging ports in the specified range.
80
-
81
  Returns:
82
  List of tuples: [(server_url, model_name), ...]
83
  """
84
  available_models = []
85
  start_port, end_port = port_range
86
-
87
  for port in range(start_port, end_port + 1):
88
  server_url = f"{base_url.rstrip('/')}:{port}"
89
  try:
@@ -108,7 +110,7 @@ def discover_available_models(base_url: str = "http://40.119.56.66", port_range:
108
  except requests.exceptions.RequestException:
109
  # Port not available, continue
110
  continue
111
-
112
  return available_models
113
 
114
 
@@ -116,7 +118,7 @@ def get_model_info_for_url(server_url: str) -> Optional[str]:
116
  """Get formatted model info for a given server URL."""
117
  if not server_url:
118
  return None
119
-
120
  try:
121
  model_info_url = server_url.rstrip("/") + "/model_info"
122
  model_info_response = requests.get(model_info_url, timeout=5.0)
@@ -325,7 +327,7 @@ def process_single_video(
325
  # Get server URL from state if not provided
326
  if not server_url:
327
  server_url = _server_state.get("server_url")
328
-
329
  if not server_url:
330
  return None, "Please select a model from the dropdown above and ensure it's connected."
331
 
@@ -435,7 +437,7 @@ def process_two_videos(
435
  # Get server URL from state if not provided
436
  if not server_url:
437
  server_url = _server_state.get("server_url")
438
-
439
  if not server_url:
440
  return "Please select a model from the dropdown above and ensure it's connected.", None, None
441
 
@@ -560,7 +562,7 @@ def process_two_videos(
560
  # - Video A as reference trajectory
561
  # - Video B as similar trajectory
562
  # diff_trajectory is None in inference mode (only need similarity between ref and sim)
563
-
564
  # Create SimilaritySample with Video A as ref and Video B as sim
565
  similarity_sample = SimilaritySample(
566
  ref_trajectory=trajectory_a,
@@ -601,8 +603,6 @@ def process_two_videos(
601
  return f"Error processing videos: {str(e)}", None, None
602
 
603
 
604
-
605
-
606
  # Create Gradio interface
607
  try:
608
  # Try with theme (Gradio 4.0+)
@@ -633,10 +633,10 @@ with demo:
633
  None,
634
  {}, # Empty mapping
635
  )
636
-
637
  _server_state["base_url"] = base_url
638
  models = discover_available_models(base_url, port_range=(8000, 8010))
639
-
640
  if not models:
641
  return (
642
  gr.update(choices=[], value=None),
@@ -645,7 +645,7 @@ with demo:
645
  None,
646
  {}, # Empty mapping
647
  )
648
-
649
  # Format choices: show model_name in dropdown
650
  # Store mapping of model_name to URL in state
651
  choices = []
@@ -653,17 +653,17 @@ with demo:
653
  for url, name in models:
654
  choices.append(name)
655
  url_map[name] = url
656
-
657
  # Auto-select first model
658
  selected_choice = choices[0] if choices else None
659
  selected_url = url_map.get(selected_choice) if selected_choice else None
660
-
661
  # Get model info for selected model
662
  model_info_text = get_model_info_for_url(selected_url) if selected_url else ""
663
  status_text = f"✅ Found {len(models)} model(s). Auto-selected first model."
664
-
665
  _server_state["server_url"] = selected_url
666
-
667
  return (
668
  gr.update(choices=choices, value=selected_choice),
669
  gr.update(value=status_text, visible=True),
@@ -680,23 +680,25 @@ with demo:
680
  gr.update(value="", visible=True),
681
  None,
682
  )
683
-
684
  # Get URL from mapping
685
  server_url = url_mapping.get(model_choice) if url_mapping else None
686
-
687
  if not server_url:
688
  return (
689
- gr.update(value="Could not find server URL for selected model. Please rediscover models.", visible=True),
 
 
690
  gr.update(value="", visible=True),
691
  None,
692
  )
693
-
694
  # Get model info
695
  model_info_text = get_model_info_for_url(server_url) or ""
696
  status, health_data, _ = check_server_health(server_url)
697
-
698
  _server_state["server_url"] = server_url
699
-
700
  return (
701
  gr.update(value=status, visible=True),
702
  gr.update(value=model_info_text, visible=True),
@@ -706,16 +708,16 @@ with demo:
706
  # Use Gradio's built-in Sidebar component (collapsible by default)
707
  with gr.Sidebar():
708
  gr.Markdown("### 🔧 Model Configuration")
709
-
710
  base_url_input = gr.Textbox(
711
  label="Base Server URL",
712
  placeholder="http://40.119.56.66",
713
  value="http://40.119.56.66",
714
  interactive=True,
715
  )
716
-
717
  discover_btn = gr.Button("🔍 Discover Models", variant="primary", size="lg")
718
-
719
  model_dropdown = gr.Dropdown(
720
  label="Select Model",
721
  choices=[],
@@ -723,11 +725,9 @@ with demo:
723
  interactive=True,
724
  info="Models will be discovered on ports 8000-8010",
725
  )
726
-
727
- server_status = gr.Markdown(
728
- "Click 'Discover Models' to find available models"
729
- )
730
-
731
  gr.Markdown("---")
732
  gr.Markdown("### 📋 Model Information")
733
  model_info_display = gr.Markdown("")
@@ -848,7 +848,9 @@ with demo:
848
  gr.update(visible=False),
849
  )
850
 
851
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
 
 
852
  if video_path:
853
  # Build metadata text
854
  metadata_lines = []
@@ -937,7 +939,9 @@ with demo:
937
  if dataset is None:
938
  return gr.update(visible=False), gr.update(visible=False)
939
 
940
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
 
 
941
  if video_path:
942
  # Build metadata text
943
  metadata_lines = []
@@ -1009,7 +1013,13 @@ with demo:
1009
 
1010
  analyze_single_btn.click(
1011
  fn=process_single_video,
1012
- inputs=[single_video_input, task_text_input, server_url_state, fps_input_single, use_frame_steps_single],
 
 
 
 
 
 
1013
  outputs=[progress_plot, info_output],
1014
  api_name="process_single_video",
1015
  )
@@ -1103,7 +1113,7 @@ with demo:
1103
  with gr.Row():
1104
  video_a_display = gr.Video(label="Video A", height=400)
1105
  video_b_display = gr.Video(label="Video B", height=400)
1106
-
1107
  # Result text at the bottom
1108
  result_text = gr.Markdown("")
1109
 
@@ -1161,7 +1171,9 @@ with demo:
1161
  gr.update(visible=False),
1162
  )
1163
 
1164
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
 
 
1165
  if video_path:
1166
  # Build metadata text
1167
  metadata_lines = []
@@ -1246,7 +1258,9 @@ with demo:
1246
  if dataset is None:
1247
  return gr.update(visible=False), gr.update(visible=False)
1248
 
1249
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
 
 
1250
  if video_path:
1251
  # Build metadata text
1252
  metadata_lines = []
@@ -1302,7 +1316,9 @@ with demo:
1302
  gr.update(visible=False),
1303
  )
1304
 
1305
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
 
 
1306
  if video_path:
1307
  # Build metadata text
1308
  metadata_lines = []
@@ -1387,7 +1403,9 @@ with demo:
1387
  if dataset is None:
1388
  return gr.update(visible=False), gr.update(visible=False)
1389
 
1390
- video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
 
 
1391
  if video_path:
1392
  # Build metadata text
1393
  metadata_lines = []
@@ -1405,13 +1423,9 @@ with demo:
1405
  return gr.update(visible=False), gr.update(visible=False)
1406
 
1407
  # Video A dataset selection handlers
1408
- dataset_name_a.change(
1409
- fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a]
1410
- )
1411
 
1412
- refresh_configs_btn_a.click(
1413
- fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a]
1414
- )
1415
 
1416
  load_dataset_btn_a.click(
1417
  fn=load_dataset_a,
@@ -1454,13 +1468,9 @@ with demo:
1454
  )
1455
 
1456
  # Video B dataset selection handlers
1457
- dataset_name_b.change(
1458
- fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b]
1459
- )
1460
 
1461
- refresh_configs_btn_b.click(
1462
- fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b]
1463
- )
1464
 
1465
  load_dataset_btn_b.click(
1466
  fn=load_dataset_b,
@@ -1504,7 +1514,15 @@ with demo:
1504
 
1505
  analyze_dual_btn.click(
1506
  fn=process_two_videos,
1507
- inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_state, fps_input_dual, use_frame_steps_dual],
 
 
 
 
 
 
 
 
1508
  outputs=[result_text, video_a_display, video_b_display],
1509
  api_name="process_two_videos",
1510
  )
 
75
  }
76
 
77
 
78
+ def discover_available_models(
79
+ base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)
80
+ ) -> List[Tuple[str, str]]:
81
  """Discover available models by pinging ports in the specified range.
82
+
83
  Returns:
84
  List of tuples: [(server_url, model_name), ...]
85
  """
86
  available_models = []
87
  start_port, end_port = port_range
88
+
89
  for port in range(start_port, end_port + 1):
90
  server_url = f"{base_url.rstrip('/')}:{port}"
91
  try:
 
110
  except requests.exceptions.RequestException:
111
  # Port not available, continue
112
  continue
113
+
114
  return available_models
115
 
116
 
 
118
  """Get formatted model info for a given server URL."""
119
  if not server_url:
120
  return None
121
+
122
  try:
123
  model_info_url = server_url.rstrip("/") + "/model_info"
124
  model_info_response = requests.get(model_info_url, timeout=5.0)
 
327
  # Get server URL from state if not provided
328
  if not server_url:
329
  server_url = _server_state.get("server_url")
330
+
331
  if not server_url:
332
  return None, "Please select a model from the dropdown above and ensure it's connected."
333
 
 
437
  # Get server URL from state if not provided
438
  if not server_url:
439
  server_url = _server_state.get("server_url")
440
+
441
  if not server_url:
442
  return "Please select a model from the dropdown above and ensure it's connected.", None, None
443
 
 
562
  # - Video A as reference trajectory
563
  # - Video B as similar trajectory
564
  # diff_trajectory is None in inference mode (only need similarity between ref and sim)
565
+
566
  # Create SimilaritySample with Video A as ref and Video B as sim
567
  similarity_sample = SimilaritySample(
568
  ref_trajectory=trajectory_a,
 
603
  return f"Error processing videos: {str(e)}", None, None
604
 
605
 
 
 
606
  # Create Gradio interface
607
  try:
608
  # Try with theme (Gradio 4.0+)
 
633
  None,
634
  {}, # Empty mapping
635
  )
636
+
637
  _server_state["base_url"] = base_url
638
  models = discover_available_models(base_url, port_range=(8000, 8010))
639
+
640
  if not models:
641
  return (
642
  gr.update(choices=[], value=None),
 
645
  None,
646
  {}, # Empty mapping
647
  )
648
+
649
  # Format choices: show model_name in dropdown
650
  # Store mapping of model_name to URL in state
651
  choices = []
 
653
  for url, name in models:
654
  choices.append(name)
655
  url_map[name] = url
656
+
657
  # Auto-select first model
658
  selected_choice = choices[0] if choices else None
659
  selected_url = url_map.get(selected_choice) if selected_choice else None
660
+
661
  # Get model info for selected model
662
  model_info_text = get_model_info_for_url(selected_url) if selected_url else ""
663
  status_text = f"✅ Found {len(models)} model(s). Auto-selected first model."
664
+
665
  _server_state["server_url"] = selected_url
666
+
667
  return (
668
  gr.update(choices=choices, value=selected_choice),
669
  gr.update(value=status_text, visible=True),
 
680
  gr.update(value="", visible=True),
681
  None,
682
  )
683
+
684
  # Get URL from mapping
685
  server_url = url_mapping.get(model_choice) if url_mapping else None
686
+
687
  if not server_url:
688
  return (
689
+ gr.update(
690
+ value="Could not find server URL for selected model. Please rediscover models.", visible=True
691
+ ),
692
  gr.update(value="", visible=True),
693
  None,
694
  )
695
+
696
  # Get model info
697
  model_info_text = get_model_info_for_url(server_url) or ""
698
  status, health_data, _ = check_server_health(server_url)
699
+
700
  _server_state["server_url"] = server_url
701
+
702
  return (
703
  gr.update(value=status, visible=True),
704
  gr.update(value=model_info_text, visible=True),
 
708
  # Use Gradio's built-in Sidebar component (collapsible by default)
709
  with gr.Sidebar():
710
  gr.Markdown("### 🔧 Model Configuration")
711
+
712
  base_url_input = gr.Textbox(
713
  label="Base Server URL",
714
  placeholder="http://40.119.56.66",
715
  value="http://40.119.56.66",
716
  interactive=True,
717
  )
718
+
719
  discover_btn = gr.Button("🔍 Discover Models", variant="primary", size="lg")
720
+
721
  model_dropdown = gr.Dropdown(
722
  label="Select Model",
723
  choices=[],
 
725
  interactive=True,
726
  info="Models will be discovered on ports 8000-8010",
727
  )
728
+
729
+ server_status = gr.Markdown("Click 'Discover Models' to find available models")
730
+
 
 
731
  gr.Markdown("---")
732
  gr.Markdown("### 📋 Model Information")
733
  model_info_display = gr.Markdown("")
 
848
  gr.update(visible=False),
849
  )
850
 
851
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
852
+ dataset, index, dataset_name
853
+ )
854
  if video_path:
855
  # Build metadata text
856
  metadata_lines = []
 
939
  if dataset is None:
940
  return gr.update(visible=False), gr.update(visible=False)
941
 
942
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
943
+ dataset, index, dataset_name
944
+ )
945
  if video_path:
946
  # Build metadata text
947
  metadata_lines = []
 
1013
 
1014
  analyze_single_btn.click(
1015
  fn=process_single_video,
1016
+ inputs=[
1017
+ single_video_input,
1018
+ task_text_input,
1019
+ server_url_state,
1020
+ fps_input_single,
1021
+ use_frame_steps_single,
1022
+ ],
1023
  outputs=[progress_plot, info_output],
1024
  api_name="process_single_video",
1025
  )
 
1113
  with gr.Row():
1114
  video_a_display = gr.Video(label="Video A", height=400)
1115
  video_b_display = gr.Video(label="Video B", height=400)
1116
+
1117
  # Result text at the bottom
1118
  result_text = gr.Markdown("")
1119
 
 
1171
  gr.update(visible=False),
1172
  )
1173
 
1174
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1175
+ dataset, index, dataset_name
1176
+ )
1177
  if video_path:
1178
  # Build metadata text
1179
  metadata_lines = []
 
1258
  if dataset is None:
1259
  return gr.update(visible=False), gr.update(visible=False)
1260
 
1261
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1262
+ dataset, index, dataset_name
1263
+ )
1264
  if video_path:
1265
  # Build metadata text
1266
  metadata_lines = []
 
1316
  gr.update(visible=False),
1317
  )
1318
 
1319
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1320
+ dataset, index, dataset_name
1321
+ )
1322
  if video_path:
1323
  # Build metadata text
1324
  metadata_lines = []
 
1403
  if dataset is None:
1404
  return gr.update(visible=False), gr.update(visible=False)
1405
 
1406
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(
1407
+ dataset, index, dataset_name
1408
+ )
1409
  if video_path:
1410
  # Build metadata text
1411
  metadata_lines = []
 
1423
  return gr.update(visible=False), gr.update(visible=False)
1424
 
1425
  # Video A dataset selection handlers
1426
+ dataset_name_a.change(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
 
 
1427
 
1428
+ refresh_configs_btn_a.click(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
 
 
1429
 
1430
  load_dataset_btn_a.click(
1431
  fn=load_dataset_a,
 
1468
  )
1469
 
1470
  # Video B dataset selection handlers
1471
+ dataset_name_b.change(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
 
 
1472
 
1473
+ refresh_configs_btn_b.click(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
 
 
1474
 
1475
  load_dataset_btn_b.click(
1476
  fn=load_dataset_b,
 
1514
 
1515
  analyze_dual_btn.click(
1516
  fn=process_two_videos,
1517
+ inputs=[
1518
+ video_a_input,
1519
+ video_b_input,
1520
+ task_text_dual,
1521
+ prediction_type,
1522
+ server_url_state,
1523
+ fps_input_dual,
1524
+ use_frame_steps_dual,
1525
+ ],
1526
  outputs=[result_text, video_a_display, video_b_display],
1527
  api_name="process_two_videos",
1528
  )
eval_utils.py CHANGED
@@ -15,8 +15,112 @@ import numpy as np
15
  import requests
16
  import torch
17
 
18
- from rfm.data.dataset_types import PreferenceSample, SimilaritySample, ProgressSample, Trajectory
19
- from rfm.data.datasets.helpers import linspace_subsample_frames, pad_trajectory_to_max_frames_np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  def extract_answer_from_text(text: str) -> str:
@@ -219,10 +323,10 @@ async def post_batch_npy_async(
219
 
220
  async def parse_npy_form_data(form_data: Any) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
221
  """Parse multipart form data to extract numpy arrays and other data.
222
-
223
  Args:
224
  form_data: FastAPI form data from request.form()
225
-
226
  Returns:
227
  Tuple of (numpy_arrays dict, other_data dict)
228
  """
@@ -271,7 +375,7 @@ def reconstruct_payload_from_npy(
271
  other_data: Dictionary of other form data
272
  trajectory_keys: List of trajectory keys to process (default: common keys)
273
  convert_embeddings_to_torch: Whether to convert embeddings to torch tensors
274
-
275
  Returns:
276
  List of reconstructed sample dictionaries
277
  """
@@ -284,7 +388,7 @@ def reconstruct_payload_from_npy(
284
  "traj_diff_trajectory",
285
  "trajectory",
286
  ]
287
-
288
  samples = []
289
 
290
  # Process each sample
 
15
  import requests
16
  import torch
17
 
18
+ from dataset_types import PreferenceSample, SimilaritySample, ProgressSample, Trajectory
19
+
20
+
21
def pad_trajectory_to_max_frames_np(
    frames: np.ndarray, progress: List[float], max_frames: int, pad_from: str = "right"
) -> Tuple[np.ndarray, List[float]]:
    """Pad trajectory frames and progress to max_frames by repeating an edge frame.

    Padding repeats the LAST frame/progress value when ``pad_from="right"``
    (the default) or the FIRST frame/progress value when ``pad_from="left"``.

    Args:
        frames: Trajectory frames, shape (N, ...).
        progress: Progress values, one per frame.
        max_frames: Target number of frames.
        pad_from: "right" to append copies of the last frame, "left" to
            prepend copies of the first frame.

    Returns:
        Tuple[np.ndarray, List[float]]: (padded_frames, padded_progress).
        The inputs are returned unchanged when no padding is needed.
    """
    current_frames = frames.shape[0]

    if current_frames >= max_frames:
        # Already long enough; nothing to do.
        return frames, progress

    if pad_from == "left":
        pad_frame = frames[0:1]  # slice (not index) to keep the leading axis
        pad_progress = progress[0]
    else:
        pad_frame = frames[-1:]
        pad_progress = progress[-1]

    # How many frames we still need to reach max_frames
    frames_to_pad = max_frames - current_frames

    # Repeat the chosen edge frame and attach it on the requested side
    if pad_from == "left":
        padded_frames = np.concatenate([np.repeat(pad_frame, frames_to_pad, axis=0), frames], axis=0)
        padded_progress = [pad_progress] * frames_to_pad + progress
    else:
        padded_frames = np.concatenate([frames, np.repeat(pad_frame, frames_to_pad, axis=0)], axis=0)
        padded_progress = progress + [pad_progress] * frames_to_pad

    return padded_frames, padded_progress
59
+
60
+
61
def linspace_subsample_frames(
    frames: np.ndarray, num_frames: int = 8, end_idx: Optional[int] = None
) -> Tuple[np.ndarray, List[int]]:
    """Uniformly subsample frames from a trajectory and return the indices.

    Takes the full trajectory (e.g., 64 frames) and uniformly subsamples
    num_frames from it. The first and last frames are always included,
    except when ``num_frames == 1``, in which case only the last frame is
    taken.

    Args:
        frames: Full trajectory frames (N frames). Anything indexable by a
            list of indices; numpy arrays use ``shape[0]``, otherwise ``len``.
        num_frames: Number of frames to subsample (default: 8).
        end_idx: Optional end index to subsample up to (if None, uses
            total_frames - 1).

    Returns:
        Tuple[np.ndarray, List[int]]: (subsampled_frames, subsampled_indices).
    """
    if hasattr(frames, "shape"):
        total_frames = frames.shape[0]
    else:
        total_frames = len(frames)

    if total_frames <= 0:
        # Empty trajectory: nothing to subsample.
        return frames, []

    # Use end_idx if provided, otherwise use the full trajectory
    if end_idx is not None:
        end_idx = min(end_idx, total_frames - 1)
        frames_to_subsample = frames[: end_idx + 1]
        effective_total = end_idx + 1
    else:
        frames_to_subsample = frames
        effective_total = total_frames

    if effective_total <= num_frames:
        # Fewer (or equal) frames than requested: return all of them.
        indices = list(range(effective_total))
        return frames_to_subsample, indices

    # Special case: a single requested frame always takes the last frame.
    if num_frames == 1:
        indices = [effective_total - 1]
        return frames_to_subsample[indices], indices

    # Evenly spaced indices from 0 to effective_total-1, inclusive
    indices = np.rint(np.linspace(0, effective_total - 1, num_frames)).astype(int).tolist()

    # Enforce first and last explicitly
    indices[0] = 0
    indices[-1] = effective_total - 1

    # Rounding can in principle produce out-of-order or out-of-range values;
    # clamp to keep indices non-decreasing and within bounds.
    for k in range(1, len(indices)):
        if indices[k] < indices[k - 1]:
            indices[k] = indices[k - 1]
        if indices[k] >= effective_total:
            indices[k] = effective_total - 1

    return frames_to_subsample[indices], indices
124
 
125
 
126
  def extract_answer_from_text(text: str) -> str:
 
323
 
324
  async def parse_npy_form_data(form_data: Any) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
325
  """Parse multipart form data to extract numpy arrays and other data.
326
+
327
  Args:
328
  form_data: FastAPI form data from request.form()
329
+
330
  Returns:
331
  Tuple of (numpy_arrays dict, other_data dict)
332
  """
 
375
  other_data: Dictionary of other form data
376
  trajectory_keys: List of trajectory keys to process (default: common keys)
377
  convert_embeddings_to_torch: Whether to convert embeddings to torch tensors
378
+
379
  Returns:
380
  List of reconstructed sample dictionaries
381
  """
 
388
  "traj_diff_trajectory",
389
  "trajectory",
390
  ]
391
+
392
  samples = []
393
 
394
  # Process each sample
eval_viz_utils.py CHANGED
@@ -180,7 +180,7 @@ def extract_frames(video_path: str, fps: float = 1.0, max_frames: int = 64) -> n
180
 
181
  # Clamp to [1, total_frames]
182
  desired_frames = max(1, min(desired_frames, total_frames))
183
-
184
  # IMPORTANT: Cap at max_frames to prevent memory issues
185
  # This is critical when fps is high or videos are long
186
  if desired_frames > max_frames:
 
180
 
181
  # Clamp to [1, total_frames]
182
  desired_frames = max(1, min(desired_frames, total_frames))
183
+
184
  # IMPORTANT: Cap at max_frames to prevent memory issues
185
  # This is critical when fps is high or videos are long
186
  if desired_frames > max_frames:
samplers/eval/confusion_matrix.py CHANGED
@@ -60,7 +60,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
60
 
61
  def _generate_all_sample_indices(self) -> list[dict]:
62
  """Generate all possible task-trajectory pair sample indices.
63
-
64
  If multiple data sources exist, samples N random trajectories from each data source.
65
  Prioritizes different video tasks first, then prioritizes different language instructions
66
  when creating pairs.
@@ -73,7 +73,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
73
 
74
  # Sample trajectories per data source (prioritizing different video tasks)
75
  sampled_trajectories, stats = self._sample_trajectories_by_data_source()
76
-
77
  rank_0_print(
78
  f"Processing {len(sampled_trajectories)} trajectories for confusion matrix analysis",
79
  verbose=self.verbose,
@@ -88,7 +88,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
88
 
89
  # Create task-trajectory pairs with prioritized language instruction pairing
90
  video_task_count = Counter()
91
-
92
  for traj_idx in sampled_trajectories:
93
  traj = self.dataset[traj_idx]
94
  video_task = traj["task"]
@@ -98,7 +98,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
98
  # continue
99
 
100
  video_task_count[video_task] += 1
101
-
102
  # Pair this trajectory with all language tasks (shuffled for variety)
103
  traj_id = traj.get("id", str(traj_idx))
104
  for lang_task in shuffled_lang_tasks:
@@ -117,15 +117,15 @@ class ConfusionMatrixSampler(RFMBaseSampler):
117
  rank_0_print(f"Generated {len(sample_indices)} task-trajectory pairs", verbose=self.verbose)
118
  rank_0_print(f" Video tasks sampled: {dict(video_task_count)}", verbose=self.verbose)
119
  rank_0_print(f" Trajectories per video task: {dict(sorted(video_task_count.items()))}", verbose=self.verbose)
120
-
121
  return sample_indices
122
 
123
  def _sample_trajectories_by_data_source(self) -> Tuple[list[int], dict]:
124
  """Sample N random trajectories from each data source, prioritizing different video tasks.
125
-
126
  When sampling N trajectories, first selects one trajectory from each unique video task,
127
  then repeats in round-robin fashion until N trajectories are sampled.
128
-
129
  Returns:
130
  Tuple of (list of sampled trajectory indices, stats dictionary)
131
  """
@@ -135,7 +135,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
135
  "by_task": Counter(),
136
  "traj_to_task": {},
137
  }
138
-
139
  # Group robot trajectories by data source, then by video task
140
  trajectories_by_source_and_task = defaultdict(lambda: defaultdict(list))
141
  for traj_idx in self.robot_trajectories:
@@ -143,7 +143,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
143
  data_source = traj.get("data_source", "unknown")
144
  video_task = traj.get("task", "unknown")
145
  trajectories_by_source_and_task[data_source][video_task].append(traj_idx)
146
-
147
  rank_0_print(
148
  f"Found {len(trajectories_by_source_and_task)} data sources: {list(trajectories_by_source_and_task.keys())}",
149
  verbose=self.verbose,
@@ -154,17 +154,17 @@ class ConfusionMatrixSampler(RFMBaseSampler):
154
  # Shuffle trajectories within each task for randomization
155
  for task in tasks_to_indices:
156
  self._local_random.shuffle(tasks_to_indices[task])
157
-
158
  # Get all unique tasks for this data source
159
  all_tasks = list(tasks_to_indices.keys())
160
  self._local_random.shuffle(all_tasks) # Randomize task order too
161
-
162
  source_stats = {
163
  "total_available": sum(len(indices) for indices in tasks_to_indices.values()),
164
  "tasks_available": {task: len(indices) for task, indices in tasks_to_indices.items()},
165
  "tasks_sampled": Counter(),
166
  }
167
-
168
  if self.n_trajectories_per_source is None:
169
  # Use all available trajectories
170
  sampled_from_source = []
@@ -172,7 +172,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
172
  sampled_from_source.extend(indices)
173
  source_stats["tasks_sampled"][task] = len(indices)
174
  stats["by_task"][task] += len(indices)
175
-
176
  rank_0_print(
177
  f" Data source '{data_source}': Using all {len(sampled_from_source)} trajectories",
178
  verbose=self.verbose,
@@ -181,18 +181,18 @@ class ConfusionMatrixSampler(RFMBaseSampler):
181
  # Sample N trajectories using round-robin to prioritize different tasks
182
  n_to_sample = min(self.n_trajectories_per_source, source_stats["total_available"])
183
  sampled_from_source = []
184
-
185
  # Round-robin sampling: first get one from each task, then repeat
186
  task_iterators = {task: iter(indices) for task, indices in tasks_to_indices.items()}
187
  task_list = all_tasks.copy()
188
  round_idx = 0
189
-
190
  while len(sampled_from_source) < n_to_sample:
191
  # If we've gone through all tasks once, reshuffle for next round
192
  if round_idx >= len(task_list):
193
  round_idx = 0
194
  self._local_random.shuffle(task_list)
195
-
196
  # Try to get one trajectory from current task
197
  task = task_list[round_idx]
198
  try:
@@ -206,9 +206,9 @@ class ConfusionMatrixSampler(RFMBaseSampler):
206
  if not task_list:
207
  break # All tasks exhausted
208
  continue
209
-
210
  round_idx += 1
211
-
212
  rank_0_print(
213
  f" Data source '{data_source}': Sampled {len(sampled_from_source)} out of {source_stats['total_available']} trajectories",
214
  verbose=self.verbose,
@@ -217,13 +217,13 @@ class ConfusionMatrixSampler(RFMBaseSampler):
217
  f" Tasks sampled: {dict(sorted(source_stats['tasks_sampled'].items()))}",
218
  verbose=self.verbose,
219
  )
220
-
221
  # Track trajectory to task mapping for stats
222
  for traj_idx in sampled_from_source:
223
  traj = self.dataset[traj_idx]
224
  traj_id = traj.get("id", str(traj_idx))
225
  stats["traj_to_task"][traj_id] = traj.get("task", "unknown")
226
-
227
  sampled_indices.extend(sampled_from_source)
228
  stats["by_source"][data_source] = source_stats
229
 
@@ -231,33 +231,33 @@ class ConfusionMatrixSampler(RFMBaseSampler):
231
 
232
  def _print_sampling_stats(self, stats: dict):
233
  """Print detailed statistics about sampled trajectories.
234
-
235
  Args:
236
  stats: Statistics dictionary from _sample_trajectories_by_data_source
237
  """
238
  if not self.verbose:
239
  return
240
-
241
  rank_0_print("\n=== Confusion Matrix Sampling Statistics ===", verbose=self.verbose)
242
-
243
  # Overall task statistics
244
  rank_0_print(f"\nOverall trajectories per video task:", verbose=self.verbose)
245
  for task, count in sorted(stats["by_task"].items()):
246
  rank_0_print(f" {task}: {count} trajectories", verbose=self.verbose)
247
-
248
  # Per data source statistics
249
  rank_0_print(f"\nPer data source breakdown:", verbose=self.verbose)
250
  for data_source, source_stats in stats["by_source"].items():
251
  rank_0_print(f" Data source: {data_source}", verbose=self.verbose)
252
  rank_0_print(f" Total available: {source_stats['total_available']}", verbose=self.verbose)
253
  rank_0_print(f" Tasks available: {len(source_stats['tasks_available'])}", verbose=self.verbose)
254
- for task, count in sorted(source_stats['tasks_available'].items()):
255
- sampled_count = source_stats['tasks_sampled'].get(task, 0)
256
  rank_0_print(
257
  f" {task}: {sampled_count}/{count} trajectories sampled",
258
  verbose=self.verbose,
259
  )
260
-
261
  rank_0_print("=" * 50, verbose=self.verbose)
262
 
263
  def _generate_sample_from_indices(self, sample_idx_info: dict) -> PreferenceSample:
 
60
 
61
  def _generate_all_sample_indices(self) -> list[dict]:
62
  """Generate all possible task-trajectory pair sample indices.
63
+
64
  If multiple data sources exist, samples N random trajectories from each data source.
65
  Prioritizes different video tasks first, then prioritizes different language instructions
66
  when creating pairs.
 
73
 
74
  # Sample trajectories per data source (prioritizing different video tasks)
75
  sampled_trajectories, stats = self._sample_trajectories_by_data_source()
76
+
77
  rank_0_print(
78
  f"Processing {len(sampled_trajectories)} trajectories for confusion matrix analysis",
79
  verbose=self.verbose,
 
88
 
89
  # Create task-trajectory pairs with prioritized language instruction pairing
90
  video_task_count = Counter()
91
+
92
  for traj_idx in sampled_trajectories:
93
  traj = self.dataset[traj_idx]
94
  video_task = traj["task"]
 
98
  # continue
99
 
100
  video_task_count[video_task] += 1
101
+
102
  # Pair this trajectory with all language tasks (shuffled for variety)
103
  traj_id = traj.get("id", str(traj_idx))
104
  for lang_task in shuffled_lang_tasks:
 
117
  rank_0_print(f"Generated {len(sample_indices)} task-trajectory pairs", verbose=self.verbose)
118
  rank_0_print(f" Video tasks sampled: {dict(video_task_count)}", verbose=self.verbose)
119
  rank_0_print(f" Trajectories per video task: {dict(sorted(video_task_count.items()))}", verbose=self.verbose)
120
+
121
  return sample_indices
122
 
123
  def _sample_trajectories_by_data_source(self) -> Tuple[list[int], dict]:
124
  """Sample N random trajectories from each data source, prioritizing different video tasks.
125
+
126
  When sampling N trajectories, first selects one trajectory from each unique video task,
127
  then repeats in round-robin fashion until N trajectories are sampled.
128
+
129
  Returns:
130
  Tuple of (list of sampled trajectory indices, stats dictionary)
131
  """
 
135
  "by_task": Counter(),
136
  "traj_to_task": {},
137
  }
138
+
139
  # Group robot trajectories by data source, then by video task
140
  trajectories_by_source_and_task = defaultdict(lambda: defaultdict(list))
141
  for traj_idx in self.robot_trajectories:
 
143
  data_source = traj.get("data_source", "unknown")
144
  video_task = traj.get("task", "unknown")
145
  trajectories_by_source_and_task[data_source][video_task].append(traj_idx)
146
+
147
  rank_0_print(
148
  f"Found {len(trajectories_by_source_and_task)} data sources: {list(trajectories_by_source_and_task.keys())}",
149
  verbose=self.verbose,
 
154
  # Shuffle trajectories within each task for randomization
155
  for task in tasks_to_indices:
156
  self._local_random.shuffle(tasks_to_indices[task])
157
+
158
  # Get all unique tasks for this data source
159
  all_tasks = list(tasks_to_indices.keys())
160
  self._local_random.shuffle(all_tasks) # Randomize task order too
161
+
162
  source_stats = {
163
  "total_available": sum(len(indices) for indices in tasks_to_indices.values()),
164
  "tasks_available": {task: len(indices) for task, indices in tasks_to_indices.items()},
165
  "tasks_sampled": Counter(),
166
  }
167
+
168
  if self.n_trajectories_per_source is None:
169
  # Use all available trajectories
170
  sampled_from_source = []
 
172
  sampled_from_source.extend(indices)
173
  source_stats["tasks_sampled"][task] = len(indices)
174
  stats["by_task"][task] += len(indices)
175
+
176
  rank_0_print(
177
  f" Data source '{data_source}': Using all {len(sampled_from_source)} trajectories",
178
  verbose=self.verbose,
 
181
  # Sample N trajectories using round-robin to prioritize different tasks
182
  n_to_sample = min(self.n_trajectories_per_source, source_stats["total_available"])
183
  sampled_from_source = []
184
+
185
  # Round-robin sampling: first get one from each task, then repeat
186
  task_iterators = {task: iter(indices) for task, indices in tasks_to_indices.items()}
187
  task_list = all_tasks.copy()
188
  round_idx = 0
189
+
190
  while len(sampled_from_source) < n_to_sample:
191
  # If we've gone through all tasks once, reshuffle for next round
192
  if round_idx >= len(task_list):
193
  round_idx = 0
194
  self._local_random.shuffle(task_list)
195
+
196
  # Try to get one trajectory from current task
197
  task = task_list[round_idx]
198
  try:
 
206
  if not task_list:
207
  break # All tasks exhausted
208
  continue
209
+
210
  round_idx += 1
211
+
212
  rank_0_print(
213
  f" Data source '{data_source}': Sampled {len(sampled_from_source)} out of {source_stats['total_available']} trajectories",
214
  verbose=self.verbose,
 
217
  f" Tasks sampled: {dict(sorted(source_stats['tasks_sampled'].items()))}",
218
  verbose=self.verbose,
219
  )
220
+
221
  # Track trajectory to task mapping for stats
222
  for traj_idx in sampled_from_source:
223
  traj = self.dataset[traj_idx]
224
  traj_id = traj.get("id", str(traj_idx))
225
  stats["traj_to_task"][traj_id] = traj.get("task", "unknown")
226
+
227
  sampled_indices.extend(sampled_from_source)
228
  stats["by_source"][data_source] = source_stats
229
 
 
231
 
232
  def _print_sampling_stats(self, stats: dict):
233
  """Print detailed statistics about sampled trajectories.
234
+
235
  Args:
236
  stats: Statistics dictionary from _sample_trajectories_by_data_source
237
  """
238
  if not self.verbose:
239
  return
240
+
241
  rank_0_print("\n=== Confusion Matrix Sampling Statistics ===", verbose=self.verbose)
242
+
243
  # Overall task statistics
244
  rank_0_print(f"\nOverall trajectories per video task:", verbose=self.verbose)
245
  for task, count in sorted(stats["by_task"].items()):
246
  rank_0_print(f" {task}: {count} trajectories", verbose=self.verbose)
247
+
248
  # Per data source statistics
249
  rank_0_print(f"\nPer data source breakdown:", verbose=self.verbose)
250
  for data_source, source_stats in stats["by_source"].items():
251
  rank_0_print(f" Data source: {data_source}", verbose=self.verbose)
252
  rank_0_print(f" Total available: {source_stats['total_available']}", verbose=self.verbose)
253
  rank_0_print(f" Tasks available: {len(source_stats['tasks_available'])}", verbose=self.verbose)
254
+ for task, count in sorted(source_stats["tasks_available"].items()):
255
+ sampled_count = source_stats["tasks_sampled"].get(task, 0)
256
  rank_0_print(
257
  f" {task}: {sampled_count}/{count} trajectories sampled",
258
  verbose=self.verbose,
259
  )
260
+
261
  rank_0_print("=" * 50, verbose=self.verbose)
262
 
263
  def _generate_sample_from_indices(self, sample_idx_info: dict) -> PreferenceSample: