# rewardeval_ui/samplers/eval/progress_policy_ranking.py
# Anthony Liang
# getting rid of dependencies
# 3e462dd
from typing import Dict, List, Any, Optional
from itertools import cycle
import numpy as np
from collections import defaultdict
from rfm.data.dataset_types import ProgressSample
from rfm.data.samplers.base import RFMBaseSampler
from rfm.utils.logger import get_logger
# Module-level logger obtained from the project's logging helper; used by the sampler below.
logger = get_logger()
class ProgressPolicyRankingSampler(RFMBaseSampler):
    """Dataset that generates progress samples for policy ranking.

    Trajectories are grouped per task by a grouping key and only tasks with
    more than one distinct key are kept. A subset of trajectories is then drawn
    from each group and expanded into sample descriptors, which are turned
    into ``ProgressSample`` objects lazily in ``__getitem__``.

    Grouping key:
        * ``quality_label`` by default;
        * ``partial_success`` (rounded to 2 decimals) when the first robot
          trajectory carries a non-``None`` ``partial_success`` value
          ("RoboArena" mode).

    NOTE(review): ``self.dataset``, ``self.robot_trajectories``,
    ``self._local_random`` and ``self._get_traj_from_data`` are not defined in
    this class; they are presumably provided by ``RFMBaseSampler``.
    """

    def __init__(
        self,
        num_examples_per_quality_pr: int = 5,
        num_partial_successes: Optional[int] = None,
        frame_step: int = 1,
        use_frame_steps: bool = True,
        max_tasks: Optional[int] = None,
        **kwargs,
    ):
        """Configure the sampler and precompute all sample descriptors.

        Args:
            num_examples_per_quality_pr: Maximum number of trajectories drawn
                per quality label and task (non-RoboArena mode).
            num_partial_successes: Maximum number of trajectories drawn per
                task in RoboArena mode. ``None`` means "no cap": every
                available trajectory of the task is used.
            frame_step: Stride between the end indices of the generated
                prefixes when ``use_frame_steps`` is true.
            use_frame_steps: If true, emit one sample per prefix
                (``0:frame_step``, ``0:2*frame_step``, ...); otherwise emit a
                single whole-trajectory sample.
            max_tasks: If a positive integer, keep only that many (randomly
                chosen) tasks.
            **kwargs: Forwarded unchanged to the base sampler.

        Raises:
            ValueError: If ``use_frame_steps`` is true and ``frame_step`` is
                smaller than 1.
        """
        super().__init__(**kwargs)
        self.num_examples_per_quality_pr = num_examples_per_quality_pr
        self.num_partial_successes = num_partial_successes
        self.frame_step = frame_step
        self.use_frame_steps = use_frame_steps
        self.max_tasks = max_tasks

        # A zero step makes range() raise later with an unhelpful message and a
        # negative step silently produces an empty dataset; fail early instead.
        if self.use_frame_steps and self.frame_step < 1:
            raise ValueError(
                f"frame_step must be a positive integer when use_frame_steps=True, got {frame_step!r}"
            )

        logger.info(f"ProgressPolicyRankingSampler initialized with {len(self.robot_trajectories)} trajectories")
        self.sample_indices = self._generate_all_sample_indices()
        logger.info(f"Generated {len(self.sample_indices)} sample indices")

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Select trajectories per task/group and expand them into sample descriptors.

        For non-RoboArena data trajectories are grouped by task and
        ``quality_label``; for RoboArena data by task and rounded
        ``partial_success``. Only tasks with more than one group are used.

        If ``use_frame_steps`` is true every selected trajectory yields one
        descriptor per prefix (``0:frame_step``, ``0:2*frame_step``, ...);
        otherwise it yields a single whole-trajectory descriptor.

        Returns:
            The list of sample descriptors, ordered by task and, within a
            task, by the order in which trajectories were drawn.
        """
        # RoboArena detection looks at the first robot trajectory only.
        is_roboarena = False
        if self.robot_trajectories:
            first_traj = self.dataset[self.robot_trajectories[0]]
            is_roboarena = first_traj.get("partial_success") is not None

        # task -> grouping key -> list of trajectory indices
        task_to_key_to_trajs = defaultdict(lambda: defaultdict(list))
        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]
            if is_roboarena:
                # Round so that nearly identical scores land in the same group;
                # trajectories without a score are skipped.
                partial_success_val = traj.get("partial_success")
                if partial_success_val is not None:
                    key = round(float(partial_success_val), 2)
                    task_to_key_to_trajs[task][key].append(traj_idx)
            else:
                task_to_key_to_trajs[task][traj["quality_label"]].append(traj_idx)

        # Ranking needs at least two distinct groups per task.
        tasks_with_multiple_values = {
            task: key_to_trajs for task, key_to_trajs in task_to_key_to_trajs.items() if len(key_to_trajs) > 1
        }
        dataset_type_str = "partial_success values" if is_roboarena else "quality labels"
        logger.info(f"Found {len(tasks_with_multiple_values)} tasks with multiple {dataset_type_str}")

        if self.max_tasks is not None and self.max_tasks > 0:
            # Sort before shuffling so the outcome depends only on the RNG
            # state, not on dict insertion order.
            tasks_list = sorted(tasks_with_multiple_values.items())
            self._local_random.shuffle(tasks_list)
            tasks_with_multiple_values = dict(tasks_list[: self.max_tasks])
            logger.info(f"Limited to {len(tasks_with_multiple_values)} tasks (max_tasks={self.max_tasks})")

        sample_indices = []
        all_sampled_traj_indices = []
        # Sorted iteration keeps the processing order deterministic.
        for _task, key_to_trajs in sorted(tasks_with_multiple_values.items()):
            if is_roboarena:
                sampled_traj_indices = self._sample_round_robin(key_to_trajs, self.num_partial_successes)
            else:
                sampled_traj_indices = []
                for quality in sorted(key_to_trajs.keys()):
                    # Sorted candidates make the draw reproducible for a given RNG state.
                    candidates = sorted(key_to_trajs[quality])
                    num_to_sample = min(self.num_examples_per_quality_pr, len(candidates))
                    sampled_traj_indices.extend(self._local_random.sample(candidates, num_to_sample))

            for traj_idx in sampled_traj_indices:
                traj = self.dataset[traj_idx]
                sample_indices.extend(self._generate_indices_for_trajectory(traj_idx, traj))
                all_sampled_traj_indices.append(traj_idx)

        logger.info(f"Sampled {len(sample_indices)} samples across {len(tasks_with_multiple_values)} tasks")
        logger.info(f"Sampled trajectory indices: {all_sampled_traj_indices}")
        return sample_indices

    def _sample_round_robin(self, key_to_trajs: Dict[Any, List[int]], limit: Optional[int]) -> List[int]:
        """Draw trajectories by cycling over the groups, one random pick per visit.

        Args:
            key_to_trajs: Mapping from grouping key to trajectory indices.
            limit: Maximum number of trajectories to draw; ``None`` draws all
                available trajectories.

        Returns:
            The drawn trajectory indices, in draw order, without repetition.
        """
        # Sorted keys and sorted (copied) index lists give deterministic draws
        # and keep the caller's lists untouched.
        pools = [sorted(key_to_trajs[key]) for key in sorted(key_to_trajs.keys())]
        pools = [pool for pool in pools if pool]

        total_available = sum(len(pool) for pool in pools)
        # Clamping to what is available guarantees the cycle below terminates.
        target = total_available if limit is None else min(limit, total_available)

        picked = []
        for pool in cycle(pools):
            if len(picked) >= target:
                break
            if not pool:
                # This group is exhausted; another one still has candidates.
                continue
            chosen = self._local_random.choice(pool)
            picked.append(chosen)
            pool.remove(chosen)
        return picked

    def _generate_indices_for_trajectory(self, traj_idx: int, traj: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate sample descriptors for a single trajectory.

        Args:
            traj_idx: Index of the trajectory in the dataset.
            traj: Trajectory dictionary; ``num_frames``, ``frames`` and ``id``
                are read.

        Returns:
            List of sample descriptor dictionaries. In frame-step mode this is
            one entry per prefix ending at a multiple of ``frame_step`` (empty
            if the trajectory has fewer than ``frame_step`` frames); otherwise
            a single whole-trajectory entry.
        """
        if not self.use_frame_steps:
            return [
                {
                    "traj_idx": traj_idx,
                    "video_path": traj["frames"],
                    "id": traj["id"],
                    "use_frame_steps": False,
                }
            ]

        num_frames = traj["num_frames"]
        # Prefixes 0:frame_step, 0:2*frame_step, ... up to num_frames.
        return [
            {
                "traj_idx": traj_idx,
                "frame_indices": list(range(end_idx)),
                "num_frames": num_frames,
                "video_path": traj["frames"],
                "id": traj["id"],
                "use_frame_steps": True,
            }
            for end_idx in range(self.frame_step, num_frames + 1, self.frame_step)
        ]

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> ProgressSample:
        """Build a single progress sample from a sample descriptor.

        Args:
            sample_idx_info: Descriptor produced by
                ``_generate_indices_for_trajectory``.

        Returns:
            A ``ProgressSample`` wrapping the trajectory built by
            ``_get_traj_from_data``.
        """
        traj = self.dataset[sample_idx_info["traj_idx"]]
        metadata = {
            "quality_label": traj["quality_label"],
            "data_source": traj["data_source"],
            "task": traj["task"],
            "id": traj["id"],
            "video_path": sample_idx_info["video_path"],
        }

        if sample_idx_info.get("use_frame_steps", True):
            # Prefix mode: record the last frame index of the prefix.
            frame_indices = sample_idx_info["frame_indices"]
            metadata["frame_step"] = frame_indices[-1] if frame_indices else 0
            trajectory = self._get_traj_from_data(
                traj=traj,
                frame_indices=frame_indices,
                metadata=metadata,
            )
        else:
            # Whole-trajectory mode: let the base helper choose the frames.
            trajectory = self._get_traj_from_data(
                traj=traj,
                metadata=metadata,
            )

        return ProgressSample(trajectory=trajectory)

    def __len__(self):
        """Return the number of precomputed sample descriptors."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Materialise the sample for descriptor ``idx``."""
        return self._generate_sample_from_indices(self.sample_indices[idx])