Spaces:
Running
Running
| from typing import Dict, List, Any | |
| from itertools import combinations | |
| from tqdm import tqdm | |
| from rfm.data.samplers.eval.base_pref import BaseQualityPreferenceSampler | |
| from rfm.utils.distributed import rank_0_print | |
class QualityPreferenceSampler(BaseQualityPreferenceSampler):
    """Eval sampler that pairs trajectories of the same task into preference samples.

    The pairing mode is chosen by inspecting the first robot trajectory:

    * Non-RoboArena: trajectories are grouped by ``quality_label`` and ranked with
      ``QUALITY_ORDER`` (failure < suboptimal < successful). The better-labelled
      trajectory is "chosen".
    * RoboArena: trajectories are grouped by ``partial_success`` (rounded to two
      decimals). The higher value is "chosen".

    Only index records are generated here, not the samples themselves.
    """

    # Rank of each recognised quality label; higher is better. Trajectories whose
    # label is missing or not listed here cannot be ordered and are skipped.
    # Subclasses may override this mapping.
    QUALITY_ORDER: Dict[str, int] = {"failure": 1, "suboptimal": 2, "successful": 3}

    def __init__(
        self,
        comparisons_per_task=None,
        max_comparisons=None,
        **kwargs,
    ):
        """Create the sampler and eagerly generate all pair index records.

        Args:
            comparisons_per_task: If not None, keep at most this many pairs per task,
                drawn uniformly at random from that task's pairs.
            max_comparisons: If not None, keep at most this many pairs in total,
                drawn uniformly at random after the per-task cap has been applied.
            **kwargs: Forwarded unchanged to ``BaseQualityPreferenceSampler.__init__``.
        """
        super().__init__(**kwargs)
        # Set data_gen_strategy for this sampler
        self.data_gen_strategy = "quality_preference"
        self.comparisons_per_task = comparisons_per_task
        self.max_comparisons = max_comparisons
        # Generate all possible sample indices upfront (not the actual samples).
        # Must run after the caps above are set, since generation reads them.
        self.sample_indices = self._generate_all_sample_indices()
        rank_0_print(f"Generated {len(self.sample_indices)} quality preference sample indices", verbose=self.verbose)

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Generate every quality-preference pair index record (not the actual samples).

        Pairing mode:
            RoboArena mode is used when the first robot trajectory has a non-None
            ``partial_success``. In that mode trajectories are paired by
            ``partial_success``; otherwise they are paired by ``quality_label``.

        Limits:
            ``comparisons_per_task`` caps the pairs kept for each task. After that,
            ``max_comparisons`` caps the overall total. Both caps sample uniformly
            via ``self._local_random``.

        Returns:
            A list of dicts, each with:

            * ``chosen_traj_idx`` and ``rejected_traj_idx``;
            * ``task``;
            * either ``chosen_partial_success``/``rejected_partial_success``
              (RoboArena) or ``chosen_quality``/``rejected_quality``.
        """
        # NOTE: the mode is decided from the first trajectory only; a dataset mixing
        # both annotation styles is handled according to whichever comes first.
        is_roboarena = False
        if self.robot_trajectories:
            first_traj = self.dataset[self.robot_trajectories[0]]
            is_roboarena = first_traj.get("partial_success") is not None

        rank_0_print(
            f"Generating quality preference samples for {len(self.robot_trajectories)} trajectories "
            f"({'RoboArena (partial_success)' if is_roboarena else 'non-RoboArena (quality_label)'})",
            verbose=self.verbose,
        )

        sample_indices: List[Dict[str, Any]] = []
        if is_roboarena:
            task_groups = self._group_by_partial_success()
            for task in tqdm(task_groups, desc="Generating RoboArena quality preference samples"):
                groups = task_groups[task]
                # A partial_success value is its own rank: higher is better.
                sample_indices.extend(
                    self._build_task_pairs(
                        task,
                        groups,
                        rank={value: value for value in groups},
                        chosen_field="chosen_partial_success",
                        rejected_field="rejected_partial_success",
                    )
                )
        else:
            task_groups = self._group_by_quality_label()
            for task in tqdm(task_groups, desc="Generating quality preference samples"):
                sample_indices.extend(
                    self._build_task_pairs(
                        task,
                        task_groups[task],
                        rank=self.QUALITY_ORDER,
                        chosen_field="chosen_quality",
                        rejected_field="rejected_quality",
                    )
                )

        # Apply max_comparisons limit if set (sample uniformly across all comparisons)
        original_count = len(sample_indices)
        if self.max_comparisons is not None and original_count > self.max_comparisons:
            sample_indices = self._local_random.sample(sample_indices, self.max_comparisons)
            rank_0_print(
                f"Limited total comparisons to {self.max_comparisons} (from {original_count} total comparisons)",
                verbose=self.verbose,
            )
        return sample_indices

    def _group_by_partial_success(self) -> Dict[Any, Dict[Any, List[Any]]]:
        """Group robot trajectory indices by task, then by ``partial_success``.

        ``partial_success`` is rounded to 2 decimals for grouping. Trajectories whose
        ``partial_success`` is missing/None are skipped with a warning. Tasks and
        values keep their first-seen order.
        """
        task_groups: Dict[Any, Dict[Any, List[Any]]] = {}
        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]
            partial_success_val = traj.get("partial_success")
            if partial_success_val is None:
                rank_0_print(
                    f"Warning: Trajectory {traj_idx} (task: {task}) missing partial_success, skipping",
                    verbose=self.verbose,
                )
                continue
            # Values that round to the same 2-decimal number share a group and are
            # therefore never paired against each other.
            partial_success = round(float(partial_success_val), 2)
            task_groups.setdefault(task, {}).setdefault(partial_success, []).append(traj_idx)
        return task_groups

    def _group_by_quality_label(self) -> Dict[Any, Dict[Any, List[Any]]]:
        """Group robot trajectory indices by task, then by ``quality_label``.

        Only labels present in ``QUALITY_ORDER`` are kept. A trajectory with a missing
        or unrecognised label is skipped with a warning rather than being silently
        ranked below "failure", which would produce inverted preference pairs.
        Tasks and labels keep their first-seen order.
        """
        task_groups: Dict[Any, Dict[Any, List[Any]]] = {}
        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]
            quality_label = traj.get("quality_label")
            if quality_label not in self.QUALITY_ORDER:
                rank_0_print(
                    f"Warning: Trajectory {traj_idx} (task: {task}) has unrecognized quality_label "
                    f"{quality_label!r}, skipping",
                    verbose=self.verbose,
                )
                continue
            task_groups.setdefault(task, {}).setdefault(quality_label, []).append(traj_idx)
        return task_groups

    def _build_task_pairs(
        self,
        task: Any,
        groups: Dict[Any, List[Any]],
        rank: Dict[Any, Any],
        chosen_field: str,
        rejected_field: str,
    ) -> List[Dict[str, Any]]:
        """Build chosen/rejected pair records for one task, honouring ``comparisons_per_task``.

        Pairing:
            Each pair of groups whose ranks differ contributes the full cross product
            of their trajectories. The higher-ranked group supplies the chosen side.
            Groups whose ranks are equal (or incomparable, e.g. NaN) are not paired.

        Per-task cap:
            If ``comparisons_per_task`` is set and exceeded, that many records are
            sampled uniformly with ``self._local_random``.

        Args:
            task: Task identifier copied into every record.
            groups: Group key -> trajectory indices, in first-seen order.
            rank: Group key -> comparable rank (higher is better).
            chosen_field: Record key that stores the chosen group's key.
            rejected_field: Record key that stores the rejected group's key.

        Returns:
            The (possibly down-sampled) list of pair records for ``task``.
        """
        task_pairs: List[Dict[str, Any]] = []
        # Fewer than two groups yields no combinations, hence no pairs.
        for key_a, key_b in combinations(groups, 2):
            rank_a, rank_b = rank[key_a], rank[key_b]
            if rank_a > rank_b:
                chosen, rejected = key_a, key_b
            elif rank_b > rank_a:
                chosen, rejected = key_b, key_a
            else:
                # Same rank: we can't reliably say which one is preferred.
                continue
            for chosen_idx in groups[chosen]:
                for rejected_idx in groups[rejected]:
                    task_pairs.append(
                        {
                            "chosen_traj_idx": chosen_idx,
                            "rejected_traj_idx": rejected_idx,
                            "task": task,
                            chosen_field: chosen,
                            rejected_field: rejected,
                        }
                    )
        # Apply comparisons_per_task limit if set (sample uniformly across all pairs for this task)
        if self.comparisons_per_task is not None and len(task_pairs) > self.comparisons_per_task:
            task_pairs = self._local_random.sample(task_pairs, self.comparisons_per_task)
        return task_pairs