Commit cc5bab9 by Anthony Liang
Parent(s): 3e462dd

Commit message: update

Files changed:
- app.py (+70, -52)
- eval_utils.py (+110, -6)
- eval_viz_utils.py (+1, -1)
- samplers/eval/confusion_matrix.py (+27, -27)
app.py CHANGED

@@ -75,15 +75,17 @@ _server_state = {
 }
 
 
-def discover_available_models(base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)) -> List[Tuple[str, str]]:
+def discover_available_models(
+    base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)
+) -> List[Tuple[str, str]]:
     """Discover available models by pinging ports in the specified range.
-
+
     Returns:
         List of tuples: [(server_url, model_name), ...]
     """
     available_models = []
     start_port, end_port = port_range
-
+
     for port in range(start_port, end_port + 1):
         server_url = f"{base_url.rstrip('/')}:{port}"
         try:
@@ -108,7 +110,7 @@ def discover_available_models(base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)) -> List[Tuple[str, str]]:
         except requests.exceptions.RequestException:
             # Port not available, continue
             continue
-
+
     return available_models
 
 
@@ -116,7 +118,7 @@ def get_model_info_for_url(server_url: str) -> Optional[str]:
     """Get formatted model info for a given server URL."""
     if not server_url:
         return None
-
+
     try:
         model_info_url = server_url.rstrip("/") + "/model_info"
         model_info_response = requests.get(model_info_url, timeout=5.0)
@@ -325,7 +327,7 @@ def process_single_video(
     # Get server URL from state if not provided
     if not server_url:
         server_url = _server_state.get("server_url")
-
+
     if not server_url:
         return None, "Please select a model from the dropdown above and ensure it's connected."
 
@@ -435,7 +437,7 @@ def process_two_videos(
     # Get server URL from state if not provided
     if not server_url:
         server_url = _server_state.get("server_url")
-
+
     if not server_url:
         return "Please select a model from the dropdown above and ensure it's connected.", None, None
 
@@ -560,7 +562,7 @@ def process_two_videos(
     # - Video A as reference trajectory
     # - Video B as similar trajectory
     # diff_trajectory is None in inference mode (only need similarity between ref and sim)
-
+
     # Create SimilaritySample with Video A as ref and Video B as sim
     similarity_sample = SimilaritySample(
         ref_trajectory=trajectory_a,
@@ -601,8 +603,6 @@
         return f"Error processing videos: {str(e)}", None, None
 
 
-
-
 # Create Gradio interface
 try:
     # Try with theme (Gradio 4.0+)
@@ -633,10 +633,10 @@ with demo:
                 None,
                 {},  # Empty mapping
             )
-
+
         _server_state["base_url"] = base_url
         models = discover_available_models(base_url, port_range=(8000, 8010))
-
+
         if not models:
             return (
                 gr.update(choices=[], value=None),
@@ -645,7 +645,7 @@ with demo:
                 None,
                 {},  # Empty mapping
             )
-
+
         # Format choices: show model_name in dropdown
         # Store mapping of model_name to URL in state
        choices = []
@@ -653,17 +653,17 @@ with demo:
         for url, name in models:
             choices.append(name)
             url_map[name] = url
-
+
         # Auto-select first model
         selected_choice = choices[0] if choices else None
         selected_url = url_map.get(selected_choice) if selected_choice else None
-
+
         # Get model info for selected model
         model_info_text = get_model_info_for_url(selected_url) if selected_url else ""
         status_text = f"✅ Found {len(models)} model(s). Auto-selected first model."
-
+
         _server_state["server_url"] = selected_url
-
+
         return (
             gr.update(choices=choices, value=selected_choice),
             gr.update(value=status_text, visible=True),
@@ -680,23 +680,25 @@ with demo:
                 gr.update(value="", visible=True),
                 None,
             )
-
+
         # Get URL from mapping
         server_url = url_mapping.get(model_choice) if url_mapping else None
-
+
         if not server_url:
             return (
-                gr.update(value="Could not find server URL for selected model. Please rediscover models.", visible=True),
+                gr.update(
+                    value="Could not find server URL for selected model. Please rediscover models.", visible=True
+                ),
                 gr.update(value="", visible=True),
                 None,
             )
-
+
         # Get model info
         model_info_text = get_model_info_for_url(server_url) or ""
         status, health_data, _ = check_server_health(server_url)
-
+
         _server_state["server_url"] = server_url
-
+
         return (
             gr.update(value=status, visible=True),
             gr.update(value=model_info_text, visible=True),
@@ -706,16 +708,16 @@ with demo:
     # Use Gradio's built-in Sidebar component (collapsible by default)
     with gr.Sidebar():
         gr.Markdown("### 🔧 Model Configuration")
-
+
         base_url_input = gr.Textbox(
             label="Base Server URL",
             placeholder="http://40.119.56.66",
             value="http://40.119.56.66",
             interactive=True,
         )
-
+
         discover_btn = gr.Button("🔍 Discover Models", variant="primary", size="lg")
-
+
         model_dropdown = gr.Dropdown(
             label="Select Model",
             choices=[],
@@ -723,11 +725,9 @@ with demo:
             interactive=True,
             info="Models will be discovered on ports 8000-8010",
         )
-
-        server_status = gr.Markdown(
-            "Click 'Discover Models' to find available models"
-        )
-
+
+        server_status = gr.Markdown("Click 'Discover Models' to find available models")
+
         gr.Markdown("---")
         gr.Markdown("### 📋 Model Information")
         model_info_display = gr.Markdown("")
@@ -848,7 +848,9 @@ with demo:
                 gr.update(visible=False),
             )
 
-        video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
+        video_path, task, quality_label, partial_success = get_trajectory_video_path(
+            dataset, index, dataset_name
+        )
         if video_path:
             # Build metadata text
             metadata_lines = []
@@ -937,7 +939,9 @@ with demo:
         if dataset is None:
             return gr.update(visible=False), gr.update(visible=False)
 
-        video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
+        video_path, task, quality_label, partial_success = get_trajectory_video_path(
+            dataset, index, dataset_name
+        )
         if video_path:
             # Build metadata text
             metadata_lines = []
@@ -1009,7 +1013,13 @@ with demo:
 
     analyze_single_btn.click(
         fn=process_single_video,
-        inputs=[single_video_input, task_text_input, server_url_state, fps_input_single, use_frame_steps_single],
+        inputs=[
+            single_video_input,
+            task_text_input,
+            server_url_state,
+            fps_input_single,
+            use_frame_steps_single,
+        ],
         outputs=[progress_plot, info_output],
         api_name="process_single_video",
     )
@@ -1103,7 +1113,7 @@ with demo:
     with gr.Row():
         video_a_display = gr.Video(label="Video A", height=400)
         video_b_display = gr.Video(label="Video B", height=400)
-
+
     # Result text at the bottom
     result_text = gr.Markdown("")
 
@@ -1161,7 +1171,9 @@ with demo:
                 gr.update(visible=False),
             )
 
-        video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
+        video_path, task, quality_label, partial_success = get_trajectory_video_path(
+            dataset, index, dataset_name
+        )
         if video_path:
             # Build metadata text
             metadata_lines = []
@@ -1246,7 +1258,9 @@ with demo:
         if dataset is None:
             return gr.update(visible=False), gr.update(visible=False)
 
-        video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
+        video_path, task, quality_label, partial_success = get_trajectory_video_path(
+            dataset, index, dataset_name
+        )
         if video_path:
             # Build metadata text
             metadata_lines = []
@@ -1302,7 +1316,9 @@ with demo:
                 gr.update(visible=False),
             )
 
-        video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
+        video_path, task, quality_label, partial_success = get_trajectory_video_path(
+            dataset, index, dataset_name
+        )
         if video_path:
             # Build metadata text
             metadata_lines = []
@@ -1387,7 +1403,9 @@ with demo:
         if dataset is None:
             return gr.update(visible=False), gr.update(visible=False)
 
-        video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
+        video_path, task, quality_label, partial_success = get_trajectory_video_path(
+            dataset, index, dataset_name
+        )
         if video_path:
             # Build metadata text
             metadata_lines = []
@@ -1405,13 +1423,9 @@ with demo:
         return gr.update(visible=False), gr.update(visible=False)
 
     # Video A dataset selection handlers
-    dataset_name_a.change(
-        fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a]
-    )
+    dataset_name_a.change(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
 
-    refresh_configs_btn_a.click(
-        fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a]
-    )
+    refresh_configs_btn_a.click(fn=update_config_choices_a, inputs=[dataset_name_a], outputs=[config_name_a])
 
     load_dataset_btn_a.click(
         fn=load_dataset_a,
@@ -1454,13 +1468,9 @@ with demo:
     )
 
     # Video B dataset selection handlers
-    dataset_name_b.change(
-        fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b]
-    )
+    dataset_name_b.change(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
 
-    refresh_configs_btn_b.click(
-        fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b]
-    )
+    refresh_configs_btn_b.click(fn=update_config_choices_b, inputs=[dataset_name_b], outputs=[config_name_b])
 
     load_dataset_btn_b.click(
         fn=load_dataset_b,
@@ -1504,7 +1514,15 @@ with demo:
 
     analyze_dual_btn.click(
         fn=process_two_videos,
-        inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_state, fps_input_dual, use_frame_steps_dual],
+        inputs=[
+            video_a_input,
+            video_b_input,
+            task_text_dual,
+            prediction_type,
+            server_url_state,
+            fps_input_dual,
+            use_frame_steps_dual,
+        ],
         outputs=[result_text, video_a_display, video_b_display],
         api_name="process_two_videos",
    )
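The app.py changes are behavior-preserving: long call sites are rewrapped and blank-line whitespace is stripped. For orientation, a minimal usage sketch of the discovery helpers shown in this diff. Importing from `app` is an assumption for illustration; in the Space these functions run behind the "Discover Models" button, and importing app.py will also build the Gradio demo as a side effect.

# Hypothetical sketch: probe the configured port range, as the sidebar button does.
from app import discover_available_models, get_model_info_for_url

# Each responsive port yields a (server_url, model_name) pair.
models = discover_available_models("http://40.119.56.66", port_range=(8000, 8010))
for server_url, model_name in models:
    print(f"{model_name} -> {server_url}")

# The dropdown handler then fetches formatted /model_info text for the selection.
if models:
    print(get_model_info_for_url(models[0][0]))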
eval_utils.py CHANGED

@@ -15,8 +15,112 @@ import numpy as np
 import requests
 import torch
 
-from …
-
+from dataset_types import PreferenceSample, SimilaritySample, ProgressSample, Trajectory
+
+
+def pad_trajectory_to_max_frames_np(
+    frames: np.ndarray, progress: List[float], max_frames: int, pad_from: str = "right"
+) -> Tuple[np.ndarray, List[float]]:
+    """Pad trajectory frames and progress to max_frames by repeating the edge frame/progress if needed.
+
+    Args:
+        frames: Trajectory frames (numpy array)
+        progress: Progress values (list of floats)
+        max_frames: Target number of frames
+
+    Returns:
+        Tuple[np.ndarray, List[float]]: (padded_frames, padded_progress)
+    """
+    current_frames = frames.shape[0]
+
+    if current_frames >= max_frames:
+        # No padding needed
+        return frames, progress
+
+    if pad_from == "left":
+        pad_frame = frames[0:1]  # Keep the batch dimension
+        pad_progress = progress[0]
+    else:
+        pad_frame = frames[-1:]
+        pad_progress = progress[-1]
+
+    # Calculate how many frames to pad
+    frames_to_pad = max_frames - current_frames
+
+    # Pad frames by repeating the edge frame
+    if pad_from == "left":
+        padded_frames = np.concatenate([np.repeat(pad_frame, frames_to_pad, axis=0), frames], axis=0)
+        padded_progress = [pad_progress] * frames_to_pad + progress
+    else:
+        padded_frames = np.concatenate([frames, np.repeat(pad_frame, frames_to_pad, axis=0)], axis=0)
+        padded_progress = progress + [pad_progress] * frames_to_pad
+
+    return padded_frames, padded_progress
+
+
+def linspace_subsample_frames(
+    frames: np.ndarray, num_frames: int = 8, end_idx: Optional[int] = None
+) -> Tuple[np.ndarray, List[int]]:
+    """Uniformly subsample frames from a trajectory and return the indices.
+
+    This method takes the full trajectory (e.g., 64 frames) and uniformly subsamples
+    num_frames from it. The first and last frames are always included.
+
+    Args:
+        frames: Full trajectory frames (N frames)
+        num_frames: Number of frames to subsample (default: 8)
+        end_idx: Optional end index to subsample up to (if None, uses total_frames - 1)
+
+    Returns:
+        Tuple[np.ndarray, List[int]]: (subsampled_frames, subsampled_indices)
+    """
+    if hasattr(frames, "shape"):
+        total_frames = frames.shape[0]
+    else:
+        total_frames = len(frames)
+
+    if total_frames <= 0:
+        return frames, []
+
+    # Use end_idx if provided, otherwise use full trajectory
+    if end_idx is not None:
+        end_idx = min(end_idx, total_frames - 1)
+        frames_to_subsample = frames[: end_idx + 1]
+        effective_total = end_idx + 1
+    else:
+        frames_to_subsample = frames
+        effective_total = total_frames
+
+    if effective_total <= num_frames:
+        # If we have fewer (or equal) frames than requested, return all frames
+        indices = list(range(effective_total))
+        return frames_to_subsample, indices
+
+    # Special case: if num_frames == 1, always take the last frame
+    if num_frames == 1:
+        indices = [effective_total - 1]
+        subsampled_frames = frames_to_subsample[indices]
+        return subsampled_frames, indices
+
+    # Evenly spaced indices from 0 to effective_total-1, inclusive
+    indices_np = np.linspace(0, effective_total - 1, num_frames)
+    indices = np.rint(indices_np).astype(int).tolist()
+
+    # Enforce first and last explicitly
+    indices[0] = 0
+    indices[-1] = effective_total - 1
+
+    # Ensure indices are non-decreasing and within bounds
+    for k in range(1, len(indices)):
+        if indices[k] < indices[k - 1]:
+            indices[k] = indices[k - 1]
+        if indices[k] >= effective_total:
+            indices[k] = effective_total - 1
+
+    # Subsample frames
+    subsampled_frames = frames_to_subsample[indices]
+
+    return subsampled_frames, indices
 
 
 def extract_answer_from_text(text: str) -> str:
@@ -219,10 +323,10 @@ async def post_batch_npy_async(
 
 async def parse_npy_form_data(form_data: Any) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
     """Parse multipart form data to extract numpy arrays and other data.
-
+
     Args:
         form_data: FastAPI form data from request.form()
-
+
     Returns:
         Tuple of (numpy_arrays dict, other_data dict)
     """
@@ -271,7 +375,7 @@ def reconstruct_payload_from_npy(
         other_data: Dictionary of other form data
         trajectory_keys: List of trajectory keys to process (default: common keys)
        convert_embeddings_to_torch: Whether to convert embeddings to torch tensors
-
+
     Returns:
         List of reconstructed sample dictionaries
     """
@@ -284,7 +388,7 @@ def reconstruct_payload_from_npy(
         "traj_diff_trajectory",
         "trajectory",
     ]
-
+
     samples = []
 
     # Process each sample
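The two helpers added here compose: `linspace_subsample_frames` reduces a long trajectory to a fixed frame budget, and `pad_trajectory_to_max_frames_np` pads a short one up to it. A self-contained sketch on synthetic data (not from the Space; assumes `eval_utils` imports cleanly, which also pulls in its `torch` and `dataset_types` dependencies):

import numpy as np

from eval_utils import linspace_subsample_frames, pad_trajectory_to_max_frames_np

# Synthetic 30-frame trajectory of 4x4 RGB images with linear progress values.
frames = np.zeros((30, 4, 4, 3), dtype=np.uint8)
progress = [i / 29 for i in range(30)]

# Uniform subsampling always keeps the first and last frame.
sub_frames, indices = linspace_subsample_frames(frames, num_frames=8)
assert indices[0] == 0 and indices[-1] == 29 and sub_frames.shape[0] == 8

# Padding repeats the edge frame/progress until max_frames is reached.
padded_frames, padded_progress = pad_trajectory_to_max_frames_np(
    sub_frames, [progress[i] for i in indices], max_frames=16, pad_from="right"
)
assert padded_frames.shape[0] == 16
assert padded_progress[-1] == progress[29]  # right-padding repeats the final progress value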
eval_viz_utils.py CHANGED

@@ -180,7 +180,7 @@ def extract_frames(video_path: str, fps: float = 1.0, max_frames: int = 64) -> np.ndarray:
 
     # Clamp to [1, total_frames]
     desired_frames = max(1, min(desired_frames, total_frames))
-
+
     # IMPORTANT: Cap at max_frames to prevent memory issues
     # This is critical when fps is high or videos are long
     if desired_frames > max_frames:
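This hunk is whitespace-only, but the surrounding logic is worth a gloss: `extract_frames` derives a desired frame count, clamps it to the video's actual frame count, then caps it at `max_frames`. A sketch of that arithmetic; how `desired_frames` is first derived is not shown in this hunk, so the duration-times-fps line is an assumption:

# Illustrative numbers, not from the Space.
duration_s, total_frames = 120.0, 3600  # hypothetical video
fps, max_frames = 1.0, 64               # the extract_frames defaults shown in the hunk header

desired_frames = int(duration_s * fps)                      # assumed derivation: 120
desired_frames = max(1, min(desired_frames, total_frames))  # clamp to [1, total_frames]
if desired_frames > max_frames:
    desired_frames = max_frames  # cap to prevent memory issues on long/high-fps videos
print(desired_frames)  # 64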
samplers/eval/confusion_matrix.py CHANGED

@@ -60,7 +60,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
 
     def _generate_all_sample_indices(self) -> list[dict]:
         """Generate all possible task-trajectory pair sample indices.
-
+
         If multiple data sources exist, samples N random trajectories from each data source.
         Prioritizes different video tasks first, then prioritizes different language instructions
         when creating pairs.
@@ -73,7 +73,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
 
         # Sample trajectories per data source (prioritizing different video tasks)
         sampled_trajectories, stats = self._sample_trajectories_by_data_source()
-
+
         rank_0_print(
             f"Processing {len(sampled_trajectories)} trajectories for confusion matrix analysis",
             verbose=self.verbose,
@@ -88,7 +88,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
 
         # Create task-trajectory pairs with prioritized language instruction pairing
         video_task_count = Counter()
-
+
         for traj_idx in sampled_trajectories:
             traj = self.dataset[traj_idx]
             video_task = traj["task"]
@@ -98,7 +98,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
             #     continue
 
             video_task_count[video_task] += 1
-
+
             # Pair this trajectory with all language tasks (shuffled for variety)
             traj_id = traj.get("id", str(traj_idx))
             for lang_task in shuffled_lang_tasks:
@@ -117,15 +117,15 @@ class ConfusionMatrixSampler(RFMBaseSampler):
         rank_0_print(f"Generated {len(sample_indices)} task-trajectory pairs", verbose=self.verbose)
         rank_0_print(f"  Video tasks sampled: {dict(video_task_count)}", verbose=self.verbose)
         rank_0_print(f"  Trajectories per video task: {dict(sorted(video_task_count.items()))}", verbose=self.verbose)
-
+
         return sample_indices
 
     def _sample_trajectories_by_data_source(self) -> Tuple[list[int], dict]:
         """Sample N random trajectories from each data source, prioritizing different video tasks.
-
+
         When sampling N trajectories, first selects one trajectory from each unique video task,
         then repeats in round-robin fashion until N trajectories are sampled.
-
+
         Returns:
             Tuple of (list of sampled trajectory indices, stats dictionary)
         """
@@ -135,7 +135,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
             "by_task": Counter(),
             "traj_to_task": {},
         }
-
+
        # Group robot trajectories by data source, then by video task
        trajectories_by_source_and_task = defaultdict(lambda: defaultdict(list))
        for traj_idx in self.robot_trajectories:
@@ -143,7 +143,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
            data_source = traj.get("data_source", "unknown")
            video_task = traj.get("task", "unknown")
            trajectories_by_source_and_task[data_source][video_task].append(traj_idx)
-
+
        rank_0_print(
            f"Found {len(trajectories_by_source_and_task)} data sources: {list(trajectories_by_source_and_task.keys())}",
            verbose=self.verbose,
@@ -154,17 +154,17 @@ class ConfusionMatrixSampler(RFMBaseSampler):
            # Shuffle trajectories within each task for randomization
            for task in tasks_to_indices:
                self._local_random.shuffle(tasks_to_indices[task])
-
+
            # Get all unique tasks for this data source
            all_tasks = list(tasks_to_indices.keys())
            self._local_random.shuffle(all_tasks)  # Randomize task order too
-
+
            source_stats = {
                "total_available": sum(len(indices) for indices in tasks_to_indices.values()),
                "tasks_available": {task: len(indices) for task, indices in tasks_to_indices.items()},
                "tasks_sampled": Counter(),
            }
-
+
            if self.n_trajectories_per_source is None:
                # Use all available trajectories
                sampled_from_source = []
@@ -172,7 +172,7 @@ class ConfusionMatrixSampler(RFMBaseSampler):
                    sampled_from_source.extend(indices)
                    source_stats["tasks_sampled"][task] = len(indices)
                    stats["by_task"][task] += len(indices)
-
+
                rank_0_print(
                    f"  Data source '{data_source}': Using all {len(sampled_from_source)} trajectories",
                    verbose=self.verbose,
@@ -181,18 +181,18 @@ class ConfusionMatrixSampler(RFMBaseSampler):
                # Sample N trajectories using round-robin to prioritize different tasks
                n_to_sample = min(self.n_trajectories_per_source, source_stats["total_available"])
                sampled_from_source = []
-
+
                # Round-robin sampling: first get one from each task, then repeat
                task_iterators = {task: iter(indices) for task, indices in tasks_to_indices.items()}
                task_list = all_tasks.copy()
                round_idx = 0
-
+
                while len(sampled_from_source) < n_to_sample:
                    # If we've gone through all tasks once, reshuffle for next round
                    if round_idx >= len(task_list):
                        round_idx = 0
                        self._local_random.shuffle(task_list)
-
+
                    # Try to get one trajectory from current task
                    task = task_list[round_idx]
                    try:
@@ -206,9 +206,9 @@ class ConfusionMatrixSampler(RFMBaseSampler):
                        if not task_list:
                            break  # All tasks exhausted
                        continue
-
+
                    round_idx += 1
-
+
                rank_0_print(
                    f"  Data source '{data_source}': Sampled {len(sampled_from_source)} out of {source_stats['total_available']} trajectories",
                    verbose=self.verbose,
@@ -217,13 +217,13 @@ class ConfusionMatrixSampler(RFMBaseSampler):
                    f"  Tasks sampled: {dict(sorted(source_stats['tasks_sampled'].items()))}",
                    verbose=self.verbose,
                )
-
+
            # Track trajectory to task mapping for stats
            for traj_idx in sampled_from_source:
                traj = self.dataset[traj_idx]
                traj_id = traj.get("id", str(traj_idx))
                stats["traj_to_task"][traj_id] = traj.get("task", "unknown")
-
+
            sampled_indices.extend(sampled_from_source)
            stats["by_source"][data_source] = source_stats
 
@@ -231,33 +231,33 @@ class ConfusionMatrixSampler(RFMBaseSampler):
 
    def _print_sampling_stats(self, stats: dict):
        """Print detailed statistics about sampled trajectories.
-
+
        Args:
            stats: Statistics dictionary from _sample_trajectories_by_data_source
        """
        if not self.verbose:
            return
-
+
        rank_0_print("\n=== Confusion Matrix Sampling Statistics ===", verbose=self.verbose)
-
+
        # Overall task statistics
        rank_0_print(f"\nOverall trajectories per video task:", verbose=self.verbose)
        for task, count in sorted(stats["by_task"].items()):
            rank_0_print(f"  {task}: {count} trajectories", verbose=self.verbose)
-
+
        # Per data source statistics
        rank_0_print(f"\nPer data source breakdown:", verbose=self.verbose)
        for data_source, source_stats in stats["by_source"].items():
            rank_0_print(f"  Data source: {data_source}", verbose=self.verbose)
            rank_0_print(f"    Total available: {source_stats['total_available']}", verbose=self.verbose)
            rank_0_print(f"    Tasks available: {len(source_stats['tasks_available'])}", verbose=self.verbose)
-            for task, count in sorted(source_stats[…
-                sampled_count = source_stats[…
+            for task, count in sorted(source_stats["tasks_available"].items()):
+                sampled_count = source_stats["tasks_sampled"].get(task, 0)
                rank_0_print(
                    f"      {task}: {sampled_count}/{count} trajectories sampled",
                    verbose=self.verbose,
                )
-
+
        rank_0_print("=" * 50, verbose=self.verbose)
 
    def _generate_sample_from_indices(self, sample_idx_info: dict) -> PreferenceSample:
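The round-robin scheme documented in `_sample_trajectories_by_data_source` (one trajectory per task per pass, reshuffling between passes, until N are drawn) can be illustrated standalone. A minimal sketch with synthetic task buckets; the function below is hypothetical and independent of the sampler class:

import random

def round_robin_sample(tasks_to_indices: dict, n: int, seed: int = 0) -> list:
    """Draw up to n indices, one task at a time, so task coverage is maximized first."""
    rng = random.Random(seed)
    iterators = {task: iter(indices) for task, indices in tasks_to_indices.items()}
    task_list = list(iterators)
    rng.shuffle(task_list)
    sampled, round_idx = [], 0
    while len(sampled) < n and task_list:
        if round_idx >= len(task_list):  # finished a pass; reshuffle for the next one
            round_idx = 0
            rng.shuffle(task_list)
        task = task_list[round_idx]
        try:
            sampled.append(next(iterators[task]))
        except StopIteration:
            task_list.remove(task)  # task exhausted; don't advance round_idx (positions shifted)
            continue
        round_idx += 1
    return sampled

buckets = {"pick": [0, 1, 2], "place": [3, 4], "push": [5]}
print(round_robin_sample(buckets, n=4))  # one index per task first, then a second pass begins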