Spaces:
Running
Running
Anthony Liang committed on
Commit ·
3e462dd
1
Parent(s): e498336
getting rid of dependencies
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +4 -4
- dataset_types.py +78 -0
- eval_utils.py +400 -0
- eval_viz_utils.py +205 -0
- requirements.txt +1 -2
- samplers/README.md +182 -0
- samplers/__init__.py +23 -0
- samplers/__pycache__/__init__.cpython-310.pyc +0 -0
- samplers/__pycache__/__init__.cpython-311.pyc +0 -0
- samplers/__pycache__/base.cpython-310.pyc +0 -0
- samplers/__pycache__/base.cpython-311.pyc +0 -0
- samplers/__pycache__/confusion_matrix.cpython-310.pyc +0 -0
- samplers/__pycache__/confusion_matrix.cpython-311.pyc +0 -0
- samplers/__pycache__/pref.cpython-310.pyc +0 -0
- samplers/__pycache__/pref.cpython-311.pyc +0 -0
- samplers/__pycache__/progress.cpython-310.pyc +0 -0
- samplers/__pycache__/progress.cpython-311.pyc +0 -0
- samplers/__pycache__/progress_default.cpython-310.pyc +0 -0
- samplers/__pycache__/progress_default.cpython-311.pyc +0 -0
- samplers/__pycache__/quality_preference.cpython-310.pyc +0 -0
- samplers/__pycache__/quality_preference.cpython-311.pyc +0 -0
- samplers/__pycache__/reward_alignment.cpython-310.pyc +0 -0
- samplers/__pycache__/reward_alignment.cpython-311.pyc +0 -0
- samplers/__pycache__/roboarena.cpython-310.pyc +0 -0
- samplers/__pycache__/sim.cpython-310.pyc +0 -0
- samplers/__pycache__/sim.cpython-311.pyc +0 -0
- samplers/__pycache__/success_failure.cpython-310.pyc +0 -0
- samplers/__pycache__/success_failure.cpython-311.pyc +0 -0
- samplers/base.py +753 -0
- samplers/eval/__pycache__/base_pref.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/base_pref.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/confusion_matrix.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/confusion_matrix.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/progress_default.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/progress_default.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/progress_policy_ranking.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/progress_policy_ranking.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/quality_preference.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/quality_preference.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/reward_alignment.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/reward_alignment.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/roboarena_quality_preference.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/roboarena_quality_preference.cpython-311.pyc +0 -0
- samplers/eval/__pycache__/similarity_score.cpython-310.pyc +0 -0
- samplers/eval/__pycache__/similarity_score.cpython-311.pyc +0 -0
- samplers/eval/base_pref.py +73 -0
- samplers/eval/confusion_matrix.py +299 -0
- samplers/eval/progress_policy_ranking.py +231 -0
- samplers/eval/quality_preference.py +219 -0
- samplers/eval/reward_alignment.py +147 -0
app.py
CHANGED
|
@@ -25,9 +25,9 @@ import numpy as np
|
|
| 25 |
import requests
|
| 26 |
from typing import Any, List, Optional, Tuple
|
| 27 |
|
| 28 |
-
from
|
| 29 |
-
from
|
| 30 |
-
from
|
| 31 |
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
|
@@ -514,7 +514,7 @@ def process_two_videos(
|
|
| 514 |
|
| 515 |
elif prediction_type == "progress":
|
| 516 |
# Create ProgressSamples for both videos
|
| 517 |
-
from
|
| 518 |
|
| 519 |
progress_sample_a = ProgressSample(
|
| 520 |
trajectory=trajectory_a,
|
|
|
|
| 25 |
import requests
|
| 26 |
from typing import Any, List, Optional, Tuple
|
| 27 |
|
| 28 |
+
from dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample
|
| 29 |
+
from eval_utils import build_payload, post_batch_npy
|
| 30 |
+
from eval_viz_utils import create_combined_progress_success_plot, extract_frames
|
| 31 |
from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
|
| 32 |
|
| 33 |
logger = logging.getLogger(__name__)
|
|
|
|
| 514 |
|
| 515 |
elif prediction_type == "progress":
|
| 516 |
# Create ProgressSamples for both videos
|
| 517 |
+
from dataset_types import ProgressSample
|
| 518 |
|
| 519 |
progress_sample_a = ProgressSample(
|
| 520 |
trajectory=trajectory_a,
|
dataset_types.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Dataclasses for RFM model dataset trajectory structures.
|
| 4 |
+
Defines the standard format for HuggingFace dataset trajectories.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Union, List, Dict, Tuple, Optional
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from pydantic import BaseModel, ConfigDict
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Trajectory(BaseModel):
    """Trajectory structure containing frames, metadata, and progress information."""

    # Core trajectory fields
    # Either a list of frame references (strings) or a decoded frame array;
    # assumed (T, H, W, C) by the processing code — TODO confirm with producers.
    frames: Union[List[str], np.ndarray, None] = None
    # Shape of the frames array, recorded by the code that builds trajectories.
    frames_shape: Optional[Tuple] = None

    # If embeddings are precomputed
    embeddings_path: Optional[str] = None
    video_embeddings: Union[torch.Tensor, np.ndarray, None] = None
    text_embedding: Union[torch.Tensor, np.ndarray, None] = None

    # Identification / provenance
    id: Optional[str] = None
    task: Optional[str] = None
    lang_vector: Union[np.ndarray, List[float], None] = None
    data_source: Optional[str] = None
    quality_label: Optional[str] = None
    is_robot: Optional[bool] = None

    # Progress and metadata
    # Can be List[float] for continuous progress, np.ndarray, or List[np.ndarray] for C51 discrete distributions
    target_progress: Optional[Union[List[float], List[torch.Tensor], torch.Tensor]] = None
    partial_success: Optional[Union[float, torch.Tensor]] = None  # float for continuous, Tensor for C51 discrete
    success_label: Optional[List[float]] = None  # Success labels for each frame (1.0 for success, 0.0 for failure)
    metadata: Optional[Dict[str, Any]] = None
    data_gen_strategy: Optional[str] = None

    # numpy arrays / torch tensors are not pydantic-native types.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ProgressSample(BaseModel):
    """Sample structure for progress evaluation."""

    # The single trajectory to score for task progress.
    trajectory: Trajectory
    # Tag identifying the sample kind; always "progress" for this class.
    sample_type: str = "progress"
    # Name of the data-generation strategy that produced this sample, if any.
    data_gen_strategy: Optional[str] = None
    # Number of sampling attempts used to produce this sample.
    resample_attempts: int = 1
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class PreferenceSample(BaseModel):
    """Sample structure for preference prediction: chosen vs rejected where chosen is preferred."""

    # Trajectories
    chosen_trajectory: Trajectory
    rejected_trajectory: Trajectory

    # Tag identifying the sample kind; always "preference" for this class.
    sample_type: str = "preference"
    # Name of the data-generation strategy that produced this sample, if any.
    data_gen_strategy: Optional[str] = None
    # Number of sampling attempts used to produce this sample.
    resample_attempts: int = 1
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SimilaritySample(BaseModel):
    """Sample structure for similarity scoring: traj_sim and traj_diff ranked against o^ref."""

    # Trajectories
    ref_trajectory: Trajectory  # o^ref
    sim_trajectory: Trajectory  # Similar trajectory
    diff_trajectory: Optional[Trajectory] = None  # Different trajectory (optional in inference mode)

    # Tag identifying the sample kind; always "similarity" for this class.
    sample_type: str = "similarity"
    # Name of the data-generation strategy that produced this sample, if any.
    data_gen_strategy: Optional[str] = None
    # Number of sampling attempts used to produce this sample.
    resample_attempts: int = 1
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Union of the three evaluation sample kinds defined above.
SampleType = Union[PreferenceSample, SimilaritySample, ProgressSample]
|
eval_utils.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import re
|
| 5 |
+
import torch
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List, Union, Optional, Tuple
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
import aiohttp
|
| 14 |
+
import numpy as np
|
| 15 |
+
import requests
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
from rfm.data.dataset_types import PreferenceSample, SimilaritySample, ProgressSample, Trajectory
|
| 19 |
+
from rfm.data.datasets.helpers import linspace_subsample_frames, pad_trajectory_to_max_frames_np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def extract_answer_from_text(text: str) -> str:
    """Extract answer from text using <ans> tags.

    Returns the stripped content of the first <ans>...</ans> span (the span
    may contain newlines), or an empty string when no such span exists.
    """
    match = re.search(r"<ans>(.*?)</ans>", text, re.DOTALL)
    if match is None:
        return ""
    return match.group(1).strip()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def raw_dict_to_sample(
    raw_data: Union[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
    max_frames: int = 16,
    sample_type: str = "progress",
) -> Union[ProgressSample, PreferenceSample]:
    """
    Convert raw data dictionary to a ProgressSample or PreferenceSample.

    Args:
        raw_data: Dict with 'frames', 'task', 'id', 'metadata', 'video_embeddings', 'text_embedding' or Tuple of (Dict[str, Any], Dict[str, Any])
        max_frames: Maximum number of frames to use (default: 16)
        sample_type: Either "progress" or "preference" (default: "progress")

    Returns:
        ProgressSample or PreferenceSample

    Raises:
        ValueError: If the frames array is not 4D, ends up empty after
            processing, or sample_type is unsupported.
    """

    def _build_trajectory(raw_data: Dict[str, Any], num_frames: int) -> Trajectory:
        # Accumulates keyword arguments for the Trajectory constructor.
        processed_item: Dict[str, Any] = {}

        # Process frames
        frames_array = raw_data["frames"]

        # Ensure we have the correct shape: (T, H, W, C)
        if len(frames_array.shape) != 4:
            raise ValueError(f"Expected 4D array (T, H, W, C), got shape {frames_array.shape}")

        # Convert from CxHxW to HxWxC if needed
        # NOTE(review): this heuristic misfires for genuine (T, H, W, C) input
        # whose height happens to be 3 — confirm callers always pass RGB video.
        if frames_array.shape[1] == 3:
            frames_array = np.transpose(frames_array, (0, 2, 3, 1))

        # Uniformly subsample down to num_frames, then right-pad back up to
        # num_frames for shorter clips; the dummy progress list only exists to
        # satisfy the padding helper's signature.
        frames_array, _ = linspace_subsample_frames(frames_array, num_frames)
        dummy_progress = [0.0] * len(frames_array)
        frames_array, _ = pad_trajectory_to_max_frames_np(frames_array, dummy_progress, num_frames, pad_from="right")

        if frames_array.size == 0:
            raise ValueError("No frames processed for example")

        processed_item["frames"] = frames_array
        processed_item["frames_shape"] = frames_array.shape
        processed_item["task"] = raw_data["task"]
        processed_item["lang_vector"] = None
        processed_item["metadata"] = raw_data.get("metadata", None)

        # Process video embeddings using same helper functions
        video_embeddings = raw_data.get("video_embeddings")
        if video_embeddings is not None:
            video_embeddings, _ = linspace_subsample_frames(video_embeddings, num_frames)
            dummy_progress_emb = [0.0] * len(video_embeddings)
            video_embeddings, _ = pad_trajectory_to_max_frames_np(
                video_embeddings, dummy_progress_emb, num_frames, pad_from="right"
            )

        text_embedding = raw_data.get("text_embedding")

        # Convert to tensors if they are numpy arrays
        if video_embeddings is not None and isinstance(video_embeddings, np.ndarray):
            video_embeddings = torch.tensor(video_embeddings)
        if text_embedding is not None and isinstance(text_embedding, np.ndarray):
            text_embedding = torch.tensor(text_embedding)

        processed_item["video_embeddings"] = video_embeddings
        processed_item["text_embedding"] = text_embedding
        # NOTE(review): Trajectory declares no video_shape/text_shape fields —
        # these extras rely on the model ignoring unknown keys; verify.
        processed_item["video_shape"] = video_embeddings.shape if video_embeddings is not None else None
        processed_item["text_shape"] = text_embedding.shape if text_embedding is not None else None

        trajectory = Trajectory(**processed_item)
        return trajectory

    if sample_type == "progress":
        assert isinstance(raw_data, dict), "raw_data must be a dictionary"
        trajectory = _build_trajectory(raw_data=raw_data, num_frames=max_frames)
        return ProgressSample(trajectory=trajectory)
    elif sample_type == "preference":
        assert isinstance(raw_data, tuple), "raw_data must be a tuple"
        assert len(raw_data) == 2, "raw_data must be a tuple of two dictionaries"
        trajectories: List[Trajectory] = []
        # Order matters: the first dict becomes the chosen trajectory.
        for trajectory_data in raw_data:
            trajectory = _build_trajectory(raw_data=trajectory_data, num_frames=max_frames)
            trajectories.append(trajectory)
        return PreferenceSample(chosen_trajectory=trajectories[0], rejected_trajectory=trajectories[1])
    else:
        raise ValueError(f"Unsupported sample_type: {sample_type}")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def build_payload(
    samples: list[PreferenceSample | SimilaritySample | ProgressSample],
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Build a payload with numpy array handling.

    Args:
        samples: List of samples to convert

    Returns:
        Tuple of (files, sample_data) where:
        - files: Dict mapping file keys to (filename, buffer, content-type)
          tuples for numpy arrays serialized in .npy format
        - sample_data: List of sample dictionaries with numpy arrays replaced
          by ``{"__numpy_file__": <file_key>}`` references
    """
    npy_files: dict[str, Any] = {}
    serialized: list[dict[str, Any]] = []

    # Trajectory slots that may appear on any of the sample kinds.
    trajectory_keys = (
        "chosen_trajectory",
        "rejected_trajectory",
        "reference_trajectory",
        "traj_sim_trajectory",
        "traj_diff_trajectory",
        "trajectory",
    )
    # Trajectory fields that may hold numpy arrays (or torch tensors).
    array_fields = ("frames", "lang_vector", "video_embeddings", "text_embedding")

    for idx, sample in enumerate(samples):
        record = sample.model_dump().copy()

        for traj_key in trajectory_keys:
            traj = record.get(traj_key)
            if not isinstance(traj, dict):
                continue

            for field in array_fields:
                if field not in traj:
                    continue

                value = traj[field]
                # Tensors are converted to numpy first so they take the same
                # .npy serialization path below.
                if isinstance(value, torch.Tensor):
                    value = value.numpy()
                    traj[field] = value
                if not isinstance(value, np.ndarray):
                    continue

                # Serialize the array to an in-memory .npy file and leave a
                # reference in its place.
                buf = io.BytesIO()
                np.save(buf, value)
                buf.seek(0)
                file_key = f"sample_{idx}_{traj_key}_{field}"
                npy_files[file_key] = (
                    f"sample_{idx}_{traj_key}_{field}.npy",
                    buf,
                    "application/octet-stream",
                )
                traj[field] = {"__numpy_file__": file_key}

        serialized.append(record)

    return npy_files, serialized
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def post_batch(url: str, payload: dict[str, Any], timeout_s: float = 120.0) -> dict[str, Any]:
    """POST a batch payload to the evaluation server and return parsed JSON.

    Raises requests.HTTPError on non-2xx responses.
    """
    endpoint = url.rstrip("/") + "/evaluate_batch"
    response = requests.post(endpoint, json=payload, timeout=timeout_s)
    response.raise_for_status()
    return response.json()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def post_batch_npy(
    url: str, files: dict[str, Any], sample_data: list[dict[str, Any]], timeout_s: float = 120.0
) -> dict[str, Any]:
    """POST batch using .npy format for numpy arrays.

    Sends a multipart request: `files` carries the .npy attachments and each
    sample dict is JSON-encoded into a `sample_<i>` form field.
    Raises requests.HTTPError on non-2xx responses.
    """
    form_fields = {}
    for index, sample in enumerate(sample_data):
        form_fields[f"sample_{index}"] = json.dumps(sample)

    endpoint = url.rstrip("/") + "/evaluate_batch_npy"
    response = requests.post(endpoint, files=files, data=form_fields, timeout=timeout_s)
    response.raise_for_status()
    return response.json()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
async def post_batch_npy_async(
    session: aiohttp.ClientSession,
    url: str,
    files: dict[str, Any],
    sample_data: list[dict[str, Any]],
    timeout_s: float = 120.0,
) -> dict[str, Any]:
    """Async version of post_batch_npy using aiohttp.

    Builds a multipart body from the .npy attachments and JSON-encoded
    `sample_<i>` fields, then POSTs it with a per-request timeout.
    Raises aiohttp.ClientResponseError on non-2xx responses.
    """
    form = aiohttp.FormData()

    # Attach the .npy payloads.
    for field_name, (filename, fileobj, content_type) in files.items():
        form.add_field(field_name, fileobj, filename=filename, content_type=content_type)

    # Attach the JSON-encoded sample dicts.
    for index, sample in enumerate(sample_data):
        form.add_field(f"sample_{index}", json.dumps(sample))

    endpoint = url.rstrip("/") + "/evaluate_batch_npy"
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    # "Connection: close" avoids reusing the connection for the next request.
    async with session.post(
        endpoint, data=form, timeout=client_timeout, headers={"Connection": "close"}
    ) as resp:
        resp.raise_for_status()
        return await resp.json()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
async def parse_npy_form_data(form_data: Any) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
    """Parse multipart form data to extract numpy arrays and other data.

    Args:
        form_data: FastAPI form data from request.form()

    Returns:
        Tuple of (numpy_arrays dict, other_data dict) where arrays come from
        uploaded .npy files and other_data holds JSON-decoded (or raw string)
        form fields.
    """
    arrays: Dict[str, np.ndarray] = {}
    fields: Dict[str, Any] = {}

    for key, value in form_data.items():
        upload_name = getattr(value, "filename", None)
        if upload_name:
            # File upload: only .npy payloads are understood; anything else
            # is skipped for now.
            if not upload_name.endswith(".npy"):
                continue
            raw = await value.read()
            arrays[key] = np.load(io.BytesIO(raw))
        else:
            # Plain form field: JSON-decode when possible, keep the raw
            # string otherwise.
            try:
                fields[key] = json.loads(value)
            except (json.JSONDecodeError, TypeError):
                fields[key] = value

    return arrays, fields
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def reconstruct_payload_from_npy(
    numpy_arrays: Dict[str, np.ndarray],
    other_data: Dict[str, Any],
    trajectory_keys: Optional[List[str]] = None,
    convert_embeddings_to_torch: bool = False,
) -> List[Dict[str, Any]]:
    """Reconstruct the original payload structure from .npy files and form data.

    The client sends data in this format:
    - Files: sample_0_chosen_trajectory_frames.npy, sample_0_trajectory_frames.npy, etc.
    - Data: sample_0, sample_1, etc. (each containing the full sample JSON with numpy file references)

    Args:
        numpy_arrays: Dictionary of numpy arrays loaded from .npy files
        other_data: Dictionary of other form data
        trajectory_keys: List of trajectory keys to process (default: common keys)
        convert_embeddings_to_torch: Whether to convert embeddings to torch tensors

    Returns:
        List of reconstructed sample dictionaries, ordered by sample index
    """
    if trajectory_keys is None:
        trajectory_keys = [
            "chosen_trajectory",
            "rejected_trajectory",
            "reference_trajectory",
            "traj_sim_trajectory",
            "traj_diff_trajectory",
            "trajectory",
        ]

    # Bug fix: iterating range(len(other_data)) assumed every form field was a
    # sample; an extra non-sample field or a gap in the numbering silently
    # dropped trailing samples. Derive the indices from the actual keys.
    sample_indices = sorted(
        int(match.group(1))
        for match in (re.fullmatch(r"sample_(\d+)", key) for key in other_data)
        if match is not None
    )

    samples: List[Dict[str, Any]] = []
    for i in sample_indices:
        # Sample data might already be parsed, or might still be a JSON string.
        sample_data = other_data[f"sample_{i}"]
        if isinstance(sample_data, str):
            sample_data = json.loads(sample_data)

        # Replace numpy file references with actual arrays.
        for key, value in sample_data.items():
            if key not in trajectory_keys or not isinstance(value, dict):
                continue
            for traj_key, traj_value in value.items():
                if isinstance(traj_value, dict) and traj_value.get("__numpy_file__"):
                    file_key = traj_value["__numpy_file__"]
                    # Unresolvable references keep their placeholder dict.
                    if file_key in numpy_arrays:
                        value[traj_key] = numpy_arrays[file_key]

                    # Convert embeddings to torch if requested.
                    if convert_embeddings_to_torch and traj_key in ["video_embeddings", "text_embedding"]:
                        if value.get(traj_key) is not None and isinstance(value[traj_key], (np.ndarray, list)):
                            value[traj_key] = torch.tensor(value[traj_key])

        samples.append(sample_data)

    return samples
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def find_video_files(directory: str) -> list[str]:
    """Find all video files in a directory.

    Args:
        directory: Path to directory containing video files

    Returns:
        Sorted list of paths to video files; empty if the directory does not
        exist. Extensions are matched case-insensitively.
    """
    allowed_suffixes = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv", ".m4v"}

    root = Path(directory)
    if not root.is_dir():
        return []

    return sorted(
        str(entry)
        for entry in root.iterdir()
        if entry.is_file() and entry.suffix.lower() in allowed_suffixes
    )
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def infer_task_from_video_name(video_path: str) -> str:
    """Infer task name from video filename.

    Task is everything before the comma (if comma exists), or everything
    before success/fail/failure. Underscores become spaces and the first
    letter is capitalized; falls back to "Complete the task" when nothing
    usable remains.

    Args:
        video_path: Path to video file

    Returns:
        Inferred task name
    """
    stem = Path(video_path).stem  # filename without extension

    if "," in stem:
        # Everything before the first comma is the task.
        raw_task = stem.split(",")[0]
    else:
        # Drop outcome markers (success/fail/failure) between underscores.
        kept = [part for part in stem.split("_") if part.lower() not in ["success", "fail", "failure"]]
        if not kept:
            return "Complete the task"
        raw_task = "_".join(kept)

    # Underscores -> spaces.
    task = " ".join(raw_task.split("_"))
    if not task:
        return "Complete the task"

    # Capitalize first letter of first word, keep the rest as is.
    return task[0].upper() + task[1:] if len(task) > 1 else task.upper()
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def setup_output_directory(output_dir: Optional[str], video_path: Optional[str] = None) -> str:
    """Create output directory and return path.

    Args:
        output_dir: Explicit directory to use; when falsy, a timestamped
            ./eval_outputs/<YYYYmmdd_HHMMSS> directory is generated.
        video_path: Unused by this function; kept for interface compatibility.

    Returns:
        The (now existing) output directory path.
    """
    if not output_dir:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(".", f"eval_outputs/{stamp}")

    os.makedirs(output_dir, exist_ok=True)
    return output_dir
|
eval_viz_utils.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Utility functions for visualization in RFM evaluations.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
import tempfile
|
| 10 |
+
import numpy as np
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
import decord
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def create_combined_progress_success_plot(
    progress_pred: np.ndarray,
    num_frames: int,
    success_binary: Optional[np.ndarray] = None,
    success_probs: Optional[np.ndarray] = None,
    success_labels: Optional[np.ndarray] = None,
    is_discrete_mode: bool = False,
    title: Optional[str] = None,
    loss: Optional[float] = None,
    pearson: Optional[float] = None,
) -> plt.Figure:
    """Build a figure combining progress with optional success predictions.

    Produces either a single progress subplot, or a 1x3 row of subplots
    (progress, binary success, success probabilities) when ``success_binary``
    is provided and matches ``progress_pred`` in length.

    Args:
        progress_pred: Per-frame progress predictions.
        num_frames: Number of frames (accepted for interface compatibility;
            not used by the plotting logic itself).
        success_binary: Optional binary success predictions per frame.
        success_probs: Optional success probabilities per frame.
        success_labels: Optional ground-truth success labels per frame.
        is_discrete_mode: Deprecated flag, kept for backward compatibility.
        title: Explicit figure title; when None one is built from
            ``loss`` / ``pearson``.
        loss: Optional loss value shown in the auto-generated title.
        pearson: Optional Pearson correlation shown in the auto-generated title.

    Returns:
        The assembled matplotlib Figure.
    """
    # Success panels are only shown when binary predictions align with progress.
    show_success = success_binary is not None and len(success_binary) == len(progress_pred)

    if show_success:
        # Three panels: progress | binary success | success probabilities.
        fig, (progress_ax, binary_ax, probs_ax) = plt.subplots(1, 3, figsize=(15, 3.5))
    else:
        # Progress-only layout.
        fig, progress_ax = plt.subplots(figsize=(6, 3.5))
        binary_ax = None
        probs_ax = None

    # --- Progress panel ---
    progress_ax.plot(progress_pred, linewidth=2)
    progress_ax.set_ylabel("Progress")

    # Auto-generate a title from the supplied metrics when none is given.
    if title is None:
        pieces = ["Progress"]
        if loss is not None:
            pieces.append(f"Loss: {loss:.3f}")
        if pearson is not None:
            pieces.append(f"Pearson: {pearson:.2f}")
        title = ", ".join(pieces)
    fig.suptitle(title)

    # Progress is always plotted on a continuous [0, 1] axis.
    progress_ax.set_ylim(0, 1)
    for edge in ("right", "top"):
        progress_ax.spines[edge].set_visible(False)
    progress_ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

    # --- Binary success panel ---
    if binary_ax is not None:
        binary_ax.step(
            range(len(success_binary)),
            success_binary,
            where="post",
            linewidth=2,
            label="Predicted",
            color="blue",
        )
        # Overlay ground truth only when its length matches the predictions.
        if success_labels is not None and len(success_labels) == len(success_binary):
            binary_ax.step(
                range(len(success_labels)),
                success_labels,
                where="post",
                linewidth=2,
                label="Ground Truth",
                color="green",
            )
        binary_ax.set_ylabel("Success (Binary)")
        binary_ax.set_ylim(-0.05, 1.05)
        for edge in ("right", "top"):
            binary_ax.spines[edge].set_visible(False)
        binary_ax.set_yticks([0, 1])
        binary_ax.legend()

    # --- Success-probability panel ---
    if probs_ax is not None and success_probs is not None:
        probs_ax.plot(
            range(len(success_probs)),
            success_probs,
            linewidth=2,
            label="Success Prob",
            color="purple",
        )
        if success_labels is not None and len(success_labels) == len(success_probs):
            probs_ax.step(
                range(len(success_labels)),
                success_labels,
                where="post",
                linewidth=2,
                label="Ground Truth",
                color="green",
                linestyle="--",
            )
        probs_ax.set_ylabel("Success Probability")
        probs_ax.set_ylim(-0.05, 1.05)
        for edge in ("right", "top"):
            probs_ax.spines[edge].set_visible(False)
        probs_ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        probs_ax.legend()

    plt.tight_layout()
    return fig
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def extract_frames(video_path: str, fps: float = 1.0, max_frames: int = 64) -> np.ndarray:
    """Decode a video into a (T, H, W, C) uint8 numpy array of frames.

    Works with both local file paths and URLs (decord handles either).
    Frame density is controlled by ``fps`` relative to the video's native
    frame rate, and the total count is capped at ``max_frames`` to bound
    memory usage.

    Args:
        video_path: Path or URL of the video (a tuple's first element is
            used, matching upstream callers that pass (path, label) pairs).
        fps: Desired sampling rate in frames per second (default: 1.0).
            Non-positive or None falls back to the native rate.
        max_frames: Hard upper bound on the number of returned frames.

    Returns:
        numpy array of shape (T, H, W, C), or None on any failure.
    """
    if video_path is None:
        return None

    # Some callers hand over (path, label) tuples; unwrap the path.
    if isinstance(video_path, tuple):
        video_path = video_path[0]

    remote = video_path.startswith(("http://", "https://"))
    # Local paths must exist; URLs are validated by decord at open time.
    if not remote and not os.path.exists(video_path):
        logger.warning(f"Video path does not exist: {video_path}")
        return None

    try:
        # decord.VideoReader accepts both local files and URLs.
        reader = decord.VideoReader(video_path, num_threads=1)
        n_total = len(reader)

        # Native FPS may be unavailable for some containers; default to 1.0.
        try:
            native_fps = float(reader.get_avg_fps())
        except Exception:
            native_fps = 1.0

        # Invalid user fps falls back to the native rate (i.e. every frame).
        if fps is None or fps <= 0:
            fps = native_fps

        # desired ≈ duration * fps = n_total * (fps / native_fps)
        if native_fps > 0:
            wanted = int(round(n_total * (fps / native_fps)))
        else:
            wanted = n_total
        wanted = max(1, min(wanted, n_total))

        # Hard cap: long videos or high fps would otherwise blow up memory.
        if wanted > max_frames:
            logger.warning(
                f"Requested {wanted} frames but capping at {max_frames} "
                f"to prevent memory issues (video has {n_total} frames at {native_fps:.2f} fps, "
                f"requested extraction at {fps:.2f} fps)"
            )
            wanted = max_frames

        # Evenly spaced indices across the whole clip.
        if wanted == n_total:
            indices = list(range(n_total))
        else:
            indices = np.linspace(0, n_total - 1, wanted, dtype=int).tolist()

        frames = reader.get_batch(indices).asnumpy()  # (T, H, W, C)
        del reader  # release decoder resources promptly
        return frames
    except Exception as e:
        logger.error(f"Error extracting frames from {video_path}: {e}")
        return None
|
requirements.txt
CHANGED
|
@@ -26,8 +26,7 @@ watchfiles # For file watching during development
|
|
| 26 |
|
| 27 |
# RFM package (installed from git repository)
|
| 28 |
# For local development, you can also install with: pip install -e ../ (from parent directory)
|
| 29 |
-
|
| 30 |
-
git+https://github.com/aliang8/reward_fm.git@anthony_working
|
| 31 |
|
| 32 |
# Make sure a newer version of gradio is installed
|
| 33 |
gradio==4.44.0
|
|
|
|
| 26 |
|
| 27 |
# RFM package (installed from git repository)
|
| 28 |
# For local development, you can also install with: pip install -e ../ (from parent directory)
|
| 29 |
+
# git+https://github.com/aliang8/reward_fm.git@anthony_working
|
|
|
|
| 30 |
|
| 31 |
# Make sure a newer version of gradio is installed
|
| 32 |
gradio==4.44.0
|
samplers/README.md
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sampler Strategies Documentation
|
| 2 |
+
|
| 3 |
+
This document summarizes the data generation strategies used by each sampler type in the RFM data pipeline.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The codebase contains three main sampler types:
|
| 8 |
+
- **SimSampler**: Generates similarity scoring samples
|
| 9 |
+
- **PrefSampler**: Generates preference prediction samples
|
| 10 |
+
- **ProgressSampler**: Generates progress prediction samples
|
| 11 |
+
|
| 12 |
+
Each sampler implements multiple strategies for generating training data, with automatic retry logic and strategy rebalancing on failure.
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## SimSampler (Similarity Scoring)
|
| 17 |
+
|
| 18 |
+
The `SimSampler` creates similarity scoring samples where two trajectories (`o^1` and `o^2`) are ranked against a reference trajectory (`o^ref`). The goal is to learn that `o^1` should be ranked higher than `o^2`.
|
| 19 |
+
|
| 20 |
+
### Strategies
|
| 21 |
+
|
| 22 |
+
#### 1. **REWIND**
|
| 23 |
+
- **Description**: Creates a similarity sample where `o^1` is an optimal trajectory from the same task, and `o^2` is a rewound (subsampled) version of the reference trajectory.
|
| 24 |
+
- **Purpose**: Learn to distinguish between optimal and suboptimal trajectories from the same task.
|
| 25 |
+
- **Implementation**:
|
| 26 |
+
- `traj_sim`: Optimal trajectory from same task (via `_get_same_task_optimal`)
|
| 27 |
+
- `traj_diff`: Rewound trajectory from reference (via `subsample_rewind`)
|
| 28 |
+
|
| 29 |
+
#### 2. **SUBOPTIMAL**
|
| 30 |
+
- **Description**: Creates a similarity sample where `o^1` is an optimal trajectory from the same task, and `o^2` is a suboptimal/failure trajectory from the same task.
|
| 31 |
+
- **Purpose**: Learn to distinguish between optimal and suboptimal trajectories from the same task.
|
| 32 |
+
- **Conditions**: Only available when:
|
| 33 |
+
- Data source is in the failure category (`is_failure_ds`)
|
| 34 |
+
- Probability is boosted by 2x for failure category data sources
|
| 35 |
+
- **Implementation**:
|
| 36 |
+
- `traj_sim`: Optimal trajectory from same task (via `_get_same_task_optimal`)
|
| 37 |
+
- `traj_diff`: Suboptimal trajectory from same task (via `_get_same_task_suboptimal`)
|
| 38 |
+
|
| 39 |
+
#### 3. **PAIRED_HUMAN_ROBOT**
|
| 40 |
+
- **Description**: Creates a similarity sample where `o^1` is a paired human/robot trajectory (opposite type from reference, same task), and `o^2` is from a different task.
|
| 41 |
+
- **Purpose**: Learn to distinguish between same-task and different-task trajectories, leveraging paired human/robot demonstrations.
|
| 42 |
+
- **Conditions**: Only available when:
|
| 43 |
+
- Data source is in the paired category (`is_paired_ds`)
|
| 44 |
+
- Paired human/robot data exists for the task
|
| 45 |
+
- Probability is boosted by 2x for paired category data sources
|
| 46 |
+
- **Implementation**:
|
| 47 |
+
- `traj_sim`: Paired human/robot trajectory (via `_get_paired_human_robot_traj`)
|
| 48 |
+
- `traj_diff`: Trajectory from different task (via `_get_different_video_traj`)
|
| 49 |
+
|
| 50 |
+
### Strategy Selection
|
| 51 |
+
- Strategies are selected probabilistically based on `similarity_strategy_ratio` configuration
|
| 52 |
+
- Probabilities are rebalanced when strategies fail
|
| 53 |
+
- Strategies are removed after 4 consecutive failures
|
| 54 |
+
- Maximum 10 total attempts per sample generation
|
| 55 |
+
|
| 56 |
+
### Reference Trajectory Requirements
|
| 57 |
+
- For non-RoboArena: Must have `quality_label == "successful"`
|
| 58 |
+
- For RoboArena: Must have `partial_success` field present
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## PrefSampler (Preference Prediction)
|
| 63 |
+
|
| 64 |
+
The `PrefSampler` creates preference prediction samples with a chosen (preferred) trajectory and a rejected (suboptimal) trajectory.
|
| 65 |
+
|
| 66 |
+
### Strategies
|
| 67 |
+
|
| 68 |
+
#### 1. **REWIND**
|
| 69 |
+
- **Description**: Uses the same optimal trajectory for both chosen and rejected, but applies rewind subsampling to the rejected trajectory.
|
| 70 |
+
- **Purpose**: Learn that full trajectories are preferred over truncated/rewound versions.
|
| 71 |
+
- **Implementation**:
|
| 72 |
+
- `chosen_trajectory`: Original optimal trajectory (forward subsampling)
|
| 73 |
+
- `rejected_trajectory`: Same trajectory with `subsample_rewind` strategy
|
| 74 |
+
|
| 75 |
+
#### 2. **SUBOPTIMAL**
|
| 76 |
+
- **Description**: Uses an optimal trajectory as chosen and a suboptimal/failure trajectory from the same task as rejected.
|
| 77 |
+
- **Purpose**: Learn to prefer optimal trajectories over suboptimal ones from the same task.
|
| 78 |
+
- **Conditions**: Only available when suboptimal trajectories exist for the task
|
| 79 |
+
- **Implementation**:
|
| 80 |
+
- `chosen_trajectory`: Optimal trajectory
|
| 81 |
+
- `rejected_trajectory`: Suboptimal trajectory from same task (via `_get_same_task_suboptimal`)
|
| 82 |
+
|
| 83 |
+
#### 3. **DIFFERENT_TASK**
|
| 84 |
+
- **Description**: Uses an optimal trajectory as chosen and a trajectory from a completely different task as rejected.
|
| 85 |
+
- **Purpose**: Learn that trajectories from the same task are preferred over trajectories from different tasks.
|
| 86 |
+
- **Implementation**:
|
| 87 |
+
- `chosen_trajectory`: Optimal trajectory
|
| 88 |
+
- `rejected_trajectory`: Trajectory from different task (via `_get_different_video_traj`)
|
| 89 |
+
- **Note**: Rejected trajectory's `target_progress` is set to `[0.0]` for all timesteps
|
| 90 |
+
|
| 91 |
+
#### 4. **REVERSE_PROGRESS**
|
| 92 |
+
- **Description**: Uses the same optimal trajectory for both chosen and rejected, but applies reverse uniform sampling to the rejected trajectory.
|
| 93 |
+
- **Purpose**: Learn that forward progress is preferred over reverse progress.
|
| 94 |
+
- **Implementation**:
|
| 95 |
+
- `chosen_trajectory`: Original optimal trajectory (forward subsampling)
|
| 96 |
+
- `rejected_trajectory`: Same trajectory with `subsample_reverse` strategy
|
| 97 |
+
|
| 98 |
+
#### 5. **ROBOARENA_PARTIAL_SUCCESS**
|
| 99 |
+
- **Description**: Uses two trajectories from the same task with different `partial_success` values. The trajectory with higher `partial_success` becomes chosen, and the one with lower `partial_success` becomes rejected.
|
| 100 |
+
- **Purpose**: Learn to prefer trajectories with higher partial success scores (RoboArena-specific).
|
| 101 |
+
- **Conditions**: Only available for RoboArena trajectories (has `partial_success` field and data_source contains "roboarena")
|
| 102 |
+
- **Implementation**:
|
| 103 |
+
- Finds a different trajectory from same task (via `_get_different_partial_success_traj`)
|
| 104 |
+
- Swaps trajectories if found trajectory has higher `partial_success`
|
| 105 |
+
- `chosen_trajectory`: Trajectory with higher `partial_success`
|
| 106 |
+
- `rejected_trajectory`: Trajectory with lower `partial_success`
|
| 107 |
+
|
| 108 |
+
### Special Handling
|
| 109 |
+
- **Non-successful trajectories**: If a trajectory has `quality_label != "successful"` (and is not RoboArena), it is automatically used as the rejected trajectory, with an optimal trajectory from the same task as the chosen trajectory.
|
| 110 |
+
|
| 111 |
+
### Strategy Selection
|
| 112 |
+
- Strategies are selected probabilistically based on `preference_strategy_ratio` configuration
|
| 113 |
+
- Probabilities are rebalanced when strategies fail
|
| 114 |
+
- Strategies are removed after 3 consecutive failures
|
| 115 |
+
- Maximum 10 total attempts per sample generation
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
## ProgressSampler (Progress Prediction)
|
| 120 |
+
|
| 121 |
+
The `ProgressSampler` creates progress prediction samples from a single trajectory, applying different subsampling strategies to create training data.
|
| 122 |
+
|
| 123 |
+
### Strategies
|
| 124 |
+
|
| 125 |
+
#### 1. **DIFFERENT_TASK_INSTRUCTION**
|
| 126 |
+
- **Description**: Uses a trajectory from a different task, but keeps the original task's embeddings and instruction.
|
| 127 |
+
- **Purpose**: Learn that progress should be 0.0 when the trajectory doesn't match the task instruction.
|
| 128 |
+
- **Implementation**:
|
| 129 |
+
- Gets trajectory from different task (via `_get_different_task_instruction`)
|
| 130 |
+
- Replaces embeddings with original task's embeddings
|
| 131 |
+
- Sets `target_progress = [0.0]` for all timesteps
|
| 132 |
+
- Uses forward subsampling
|
| 133 |
+
|
| 134 |
+
#### 2. **FORWARD_PROGRESS**
|
| 135 |
+
- **Description**: Samples the same trajectory with forward direction (start < middle < end).
|
| 136 |
+
- **Purpose**: Learn normal forward progress patterns.
|
| 137 |
+
- **Implementation**:
|
| 138 |
+
- Uses same trajectory with `subsample_forward` strategy
|
| 139 |
+
- Progress increases from start to end
|
| 140 |
+
|
| 141 |
+
#### 3. **REVERSE_PROGRESS**
|
| 142 |
+
- **Description**: Samples the same trajectory with reverse direction (end < middle < start).
|
| 143 |
+
- **Purpose**: Learn to handle reverse progress scenarios.
|
| 144 |
+
- **Implementation**:
|
| 145 |
+
- Uses same trajectory with `subsample_reverse` strategy
|
| 146 |
+
- Progress decreases from start to end
|
| 147 |
+
|
| 148 |
+
#### 4. **REWIND**
|
| 149 |
+
- **Description**: Samples the same trajectory with rewind direction (start < end < middle).
|
| 150 |
+
- **Purpose**: Learn to handle non-monotonic progress patterns.
|
| 151 |
+
- **Implementation**:
|
| 152 |
+
- Uses same trajectory with `subsample_rewind` strategy
|
| 153 |
+
- Progress pattern: increases, then decreases
|
| 154 |
+
|
| 155 |
+
### Strategy Selection
|
| 156 |
+
- Strategies are selected probabilistically based on `progress_strategy_ratio` configuration
|
| 157 |
+
- Probabilities are rebalanced when strategies fail
|
| 158 |
+
- Failed strategies are immediately removed (no retry count threshold)
|
| 159 |
+
- Maximum 10 total attempts per sample generation
|
| 160 |
+
|
| 161 |
+
---
|
| 162 |
+
|
| 163 |
+
## Common Features
|
| 164 |
+
|
| 165 |
+
### Retry Logic
|
| 166 |
+
All samplers implement retry logic with:
|
| 167 |
+
- Maximum attempt limits (typically 10 attempts)
|
| 168 |
+
- Strategy-specific retry counts (3-4 attempts per strategy)
|
| 169 |
+
- Automatic strategy removal after consecutive failures
|
| 170 |
+
- Probability rebalancing when strategies are removed
|
| 171 |
+
|
| 172 |
+
### Subsample Strategies
|
| 173 |
+
Common subsampling strategies used across samplers:
|
| 174 |
+
- `subsample_forward`: Normal forward sampling (start → end)
|
| 175 |
+
- `subsample_reverse`: Reverse sampling (end → start)
|
| 176 |
+
- `subsample_rewind`: Rewind sampling (start → end → start)
|
| 177 |
+
|
| 178 |
+
### Data Source Filtering
|
| 179 |
+
- Strategies may be filtered or boosted based on data source categories:
|
| 180 |
+
- **Failure category**: Boosts SUBOPTIMAL strategy probability by 2x
|
| 181 |
+
- **Paired category**: Boosts PAIRED_HUMAN_ROBOT strategy probability by 2x
|
| 182 |
+
- **RoboArena**: Special handling for `partial_success` field
|
samplers/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rfm.data.samplers.base import RFMBaseSampler
|
| 2 |
+
from rfm.data.samplers.pref import PrefSampler
|
| 3 |
+
from rfm.data.samplers.sim import SimSampler
|
| 4 |
+
from rfm.data.samplers.progress import ProgressSampler
|
| 5 |
+
from rfm.data.samplers.eval.confusion_matrix import ConfusionMatrixSampler
|
| 6 |
+
from rfm.data.samplers.eval.progress_policy_ranking import ProgressPolicyRankingSampler
|
| 7 |
+
from rfm.data.samplers.eval.reward_alignment import RewardAlignmentSampler
|
| 8 |
+
from rfm.data.samplers.eval.quality_preference import QualityPreferenceSampler
|
| 9 |
+
from rfm.data.samplers.eval.roboarena_quality_preference import RoboArenaQualityPreferenceSampler
|
| 10 |
+
from rfm.data.samplers.eval.similarity_score import SimilarityScoreSampler
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"RFMBaseSampler",
|
| 14 |
+
"PrefSampler",
|
| 15 |
+
"SimSampler",
|
| 16 |
+
"ProgressSampler",
|
| 17 |
+
"ConfusionMatrixSampler",
|
| 18 |
+
"ProgressPolicyRankingSampler",
|
| 19 |
+
"RewardAlignmentSampler",
|
| 20 |
+
"QualityPreferenceSampler",
|
| 21 |
+
"RoboArenaQualityPreferenceSampler",
|
| 22 |
+
"SimilarityScoreSampler",
|
| 23 |
+
]
|
samplers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
samplers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
samplers/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (19.8 kB). View file
|
|
|
samplers/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (33.1 kB). View file
|
|
|
samplers/__pycache__/confusion_matrix.cpython-310.pyc
ADDED
|
Binary file (4.58 kB). View file
|
|
|
samplers/__pycache__/confusion_matrix.cpython-311.pyc
ADDED
|
Binary file (7.71 kB). View file
|
|
|
samplers/__pycache__/pref.cpython-310.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
samplers/__pycache__/pref.cpython-311.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
samplers/__pycache__/progress.cpython-310.pyc
ADDED
|
Binary file (5.11 kB). View file
|
|
|
samplers/__pycache__/progress.cpython-311.pyc
ADDED
|
Binary file (8.45 kB). View file
|
|
|
samplers/__pycache__/progress_default.cpython-310.pyc
ADDED
|
Binary file (3.98 kB). View file
|
|
|
samplers/__pycache__/progress_default.cpython-311.pyc
ADDED
|
Binary file (6.48 kB). View file
|
|
|
samplers/__pycache__/quality_preference.cpython-310.pyc
ADDED
|
Binary file (5.49 kB). View file
|
|
|
samplers/__pycache__/quality_preference.cpython-311.pyc
ADDED
|
Binary file (9.6 kB). View file
|
|
|
samplers/__pycache__/reward_alignment.cpython-310.pyc
ADDED
|
Binary file (4.5 kB). View file
|
|
|
samplers/__pycache__/reward_alignment.cpython-311.pyc
ADDED
|
Binary file (7.15 kB). View file
|
|
|
samplers/__pycache__/roboarena.cpython-310.pyc
ADDED
|
Binary file (4.82 kB). View file
|
|
|
samplers/__pycache__/sim.cpython-310.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
samplers/__pycache__/sim.cpython-311.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
samplers/__pycache__/success_failure.cpython-310.pyc
ADDED
|
Binary file (4.27 kB). View file
|
|
|
samplers/__pycache__/success_failure.cpython-311.pyc
ADDED
|
Binary file (7.05 kB). View file
|
|
|
samplers/base.py
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from typing import Optional, Dict, Any, List, Set, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import torch
|
| 7 |
+
from random import Random
|
| 8 |
+
from datasets import Dataset
|
| 9 |
+
|
| 10 |
+
from rfm.configs.experiment_configs import DataConfig
|
| 11 |
+
from rfm.data.datasets.helpers import (
|
| 12 |
+
load_frames_from_npz,
|
| 13 |
+
get_segment_indices_with_middle,
|
| 14 |
+
compute_progress_from_segment,
|
| 15 |
+
pad_trajectory_to_max_frames_torch,
|
| 16 |
+
pad_trajectory_to_max_frames_np,
|
| 17 |
+
compute_success_labels,
|
| 18 |
+
create_trajectory_from_dict,
|
| 19 |
+
load_embeddings_from_path,
|
| 20 |
+
linspace_subsample_frames,
|
| 21 |
+
convert_continuous_to_discrete_bins,
|
| 22 |
+
)
|
| 23 |
+
from rfm.data.dataset_types import Trajectory
|
| 24 |
+
from rfm.utils.logger import get_logger
|
| 25 |
+
|
| 26 |
+
logger = get_logger()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RFMBaseSampler:
|
| 30 |
+
"""Base sampler class that provides trajectory retrieval functions for generating samples."""
|
| 31 |
+
|
| 32 |
+
def __init__(
    self,
    config: DataConfig,
    dataset: Dataset,
    combined_indices: Dict[str, Any],
    dataset_success_cutoff_map: Optional[Dict[str, float]] = None,
    verbose: bool = True,
    random_seed: int = 42,
):
    """Set up the sampler's dataset handles, RNG, and index structures.

    Args:
        config: Experiment data configuration.
        dataset: The loaded HF dataset to sample from.
        combined_indices: Precomputed index dict produced at dataset load time.
        dataset_success_cutoff_map: Optional map of dataset name -> success
            cutoff percentage; defaults to an empty mapping.
        verbose: Whether to log verbosely.
        random_seed: Seed for a private Random instance so sampling is
            deterministic without touching the global random state.
    """
    self.config = config
    self.dataset = dataset
    self.verbose = verbose
    self.dataset_success_cutoff_map = dataset_success_cutoff_map or {}
    # Instance-local RNG keeps sampling reproducible and isolated.
    self._local_random = Random(random_seed)

    # Cache frequently accessed columns to avoid repeated dataset lookups.
    self._cached_ids = self.dataset["id"]
    self._cached_is_robot = self.dataset["is_robot"]

    # Materialize all lookup tables from the precomputed indices.
    self._build_indices(combined_indices)
|
| 62 |
+
|
| 63 |
+
def _build_indices(self, combined_indices):
|
| 64 |
+
"""Build all index mappings from combined_indices.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
combined_indices: Dictionary of combined indices from dataset loading
|
| 68 |
+
"""
|
| 69 |
+
# Initialize index mappings from the loaded indices
|
| 70 |
+
self.robot_trajectories = combined_indices["robot_trajectories"]
|
| 71 |
+
self.human_trajectories = combined_indices["human_trajectories"]
|
| 72 |
+
self.optimal_by_task = combined_indices["optimal_by_task"]
|
| 73 |
+
self.suboptimal_by_task = combined_indices["suboptimal_by_task"]
|
| 74 |
+
self.quality_indices = combined_indices["quality_indices"]
|
| 75 |
+
self.task_indices = combined_indices["task_indices"]
|
| 76 |
+
self.source_indices = combined_indices["source_indices"]
|
| 77 |
+
self.partial_success_indices = combined_indices["partial_success_indices"]
|
| 78 |
+
self.paired_human_robot_by_task = combined_indices["paired_human_robot_by_task"]
|
| 79 |
+
self.tasks_with_multiple_quality_labels = combined_indices["tasks_with_multiple_quality_labels"]
|
| 80 |
+
|
| 81 |
+
# Build mapping from data source -> available task instructions
|
| 82 |
+
self._build_tasks_by_data_source()
|
| 83 |
+
|
| 84 |
+
def _build_tasks_by_data_source(self):
|
| 85 |
+
"""Cache mapping from data_source to available task instructions."""
|
| 86 |
+
self.tasks_by_data_source: Dict[str, List[str]] = {}
|
| 87 |
+
|
| 88 |
+
all_tasks = self.dataset["task"]
|
| 89 |
+
all_sources = self.dataset["data_source"]
|
| 90 |
+
|
| 91 |
+
source_to_tasks: Dict[str, Set[str]] = {}
|
| 92 |
+
for task, source in zip(all_tasks, all_sources):
|
| 93 |
+
if task is None or source is None:
|
| 94 |
+
continue
|
| 95 |
+
if source not in source_to_tasks:
|
| 96 |
+
source_to_tasks[source] = set()
|
| 97 |
+
source_to_tasks[source].add(task)
|
| 98 |
+
|
| 99 |
+
self.tasks_by_data_source = {source: list(tasks) for source, tasks in source_to_tasks.items()}
|
| 100 |
+
|
| 101 |
+
def _generate_sample(self, item):
|
| 102 |
+
"""Generate a sample from an item.
|
| 103 |
+
|
| 104 |
+
This method should be overridden by subclasses to implement their specific
|
| 105 |
+
sample generation logic.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
item: An item from the dataset (typically a trajectory dict)
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
A sample object (e.g., PreferenceSample, SimilaritySample, ProgressSample)
|
| 112 |
+
"""
|
| 113 |
+
raise NotImplementedError("Subclasses must implement _generate_sample")
|
| 114 |
+
|
| 115 |
+
def _get_same_task_optimal(self, ref_traj: dict) -> dict | None:
|
| 116 |
+
"""Get optimal trajectory from same task (different from ref).
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
ref_traj: Reference trajectory
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
Same task optimal trajectory dict or None if not available
|
| 123 |
+
"""
|
| 124 |
+
task_name = ref_traj["task"]
|
| 125 |
+
same_task_optimal_indices = self.optimal_by_task.get(task_name, [])
|
| 126 |
+
if not same_task_optimal_indices:
|
| 127 |
+
logger.trace(f"[BASE SAMPLER] _get_same_task_optimal: No optimal indices for task '{task_name}'")
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
# Use cached IDs to check without loading full trajectories
|
| 131 |
+
chosen_id = ref_traj["id"]
|
| 132 |
+
random_idx = random.choice(same_task_optimal_indices)
|
| 133 |
+
|
| 134 |
+
# Retry if the selected trajectory has the same ID as ref
|
| 135 |
+
max_retries = min(10, len(same_task_optimal_indices))
|
| 136 |
+
retries = 0
|
| 137 |
+
while self._cached_ids[random_idx] == chosen_id and retries < max_retries:
|
| 138 |
+
random_idx = random.choice(same_task_optimal_indices)
|
| 139 |
+
retries += 1
|
| 140 |
+
|
| 141 |
+
# If still matches after retries, fall back to filtering
|
| 142 |
+
if self._cached_ids[random_idx] == chosen_id:
|
| 143 |
+
filtered_indices = [idx for idx in same_task_optimal_indices if self._cached_ids[idx] != chosen_id]
|
| 144 |
+
if filtered_indices:
|
| 145 |
+
random_idx = random.choice(filtered_indices)
|
| 146 |
+
else:
|
| 147 |
+
# No other trajectories available
|
| 148 |
+
logger.trace(
|
| 149 |
+
f"[BASE SAMPLER] _get_same_task_optimal: All trajectories have same ID '{chosen_id}' for task '{task_name}'"
|
| 150 |
+
)
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
result = self.dataset[random_idx]
|
| 154 |
+
logger.trace(
|
| 155 |
+
f"[BASE SAMPLER] _get_same_task_optimal: Found trajectory {result.get('id', 'unknown')} for task '{task_name}'"
|
| 156 |
+
)
|
| 157 |
+
return result
|
| 158 |
+
|
| 159 |
+
def _get_same_task_suboptimal(self, ref_traj: dict) -> dict | None:
|
| 160 |
+
"""Get suboptimal trajectory from same task.
|
| 161 |
+
|
| 162 |
+
For trajectories with partial_success, uses partial_success logic instead of quality_label logic.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
ref_traj: Reference trajectory
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
Suboptimal trajectory dict or None if not available
|
| 169 |
+
"""
|
| 170 |
+
# Check if this trajectory uses partial_success
|
| 171 |
+
use_partial_success = ref_traj.get("partial_success") is not None
|
| 172 |
+
|
| 173 |
+
if use_partial_success:
|
| 174 |
+
# For trajectories with partial_success, use partial_success logic
|
| 175 |
+
return self._get_different_partial_success_traj(ref_traj)
|
| 176 |
+
|
| 177 |
+
# For trajectories without partial_success, use the standard suboptimal logic
|
| 178 |
+
task_name = ref_traj["task"]
|
| 179 |
+
same_task_suboptimal_indices = self.suboptimal_by_task.get(task_name, [])
|
| 180 |
+
if not same_task_suboptimal_indices:
|
| 181 |
+
logger.trace(f"[BASE SAMPLER] _get_same_task_suboptimal: No suboptimal indices for task '{task_name}'")
|
| 182 |
+
return None
|
| 183 |
+
|
| 184 |
+
# Use cached IDs to check without loading full trajectories
|
| 185 |
+
chosen_id = ref_traj["id"]
|
| 186 |
+
random_idx = random.choice(same_task_suboptimal_indices)
|
| 187 |
+
|
| 188 |
+
# Retry if the selected trajectory has the same ID as ref
|
| 189 |
+
max_retries = min(10, len(same_task_suboptimal_indices))
|
| 190 |
+
retries = 0
|
| 191 |
+
while self._cached_ids[random_idx] == chosen_id and retries < max_retries:
|
| 192 |
+
random_idx = random.choice(same_task_suboptimal_indices)
|
| 193 |
+
retries += 1
|
| 194 |
+
|
| 195 |
+
# If still matches after retries, fall back to filtering
|
| 196 |
+
if self._cached_ids[random_idx] == chosen_id:
|
| 197 |
+
filtered_indices = [idx for idx in same_task_suboptimal_indices if self._cached_ids[idx] != chosen_id]
|
| 198 |
+
if filtered_indices:
|
| 199 |
+
random_idx = random.choice(filtered_indices)
|
| 200 |
+
else:
|
| 201 |
+
# No other trajectories available
|
| 202 |
+
logger.trace(
|
| 203 |
+
f"[BASE SAMPLER] _get_same_task_suboptimal: All trajectories have same ID '{chosen_id}' for task '{task_name}'"
|
| 204 |
+
)
|
| 205 |
+
return None
|
| 206 |
+
|
| 207 |
+
result = self.dataset[random_idx]
|
| 208 |
+
logger.trace(
|
| 209 |
+
f"[BASE SAMPLER] _get_same_task_suboptimal: Found trajectory {result.get('id', 'unknown')} for task '{task_name}'"
|
| 210 |
+
)
|
| 211 |
+
return result
|
| 212 |
+
|
| 213 |
+
def _get_different_video_traj(self, ref_traj: dict) -> dict | None:
|
| 214 |
+
"""Get trajectory from different task.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
ref_traj: Reference trajectory
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
Different task trajectory dict or None if not available
|
| 221 |
+
"""
|
| 222 |
+
same_source_prob = self.config.traj_same_source_prob
|
| 223 |
+
data_source = ref_traj.get("data_source")
|
| 224 |
+
other_tasks = []
|
| 225 |
+
|
| 226 |
+
if data_source and data_source in self.tasks_by_data_source and random.random() < same_source_prob:
|
| 227 |
+
other_tasks = [task for task in self.tasks_by_data_source[data_source] if task != ref_traj["task"]]
|
| 228 |
+
|
| 229 |
+
if not other_tasks:
|
| 230 |
+
other_tasks = [task for task in self.optimal_by_task.keys() if task != ref_traj["task"]]
|
| 231 |
+
|
| 232 |
+
if not other_tasks:
|
| 233 |
+
logger.trace(
|
| 234 |
+
f"[BASE SAMPLER] _get_different_video_traj: No other tasks available (ref task: '{ref_traj['task']}')"
|
| 235 |
+
)
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
# Try up to 2 times to find a valid task
|
| 239 |
+
max_retries = 2
|
| 240 |
+
other_task_indices = None
|
| 241 |
+
other_task = None
|
| 242 |
+
|
| 243 |
+
for attempt in range(max_retries):
|
| 244 |
+
other_task = random.choice(other_tasks)
|
| 245 |
+
if other_task not in self.optimal_by_task:
|
| 246 |
+
logger.trace(
|
| 247 |
+
f"[BASE SAMPLER] _get_different_video_traj: Attempt {attempt + 1}/{max_retries}: Task '{other_task}' not found in optimal_by_task"
|
| 248 |
+
)
|
| 249 |
+
continue
|
| 250 |
+
|
| 251 |
+
other_task_indices = self.optimal_by_task[other_task]
|
| 252 |
+
if not other_task_indices:
|
| 253 |
+
logger.trace(
|
| 254 |
+
f"[BASE SAMPLER] _get_different_video_traj: Attempt {attempt + 1}/{max_retries}: Task '{other_task}' has no optimal indices"
|
| 255 |
+
)
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
# Found a valid task with indices
|
| 259 |
+
break
|
| 260 |
+
|
| 261 |
+
if other_task_indices is None or not other_task_indices:
|
| 262 |
+
logger.trace(
|
| 263 |
+
f"[BASE SAMPLER] _get_different_video_traj: Failed to find valid task after {max_retries} attempts"
|
| 264 |
+
)
|
| 265 |
+
return None
|
| 266 |
+
|
| 267 |
+
other_idx = random.choice(other_task_indices)
|
| 268 |
+
result = self.dataset[other_idx]
|
| 269 |
+
logger.trace(
|
| 270 |
+
f"[BASE SAMPLER] _get_different_video_traj: Found trajectory {result.get('id', 'unknown')} from task '{other_task}'"
|
| 271 |
+
)
|
| 272 |
+
return result
|
| 273 |
+
|
| 274 |
+
def _get_different_task_instruction(self, ref_traj: dict) -> dict | None:
|
| 275 |
+
"""Get the same trajectory but with a different task instruction.
|
| 276 |
+
|
| 277 |
+
Args:
|
| 278 |
+
ref_traj: Reference trajectory
|
| 279 |
+
|
| 280 |
+
Returns:
|
| 281 |
+
Trajectory dict with different task instruction or None if not available
|
| 282 |
+
"""
|
| 283 |
+
same_source_prob = self.config.traj_same_source_prob
|
| 284 |
+
data_source = ref_traj.get("data_source")
|
| 285 |
+
candidate_tasks = []
|
| 286 |
+
|
| 287 |
+
if data_source and data_source in self.tasks_by_data_source and random.random() < same_source_prob:
|
| 288 |
+
candidate_tasks = [task for task in self.tasks_by_data_source[data_source] if task != ref_traj["task"]]
|
| 289 |
+
|
| 290 |
+
if not candidate_tasks:
|
| 291 |
+
candidate_tasks = [task for task in self.optimal_by_task.keys() if task != ref_traj["task"]]
|
| 292 |
+
|
| 293 |
+
if not candidate_tasks:
|
| 294 |
+
logger.trace(
|
| 295 |
+
f"[BASE SAMPLER] _get_different_task_instruction: No candidate tasks available (ref task: '{ref_traj['task']}')"
|
| 296 |
+
)
|
| 297 |
+
return None
|
| 298 |
+
|
| 299 |
+
other_task = random.choice(candidate_tasks)
|
| 300 |
+
|
| 301 |
+
# Get embeddings_path and lang_vector from a random trajectory with the other_task
|
| 302 |
+
other_task_indices = self.optimal_by_task.get(other_task, [])
|
| 303 |
+
if not other_task_indices:
|
| 304 |
+
logger.trace(f"[BASE SAMPLER] _get_different_task_instruction: Task '{other_task}' has no optimal indices")
|
| 305 |
+
return None
|
| 306 |
+
|
| 307 |
+
other_task_idx = random.choice(other_task_indices)
|
| 308 |
+
other_task_traj = self.dataset[other_task_idx]
|
| 309 |
+
|
| 310 |
+
# Create a copy of the trajectory with the task changed
|
| 311 |
+
# Use embeddings_path and lang_vector from the other_task trajectory
|
| 312 |
+
new_traj = ref_traj.copy()
|
| 313 |
+
new_traj["task"] = other_task
|
| 314 |
+
# Get embeddings_path and lang_vector from a random trajectory with the other_task
|
| 315 |
+
if "embeddings_path" in other_task_traj:
|
| 316 |
+
new_traj["embeddings_path"] = other_task_traj["embeddings_path"]
|
| 317 |
+
if "lang_vector" in other_task_traj:
|
| 318 |
+
new_traj["lang_vector"] = other_task_traj["lang_vector"]
|
| 319 |
+
return new_traj
|
| 320 |
+
|
| 321 |
+
def _get_paired_human_robot_traj(self, ref_traj: dict) -> dict | None:
|
| 322 |
+
"""Get paired human/robot trajectory for the same task.
|
| 323 |
+
|
| 324 |
+
Given a reference trajectory, if it's a robot trajectory, returns a human trajectory
|
| 325 |
+
from the same task. If it's a human trajectory, returns a robot trajectory from the
|
| 326 |
+
same task.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
ref_traj: Reference trajectory (can be robot or human)
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
Paired trajectory dict (opposite type) or None if not available
|
| 333 |
+
"""
|
| 334 |
+
task = ref_traj["task"]
|
| 335 |
+
is_robot = ref_traj.get("is_robot", True)
|
| 336 |
+
|
| 337 |
+
if task not in self.paired_human_robot_by_task:
|
| 338 |
+
logger.trace(
|
| 339 |
+
f"[BASE SAMPLER] _get_paired_human_robot_traj: Task '{task}' not in paired_human_robot_by_task"
|
| 340 |
+
)
|
| 341 |
+
return None
|
| 342 |
+
|
| 343 |
+
task_pairs = self.paired_human_robot_by_task[task]
|
| 344 |
+
|
| 345 |
+
# Get opposite type
|
| 346 |
+
opposite_key = "human" if is_robot else "robot"
|
| 347 |
+
opposite_indices = task_pairs.get(opposite_key, [])
|
| 348 |
+
|
| 349 |
+
if not opposite_indices:
|
| 350 |
+
logger.trace(f"[BASE SAMPLER] _get_paired_human_robot_traj: No {opposite_key} indices for task '{task}'")
|
| 351 |
+
return None
|
| 352 |
+
|
| 353 |
+
# Sample a paired trajectory and verify it's different from reference
|
| 354 |
+
chosen_id = ref_traj["id"]
|
| 355 |
+
available_indices = opposite_indices.copy()
|
| 356 |
+
paired_traj = None
|
| 357 |
+
|
| 358 |
+
# Add retry limit to prevent infinite loops
|
| 359 |
+
max_retries = min(len(available_indices), 10)
|
| 360 |
+
retries = 0
|
| 361 |
+
|
| 362 |
+
logger.trace(
|
| 363 |
+
f"[BASE SAMPLER] _get_paired_human_robot_traj: Looking for {opposite_key} trajectory (chosen_id: {chosen_id}, available: {len(available_indices)})"
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
while (paired_traj is None or paired_traj.get("id") == chosen_id) and retries < max_retries:
|
| 367 |
+
retries += 1
|
| 368 |
+
|
| 369 |
+
if not available_indices:
|
| 370 |
+
logger.trace(
|
| 371 |
+
f"[BASE SAMPLER] _get_paired_human_robot_traj: No more available indices after {retries} retries"
|
| 372 |
+
)
|
| 373 |
+
return None
|
| 374 |
+
|
| 375 |
+
paired_idx = random.choice(available_indices)
|
| 376 |
+
paired_traj = self.dataset[paired_idx]
|
| 377 |
+
|
| 378 |
+
# If it matches, remove this index and try again
|
| 379 |
+
if paired_traj.get("id") == chosen_id:
|
| 380 |
+
available_indices = [idx for idx in available_indices if idx != paired_idx]
|
| 381 |
+
paired_traj = None
|
| 382 |
+
continue
|
| 383 |
+
|
| 384 |
+
# If we exhausted retries without finding a valid trajectory, return None
|
| 385 |
+
if paired_traj is None or paired_traj.get("id") == chosen_id:
|
| 386 |
+
logger.trace(
|
| 387 |
+
f"[BASE SAMPLER] _get_paired_human_robot_traj: Failed to find valid paired trajectory after {max_retries} retries"
|
| 388 |
+
)
|
| 389 |
+
return None
|
| 390 |
+
|
| 391 |
+
logger.trace(
|
| 392 |
+
f"[BASE SAMPLER] _get_paired_human_robot_traj: Found paired trajectory {paired_traj.get('id', 'unknown')} on retry {retries}"
|
| 393 |
+
)
|
| 394 |
+
return paired_traj
|
| 395 |
+
|
| 396 |
+
def _get_different_partial_success_traj(self, ref_traj: dict) -> dict | None:
|
| 397 |
+
"""Get trajectory from same task with different partial_success.
|
| 398 |
+
|
| 399 |
+
Finds trajectories with either higher or lower partial_success than the reference,
|
| 400 |
+
using absolute difference for threshold checking.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
ref_traj: Reference trajectory
|
| 404 |
+
|
| 405 |
+
Returns:
|
| 406 |
+
Trajectory dict with different partial_success from same task or None if not available
|
| 407 |
+
"""
|
| 408 |
+
task_name = ref_traj["task"]
|
| 409 |
+
ref_partial_success = ref_traj.get("partial_success")
|
| 410 |
+
|
| 411 |
+
# Check if partial_success is available
|
| 412 |
+
if ref_partial_success is None:
|
| 413 |
+
logger.trace(
|
| 414 |
+
f"[BASE SAMPLER] _get_different_partial_success_traj: No partial_success for trajectory {ref_traj.get('id', 'unknown')}"
|
| 415 |
+
)
|
| 416 |
+
return None
|
| 417 |
+
|
| 418 |
+
# Get minimum threshold from config
|
| 419 |
+
min_threshold = getattr(self.config, "partial_success_threshold", 0.2)
|
| 420 |
+
|
| 421 |
+
# Get all trajectories from the same task
|
| 422 |
+
same_task_indices = self.task_indices.get(task_name, [])
|
| 423 |
+
if not same_task_indices:
|
| 424 |
+
logger.trace(
|
| 425 |
+
f"[BASE SAMPLER] _get_different_partial_success_traj: No trajectories found for task '{task_name}'"
|
| 426 |
+
)
|
| 427 |
+
return None
|
| 428 |
+
|
| 429 |
+
# Filter to trajectories with different partial_success that meet the threshold requirement
|
| 430 |
+
# Uses absolute difference to allow both higher and lower partial_success
|
| 431 |
+
chosen_id = ref_traj["id"]
|
| 432 |
+
candidate_indices = []
|
| 433 |
+
|
| 434 |
+
for idx in same_task_indices:
|
| 435 |
+
# Skip if same trajectory
|
| 436 |
+
if self._cached_ids[idx] == chosen_id:
|
| 437 |
+
continue
|
| 438 |
+
|
| 439 |
+
# Get partial_success for this trajectory
|
| 440 |
+
traj_dict = self.dataset[idx]
|
| 441 |
+
traj_partial_success = traj_dict.get("partial_success", None)
|
| 442 |
+
|
| 443 |
+
if traj_partial_success is None:
|
| 444 |
+
logger.trace(
|
| 445 |
+
f"[BASE SAMPLER] _get_different_partial_success_traj: No partial_success for trajectory {traj_dict.get('id', 'unknown')}, task '{task_name}'"
|
| 446 |
+
)
|
| 447 |
+
continue
|
| 448 |
+
|
| 449 |
+
# Include if partial_success differs from reference by at least the threshold (using abs)
|
| 450 |
+
partial_success_diff = abs(ref_partial_success - traj_partial_success)
|
| 451 |
+
if partial_success_diff >= min_threshold:
|
| 452 |
+
candidate_indices.append(idx)
|
| 453 |
+
|
| 454 |
+
if not candidate_indices:
|
| 455 |
+
logger.trace(
|
| 456 |
+
f"[BASE SAMPLER] _get_different_partial_success_traj: No trajectories with different partial_success (threshold: {min_threshold}) for task '{task_name}' (ref: {ref_partial_success})"
|
| 457 |
+
)
|
| 458 |
+
return None
|
| 459 |
+
|
| 460 |
+
# Randomly select from candidates
|
| 461 |
+
selected_idx = random.choice(candidate_indices)
|
| 462 |
+
result = self.dataset[selected_idx]
|
| 463 |
+
result_partial_success = result.get("partial_success")
|
| 464 |
+
# If ref_partial_success is 1.0, direction is always "lower" since 1.0 is the maximum
|
| 465 |
+
if ref_partial_success == 1.0:
|
| 466 |
+
direction = "lower"
|
| 467 |
+
else:
|
| 468 |
+
direction = "higher" if result_partial_success > ref_partial_success else "lower"
|
| 469 |
+
logger.trace(
|
| 470 |
+
f"[BASE SAMPLER] _get_different_partial_success_traj: Found trajectory {result.get('id', 'unknown')} with partial_success {result_partial_success} ({direction} than {ref_partial_success}, abs diff: {abs(ref_partial_success - result_partial_success):.3f}, threshold: {min_threshold})"
|
| 471 |
+
)
|
| 472 |
+
return result
|
| 473 |
+
|
| 474 |
+
def _get_subsample_indices(
|
| 475 |
+
self, data, direction: str = "bidirectional", max_frames: int = None
|
| 476 |
+
) -> Optional[Tuple[int, int, int]]:
|
| 477 |
+
"""Get start, middle, and end indices for subsample strategy.
|
| 478 |
+
|
| 479 |
+
Samples three random frames from the trajectory. The relationship between indices
|
| 480 |
+
follows three main scenarios:
|
| 481 |
+
1. start < middle < end: forward progress - normal forward progression through trajectory
|
| 482 |
+
2. start < end < middle: rewind progress - forward from start to end, then continues to middle (simulating rewind/backtrack)
|
| 483 |
+
3. end < middle < start: reverse progress - backward from start through middle to end (full backward traversal)
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
data: Trajectory data (frames or embeddings) to sample from
|
| 487 |
+
direction: Sampling direction - "forward" (start < middle < end),
|
| 488 |
+
"reverse" (end < middle < start),
|
| 489 |
+
"rewind" (start < end < middle),
|
| 490 |
+
or "bidirectional" (any of the 3 orderings)
|
| 491 |
+
max_frames: Maximum number of frames to subsample. If 1, returns only start. If 2, returns start and end.
|
| 492 |
+
|
| 493 |
+
Returns:
|
| 494 |
+
Tuple of (start_idx, middle_idx, end_idx), or None if insufficient frames
|
| 495 |
+
For max_frames == 1: returns (start_idx, None, None)
|
| 496 |
+
For max_frames == 2: returns (start_idx, None, end_idx)
|
| 497 |
+
"""
|
| 498 |
+
num_frames_total = len(data) if hasattr(data, "__len__") else data.shape[0]
|
| 499 |
+
|
| 500 |
+
# Handle edge cases for max_frames == 1 or 2
|
| 501 |
+
if max_frames == 1:
|
| 502 |
+
# Randomly sample 1 frame
|
| 503 |
+
random_idx = random.randint(0, num_frames_total - 1)
|
| 504 |
+
logger.trace(f"[BASE SAMPLER] _get_subsample_indices: max_frames=1, randomly sampled idx={random_idx}")
|
| 505 |
+
return (random_idx, None, None)
|
| 506 |
+
|
| 507 |
+
if max_frames == 2:
|
| 508 |
+
# Sample 2 frames: either forward (start < end) or reverse (end < start)
|
| 509 |
+
# No rewind possible with only 2 frames
|
| 510 |
+
if direction == "reverse":
|
| 511 |
+
# Reverse: sample end first, then start (end < start)
|
| 512 |
+
end_idx = random.randint(0, num_frames_total - 2)
|
| 513 |
+
start_idx = random.randint(end_idx + 1, num_frames_total - 1)
|
| 514 |
+
else:
|
| 515 |
+
# Forward: sample start first, then end (start < end)
|
| 516 |
+
start_idx = random.randint(0, num_frames_total - 2)
|
| 517 |
+
end_idx = random.randint(start_idx + 1, num_frames_total - 1)
|
| 518 |
+
logger.trace(
|
| 519 |
+
f"[BASE SAMPLER] _get_subsample_indices: max_frames=2, start_idx={start_idx}, end_idx={end_idx}, direction={direction}"
|
| 520 |
+
)
|
| 521 |
+
return (start_idx, None, end_idx)
|
| 522 |
+
|
| 523 |
+
if num_frames_total < 3:
|
| 524 |
+
logger.trace(f"[BASE SAMPLER] _get_subsample_indices: Not enough frames ({num_frames_total})")
|
| 525 |
+
return None
|
| 526 |
+
|
| 527 |
+
# Sample three random distinct frames
|
| 528 |
+
frame_indices = sorted(random.sample(range(num_frames_total), 3))
|
| 529 |
+
frame1_idx, frame2_idx, frame3_idx = frame_indices
|
| 530 |
+
|
| 531 |
+
# Determine start, middle, and end based on direction
|
| 532 |
+
# We only care about 3 cases:
|
| 533 |
+
# 1. start < middle < end: forward progress
|
| 534 |
+
# 2. start < end < middle: rewind progress
|
| 535 |
+
# 3. end < middle < start: reverse progress
|
| 536 |
+
|
| 537 |
+
if direction == "forward":
|
| 538 |
+
# Case 1: start < middle < end
|
| 539 |
+
start_idx = frame1_idx
|
| 540 |
+
middle_idx = frame2_idx
|
| 541 |
+
end_idx = frame3_idx
|
| 542 |
+
elif direction == "reverse":
|
| 543 |
+
# Case 3: end < middle < start
|
| 544 |
+
end_idx = frame1_idx
|
| 545 |
+
middle_idx = frame2_idx
|
| 546 |
+
start_idx = frame3_idx
|
| 547 |
+
elif direction == "rewind":
|
| 548 |
+
# Case 2: start < end < middle
|
| 549 |
+
start_idx = frame1_idx
|
| 550 |
+
end_idx = frame2_idx
|
| 551 |
+
middle_idx = frame3_idx
|
| 552 |
+
else: # bidirectional (default)
|
| 553 |
+
# Randomly choose from the 3 cases
|
| 554 |
+
pattern = random.choice([1, 2, 3])
|
| 555 |
+
if pattern == 1: # start < middle < end: forward progress
|
| 556 |
+
start_idx = frame1_idx
|
| 557 |
+
middle_idx = frame2_idx
|
| 558 |
+
end_idx = frame3_idx
|
| 559 |
+
elif pattern == 2: # start < end < middle: rewind progress
|
| 560 |
+
start_idx = frame1_idx
|
| 561 |
+
end_idx = frame2_idx
|
| 562 |
+
middle_idx = frame3_idx
|
| 563 |
+
else: # pattern == 3: end < middle < start: reverse progress
|
| 564 |
+
end_idx = frame1_idx
|
| 565 |
+
middle_idx = frame2_idx
|
| 566 |
+
start_idx = frame3_idx
|
| 567 |
+
|
| 568 |
+
logger.trace(
|
| 569 |
+
f"[BASE SAMPLER] _get_subsample_indices: Selected indices start={start_idx}, middle={middle_idx}, end={end_idx} "
|
| 570 |
+
f"from {num_frames_total} total frames (direction: {direction})"
|
| 571 |
+
)
|
| 572 |
+
return start_idx, middle_idx, end_idx
|
| 573 |
+
|
| 574 |
+
    def _get_traj_from_data(
        self,
        traj: dict | Trajectory,
        subsample_strategy: str | None = None,
        frame_indices: List[int] | None = None,
        metadata: Dict[str, Any] | None = None,
    ) -> Trajectory | None:
        """Load, subsample, and optionally pad trajectory data and create a Trajectory object.

        Pipeline: load frames or embeddings -> choose frame indices (explicit,
        strategy-based, or all) -> compute per-frame progress -> uniform
        subsample down to max_frames -> pad -> compute success labels ->
        optionally discretize -> build the Trajectory.

        Args:
            traj: Trajectory dict or Trajectory object
            subsample_strategy: Optional strategy for subsampling ("subsample_forward", "subsample_reverse", "subsample_rewind", or None for default/bidirectional). Ignored if frame_indices is provided.
            frame_indices: Optional list of specific frame indices to use. If provided, subsample_strategy is ignored.
            metadata: Optional metadata dict to merge into trajectory metadata.

        Returns:
            Trajectory object with loaded and subsampled data (padded), or
            None when a subsample strategy fails to produce indices.
        """
        # Initialize variables
        frames = None
        video_embeddings = None
        text_embedding = None
        data = None

        if isinstance(traj, Trajectory):
            # If already a Trajectory, just return it (no reprocessing).
            return traj

        # Load from dict
        # Check if text_embedding is already provided in the dict (for samplers that need to override it)
        if "text_embedding" in traj and traj["text_embedding"] is not None:
            text_embedding = traj["text_embedding"]

        if self.config.load_embeddings and traj.get("embeddings_path"):
            embeddings = load_embeddings_from_path(traj["embeddings_path"])
            video_embeddings = embeddings["video_embeddings"]
            # Only use loaded text_embedding if not already provided in dict
            if text_embedding is None:
                text_embedding = embeddings["text_embedding"]
            data = video_embeddings
        else:
            # Raw frames: either a path string to an .npz file or in-memory frames.
            if isinstance(traj["frames"], str):
                frames = load_frames_from_npz(traj["frames"])
            else:
                frames = traj["frames"]
            data = frames

        # Get total frames for progress computation
        if hasattr(data, "shape"):
            num_frames_total = data.shape[0]
        else:
            num_frames_total = len(data)

        # Per-dataset success cutoff, falling back to the global config value.
        ds_key = traj["data_source"]
        success_cutoff = self.dataset_success_cutoff_map.get(ds_key, self.config.max_success)

        # Determine which indices to use (construct indices first, then subsample uniformly)
        if frame_indices is not None:
            # Use provided frame indices directly
            indices = frame_indices
        elif subsample_strategy is not None:
            # Use subsampling strategy
            # Get subsample indices (handles edge cases for max_frames == 1 or 2)
            if subsample_strategy == "subsample_forward":
                strategy_indices = self._get_subsample_indices(
                    data, direction="forward", max_frames=self.config.max_frames
                )
            elif subsample_strategy == "subsample_reverse":
                strategy_indices = self._get_subsample_indices(
                    data, direction="reverse", max_frames=self.config.max_frames
                )
            elif subsample_strategy == "subsample_rewind":
                strategy_indices = self._get_subsample_indices(
                    data, direction="rewind", max_frames=self.config.max_frames
                )
            else:
                # Any other (unrecognized) strategy string falls back to bidirectional.
                strategy_indices = self._get_subsample_indices(
                    data, direction="bidirectional", max_frames=self.config.max_frames
                )

            if strategy_indices is None:
                logger.trace("[BASE SAMPLER] _get_traj_from_data: Failed to get uniform sample indices")
                return None

            start_idx, middle_idx, end_idx = strategy_indices

            logger.trace(
                f"[BASE SAMPLER] _get_traj_from_data: Subsampling trajectory with strategy: {subsample_strategy}, start_idx: {start_idx}, middle_idx: {middle_idx}, end_idx: {end_idx}"
            )

            # Use middle_idx only for rewind strategy (requires at least 3 frames)
            use_middle = subsample_strategy == "subsample_rewind" and middle_idx is not None and num_frames_total >= 3

            # Use get_segment_indices_with_middle to construct indices
            indices = get_segment_indices_with_middle(
                num_frames_total=num_frames_total,
                start_idx=start_idx,
                end_idx=end_idx,
                middle_idx=middle_idx if use_middle else None,
                max_frames=self.config.max_frames,
            )
        else:
            # No subsampling strategy or indices provided - use all frames
            indices = list(range(num_frames_total))

        # Extract data using indices
        # NOTE(review): assumes `data` supports fancy indexing with a list
        # (numpy array / torch tensor) — TODO confirm for all frame sources.
        subsampled = data[indices]

        # Get partial_success early to pass to compute_progress_from_segment
        partial_success = traj.get("partial_success")

        # Compute progress
        target_progress = compute_progress_from_segment(
            num_frames_total=num_frames_total,
            frame_indices=indices,
            progress_pred_type=self.config.progress_pred_type,
            success_cutoff=success_cutoff,
            partial_success=partial_success,
        )

        # Subsample uniformly if needed (if we have more frames than max_frames)
        current_frame_count = len(subsampled) if hasattr(subsampled, "__len__") else subsampled.shape[0]
        if current_frame_count > self.config.max_frames:
            subsampled, frame_indices_subsample = linspace_subsample_frames(subsampled, self.config.max_frames)
            # Update indices and target_progress
            # (target_progress is only remapped when it aligns 1:1 with the pre-subsample frames)
            if target_progress and len(target_progress) == current_frame_count:
                target_progress = [target_progress[idx] for idx in frame_indices_subsample]
            indices = [indices[idx] for idx in frame_indices_subsample] if isinstance(indices, list) else indices

        # Pad if needed
        # (padding is skipped entirely when target_progress is empty/falsy)
        if target_progress:
            if self.config.load_embeddings:
                subsampled, target_progress = pad_trajectory_to_max_frames_torch(
                    subsampled, target_progress, self.config.max_frames
                )
            else:
                subsampled, target_progress = pad_trajectory_to_max_frames_np(
                    subsampled, target_progress, self.config.max_frames
                )

        # Update frames_shape
        frames_shape = subsampled.shape if hasattr(subsampled, "shape") else tuple()

        # Set frames or video_embeddings
        if self.config.load_embeddings:
            video_embeddings = subsampled
        else:
            frames = subsampled

        # Compute success labels
        success_label = compute_success_labels(
            target_progress=target_progress,
            data_source=traj["data_source"],
            dataset_success_percent=self.dataset_success_cutoff_map,
            max_success=self.config.max_success,
            quality_label=traj.get("quality_label"),
        )

        # Convert partial_success and target_progress to discrete bins if in discrete mode
        if self.config.progress_loss_type.lower() == "discrete":
            if partial_success is not None:
                partial_success = convert_continuous_to_discrete_bins(
                    [partial_success], self.config.progress_discrete_bins
                )[0]
            target_progress = convert_continuous_to_discrete_bins(target_progress, self.config.progress_discrete_bins)

        trajectory = create_trajectory_from_dict(
            traj,
            overrides={
                "frames": frames,
                "frames_shape": frames_shape,
                "video_embeddings": video_embeddings,
                "text_embedding": text_embedding,
                "target_progress": target_progress,
                "success_label": success_label,
                "partial_success": partial_success,
                "metadata": metadata,
            },
        )
        return trajectory
|
samplers/eval/__pycache__/base_pref.cpython-310.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
samplers/eval/__pycache__/base_pref.cpython-311.pyc
ADDED
|
Binary file (3.22 kB). View file
|
|
|
samplers/eval/__pycache__/confusion_matrix.cpython-310.pyc
ADDED
|
Binary file (8.56 kB). View file
|
|
|
samplers/eval/__pycache__/confusion_matrix.cpython-311.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
samplers/eval/__pycache__/progress_default.cpython-310.pyc
ADDED
|
Binary file (2.74 kB). View file
|
|
|
samplers/eval/__pycache__/progress_default.cpython-311.pyc
ADDED
|
Binary file (6.93 kB). View file
|
|
|
samplers/eval/__pycache__/progress_policy_ranking.cpython-310.pyc
ADDED
|
Binary file (6.34 kB). View file
|
|
|
samplers/eval/__pycache__/progress_policy_ranking.cpython-311.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
samplers/eval/__pycache__/quality_preference.cpython-310.pyc
ADDED
|
Binary file (4.38 kB). View file
|
|
|
samplers/eval/__pycache__/quality_preference.cpython-311.pyc
ADDED
|
Binary file (7.59 kB). View file
|
|
|
samplers/eval/__pycache__/reward_alignment.cpython-310.pyc
ADDED
|
Binary file (4.24 kB). View file
|
|
|
samplers/eval/__pycache__/reward_alignment.cpython-311.pyc
ADDED
|
Binary file (6.6 kB). View file
|
|
|
samplers/eval/__pycache__/roboarena_quality_preference.cpython-310.pyc
ADDED
|
Binary file (2.9 kB). View file
|
|
|
samplers/eval/__pycache__/roboarena_quality_preference.cpython-311.pyc
ADDED
|
Binary file (4.71 kB). View file
|
|
|
samplers/eval/__pycache__/similarity_score.cpython-310.pyc
ADDED
|
Binary file (4.35 kB). View file
|
|
|
samplers/eval/__pycache__/similarity_score.cpython-311.pyc
ADDED
|
Binary file (6.66 kB). View file
|
|
|
samplers/eval/base_pref.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from rfm.data.dataset_types import PreferenceSample, Trajectory
|
| 6 |
+
from rfm.data.samplers.base import RFMBaseSampler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseQualityPreferenceSampler(RFMBaseSampler):
    """Base class for quality preference samplers.

    Subclasses should implement `_generate_all_sample_indices` to define how
    trajectories are paired. This base class provides the common `_generate_sample_from_indices`
    method that loads and processes the trajectories.
    """

    @staticmethod
    def _build_pref_metadata(traj: Dict[str, Any]) -> Dict[str, Any]:
        """Build the metadata dict attached to one side of a preference pair.

        Copies the trajectory's identifying fields; `partial_success` is only
        included when the trajectory actually carries one (RoboArena-style data).
        """
        metadata = {
            "quality_label": traj["quality_label"],
            "data_source": traj["data_source"],
            "task": traj["task"],
            "id": traj["id"],
            "video_path": traj["frames"],
        }
        # Add partial_success if available
        if traj.get("partial_success") is not None:
            metadata["partial_success"] = traj.get("partial_success")
        return metadata

    def _generate_sample_from_indices(self, sample_idx_info: Dict[str, Any]) -> PreferenceSample:
        """Generate a single preference sample from stored pair indices.

        Args:
            sample_idx_info: Dict with "chosen_traj_idx" and "rejected_traj_idx"
                keys pointing into `self.dataset`.

        Returns:
            A PreferenceSample pairing the chosen and rejected trajectories.
        """
        chosen_idx = sample_idx_info["chosen_traj_idx"]
        rejected_idx = sample_idx_info["rejected_traj_idx"]

        # Get the trajectories
        chosen_traj = self.dataset[chosen_idx]
        rejected_traj = self.dataset[rejected_idx]

        # The metadata for both sides is built identically; a shared helper
        # keeps the two paths from drifting apart.
        chosen_trajectory = self._get_traj_from_data(
            traj=chosen_traj,
            metadata=self._build_pref_metadata(chosen_traj),
        )

        rejected_trajectory = self._get_traj_from_data(
            traj=rejected_traj,
            metadata=self._build_pref_metadata(rejected_traj),
        )

        # Subclasses may set a more specific strategy name; fall back to the
        # generic quality-preference label.
        data_gen_strategy = getattr(self, "data_gen_strategy", "quality_preference")

        # Create preference sample
        sample = PreferenceSample(
            chosen_trajectory=chosen_trajectory,
            rejected_trajectory=rejected_trajectory,
            data_gen_strategy=data_gen_strategy,
        )

        return sample

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        return self._generate_sample_from_indices(self.sample_indices[idx])
|
samplers/eval/confusion_matrix.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Data generator for confusion matrix analysis.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
from collections import Counter, defaultdict
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
from rfm.data.dataset_types import PreferenceSample, ProgressSample
|
| 12 |
+
from rfm.data.samplers.base import RFMBaseSampler
|
| 13 |
+
from rfm.utils.distributed import rank_0_print
|
| 14 |
+
from sentence_transformers import SentenceTransformer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ConfusionMatrixSampler(RFMBaseSampler):
    """
    Data generator that creates task-trajectory pairs for confusion matrix analysis.

    For each unique task, creates samples with each trajectory to analyze
    how well the model can distinguish between different tasks.

    If multiple data sources are present, samples N random trajectories from each data source
    and prioritizes different language instructions by randomizing the pairing order.
    """

    def __init__(self, n_trajectories_per_source: int = 5, **kwargs):
        """Initialize confusion matrix sampler.

        Args:
            n_trajectories_per_source: Number of trajectories to sample from each data source.
                If None, uses all available trajectories.
            **kwargs: Additional arguments passed to parent class.
        """
        super().__init__(**kwargs)
        self.n_trajectories_per_source = n_trajectories_per_source

        # Load sentence transformer model and precompute embeddings for all unique tasks
        self.sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")
        self.sentence_model.eval()

        # Precompute language embeddings for all unique tasks in a single batched
        # forward pass (one encode() call for the whole list instead of one per task).
        unique_tasks = list(self.task_indices.keys())
        rank_0_print(f"Precomputing language embeddings for {len(unique_tasks)} unique tasks", verbose=self.verbose)
        embeddings = self.sentence_model.encode(unique_tasks)
        self.task_embeddings = {task: torch.tensor(embedding) for task, embedding in zip(unique_tasks, embeddings)}
        rank_0_print(f"Precomputed {len(self.task_embeddings)} language embeddings", verbose=self.verbose)

        # Free up the model after precomputation (no longer needed)
        del self.sentence_model

        self.sample_indices = self._generate_all_sample_indices()

        rank_0_print(
            f"Generated {len(self.sample_indices)} confusion matrix sample indices from {len(self.robot_trajectories)} trajectories and {len(self.task_indices)} tasks"
        )

    def _generate_all_sample_indices(self) -> list[dict]:
        """Generate all possible task-trajectory pair sample indices.

        If multiple data sources exist, samples N random trajectories from each data source.
        Prioritizes different video tasks first, then prioritizes different language instructions
        when creating pairs.
        """
        sample_indices = []

        # Get unique tasks (these will be the language instructions)
        unique_lang_tasks = list(self.task_indices.keys())
        rank_0_print(f"Found {len(unique_lang_tasks)} unique language tasks: {unique_lang_tasks}", verbose=self.verbose)

        # Sample trajectories per data source (prioritizing different video tasks)
        sampled_trajectories, stats = self._sample_trajectories_by_data_source()

        rank_0_print(
            f"Processing {len(sampled_trajectories)} trajectories for confusion matrix analysis",
            verbose=self.verbose,
        )

        # Print statistics about sampled trajectories
        self._print_sampling_stats(stats)

        # Shuffle language tasks once for round-robin pairing
        shuffled_lang_tasks = unique_lang_tasks.copy()
        self._local_random.shuffle(shuffled_lang_tasks)

        # Create task-trajectory pairs with prioritized language instruction pairing
        video_task_count = Counter()

        for traj_idx in sampled_trajectories:
            traj = self.dataset[traj_idx]
            video_task = traj["task"]
            video_task_count[video_task] += 1

            # Pair this trajectory with all language tasks (shuffled for variety)
            traj_id = traj.get("id", str(traj_idx))
            for lang_task in shuffled_lang_tasks:
                sample_indices.append({
                    "traj_idx": traj_idx,
                    "lang_task": lang_task,
                    "video_task": video_task,
                    "video_path": traj["frames"],
                    "id": traj_id,
                })

        # Shuffle final sample indices to further randomize the order
        self._local_random.shuffle(sample_indices)

        # Print statistics about pairs created
        rank_0_print(f"Generated {len(sample_indices)} task-trajectory pairs", verbose=self.verbose)
        rank_0_print(f"  Video tasks sampled: {dict(video_task_count)}", verbose=self.verbose)
        rank_0_print(f"  Trajectories per video task: {dict(sorted(video_task_count.items()))}", verbose=self.verbose)

        return sample_indices

    def _sample_trajectories_by_data_source(self) -> Tuple[list[int], dict]:
        """Sample N random trajectories from each data source, prioritizing different video tasks.

        When sampling N trajectories, first selects one trajectory from each unique video task,
        then repeats in round-robin fashion until N trajectories are sampled.

        Returns:
            Tuple of (list of sampled trajectory indices, stats dictionary)
        """
        sampled_indices = []
        stats = {
            "by_source": {},
            "by_task": Counter(),
            "traj_to_task": {},
        }

        # Group robot trajectories by data source, then by video task
        trajectories_by_source_and_task = defaultdict(lambda: defaultdict(list))
        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            data_source = traj.get("data_source", "unknown")
            video_task = traj.get("task", "unknown")
            trajectories_by_source_and_task[data_source][video_task].append(traj_idx)

        rank_0_print(
            f"Found {len(trajectories_by_source_and_task)} data sources: {list(trajectories_by_source_and_task.keys())}",
            verbose=self.verbose,
        )

        # Sample N trajectories from each data source, prioritizing different tasks
        for data_source, tasks_to_indices in trajectories_by_source_and_task.items():
            # Shuffle trajectories within each task for randomization
            for task in tasks_to_indices:
                self._local_random.shuffle(tasks_to_indices[task])

            # Get all unique tasks for this data source
            all_tasks = list(tasks_to_indices.keys())
            self._local_random.shuffle(all_tasks)  # Randomize task order too

            source_stats = {
                "total_available": sum(len(indices) for indices in tasks_to_indices.values()),
                "tasks_available": {task: len(indices) for task, indices in tasks_to_indices.items()},
                "tasks_sampled": Counter(),
            }

            if self.n_trajectories_per_source is None:
                # Use all available trajectories
                sampled_from_source = []
                for task, indices in tasks_to_indices.items():
                    sampled_from_source.extend(indices)
                    source_stats["tasks_sampled"][task] = len(indices)
                    stats["by_task"][task] += len(indices)

                rank_0_print(
                    f"  Data source '{data_source}': Using all {len(sampled_from_source)} trajectories",
                    verbose=self.verbose,
                )
            else:
                # Sample N trajectories using round-robin to prioritize different tasks
                n_to_sample = min(self.n_trajectories_per_source, source_stats["total_available"])
                sampled_from_source = []

                # Round-robin sampling: first get one from each task, then repeat
                task_iterators = {task: iter(indices) for task, indices in tasks_to_indices.items()}
                task_list = all_tasks.copy()
                round_idx = 0

                while len(sampled_from_source) < n_to_sample:
                    # If we've gone through all tasks once, reshuffle for next round
                    if round_idx >= len(task_list):
                        round_idx = 0
                        self._local_random.shuffle(task_list)

                    # Try to get one trajectory from current task
                    task = task_list[round_idx]
                    try:
                        traj_idx = next(task_iterators[task])
                        sampled_from_source.append(traj_idx)
                        source_stats["tasks_sampled"][task] += 1
                        stats["by_task"][task] += 1
                    except StopIteration:
                        # This task is exhausted, remove it from rotation
                        task_list.pop(round_idx)
                        if not task_list:
                            break  # All tasks exhausted
                        continue

                    round_idx += 1

                rank_0_print(
                    f"  Data source '{data_source}': Sampled {len(sampled_from_source)} out of {source_stats['total_available']} trajectories",
                    verbose=self.verbose,
                )
                rank_0_print(
                    f"    Tasks sampled: {dict(sorted(source_stats['tasks_sampled'].items()))}",
                    verbose=self.verbose,
                )

            # Track trajectory to task mapping for stats
            for traj_idx in sampled_from_source:
                traj = self.dataset[traj_idx]
                traj_id = traj.get("id", str(traj_idx))
                stats["traj_to_task"][traj_id] = traj.get("task", "unknown")

            sampled_indices.extend(sampled_from_source)
            stats["by_source"][data_source] = source_stats

        return sampled_indices, stats

    def _print_sampling_stats(self, stats: dict):
        """Print detailed statistics about sampled trajectories.

        Args:
            stats: Statistics dictionary from _sample_trajectories_by_data_source
        """
        if not self.verbose:
            return

        rank_0_print("\n=== Confusion Matrix Sampling Statistics ===", verbose=self.verbose)

        # Overall task statistics
        rank_0_print(f"\nOverall trajectories per video task:", verbose=self.verbose)
        for task, count in sorted(stats["by_task"].items()):
            rank_0_print(f"  {task}: {count} trajectories", verbose=self.verbose)

        # Per data source statistics
        rank_0_print(f"\nPer data source breakdown:", verbose=self.verbose)
        for data_source, source_stats in stats["by_source"].items():
            rank_0_print(f"  Data source: {data_source}", verbose=self.verbose)
            rank_0_print(f"    Total available: {source_stats['total_available']}", verbose=self.verbose)
            rank_0_print(f"    Tasks available: {len(source_stats['tasks_available'])}", verbose=self.verbose)
            for task, count in sorted(source_stats['tasks_available'].items()):
                sampled_count = source_stats['tasks_sampled'].get(task, 0)
                rank_0_print(
                    f"      {task}: {sampled_count}/{count} trajectories sampled",
                    verbose=self.verbose,
                )

        rank_0_print("=" * 50, verbose=self.verbose)

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> PreferenceSample:
        """Generate a single task-trajectory sample from stored indices."""
        traj_idx = sample_idx_info["traj_idx"]
        lang_task = sample_idx_info["lang_task"]
        video_task = sample_idx_info["video_task"]
        video_path = sample_idx_info["video_path"]

        video_traj = self.dataset[traj_idx]

        # Look up precomputed embedding instead of encoding
        text_embedding = self.task_embeddings[lang_task]

        metadata = {
            "id": video_traj["id"],
            "lang_task": lang_task,
            "video_task": video_task,
            "video_path": video_path,
        }

        # Override task and text_embedding in the trajectory dict so the sample
        # carries the (possibly mismatched) language instruction under test.
        video_traj_with_task = video_traj.copy()
        video_traj_with_task["task"] = lang_task
        video_traj_with_task["text_embedding"] = text_embedding

        sample_trajectory = self._get_traj_from_data(
            traj=video_traj_with_task,
            metadata=metadata,
        )

        sample = ProgressSample(trajectory=sample_trajectory)
        return sample

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        return self._generate_sample_from_indices(self.sample_indices[idx])
|
samplers/eval/progress_policy_ranking.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Any, Optional
|
| 2 |
+
from itertools import cycle
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from rfm.data.dataset_types import ProgressSample
|
| 7 |
+
from rfm.data.samplers.base import RFMBaseSampler
|
| 8 |
+
from rfm.utils.logger import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ProgressPolicyRankingSampler(RFMBaseSampler):
    """Dataset that generates progress samples for policy ranking by selecting N trajectories per quality label for tasks with multiple quality labels."""

    def __init__(
        self,
        num_examples_per_quality_pr: int = 5,
        num_partial_successes: Optional[int] = None,
        frame_step: int = 1,
        use_frame_steps: bool = True,
        max_tasks: Optional[int] = None,
        **kwargs,
    ):
        """Initialize the sampler.

        Args:
            num_examples_per_quality_pr: Trajectories sampled per quality label (non-RoboArena).
            num_partial_successes: Total trajectories sampled per task via circular sampling
                over partial_success groups (RoboArena). None means "use all available".
            frame_step: Stride for subsequence generation when use_frame_steps is True.
            use_frame_steps: If True, emit growing-prefix subsequences; otherwise one
                sample per whole trajectory.
            max_tasks: Optional cap on the number of eligible tasks.
            **kwargs: Forwarded to the base sampler.
        """
        super().__init__(**kwargs)

        self.num_examples_per_quality_pr = num_examples_per_quality_pr
        self.num_partial_successes = num_partial_successes
        self.frame_step = frame_step
        self.use_frame_steps = use_frame_steps
        self.max_tasks = max_tasks
        logger.info(f"ProgressPolicyRankingSampler initialized with {len(self.robot_trajectories)} trajectories")

        self.sample_indices = self._generate_all_sample_indices()

        logger.info(f"Generated {len(self.sample_indices)} sample indices")

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Generate sample indices by selecting tasks with multiple quality labels/partial_success values and sampling N trajectories per group.

        For non-RoboArena: Groups by task and quality_label.
        For RoboArena: Groups by task and partial_success values.

        If use_frame_steps=True, generates subsequence samples like reward_alignment (0:frame_step, 0:2*frame_step, etc.).
        If use_frame_steps=False, generates one sample per trajectory (whole trajectory).
        """

        # Check if this is RoboArena (has partial_success)
        is_roboarena = False
        if self.robot_trajectories:
            first_traj = self.dataset[self.robot_trajectories[0]]
            is_roboarena = first_traj.get("partial_success") is not None

        # Group trajectories by task and grouping key (quality_label or partial_success)
        task_to_key_to_trajs = defaultdict(lambda: defaultdict(list))

        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]

            if is_roboarena:
                # RoboArena: Use rounded partial_success as key to group similar values
                partial_success_val = traj.get("partial_success")
                if partial_success_val is not None:
                    partial_success = round(float(partial_success_val), 2)
                    task_to_key_to_trajs[task][partial_success].append(traj_idx)
            else:
                # Non-RoboArena: Use quality_label
                quality = traj["quality_label"]
                task_to_key_to_trajs[task][quality].append(traj_idx)

        # Filter to tasks that have multiple grouping values
        tasks_with_multiple_values = {
            task: key_to_trajs for task, key_to_trajs in task_to_key_to_trajs.items() if len(key_to_trajs) > 1
        }

        dataset_type_str = "partial_success values" if is_roboarena else "quality labels"
        logger.info(f"Found {len(tasks_with_multiple_values)} tasks with multiple {dataset_type_str}")

        # Limit number of tasks if max_tasks is specified
        if self.max_tasks is not None and self.max_tasks > 0:
            # Convert to list, shuffle, and take first max_tasks
            # Sort tasks first to ensure deterministic ordering before shuffling
            tasks_list = sorted(tasks_with_multiple_values.items())
            self._local_random.shuffle(tasks_list)
            tasks_with_multiple_values = dict(tasks_list[: self.max_tasks])
            logger.info(f"Limited to {len(tasks_with_multiple_values)} tasks (max_tasks={self.max_tasks})")

        # Sample trajectories for each task
        sample_indices = []
        all_sampled_traj_indices = []
        # Sort tasks to ensure deterministic processing order
        for task, key_to_trajs in sorted(tasks_with_multiple_values.items()):
            if is_roboarena:
                # Build lists of available indices per partial_success (sorted for deterministic sampling)
                available_lists = []
                for partial_success in sorted(key_to_trajs.keys()):
                    traj_indices = sorted(key_to_trajs[partial_success])
                    if traj_indices:
                        available_lists.append(traj_indices)

                # RoboArena: Use num_partial_successes for circular sampling.
                # A None cap previously crashed the `>=` comparison below; treat
                # None as "sample everything available" instead.
                total_available = sum(len(lst) for lst in available_lists)
                num_to_sample_total = (
                    self.num_partial_successes if self.num_partial_successes is not None else total_available
                )

                # Circular sampling: cycle through partial_success groups until we reach max
                sampled_traj_indices = []
                for available_indices in cycle(available_lists):
                    if len(sampled_traj_indices) >= num_to_sample_total:
                        break
                    if not available_indices:
                        # If all lists are empty, stop
                        if all(not lst for lst in available_lists):
                            break
                        continue

                    # Sample one trajectory from this group
                    sampled_idx = self._local_random.choice(available_indices)
                    sampled_traj_indices.append(sampled_idx)
                    # Remove the sampled index from this list
                    available_indices.remove(sampled_idx)

                # Generate samples for all sampled trajectories
                for traj_idx in sampled_traj_indices:
                    traj = self.dataset[traj_idx]
                    sample_indices.extend(self._generate_indices_for_trajectory(traj_idx, traj))
                    all_sampled_traj_indices.append(traj_idx)
            else:
                # Non-RoboArena: Sample N trajectories per quality label
                # Sort quality labels to ensure deterministic order
                for quality in sorted(key_to_trajs.keys()):
                    traj_indices = key_to_trajs[quality]
                    # Sort trajectory indices to ensure deterministic sampling
                    traj_indices = sorted(traj_indices)
                    # Sample up to num_examples_per_quality_pr trajectories for this quality label
                    num_to_sample = min(self.num_examples_per_quality_pr, len(traj_indices))
                    sampled_traj_indices = self._local_random.sample(traj_indices, num_to_sample)
                    for traj_idx in sampled_traj_indices:
                        traj = self.dataset[traj_idx]
                        sample_indices.extend(self._generate_indices_for_trajectory(traj_idx, traj))
                        all_sampled_traj_indices.append(traj_idx)

        logger.info(f"Sampled {len(sample_indices)} samples across {len(tasks_with_multiple_values)} tasks")
        logger.info(f"Sampled trajectory indices: {all_sampled_traj_indices}")

        return sample_indices

    def _generate_indices_for_trajectory(self, traj_idx: int, traj: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate sample indices for a single trajectory.

        Args:
            traj_idx: Index of the trajectory in the dataset
            traj: Trajectory dictionary

        Returns:
            List of sample index dictionaries
        """
        num_frames = traj["num_frames"]
        indices = []

        if self.use_frame_steps:
            # Generate subsequence indices like reward_alignment: 0:frame_step, 0:2*frame_step, etc.
            for end_idx in range(self.frame_step, num_frames + 1, self.frame_step):
                frame_indices = list(range(end_idx))
                indices.append({
                    "traj_idx": traj_idx,
                    "frame_indices": frame_indices,
                    "num_frames": num_frames,
                    "video_path": traj["frames"],
                    "id": traj["id"],
                    "use_frame_steps": True,
                })
        else:
            # Generate one sample per trajectory (whole trajectory)
            indices.append({
                "traj_idx": traj_idx,
                "video_path": traj["frames"],
                "id": traj["id"],
                "use_frame_steps": False,
            })

        return indices

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> ProgressSample:
        """Generate a single progress sample from trajectory index."""
        traj_idx = sample_idx_info["traj_idx"]
        use_frame_steps = sample_idx_info.get("use_frame_steps", True)

        traj = self.dataset[traj_idx]

        if use_frame_steps:
            # Frame steps mode: create subsequence like reward_alignment
            frame_indices = sample_idx_info["frame_indices"]
            num_frames = sample_idx_info["num_frames"]

            metadata = {
                "quality_label": traj["quality_label"],
                "data_source": traj["data_source"],
                "task": traj["task"],
                "id": traj["id"],
                "video_path": sample_idx_info["video_path"],
                "frame_step": frame_indices[-1] if frame_indices else 0,
            }

            trajectory = self._get_traj_from_data(
                traj=traj,
                frame_indices=frame_indices,
                metadata=metadata,
            )
        else:
            # Whole trajectory mode
            metadata = {
                "quality_label": traj["quality_label"],
                "data_source": traj["data_source"],
                "task": traj["task"],
                "id": traj["id"],
                "video_path": sample_idx_info["video_path"],
            }

            trajectory = self._get_traj_from_data(
                traj=traj,
                metadata=metadata,
            )

        sample = ProgressSample(trajectory=trajectory)
        return sample

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        return self._generate_sample_from_indices(self.sample_indices[idx])
|
samplers/eval/quality_preference.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Any
|
| 2 |
+
|
| 3 |
+
from itertools import combinations
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from rfm.data.samplers.eval.base_pref import BaseQualityPreferenceSampler
|
| 7 |
+
from rfm.utils.distributed import rank_0_print
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class QualityPreferenceSampler(BaseQualityPreferenceSampler):
    """Dataset that generates preference samples by pairing trajectories with different quality labels or partial_success values for the same task.

    For non-RoboArena: Pairs trajectories with different quality labels (failure, suboptimal, successful).
    For RoboArena: Pairs trajectories with different partial_success values (higher partial_success = chosen).
    """

    def __init__(
        self,
        comparisons_per_task=None,
        max_comparisons=None,
        **kwargs,
    ):
        """Build the sampler and eagerly enumerate all comparison indices.

        Args:
            comparisons_per_task: Optional cap on the number of chosen/rejected
                pairs kept per task (uniformly sampled when exceeded).
            max_comparisons: Optional global cap across all tasks (uniformly
                sampled when exceeded).
            **kwargs: Forwarded to ``BaseQualityPreferenceSampler``.
        """
        super().__init__(**kwargs)

        # Set data_gen_strategy for this sampler
        self.data_gen_strategy = "quality_preference"
        self.comparisons_per_task = comparisons_per_task
        self.max_comparisons = max_comparisons

        # Generate all possible sample indices upfront (not the actual samples)
        self.sample_indices = self._generate_all_sample_indices()
        rank_0_print(f"Generated {len(self.sample_indices)} quality preference sample indices", verbose=self.verbose)

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Generate all possible quality preference sample indices (not the actual samples).

        For non-RoboArena: Groups by task and quality_label, pairs trajectories with different quality labels.
        For RoboArena: Groups by task and partial_success values, pairs trajectories with different partial_success.
        """
        # Check if this is RoboArena (has partial_success) by probing the first trajectory.
        is_roboarena = False
        if self.robot_trajectories:
            first_traj = self.dataset[self.robot_trajectories[0]]
            is_roboarena = first_traj.get("partial_success") is not None

        rank_0_print(
            f"Generating quality preference samples for {len(self.robot_trajectories)} trajectories "
            f"({'RoboArena (partial_success)' if is_roboarena else 'non-RoboArena (quality_label)'})",
            verbose=self.verbose,
        )

        if is_roboarena:
            sample_indices = self._generate_roboarena_indices()
        else:
            sample_indices = self._generate_quality_label_indices()

        # Apply max_comparisons limit if set (sample uniformly across all comparisons)
        original_count = len(sample_indices)
        if self.max_comparisons is not None and original_count > self.max_comparisons:
            sample_indices = self._local_random.sample(sample_indices, self.max_comparisons)
            rank_0_print(
                f"Limited total comparisons to {self.max_comparisons} (from {original_count} total comparisons)",
                verbose=self.verbose,
            )

        return sample_indices

    def _task_pairs(self, task, groups, rank_fn, chosen_field, rejected_field) -> List[Dict[str, Any]]:
        """Build all chosen/rejected pairs for one task, then apply the per-task cap.

        Args:
            task: Task name stored on every emitted pair.
            groups: Mapping of group label -> list of trajectory indices.
            rank_fn: Callable mapping a group label to a comparable rank;
                the higher-ranked label becomes "chosen".
            chosen_field: Key under which the chosen group's label is stored.
            rejected_field: Key under which the rejected group's label is stored.

        Returns:
            List of pair dicts for this task (possibly sub-sampled to
            ``comparisons_per_task``).
        """
        task_pairs: List[Dict[str, Any]] = []

        # Pair every distinct combination of group labels.
        for label1, label2 in combinations(list(groups.keys()), 2):
            trajs1 = groups[label1]
            trajs2 = groups[label2]

            if not trajs1 or not trajs2:
                continue

            # The strictly higher-ranked label is "chosen"; equal ranks are
            # skipped because we cannot reliably compare them.
            rank1 = rank_fn(label1)
            rank2 = rank_fn(label2)
            if rank1 > rank2:
                chosen_label, rejected_label = label1, label2
                chosen_indices, rejected_indices = trajs1, trajs2
            elif rank2 > rank1:
                chosen_label, rejected_label = label2, label1
                chosen_indices, rejected_indices = trajs2, trajs1
            else:
                continue

            # Create all possible cross-product pairs for this label combination.
            for chosen_idx in chosen_indices:
                for rejected_idx in rejected_indices:
                    task_pairs.append({
                        "chosen_traj_idx": chosen_idx,
                        "rejected_traj_idx": rejected_idx,
                        "task": task,
                        chosen_field: chosen_label,
                        rejected_field: rejected_label,
                    })

        # Apply comparisons_per_task limit if set (sample uniformly across all pairs for this task).
        if self.comparisons_per_task is not None and len(task_pairs) > self.comparisons_per_task:
            task_pairs = self._local_random.sample(task_pairs, self.comparisons_per_task)

        return task_pairs

    def _generate_roboarena_indices(self) -> List[Dict[str, Any]]:
        """RoboArena branch: group by task and partial_success (rounded to 2 decimals)."""
        task_to_partial_trajs: Dict[str, Dict[float, List[int]]] = {}

        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]
            partial_success_val = traj.get("partial_success")

            if partial_success_val is None:
                rank_0_print(
                    f"Warning: Trajectory {traj_idx} (task: {task}) missing partial_success, skipping",
                    verbose=self.verbose,
                )
                continue

            # Round partial_success to 2 decimals for grouping.
            partial_success = round(float(partial_success_val), 2)
            task_to_partial_trajs.setdefault(task, {}).setdefault(partial_success, []).append(traj_idx)

        sample_indices: List[Dict[str, Any]] = []
        for task in tqdm(task_to_partial_trajs, desc="Generating RoboArena quality preference samples"):
            sample_indices.extend(
                self._task_pairs(
                    task=task,
                    groups=task_to_partial_trajs[task],
                    # Higher partial_success wins directly.
                    rank_fn=lambda value: value,
                    chosen_field="chosen_partial_success",
                    rejected_field="rejected_partial_success",
                )
            )
        return sample_indices

    def _generate_quality_label_indices(self) -> List[Dict[str, Any]]:
        """Non-RoboArena branch: group by task and categorical quality_label."""
        task_to_quality_trajs: Dict[str, Dict[str, List[int]]] = {}

        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task_to_quality_trajs.setdefault(traj["task"], {}).setdefault(traj["quality_label"], []).append(traj_idx)

        # Higher rank = better quality; unknown labels rank 0 so two unknown
        # labels tie and are skipped, while any known label beats an unknown one.
        quality_order = {"failure": 1, "suboptimal": 2, "successful": 3}

        sample_indices: List[Dict[str, Any]] = []
        for task in tqdm(task_to_quality_trajs, desc="Generating quality preference samples"):
            sample_indices.extend(
                self._task_pairs(
                    task=task,
                    groups=task_to_quality_trajs[task],
                    rank_fn=lambda label: quality_order.get(label, 0),
                    chosen_field="chosen_quality",
                    rejected_field="rejected_quality",
                )
            )
        return sample_indices
|
samplers/eval/reward_alignment.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Data generator for reward alignment evaluation.
|
| 4 |
+
|
| 5 |
+
This generator creates subsequence samples from trajectories for progress prediction evaluation.
|
| 6 |
+
For each trajectory, it creates multiple subsequences (0:2, 0:4, 0:6, etc.) and formats them
|
| 7 |
+
as ProgressSample objects that can be evaluated by the model.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Dict, List, Any
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from rfm.data.dataset_types import ProgressSample, Trajectory
|
| 16 |
+
from rfm.data.samplers.base import RFMBaseSampler
|
| 17 |
+
from rfm.utils.distributed import rank_0_print
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class RewardAlignmentSampler(RFMBaseSampler):
    """
    Data generator that creates subsequence samples for reward alignment evaluation.

    For each trajectory, creates prefix subsequences of frames
    (0:frame_step, 0:2*frame_step, etc.) and formats them as ProgressSample
    objects for evaluation.
    """

    def __init__(
        self,
        max_trajectories: int | None = None,
        frame_step: int = 1,
        use_frame_steps: bool = True,
        **kwargs,
    ):
        """Build the sampler and eagerly enumerate all sample indices.

        Args:
            max_trajectories: If set, uniformly sample at most this many
                trajectories before generating subsequences.
            frame_step: Stride between successive subsequence end points when
                ``use_frame_steps`` is True.
            use_frame_steps: If True, emit one sample per prefix subsequence;
                otherwise emit a single whole-trajectory sample.
            **kwargs: Forwarded to ``RFMBaseSampler``.
        """
        super().__init__(**kwargs)

        self.max_trajectories = max_trajectories
        self.frame_step = frame_step
        self.use_frame_steps = use_frame_steps
        self.sample_indices = self._generate_all_sample_indices()

        rank_0_print(
            f"Generated {len(self.sample_indices)} reward alignment sample indices from {min(len(self.robot_trajectories), self.max_trajectories) if self.max_trajectories else len(self.robot_trajectories)} trajectories",
            verbose=self.verbose,
        )

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Generate all possible subsequence sample indices (not the actual samples)."""
        sample_indices: List[Dict[str, Any]] = []

        # Limit number of trajectories if specified (uniform sample without replacement).
        trajectories_to_process = self.robot_trajectories
        if self.max_trajectories is not None and self.max_trajectories < len(self.robot_trajectories):
            trajectories_to_process = self._local_random.sample(self.robot_trajectories, self.max_trajectories)

        rank_0_print(
            f"Generating subsequence samples for {len(trajectories_to_process)} trajectories", verbose=self.verbose
        )

        for traj_idx in trajectories_to_process:
            traj = self.dataset[traj_idx]
            sample_indices.extend(self._generate_indices_for_trajectory(traj_idx, traj))

        return sample_indices

    def _generate_indices_for_trajectory(self, traj_idx: int, traj: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate sample indices for a single trajectory.

        Args:
            traj_idx: Index of the trajectory in the dataset
            traj: Trajectory dictionary

        Returns:
            List of sample index dictionaries
        """
        num_frames = traj["num_frames"]
        indices: List[Dict[str, Any]] = []

        if self.use_frame_steps:
            # Generate prefix subsequences: 0:frame_step, 0:2*frame_step, etc.
            # NOTE(review): when num_frames % frame_step != 0, the trailing
            # partial prefix (and thus the full trajectory) is never emitted —
            # confirm this is intended.
            for end_idx in range(self.frame_step, num_frames + 1, self.frame_step):
                frame_indices = list(range(end_idx))
                indices.append({
                    "traj_idx": traj_idx,
                    "frame_indices": frame_indices,
                    "num_frames": num_frames,
                    "video_path": traj["frames"],
                    "id": traj["id"],
                    "use_frame_steps": True,
                })
        else:
            # Generate one sample per trajectory (whole trajectory).
            indices.append({
                "traj_idx": traj_idx,
                "video_path": traj["frames"],
                "id": traj["id"],
                "use_frame_steps": False,
            })

        return indices

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> ProgressSample:
        """Generate a single subsequence sample from stored indices."""
        traj_idx = sample_idx_info["traj_idx"]
        use_frame_steps = sample_idx_info.get("use_frame_steps", True)

        traj = self.dataset[traj_idx]

        if use_frame_steps:
            # Frame steps mode: create a prefix subsequence of the trajectory.
            frame_indices = sample_idx_info["frame_indices"]
            num_frames = sample_idx_info["num_frames"]

            metadata = {
                "data_gen_strategy": "reward_alignment",
                "id": traj["id"],
                "video_path": sample_idx_info["video_path"],
                # NOTE(review): this stores the LAST frame index, not the
                # configured stride, under "frame_step" — confirm consumers
                # expect an index here.
                "frame_step": frame_indices[-1] if frame_indices else 0,
                "num_frames": num_frames,
            }

            trajectory = self._get_traj_from_data(
                traj=traj,
                frame_indices=frame_indices,
                metadata=metadata,
            )
        else:
            # Whole trajectory mode: no frame subsetting.
            metadata = {
                "data_gen_strategy": "reward_alignment",
                "id": traj["id"],
                "video_path": sample_idx_info["video_path"],
            }

            trajectory = self._get_traj_from_data(
                traj=traj,
                metadata=metadata,
            )

        sample = ProgressSample(trajectory=trajectory, sample_type="progress")
        return sample

    def __len__(self):
        """Number of precomputed sample indices."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Materialize the sample for the given index on demand."""
        return self._generate_sample_from_indices(self.sample_indices[idx])
|