Spaces:
Running
Running
Anthony Liang commited on
Commit ·
1b2bc24
1
Parent(s): 1fe73ab
update
Browse files- app.py +43 -7
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -25,7 +25,7 @@ import numpy as np
|
|
| 25 |
import requests
|
| 26 |
from typing import Any, Optional, Tuple
|
| 27 |
|
| 28 |
-
from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
|
| 29 |
from rfm.evals.eval_utils import build_payload, post_batch_npy
|
| 30 |
from rfm.evals.eval_viz_utils import create_combined_progress_success_plot, extract_frames
|
| 31 |
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
|
|
@@ -368,7 +368,7 @@ def process_single_video(
|
|
| 368 |
return None, f"Error processing video: {str(e)}"
|
| 369 |
|
| 370 |
|
| 371 |
-
def
|
| 372 |
video_a_path: str,
|
| 373 |
video_b_path: str,
|
| 374 |
task_text: str = "Complete the task",
|
|
@@ -376,7 +376,7 @@ def process_dual_videos(
|
|
| 376 |
server_url: str = "",
|
| 377 |
fps: float = 1.0,
|
| 378 |
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
| 379 |
-
"""Process two videos for preference or
|
| 380 |
if not server_url:
|
| 381 |
return "Please provide a server URL and check connection first.", None, None
|
| 382 |
|
|
@@ -497,8 +497,44 @@ def process_dual_videos(
|
|
| 497 |
else:
|
| 498 |
result_text += "Could not extract progress predictions from server response.\n"
|
| 499 |
|
| 500 |
-
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
# Return result text and both video paths
|
| 504 |
return result_text, video_a_path, video_b_path
|
|
@@ -1292,10 +1328,10 @@ with demo:
|
|
| 1292 |
)
|
| 1293 |
|
| 1294 |
analyze_dual_btn.click(
|
| 1295 |
-
fn=
|
| 1296 |
inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
|
| 1297 |
outputs=[result_text, video_a_display, video_b_display],
|
| 1298 |
-
api_name="
|
| 1299 |
)
|
| 1300 |
|
| 1301 |
|
|
|
|
| 25 |
import requests
|
| 26 |
from typing import Any, Optional, Tuple
|
| 27 |
|
| 28 |
+
from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample
|
| 29 |
from rfm.evals.eval_utils import build_payload, post_batch_npy
|
| 30 |
from rfm.evals.eval_viz_utils import create_combined_progress_success_plot, extract_frames
|
| 31 |
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
|
|
|
|
| 368 |
return None, f"Error processing video: {str(e)}"
|
| 369 |
|
| 370 |
|
| 371 |
+
def process_two_videos(
|
| 372 |
video_a_path: str,
|
| 373 |
video_b_path: str,
|
| 374 |
task_text: str = "Complete the task",
|
|
|
|
| 376 |
server_url: str = "",
|
| 377 |
fps: float = 1.0,
|
| 378 |
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
| 379 |
+
"""Process two videos for preference, similarity, or progress prediction using eval server."""
|
| 380 |
if not server_url:
|
| 381 |
return "Please provide a server URL and check connection first.", None, None
|
| 382 |
|
|
|
|
| 497 |
else:
|
| 498 |
result_text += "Could not extract progress predictions from server response.\n"
|
| 499 |
|
| 500 |
+
elif prediction_type == "similarity":
|
| 501 |
+
# For similarity inference, we have two videos:
|
| 502 |
+
# - Video A as reference trajectory
|
| 503 |
+
# - Video B as similar trajectory
|
| 504 |
+
# diff_trajectory is None in inference mode (only need similarity between ref and sim)
|
| 505 |
+
|
| 506 |
+
# Create SimilaritySample with Video A as ref and Video B as sim
|
| 507 |
+
similarity_sample = SimilaritySample(
|
| 508 |
+
ref_trajectory=trajectory_a,
|
| 509 |
+
sim_trajectory=trajectory_b,
|
| 510 |
+
diff_trajectory=None, # None in inference mode
|
| 511 |
+
data_gen_strategy="demo",
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Build payload and send to server
|
| 515 |
+
files, sample_data = build_payload([similarity_sample])
|
| 516 |
+
response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
|
| 517 |
+
|
| 518 |
+
# Process response - we only care about sim_score_ref_sim (similarity between Video A and Video B)
|
| 519 |
+
outputs_similarity = response.get("outputs_similarity", {})
|
| 520 |
+
sim_score_ref_sim = outputs_similarity.get("sim_score_ref_sim", [])
|
| 521 |
+
|
| 522 |
+
result_text = f"**Similarity Prediction:**\n"
|
| 523 |
+
if sim_score_ref_sim and len(sim_score_ref_sim) > 0:
|
| 524 |
+
sim_score = sim_score_ref_sim[0]
|
| 525 |
+
if sim_score is not None:
|
| 526 |
+
result_text += f"- Similarity score (Video A vs Video B): {sim_score:.3f}\n"
|
| 527 |
+
# Interpret similarity score (higher = more similar)
|
| 528 |
+
if sim_score > 0.7:
|
| 529 |
+
result_text += f"- Interpretation: High similarity - videos are very similar\n"
|
| 530 |
+
elif sim_score > 0.4:
|
| 531 |
+
result_text += f"- Interpretation: Moderate similarity - videos share some similarities\n"
|
| 532 |
+
else:
|
| 533 |
+
result_text += f"- Interpretation: Low similarity - videos are quite different\n"
|
| 534 |
+
else:
|
| 535 |
+
result_text += "Could not extract similarity score from server response.\n"
|
| 536 |
+
else:
|
| 537 |
+
result_text += "Could not extract similarity prediction from server response.\n"
|
| 538 |
|
| 539 |
# Return result text and both video paths
|
| 540 |
return result_text, video_a_path, video_b_path
|
|
|
|
| 1328 |
)
|
| 1329 |
|
| 1330 |
analyze_dual_btn.click(
|
| 1331 |
+
fn=process_two_videos,
|
| 1332 |
inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
|
| 1333 |
outputs=[result_text, video_a_display, video_b_display],
|
| 1334 |
+
api_name="process_two_videos",
|
| 1335 |
)
|
| 1336 |
|
| 1337 |
|
requirements.txt
CHANGED
|
@@ -26,7 +26,7 @@ watchfiles # For file watching during development
|
|
| 26 |
|
| 27 |
# RFM package (installed from git repository)
|
| 28 |
# For local development, you can also install with: pip install -e ../ (from parent directory)
|
| 29 |
-
git+https://github.com/aliang8/reward_fm.git@
|
| 30 |
|
| 31 |
# Make sure a newer version of gradio is installed
|
| 32 |
gradio==4.44.0
|
|
|
|
| 26 |
|
| 27 |
# RFM package (installed from git repository)
|
| 28 |
# For local development, you can also install with: pip install -e ../ (from parent directory)
|
| 29 |
+
git+https://github.com/aliang8/reward_fm.git@7fd45f3854d45aa297a9873c84e4dc663ef5519e
|
| 30 |
|
| 31 |
# Make sure a newer version of gradio is installed
|
| 32 |
gradio==4.44.0
|