Anthony Liang committed on
Commit
1b2bc24
·
1 Parent(s): 1fe73ab
Files changed (2) hide show
  1. app.py +43 -7
  2. requirements.txt +1 -1
app.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
25
  import requests
26
  from typing import Any, Optional, Tuple
27
 
28
- from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
29
  from rfm.evals.eval_utils import build_payload, post_batch_npy
30
  from rfm.evals.eval_viz_utils import create_combined_progress_success_plot, extract_frames
31
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
@@ -368,7 +368,7 @@ def process_single_video(
368
  return None, f"Error processing video: {str(e)}"
369
 
370
 
371
- def process_dual_videos(
372
  video_a_path: str,
373
  video_b_path: str,
374
  task_text: str = "Complete the task",
@@ -376,7 +376,7 @@ def process_dual_videos(
376
  server_url: str = "",
377
  fps: float = 1.0,
378
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
379
- """Process two videos for preference or similarity prediction using eval server."""
380
  if not server_url:
381
  return "Please provide a server URL and check connection first.", None, None
382
 
@@ -497,8 +497,44 @@ def process_dual_videos(
497
  else:
498
  result_text += "Could not extract progress predictions from server response.\n"
499
 
500
- else: # similarity - not yet implemented in eval server response format
501
- result_text = "Similarity prediction not yet supported in eval server response format."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
  # Return result text and both video paths
504
  return result_text, video_a_path, video_b_path
@@ -1292,10 +1328,10 @@ with demo:
1292
  )
1293
 
1294
  analyze_dual_btn.click(
1295
- fn=process_dual_videos,
1296
  inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
1297
  outputs=[result_text, video_a_display, video_b_display],
1298
- api_name="process_dual_videos",
1299
  )
1300
 
1301
 
 
25
  import requests
26
  from typing import Any, Optional, Tuple
27
 
28
+ from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample
29
  from rfm.evals.eval_utils import build_payload, post_batch_npy
30
  from rfm.evals.eval_viz_utils import create_combined_progress_success_plot, extract_frames
31
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
 
368
  return None, f"Error processing video: {str(e)}"
369
 
370
 
371
+ def process_two_videos(
372
  video_a_path: str,
373
  video_b_path: str,
374
  task_text: str = "Complete the task",
 
376
  server_url: str = "",
377
  fps: float = 1.0,
378
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
379
+ """Process two videos for preference, similarity, or progress prediction using eval server."""
380
  if not server_url:
381
  return "Please provide a server URL and check connection first.", None, None
382
 
 
497
  else:
498
  result_text += "Could not extract progress predictions from server response.\n"
499
 
500
+ elif prediction_type == "similarity":
501
+ # For similarity inference, we have two videos:
502
+ # - Video A as reference trajectory
503
+ # - Video B as similar trajectory
504
+ # diff_trajectory is None in inference mode (only need similarity between ref and sim)
505
+
506
+ # Create SimilaritySample with Video A as ref and Video B as sim
507
+ similarity_sample = SimilaritySample(
508
+ ref_trajectory=trajectory_a,
509
+ sim_trajectory=trajectory_b,
510
+ diff_trajectory=None, # None in inference mode
511
+ data_gen_strategy="demo",
512
+ )
513
+
514
+ # Build payload and send to server
515
+ files, sample_data = build_payload([similarity_sample])
516
+ response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
517
+
518
+ # Process response - we only care about sim_score_ref_sim (similarity between Video A and Video B)
519
+ outputs_similarity = response.get("outputs_similarity", {})
520
+ sim_score_ref_sim = outputs_similarity.get("sim_score_ref_sim", [])
521
+
522
+ result_text = f"**Similarity Prediction:**\n"
523
+ if sim_score_ref_sim and len(sim_score_ref_sim) > 0:
524
+ sim_score = sim_score_ref_sim[0]
525
+ if sim_score is not None:
526
+ result_text += f"- Similarity score (Video A vs Video B): {sim_score:.3f}\n"
527
+ # Interpret similarity score (higher = more similar)
528
+ if sim_score > 0.7:
529
+ result_text += f"- Interpretation: High similarity - videos are very similar\n"
530
+ elif sim_score > 0.4:
531
+ result_text += f"- Interpretation: Moderate similarity - videos share some similarities\n"
532
+ else:
533
+ result_text += f"- Interpretation: Low similarity - videos are quite different\n"
534
+ else:
535
+ result_text += "Could not extract similarity score from server response.\n"
536
+ else:
537
+ result_text += "Could not extract similarity prediction from server response.\n"
538
 
539
  # Return result text and both video paths
540
  return result_text, video_a_path, video_b_path
 
1328
  )
1329
 
1330
  analyze_dual_btn.click(
1331
+ fn=process_two_videos,
1332
  inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
1333
  outputs=[result_text, video_a_display, video_b_display],
1334
+ api_name="process_two_videos",
1335
  )
1336
 
1337
 
requirements.txt CHANGED
@@ -26,7 +26,7 @@ watchfiles # For file watching during development
26
 
27
  # RFM package (installed from git repository)
28
  # For local development, you can also install with: pip install -e ../ (from parent directory)
29
- git+https://github.com/aliang8/reward_fm.git@d0bfb225f0a8002ef301ea36a6eeadb7becc62d9
30
 
31
  # Make sure a newer version of gradio is installed
32
  gradio==4.44.0
 
26
 
27
  # RFM package (installed from git repository)
28
  # For local development, you can also install with: pip install -e ../ (from parent directory)
29
+ git+https://github.com/aliang8/reward_fm.git@7fd45f3854d45aa297a9873c84e4dc663ef5519e
30
 
31
  # Make sure a newer version of gradio is installed
32
  gradio==4.44.0