# rewardeval_ui/samplers/eval/reward_alignment.py
# Author: Anthony Liang — "getting rid of dependencies" (commit 3e462dd)
#!/usr/bin/env python3
"""
Data generator for reward alignment evaluation.
This generator creates subsequence samples from trajectories for progress prediction evaluation.
For each trajectory, it creates multiple subsequences (0:2, 0:4, 0:6, etc.) and formats them
as ProgressSample objects that can be evaluated by the model.
"""
from typing import Dict, List, Any
import torch
from tqdm import tqdm
from rfm.data.dataset_types import ProgressSample, Trajectory
from rfm.data.samplers.base import RFMBaseSampler
from rfm.utils.distributed import rank_0_print
class RewardAlignmentSampler(RFMBaseSampler):
    """Sampler that yields trajectory-prefix samples for reward alignment evaluation.

    Each robot trajectory is expanded into a series of growing frame prefixes
    (0:frame_step, 0:2*frame_step, ...), each wrapped in a ProgressSample so a
    model's progress predictions can be scored along the episode. With
    ``use_frame_steps=False`` a single whole-trajectory sample is emitted
    per trajectory instead.
    """

    def __init__(
        self,
        max_trajectories: int | None = None,
        frame_step: int = 1,
        use_frame_steps: bool = True,
        **kwargs,
    ):
        """Configure the sampler and precompute every sample index.

        Args:
            max_trajectories: If set, randomly subsample this many trajectories.
            frame_step: Stride between successive prefix end points.
            use_frame_steps: If False, emit one whole-trajectory sample per
                trajectory instead of a series of prefixes.
            **kwargs: Forwarded to RFMBaseSampler.
        """
        super().__init__(**kwargs)
        self.max_trajectories = max_trajectories
        self.frame_step = frame_step
        self.use_frame_steps = use_frame_steps
        self.sample_indices = self._generate_all_sample_indices()
        # NOTE: a falsy max_trajectories (None or 0) reports the full count.
        traj_count = (
            min(len(self.robot_trajectories), self.max_trajectories)
            if self.max_trajectories
            else len(self.robot_trajectories)
        )
        rank_0_print(
            f"Generated {len(self.sample_indices)} reward alignment sample indices from {traj_count} trajectories",
            verbose=self.verbose,
        )

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Build the flat list of per-sample index records for every trajectory."""
        pool = self.robot_trajectories
        # Optionally down-sample which trajectories contribute.
        if self.max_trajectories is not None and self.max_trajectories < len(pool):
            pool = self._local_random.sample(pool, self.max_trajectories)
        rank_0_print(
            f"Generating subsequence samples for {len(pool)} trajectories", verbose=self.verbose
        )
        all_indices: List[Dict[str, Any]] = []
        for traj_idx in pool:
            all_indices.extend(
                self._generate_indices_for_trajectory(traj_idx, self.dataset[traj_idx])
            )
        return all_indices

    def _generate_indices_for_trajectory(self, traj_idx: int, traj: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Return the sample-index records derived from one trajectory.

        Args:
            traj_idx: Position of the trajectory within the dataset.
            traj: Raw trajectory record (expects "num_frames", "frames", "id").

        Returns:
            One record per frame prefix when use_frame_steps is on, otherwise a
            single record covering the whole trajectory.
        """
        num_frames = traj["num_frames"]
        if not self.use_frame_steps:
            # Single sample spanning the entire trajectory.
            return [
                {
                    "traj_idx": traj_idx,
                    "video_path": traj["frames"],
                    "id": traj["id"],
                    "use_frame_steps": False,
                }
            ]
        # Prefixes ending at frame_step, 2*frame_step, ... up to num_frames.
        return [
            {
                "traj_idx": traj_idx,
                "frame_indices": list(range(end_idx)),
                "num_frames": num_frames,
                "video_path": traj["frames"],
                "id": traj["id"],
                "use_frame_steps": True,
            }
            for end_idx in range(self.frame_step, num_frames + 1, self.frame_step)
        ]

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> ProgressSample:
        """Materialize one ProgressSample from a stored index record."""
        traj = self.dataset[sample_idx_info["traj_idx"]]
        metadata = {
            "data_gen_strategy": "reward_alignment",
            "id": traj["id"],
            "video_path": sample_idx_info["video_path"],
        }
        if sample_idx_info.get("use_frame_steps", True):
            # Prefix mode: record the prefix extent alongside the slice.
            # "frame_step" holds the last frame index of this prefix.
            frame_indices = sample_idx_info["frame_indices"]
            metadata["frame_step"] = frame_indices[-1] if frame_indices else 0
            metadata["num_frames"] = sample_idx_info["num_frames"]
            trajectory = self._get_traj_from_data(
                traj=traj,
                frame_indices=frame_indices,
                metadata=metadata,
            )
        else:
            # Whole-trajectory mode: no frame slicing.
            trajectory = self._get_traj_from_data(traj=traj, metadata=metadata)
        return ProgressSample(trajectory=trajectory, sample_type="progress")

    def __len__(self):
        """Number of precomputed samples."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Build and return the sample at position idx."""
        return self._generate_sample_from_indices(self.sample_indices[idx])