File size: 5,413 Bytes
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python3
"""
Data generator for reward alignment evaluation.

This generator creates subsequence samples from trajectories for progress prediction evaluation.
For each trajectory, it creates multiple prefix subsequences (0:frame_step, 0:2*frame_step, etc.)
and formats them as ProgressSample objects that can be evaluated by the model.
"""

from typing import Dict, List, Any

import torch
from tqdm import tqdm

from rfm.data.dataset_types import ProgressSample, Trajectory
from rfm.data.samplers.base import RFMBaseSampler
from rfm.utils.distributed import rank_0_print


class RewardAlignmentSampler(RFMBaseSampler):
    """
    Sampler that yields prefix subsequences of trajectories for reward alignment evaluation.

    With ``use_frame_steps=True`` every selected trajectory contributes one sample per
    prefix (0:frame_step, 0:2*frame_step, etc.); otherwise it contributes a single
    sample covering the whole trajectory. Index records are computed once in
    ``__init__`` and turned into ``ProgressSample`` objects on demand in ``__getitem__``.
    """

    def __init__(
        self,
        max_trajectories: int | None = None,
        frame_step: int = 1,
        use_frame_steps: bool = True,
        **kwargs,
    ):
        """Configure the sampler and precompute every sample index record.

        Args:
            max_trajectories: Upper bound on how many trajectories are used; ``None``
                keeps all of them. When it is smaller than the number available, a
                random subset of that size is drawn.
            frame_step: Number of frames by which consecutive prefixes grow. Only
                used when ``use_frame_steps`` is true.
            use_frame_steps: If true, emit one sample per prefix; if false, emit one
                whole-trajectory sample per trajectory.
            **kwargs: Forwarded unchanged to ``RFMBaseSampler.__init__``.

        Raises:
            ValueError: If ``use_frame_steps`` is true and ``frame_step`` is below 1.
        """
        # A zero step makes range() fail deep inside index generation and a negative
        # step silently produces no samples at all, so reject both before any work.
        if use_frame_steps and frame_step < 1:
            raise ValueError(f"frame_step must be >= 1 when use_frame_steps is True, got {frame_step}")

        super().__init__(**kwargs)

        self.max_trajectories = max_trajectories
        self.frame_step = frame_step
        self.use_frame_steps = use_frame_steps
        self.sample_indices = self._generate_all_sample_indices()

        # Mirror the `is not None` test used during selection so that
        # max_trajectories=0 is reported as 0 trajectories rather than all of them.
        available = len(self.robot_trajectories)
        used = available if self.max_trajectories is None else min(available, self.max_trajectories)
        rank_0_print(
            f"Generated {len(self.sample_indices)} reward alignment sample indices from {used} trajectories",
            verbose=self.verbose,
        )

    def _generate_all_sample_indices(self) -> list[dict[str, Any]]:
        """Generate all possible subsequence sample indices (not the actual samples).

        Returns:
            A flat list of index records, grouped by trajectory in processing order.
        """
        # NOTE(review): robot_trajectories and _local_random are provided by the base
        # class; the entries are used below as keys into self.dataset.
        selected = self.robot_trajectories
        if self.max_trajectories is not None and self.max_trajectories < len(selected):
            # Only subsample when the cap is actually binding.
            selected = self._local_random.sample(selected, self.max_trajectories)

        rank_0_print(f"Generating subsequence samples for {len(selected)} trajectories", verbose=self.verbose)

        sample_indices: list[dict[str, Any]] = []
        for traj_idx in selected:
            sample_indices.extend(self._generate_indices_for_trajectory(traj_idx, self.dataset[traj_idx]))
        return sample_indices

    def _generate_indices_for_trajectory(self, traj_idx: int, traj: dict[str, Any]) -> list[dict[str, Any]]:
        """Generate sample indices for a single trajectory.

        Args:
            traj_idx: Index of the trajectory in the dataset
            traj: Trajectory dictionary; its "num_frames", "frames" and "id" entries are read

        Returns:
            List of sample index dictionaries: one per prefix in frame-step mode,
            otherwise a single whole-trajectory entry.
        """
        if not self.use_frame_steps:
            # Whole-trajectory mode: one record, no frame selection stored.
            return [
                {
                    "traj_idx": traj_idx,
                    "video_path": traj["frames"],
                    "id": traj["id"],
                    "use_frame_steps": False,
                }
            ]

        num_frames = traj["num_frames"]
        # Prefix end points are frame_step, 2*frame_step, ... up to num_frames; the
        # full trajectory is therefore only included when num_frames is a multiple
        # of frame_step, and a trajectory shorter than frame_step yields nothing.
        return [
            {
                "traj_idx": traj_idx,
                "frame_indices": list(range(end_idx)),
                "num_frames": num_frames,
                "video_path": traj["frames"],
                "id": traj["id"],
                "use_frame_steps": True,
            }
            for end_idx in range(self.frame_step, num_frames + 1, self.frame_step)
        ]

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> ProgressSample:
        """Generate a single sample from a stored index record.

        Args:
            sample_idx_info: A record produced by ``_generate_indices_for_trajectory``.
                A record without a "use_frame_steps" key is treated as frame-step mode.

        Returns:
            A ``ProgressSample`` with ``sample_type="progress"`` wrapping the
            trajectory returned by ``_get_traj_from_data``.
        """
        traj = self.dataset[sample_idx_info["traj_idx"]]

        # Metadata common to both modes.
        metadata = {
            "data_gen_strategy": "reward_alignment",
            "id": traj["id"],
            "video_path": sample_idx_info["video_path"],
        }

        # frame_indices is only passed in frame-step mode; whole-trajectory mode
        # leaves frame selection to _get_traj_from_data.
        traj_kwargs = {}
        if sample_idx_info.get("use_frame_steps", True):
            frame_indices = sample_idx_info["frame_indices"]
            # "frame_step" here is the index of the last frame in the prefix, not
            # the configured step size.
            metadata["frame_step"] = frame_indices[-1] if frame_indices else 0
            metadata["num_frames"] = sample_idx_info["num_frames"]
            traj_kwargs["frame_indices"] = frame_indices

        trajectory = self._get_traj_from_data(traj=traj, metadata=metadata, **traj_kwargs)
        return ProgressSample(trajectory=trajectory, sample_type="progress")

    def __len__(self):
        """Return the number of precomputed sample index records."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Build and return the sample for the index record at position ``idx``."""
        return self._generate_sample_from_indices(self.sample_indices[idx])