File size: 2,647 Bytes
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Dict, Any

import numpy as np

from rfm.data.dataset_types import PreferenceSample, Trajectory
from rfm.data.samplers.base import RFMBaseSampler


class BaseQualityPreferenceSampler(RFMBaseSampler):
    """Base class for quality preference samplers.

    Subclasses should implement `_generate_all_sample_indices` to define how
    trajectories are paired. This base class provides the common
    `_generate_sample_from_indices` method that loads and processes the
    trajectories, and exposes `__len__` / `__getitem__` on top of
    `self.sample_indices`.
    """

    def _build_quality_trajectory(self, traj):
        """Build a processed trajectory from one raw dataset record.

        Collects the metadata fields attached to every trajectory in a
        preference pair and forwards the record together with that metadata
        to `self._get_traj_from_data`.

        Args:
            traj: A dataset record supporting `[]` lookup and `.get`. It must
                contain the keys "quality_label", "data_source", "task",
                "id" and "frames"; "partial_success" is optional.

        Returns:
            Whatever `self._get_traj_from_data` returns for this record.

        Raises:
            KeyError: If a required key is missing (assuming `traj` is a
                plain dict).
        """
        metadata = {
            "quality_label": traj["quality_label"],
            "data_source": traj["data_source"],
            "task": traj["task"],
            "id": traj["id"],
            # The "frames" entry of the record is exposed as "video_path".
            "video_path": traj["frames"],
        }

        # Only attach partial_success when the record carries a non-None value,
        # so the key is absent (rather than None) for records without it.
        partial_success = traj.get("partial_success")
        if partial_success is not None:
            metadata["partial_success"] = partial_success

        return self._get_traj_from_data(traj=traj, metadata=metadata)

    def _generate_sample_from_indices(self, sample_idx_info: Dict[str, Any]) -> PreferenceSample:
        """Generate a single preference sample from stored indices.

        Args:
            sample_idx_info: Mapping with the keys "chosen_traj_idx" and
                "rejected_traj_idx", each an index into `self.dataset`.

        Returns:
            A `PreferenceSample` holding the chosen and rejected trajectories.
            Its `data_gen_strategy` is taken from `self.data_gen_strategy`
            when that attribute exists and is "quality_preference" otherwise.
        """
        # Fetch both records before any processing so that an invalid index
        # fails before either trajectory is built.
        chosen_traj = self.dataset[sample_idx_info["chosen_traj_idx"]]
        rejected_traj = self.dataset[sample_idx_info["rejected_traj_idx"]]

        # Both sides of the pair go through the same helper so their metadata
        # cannot drift apart. Chosen is built first, then rejected.
        chosen_trajectory = self._build_quality_trajectory(chosen_traj)
        rejected_trajectory = self._build_quality_trajectory(rejected_traj)

        return PreferenceSample(
            chosen_trajectory=chosen_trajectory,
            rejected_trajectory=rejected_trajectory,
            data_gen_strategy=getattr(self, "data_gen_strategy", "quality_preference"),
        )

    def __len__(self) -> int:
        """Return the number of entries in `self.sample_indices`."""
        return len(self.sample_indices)

    def __getitem__(self, idx) -> PreferenceSample:
        """Return the preference sample built from `self.sample_indices[idx]`."""
        return self._generate_sample_from_indices(self.sample_indices[idx])