# rewardeval_ui/samplers/eval/quality_preference.py
# Anthony Liang — "getting rid of dependencies" (3e462dd)
from typing import Dict, List, Any
from itertools import combinations
from tqdm import tqdm
from rfm.data.samplers.eval.base_pref import BaseQualityPreferenceSampler
from rfm.utils.distributed import rank_0_print
class QualityPreferenceSampler(BaseQualityPreferenceSampler):
    """Dataset that generates preference samples by pairing trajectories of the same task.

    For non-RoboArena: pairs trajectories with different quality labels, ranked
    "failure" < "suboptimal" < "successful"; the better-ranked one is chosen.
    Trajectories whose quality label is missing or not one of those three are
    skipped, because their rank relative to the known labels is undefined.
    For RoboArena: pairs trajectories with different partial_success values
    (higher partial_success = chosen).

    Args:
        comparisons_per_task: Optional cap on the number of pairs kept per task
            (sampled uniformly via ``self._local_random.sample``).
        max_comparisons: Optional cap on the total number of pairs kept across all
            tasks, applied after the per-task cap (also sampled uniformly).
        **kwargs: Forwarded to ``BaseQualityPreferenceSampler.__init__``.
    """

    # Rank of each recognized quality label; a higher rank is preferred.
    QUALITY_ORDER = {"failure": 1, "suboptimal": 2, "successful": 3}

    def __init__(
        self,
        comparisons_per_task=None,
        max_comparisons=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Set data_gen_strategy for this sampler
        self.data_gen_strategy = "quality_preference"
        self.comparisons_per_task = comparisons_per_task
        self.max_comparisons = max_comparisons
        # Generate all possible sample indices upfront (not the actual samples)
        self.sample_indices = self._generate_all_sample_indices()
        rank_0_print(f"Generated {len(self.sample_indices)} quality preference sample indices", verbose=self.verbose)

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Generate all quality preference sample indices (not the actual samples).

        The dataset is treated as RoboArena when the FIRST robot trajectory has a
        non-None ``partial_success``; only that trajectory is inspected.

        For non-RoboArena: groups by task and quality_label and pairs trajectories
        with different (recognized) quality labels.
        For RoboArena: groups by task and partial_success (rounded to 2 decimals)
        and pairs trajectories with different partial_success values.

        Returns:
            A list of dicts with keys "chosen_traj_idx", "rejected_traj_idx",
            "task", plus either "chosen_partial_success"/"rejected_partial_success"
            (RoboArena) or "chosen_quality"/"rejected_quality" (non-RoboArena).
        """
        sample_indices = []

        # Check if this is RoboArena (has partial_success)
        is_roboarena = False
        if self.robot_trajectories:
            first_traj = self.dataset[self.robot_trajectories[0]]
            is_roboarena = first_traj.get("partial_success") is not None

        rank_0_print(
            f"Generating quality preference samples for {len(self.robot_trajectories)} trajectories "
            f"({'RoboArena (partial_success)' if is_roboarena else 'non-RoboArena (quality_label)'})",
            verbose=self.verbose,
        )

        if is_roboarena:
            task_to_groups = self._group_by_partial_success()
            desc = "Generating RoboArena quality preference samples"
            chosen_key, rejected_key = "chosen_partial_success", "rejected_partial_success"
            ranks = None  # the rounded partial_success value is itself the rank
        else:
            task_to_groups = self._group_by_quality_label()
            desc = "Generating quality preference samples"
            chosen_key, rejected_key = "chosen_quality", "rejected_quality"
            ranks = self.QUALITY_ORDER

        # Generate pairs for each task
        for task in tqdm(task_to_groups, desc=desc):
            groups = task_to_groups[task]
            # Only create pairs if we have at least 2 different group values
            if len(groups) < 2:
                continue
            task_pairs = self._build_task_pairs(task, groups, chosen_key, rejected_key, ranks)
            # Apply comparisons_per_task limit if set (sample uniformly across all pairs for this task)
            if self.comparisons_per_task is not None and len(task_pairs) > self.comparisons_per_task:
                task_pairs = self._local_random.sample(task_pairs, self.comparisons_per_task)
            sample_indices.extend(task_pairs)

        # Apply max_comparisons limit if set (sample uniformly across all comparisons)
        original_count = len(sample_indices)
        if self.max_comparisons is not None and original_count > self.max_comparisons:
            sample_indices = self._local_random.sample(sample_indices, self.max_comparisons)
            rank_0_print(
                f"Limited total comparisons to {self.max_comparisons} (from {original_count} total comparisons)",
                verbose=self.verbose,
            )

        return sample_indices

    def _group_by_partial_success(self) -> Dict[Any, Dict[float, List[Any]]]:
        """Group robot trajectories by task, then by partial_success rounded to 2 decimals.

        Trajectories whose partial_success is None/missing are skipped with a warning.
        """
        task_to_partial_trajs = {}
        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]
            partial_success_val = traj.get("partial_success")
            if partial_success_val is None:
                rank_0_print(
                    f"Warning: Trajectory {traj_idx} (task: {task}) missing partial_success, skipping",
                    verbose=self.verbose,
                )
                continue
            # Round partial_success to 2 decimals for grouping
            partial_success = round(float(partial_success_val), 2)
            task_to_partial_trajs.setdefault(task, {}).setdefault(partial_success, []).append(traj_idx)
        return task_to_partial_trajs

    def _group_by_quality_label(self) -> Dict[Any, Dict[str, List[Any]]]:
        """Group robot trajectories by task, then by quality_label.

        Trajectories whose quality_label is missing or not a key of QUALITY_ORDER
        are skipped: ranking them (e.g. below "failure") would fabricate a
        preference. Skipped labels are reported in a single aggregated warning.
        """
        task_to_quality_trajs = {}
        skipped_labels = {}
        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]
            quality_label = traj.get("quality_label")
            if quality_label not in self.QUALITY_ORDER:
                skipped_labels[quality_label] = skipped_labels.get(quality_label, 0) + 1
                continue
            task_to_quality_trajs.setdefault(task, {}).setdefault(quality_label, []).append(traj_idx)

        if skipped_labels:
            rank_0_print(
                f"Warning: skipped {sum(skipped_labels.values())} trajectories with unrecognized "
                f"quality_label (counts: {skipped_labels}); expected one of {list(self.QUALITY_ORDER)}",
                verbose=self.verbose,
            )
        return task_to_quality_trajs

    @staticmethod
    def _build_task_pairs(task, groups, chosen_key, rejected_key, ranks=None) -> List[Dict[str, Any]]:
        """Build every (chosen, rejected) trajectory pair across the groups of one task.

        Args:
            task: Task identifier stored in every emitted pair.
            groups: Maps a group value (quality label or rounded partial_success)
                to the trajectory indices in that group, in insertion order.
            chosen_key: Output key recording the chosen trajectory's group value.
            rejected_key: Output key recording the rejected trajectory's group value.
            ranks: Maps a group value to its rank. When None, the group value
                itself is used as the rank.

        Returns:
            Pairs for every two groups whose ranks differ, with the higher-ranked
            group as chosen. Groups with equal or incomparable (e.g. NaN) ranks are
            not paired, since no reliable preference exists between them.
        """
        task_pairs = []
        for value1, value2 in combinations(groups, 2):
            rank1 = value1 if ranks is None else ranks[value1]
            rank2 = value2 if ranks is None else ranks[value2]
            if rank1 > rank2:
                chosen_value, rejected_value = value1, value2
            elif rank2 > rank1:
                chosen_value, rejected_value = value2, value1
            else:
                continue
            # Create all possible pairs for this group combination
            for chosen_idx in groups[chosen_value]:
                for rejected_idx in groups[rejected_value]:
                    task_pairs.append({
                        "chosen_traj_idx": chosen_idx,
                        "rejected_traj_idx": rejected_idx,
                        "task": task,
                        chosen_key: chosen_value,
                        rejected_key: rejected_value,
                    })
        return task_pairs