Spaces:
Running
Running
File size: 2,647 Bytes
3e462dd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 | from typing import Dict, Any
import numpy as np
from rfm.data.dataset_types import PreferenceSample, Trajectory
from rfm.data.samplers.base import RFMBaseSampler
class BaseQualityPreferenceSampler(RFMBaseSampler):
    """Shared foundation for samplers that emit quality-based preference pairs.

    A subclass is expected to implement `_generate_all_sample_indices`, which
    decides which trajectories get paired together. What this class supplies
    is the common `_generate_sample_from_indices` step that loads the two
    paired trajectories and wraps them in a `PreferenceSample`, plus the
    `__len__` / `__getitem__` accessors built on `self.sample_indices`.
    """

    def _generate_sample_from_indices(self, sample_idx_info: Dict[str, Any]) -> PreferenceSample:
        """Build one preference sample from a stored pair of dataset indices.

        Args:
            sample_idx_info: Mapping holding the keys "chosen_traj_idx" and
                "rejected_traj_idx", each used to index into `self.dataset`.

        Returns:
            A `PreferenceSample` wrapping the chosen and rejected trajectories.
        """
        # Resolve both indices first, then fetch both records, so that the
        # lookups happen before any trajectory processing starts.
        idx_of_chosen = sample_idx_info["chosen_traj_idx"]
        idx_of_rejected = sample_idx_info["rejected_traj_idx"]
        chosen_record = self.dataset[idx_of_chosen]
        rejected_record = self.dataset[idx_of_rejected]

        # Both sides of the pair go through the same steps; chosen is
        # processed before rejected.
        processed = []
        for record in (chosen_record, rejected_record):
            metadata = {
                "quality_label": record["quality_label"],
                "data_source": record["data_source"],
                "task": record["task"],
                "id": record["id"],
                # The record's "frames" entry is exposed as "video_path".
                "video_path": record["frames"],
            }
            # partial_success is optional: only forwarded when present and not None.
            partial_success = record.get("partial_success")
            if partial_success is not None:
                metadata["partial_success"] = partial_success
            processed.append(self._get_traj_from_data(traj=record, metadata=metadata))
        chosen_trajectory, rejected_trajectory = processed

        # An instance may carry its own strategy label; otherwise use the default.
        strategy = getattr(self, "data_gen_strategy", "quality_preference")

        return PreferenceSample(
            chosen_trajectory=chosen_trajectory,
            rejected_trajectory=rejected_trajectory,
            data_gen_strategy=strategy,
        )

    def __len__(self):
        """Return the number of entries in `self.sample_indices`."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return the preference sample built from `self.sample_indices[idx]`."""
        idx_info = self.sample_indices[idx]
        return self._generate_sample_from_indices(idx_info)
|