Anthony Liang committed on
Commit
679ca41
·
1 Parent(s): a326636

remove similarity

Browse files
Files changed (3) hide show
  1. app.py +16 -63
  2. dataset_types.py +2 -15
  3. eval_utils.py +4 -18
app.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
- Gradio app for RFM (Reward Foundation Model) inference visualization.
4
- Supports single video (progress/success) and dual video (preference/similarity) predictions.
5
  Uses eval server for inference instead of loading models locally.
6
  """
7
 
@@ -25,7 +25,7 @@ import numpy as np
25
  import requests
26
  from typing import Any, List, Optional, Tuple
27
 
28
- from dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample
29
  from eval_utils import build_payload, post_batch_npy
30
  from eval_viz_utils import create_combined_progress_success_plot, extract_frames
31
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
@@ -225,7 +225,6 @@ def format_model_info(model_info: dict) -> str:
225
  lines.append(f"- **Model Type:** `{model_cfg.get('model_type', 'N/A')}`\n")
226
  lines.append(f"- **Train Progress Head:** {model_cfg.get('train_progress_head', False)}\n")
227
  lines.append(f"- **Train Preference Head:** {model_cfg.get('train_preference_head', False)}\n")
228
- lines.append(f"- **Train Similarity Head:** {model_cfg.get('train_similarity_head', False)}\n")
229
  lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
230
  lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
231
  lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")
@@ -259,8 +258,8 @@ def format_model_info(model_info: dict) -> str:
259
  return "".join(lines)
260
 
261
 
262
- def load_rfm_dataset(dataset_name, config_name):
263
- """Load the RFM dataset from HuggingFace Hub."""
264
  try:
265
  if not dataset_name or not config_name:
266
  return None, "Please provide both dataset name and configuration"
@@ -302,7 +301,7 @@ def get_trajectory_video_path(dataset, index, dataset_name):
302
  if dataset_name:
303
  video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
304
  else:
305
- video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
306
 
307
  task = item.get("task", "Complete the task")
308
  quality_label = item.get("quality_label", None)
@@ -432,7 +431,7 @@ def process_two_videos(
432
  server_url: str = "",
433
  fps: float = 1.0,
434
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
435
- """Process two videos for preference, similarity, or progress prediction using eval server."""
436
  # Get server URL from state if not provided
437
  if not server_url:
438
  server_url = _server_state.get("server_url")
@@ -515,8 +514,6 @@ def process_two_videos(
515
 
516
  elif prediction_type == "progress":
517
  # Create ProgressSamples for both videos
518
- from dataset_types import ProgressSample
519
-
520
  progress_sample_a = ProgressSample(
521
  trajectory=trajectory_a,
522
  data_gen_strategy="demo",
@@ -554,45 +551,6 @@ def process_two_videos(
554
  else:
555
  result_text += "Could not extract progress predictions from server response.\n"
556
 
557
- elif prediction_type == "similarity":
558
- # For similarity inference, we have two videos:
559
- # - Video A as reference trajectory
560
- # - Video B as similar trajectory
561
- # diff_trajectory is None in inference mode (only need similarity between ref and sim)
562
-
563
- # Create SimilaritySample with Video A as ref and Video B as sim
564
- similarity_sample = SimilaritySample(
565
- ref_trajectory=trajectory_a,
566
- sim_trajectory=trajectory_b,
567
- diff_trajectory=None, # None in inference mode
568
- data_gen_strategy="demo",
569
- )
570
-
571
- # Build payload and send to server
572
- files, sample_data = build_payload([similarity_sample])
573
- response = post_batch_npy(server_url, files, sample_data, timeout_s=120.0)
574
-
575
- # Process response - we only care about sim_score_ref_sim (similarity between Video A and Video B)
576
- outputs_similarity = response.get("outputs_similarity", {})
577
- sim_score_ref_sim = outputs_similarity.get("sim_score_ref_sim", [])
578
-
579
- result_text = f"**Similarity Prediction:**\n"
580
- if sim_score_ref_sim and len(sim_score_ref_sim) > 0:
581
- sim_score = sim_score_ref_sim[0]
582
- if sim_score is not None:
583
- result_text += f"- Similarity score (Video A vs Video B): {sim_score:.3f}\n"
584
- # Interpret similarity score (higher = more similar)
585
- if sim_score > 0.7:
586
- result_text += f"- Interpretation: High similarity - videos are very similar\n"
587
- elif sim_score > 0.4:
588
- result_text += f"- Interpretation: Moderate similarity - videos share some similarities\n"
589
- else:
590
- result_text += f"- Interpretation: Low similarity - videos are quite different\n"
591
- else:
592
- result_text += "Could not extract similarity score from server response.\n"
593
- else:
594
- result_text += "Could not extract similarity prediction from server response.\n"
595
-
596
  # Return result text and both video paths
597
  return result_text, video_a_path, video_b_path
598
 
@@ -603,15 +561,15 @@ def process_two_videos(
603
  # Create Gradio interface
604
  try:
605
  # Try with theme (Gradio 4.0+)
606
- demo = gr.Blocks(title="RFM Evaluation Server", theme=gr.themes.Soft())
607
  except TypeError:
608
  # Fallback for older Gradio versions without theme support
609
- demo = gr.Blocks(title="RFM Evaluation Server")
610
 
611
  with demo:
612
  gr.Markdown(
613
  """
614
- # RFM (Reward Foundation Model) Evaluation Server
615
  """
616
  )
617
 
@@ -822,7 +780,7 @@ with demo:
822
 
823
  def load_dataset_single(dataset_name, config_name):
824
  """Load dataset and update slider."""
825
- dataset, status = load_rfm_dataset(dataset_name, config_name)
826
  if dataset is not None:
827
  max_index = len(dataset) - 1
828
  return (
@@ -1021,8 +979,8 @@ with demo:
1021
  api_name="process_single_video",
1022
  )
1023
 
1024
- with gr.Tab("Preference/Similarity Analysis"):
1025
- gr.Markdown("### Preference & Similarity Prediction")
1026
  with gr.Row():
1027
  with gr.Column():
1028
  video_a_input = gr.Video(label="Video A", height=250)
@@ -1033,7 +991,7 @@ with demo:
1033
  value="Complete the task",
1034
  )
1035
  prediction_type = gr.Radio(
1036
- choices=["preference", "similarity", "progress"],
1037
  value="preference",
1038
  label="Prediction Type",
1039
  )
@@ -1129,7 +1087,7 @@ with demo:
1129
 
1130
  def load_dataset_a(dataset_name, config_name):
1131
  """Load dataset A and update slider."""
1132
- dataset, status = load_rfm_dataset(dataset_name, config_name)
1133
  if dataset is not None:
1134
  max_index = len(dataset) - 1
1135
  return (
@@ -1274,7 +1232,7 @@ with demo:
1274
 
1275
  def load_dataset_b(dataset_name, config_name):
1276
  """Load dataset B and update slider."""
1277
- dataset, status = load_rfm_dataset(dataset_name, config_name)
1278
  if dataset is not None:
1279
  max_index = len(dataset) - 1
1280
  return (
@@ -1509,11 +1467,6 @@ with demo:
1509
 
1510
  def main():
1511
  """Launch the Gradio app."""
1512
- import sys
1513
-
1514
- # Check if reload mode is requested
1515
- watch_files = os.getenv("GRADIO_WATCH", "0") == "1" or "--reload" in sys.argv
1516
-
1517
  demo.launch(
1518
  server_name="0.0.0.0",
1519
  server_port=7860,
 
1
  #!/usr/bin/env python3
2
  """
3
+ Gradio app for RBM (Reward Foundation Model) inference visualization.
4
+ Supports single video (progress/success) and dual video (preference/progress) predictions.
5
  Uses eval server for inference instead of loading models locally.
6
  """
7
 
 
25
  import requests
26
  from typing import Any, List, Optional, Tuple
27
 
28
+ from dataset_types import Trajectory, ProgressSample, PreferenceSample
29
  from eval_utils import build_payload, post_batch_npy
30
  from eval_viz_utils import create_combined_progress_success_plot, extract_frames
31
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
 
225
  lines.append(f"- **Model Type:** `{model_cfg.get('model_type', 'N/A')}`\n")
226
  lines.append(f"- **Train Progress Head:** {model_cfg.get('train_progress_head', False)}\n")
227
  lines.append(f"- **Train Preference Head:** {model_cfg.get('train_preference_head', False)}\n")
 
228
  lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
229
  lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
230
  lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")
 
258
  return "".join(lines)
259
 
260
 
261
+ def load_rbm_dataset(dataset_name, config_name):
262
+ """Load an RBM-format dataset from HuggingFace Hub."""
263
  try:
264
  if not dataset_name or not config_name:
265
  return None, "Please provide both dataset name and configuration"
 
301
  if dataset_name:
302
  video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
303
  else:
304
+ video_path = f"https://huggingface.co/datasets/rewardfm/rbm-1m/resolve/main/{frames_data}"
305
 
306
  task = item.get("task", "Complete the task")
307
  quality_label = item.get("quality_label", None)
 
431
  server_url: str = "",
432
  fps: float = 1.0,
433
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
434
+ """Process two videos for preference or progress prediction using eval server."""
435
  # Get server URL from state if not provided
436
  if not server_url:
437
  server_url = _server_state.get("server_url")
 
514
 
515
  elif prediction_type == "progress":
516
  # Create ProgressSamples for both videos
 
 
517
  progress_sample_a = ProgressSample(
518
  trajectory=trajectory_a,
519
  data_gen_strategy="demo",
 
551
  else:
552
  result_text += "Could not extract progress predictions from server response.\n"
553
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  # Return result text and both video paths
555
  return result_text, video_a_path, video_b_path
556
 
 
561
  # Create Gradio interface
562
  try:
563
  # Try with theme (Gradio 4.0+)
564
+ demo = gr.Blocks(title="RBM Evaluation Server", theme=gr.themes.Soft())
565
  except TypeError:
566
  # Fallback for older Gradio versions without theme support
567
+ demo = gr.Blocks(title="RBM Evaluation Server")
568
 
569
  with demo:
570
  gr.Markdown(
571
  """
572
+ # RBM (Reward Foundation Model) Evaluation Server
573
  """
574
  )
575
 
 
780
 
781
  def load_dataset_single(dataset_name, config_name):
782
  """Load dataset and update slider."""
783
+ dataset, status = load_rbm_dataset(dataset_name, config_name)
784
  if dataset is not None:
785
  max_index = len(dataset) - 1
786
  return (
 
979
  api_name="process_single_video",
980
  )
981
 
982
+ with gr.Tab("Preference Analysis"):
983
+ gr.Markdown("### Preference & Progress Prediction")
984
  with gr.Row():
985
  with gr.Column():
986
  video_a_input = gr.Video(label="Video A", height=250)
 
991
  value="Complete the task",
992
  )
993
  prediction_type = gr.Radio(
994
+ choices=["preference", "progress"],
995
  value="preference",
996
  label="Prediction Type",
997
  )
 
1087
 
1088
  def load_dataset_a(dataset_name, config_name):
1089
  """Load dataset A and update slider."""
1090
+ dataset, status = load_rbm_dataset(dataset_name, config_name)
1091
  if dataset is not None:
1092
  max_index = len(dataset) - 1
1093
  return (
 
1232
 
1233
  def load_dataset_b(dataset_name, config_name):
1234
  """Load dataset B and update slider."""
1235
+ dataset, status = load_rbm_dataset(dataset_name, config_name)
1236
  if dataset is not None:
1237
  max_index = len(dataset) - 1
1238
  return (
 
1467
 
1468
  def main():
1469
  """Launch the Gradio app."""
 
 
 
 
 
1470
  demo.launch(
1471
  server_name="0.0.0.0",
1472
  server_port=7860,
dataset_types.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python3
2
  """
3
- Dataclasses for RFM model dataset trajectory structures.
4
  Defines the standard format for HuggingFace dataset trajectories.
5
  """
6
 
@@ -62,17 +62,4 @@ class PreferenceSample(BaseModel):
62
  resample_attempts: int = 1
63
 
64
 
65
- class SimilaritySample(BaseModel):
66
- """Sample structure for similarity scoring: traj_sim and traj_diff ranked against o^ref."""
67
-
68
- # Trajectories
69
- ref_trajectory: Trajectory # o^ref
70
- sim_trajectory: Trajectory # Similar trajectory
71
- diff_trajectory: Optional[Trajectory] = None # Different trajectory (optional in inference mode)
72
-
73
- sample_type: str = "similarity"
74
- data_gen_strategy: Optional[str] = None
75
- resample_attempts: int = 1
76
-
77
-
78
- SampleType = Union[PreferenceSample, SimilaritySample, ProgressSample]
 
1
  #!/usr/bin/env python3
2
  """
3
+ Dataclasses for RBM model dataset trajectory structures.
4
  Defines the standard format for HuggingFace dataset trajectories.
5
  """
6
 
 
62
  resample_attempts: int = 1
63
 
64
 
65
+ SampleType = Union[PreferenceSample, ProgressSample]
 
 
 
 
 
 
 
 
 
 
 
 
 
eval_utils.py CHANGED
@@ -1,7 +1,6 @@
1
  #!/usr/bin/env python3
2
  from __future__ import annotations
3
 
4
- import re
5
  import torch
6
  import io
7
  import json
@@ -15,7 +14,7 @@ import numpy as np
15
  import requests
16
  import torch
17
 
18
- from dataset_types import PreferenceSample, SimilaritySample, ProgressSample, Trajectory
19
 
20
 
21
  def pad_trajectory_to_max_frames_np(
@@ -123,13 +122,6 @@ def linspace_subsample_frames(
123
  return subsampled_frames, indices
124
 
125
 
126
- def extract_answer_from_text(text: str) -> str:
127
- """Extract answer from text using <ans> tags."""
128
- m = re.search(r"<ans>(.*?)</ans>", text, re.DOTALL)
129
- ans = m.group(1).strip() if m else ""
130
- return ans
131
-
132
-
133
  def raw_dict_to_sample(
134
  raw_data: Union[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
135
  max_frames: int = 16,
@@ -216,7 +208,7 @@ def raw_dict_to_sample(
216
 
217
 
218
  def build_payload(
219
- samples: list[PreferenceSample | SimilaritySample | ProgressSample],
220
  ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
221
  """Build a payload with numpy array handling.
222
 
@@ -239,9 +231,6 @@ def build_payload(
239
  for key in [
240
  "chosen_trajectory",
241
  "rejected_trajectory",
242
- "reference_trajectory",
243
- "traj_sim_trajectory",
244
- "traj_diff_trajectory",
245
  "trajectory",
246
  ]:
247
  if key in processed_sample and isinstance(processed_sample[key], dict):
@@ -287,7 +276,7 @@ def post_batch_npy(
287
  extra_form_data: Optional[dict[str, Any]] = None,
288
  ) -> dict[str, Any]:
289
  """POST batch using .npy format for numpy arrays.
290
-
291
  Args:
292
  url: Server URL
293
  files: Dict of numpy arrays converted to .npy format
@@ -297,7 +286,7 @@ def post_batch_npy(
297
  """
298
  # Convert sample_data to form data
299
  data = {f"sample_{i}": json.dumps(sample) for i, sample in enumerate(sample_data)}
300
-
301
  # Add extra form data if provided
302
  if extra_form_data:
303
  for key, value in extra_form_data.items():
@@ -400,9 +389,6 @@ def reconstruct_payload_from_npy(
400
  trajectory_keys = [
401
  "chosen_trajectory",
402
  "rejected_trajectory",
403
- "reference_trajectory",
404
- "traj_sim_trajectory",
405
- "traj_diff_trajectory",
406
  "trajectory",
407
  ]
408
 
 
1
  #!/usr/bin/env python3
2
  from __future__ import annotations
3
 
 
4
  import torch
5
  import io
6
  import json
 
14
  import requests
15
  import torch
16
 
17
+ from dataset_types import PreferenceSample, ProgressSample, Trajectory
18
 
19
 
20
  def pad_trajectory_to_max_frames_np(
 
122
  return subsampled_frames, indices
123
 
124
 
 
 
 
 
 
 
 
125
  def raw_dict_to_sample(
126
  raw_data: Union[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
127
  max_frames: int = 16,
 
208
 
209
 
210
  def build_payload(
211
+ samples: list[PreferenceSample | ProgressSample],
212
  ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
213
  """Build a payload with numpy array handling.
214
 
 
231
  for key in [
232
  "chosen_trajectory",
233
  "rejected_trajectory",
 
 
 
234
  "trajectory",
235
  ]:
236
  if key in processed_sample and isinstance(processed_sample[key], dict):
 
276
  extra_form_data: Optional[dict[str, Any]] = None,
277
  ) -> dict[str, Any]:
278
  """POST batch using .npy format for numpy arrays.
279
+
280
  Args:
281
  url: Server URL
282
  files: Dict of numpy arrays converted to .npy format
 
286
  """
287
  # Convert sample_data to form data
288
  data = {f"sample_{i}": json.dumps(sample) for i, sample in enumerate(sample_data)}
289
+
290
  # Add extra form data if provided
291
  if extra_form_data:
292
  for key, value in extra_form_data.items():
 
389
  trajectory_keys = [
390
  "chosen_trajectory",
391
  "rejected_trajectory",
 
 
 
392
  "trajectory",
393
  ]
394