File size: 10,283 Bytes
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from typing import Dict, List, Any, Optional
from itertools import cycle

import numpy as np
from collections import defaultdict
from rfm.data.dataset_types import ProgressSample
from rfm.data.samplers.base import RFMBaseSampler
from rfm.utils.logger import get_logger

logger = get_logger()


class ProgressPolicyRankingSampler(RFMBaseSampler):
    """Dataset that generates progress samples for policy ranking by selecting N trajectories per quality label for tasks with multiple quality labels.

    Only tasks that have at least two distinct grouping values are kept, since a
    ranking needs trajectories of differing quality. Grouping uses
    ``quality_label`` for regular datasets and rounded ``partial_success`` for
    RoboArena-style datasets (detected from the first robot trajectory).

    Args:
        num_examples_per_quality_pr: Non-RoboArena only. Maximum number of
            trajectories sampled per (task, quality_label) group.
        num_partial_successes: RoboArena only. Maximum total number of
            trajectories sampled per task, drawn round-robin across the task's
            partial_success groups. ``None`` means no cap (every trajectory of
            an eligible task is used).
        frame_step: Stride between successive prefix end points when
            ``use_frame_steps`` is True. Must be >= 1 in that mode.
        use_frame_steps: If True, emit growing prefixes ``[0:frame_step]``,
            ``[0:2*frame_step]``, ... of each trajectory. If False, emit one
            whole-trajectory sample per trajectory.
        max_tasks: If a positive int, keep at most this many eligible tasks,
            chosen by shuffling the sorted task list with ``self._local_random``.
        **kwargs: Forwarded to ``RFMBaseSampler.__init__``.

    Raises:
        ValueError: If ``use_frame_steps`` is True and ``frame_step < 1``.
    """

    def __init__(
        self,
        num_examples_per_quality_pr: int = 5,
        num_partial_successes: Optional[int] = None,
        frame_step: int = 1,
        use_frame_steps: bool = True,
        max_tasks: Optional[int] = None,
        **kwargs,
    ):
        # Fail fast: a non-positive stride would make range() raise (0) or yield nothing (<0).
        if use_frame_steps and frame_step < 1:
            raise ValueError(f"frame_step must be >= 1 when use_frame_steps=True, got {frame_step}")

        super().__init__(**kwargs)

        self.num_examples_per_quality_pr = num_examples_per_quality_pr
        self.num_partial_successes = num_partial_successes
        self.frame_step = frame_step
        self.use_frame_steps = use_frame_steps
        self.max_tasks = max_tasks
        logger.info(f"ProgressPolicyRankingSampler initialized with {len(self.robot_trajectories)} trajectories")

        self.sample_indices = self._generate_all_sample_indices()

        logger.info(f"Generated {len(self.sample_indices)} sample indices")

    def _generate_all_sample_indices(self) -> List[Dict[str, Any]]:
        """Generate sample indices by selecting tasks with multiple quality labels/partial_success values and sampling trajectories per group.

        For non-RoboArena: Groups by task and quality_label, then samples up to
        ``num_examples_per_quality_pr`` trajectories per quality label.
        For RoboArena: Groups by task and rounded partial_success value, then
        samples up to ``num_partial_successes`` trajectories per task
        round-robin across groups (``None`` = no cap).

        If use_frame_steps=True, generates subsequence samples like reward_alignment (0:frame_step, 0:2*frame_step, etc.).
        If use_frame_steps=False, generates one sample per trajectory (whole trajectory).

        Returns:
            List of sample index dicts, in deterministic task order (sorted),
            given the state of ``self._local_random``.
        """
        is_roboarena = self._detect_roboarena()
        task_to_key_to_trajs = self._group_trajectories_by_task(is_roboarena)

        # A ranking needs at least two distinct grouping values within a task.
        tasks_with_multiple_values = {
            task: key_to_trajs for task, key_to_trajs in task_to_key_to_trajs.items() if len(key_to_trajs) > 1
        }

        dataset_type_str = "partial_success values" if is_roboarena else "quality labels"
        logger.info(f"Found {len(tasks_with_multiple_values)} tasks with multiple {dataset_type_str}")

        tasks_with_multiple_values = self._limit_num_tasks(tasks_with_multiple_values)

        sample_indices: List[Dict[str, Any]] = []
        all_sampled_traj_indices: List[Any] = []
        # Sort tasks to ensure deterministic processing order (and RNG consumption order).
        for task in sorted(tasks_with_multiple_values):
            key_to_trajs = tasks_with_multiple_values[task]
            if is_roboarena:
                sampled_traj_indices = self._sample_round_robin_by_key(key_to_trajs, self.num_partial_successes)
            else:
                sampled_traj_indices = self._sample_per_key(key_to_trajs, self.num_examples_per_quality_pr)

            for traj_idx in sampled_traj_indices:
                traj = self.dataset[traj_idx]
                sample_indices.extend(self._generate_indices_for_trajectory(traj_idx, traj))
                all_sampled_traj_indices.append(traj_idx)

        logger.info(f"Sampled {len(sample_indices)} samples across {len(tasks_with_multiple_values)} tasks")
        logger.info(f"Sampled trajectory indices: {all_sampled_traj_indices}")

        return sample_indices

    def _detect_roboarena(self) -> bool:
        """Return True if the first robot trajectory carries a non-None ``partial_success``.

        NOTE(review): only the first trajectory is inspected; this assumes the
        dataset is homogeneous in whether ``partial_success`` is present.
        """
        if not self.robot_trajectories:
            return False
        first_traj = self.dataset[self.robot_trajectories[0]]
        return first_traj.get("partial_success") is not None

    def _group_trajectories_by_task(self, is_roboarena: bool) -> Dict[Any, Dict[Any, List[Any]]]:
        """Group robot trajectory indices as ``{task: {grouping_key: [traj_idx, ...]}}``.

        The grouping key is ``round(float(partial_success), 2)`` for RoboArena, so
        near-identical scores share a group. Otherwise it is ``quality_label``.
        RoboArena trajectories whose ``partial_success`` is None are skipped.
        """
        task_to_key_to_trajs: Dict[Any, Dict[Any, List[Any]]] = defaultdict(lambda: defaultdict(list))

        for traj_idx in self.robot_trajectories:
            traj = self.dataset[traj_idx]
            task = traj["task"]

            if is_roboarena:
                partial_success_val = traj.get("partial_success")
                if partial_success_val is None:
                    # No score -> cannot be placed in any group; skip without creating the task entry.
                    continue
                key = round(float(partial_success_val), 2)
            else:
                key = traj["quality_label"]

            task_to_key_to_trajs[task][key].append(traj_idx)

        return task_to_key_to_trajs

    def _limit_num_tasks(self, tasks: Dict[Any, Any]) -> Dict[Any, Any]:
        """Keep at most ``self.max_tasks`` tasks (when it is a positive int), chosen by shuffle."""
        if self.max_tasks is None or self.max_tasks <= 0:
            return tasks

        # Sort first so the shuffled order depends only on the RNG state, not on dict insertion order.
        tasks_list = sorted(tasks.items())
        self._local_random.shuffle(tasks_list)
        limited = dict(tasks_list[: self.max_tasks])
        logger.info(f"Limited to {len(limited)} tasks (max_tasks={self.max_tasks})")
        return limited

    def _sample_round_robin_by_key(
        self, key_to_trajs: Dict[Any, List[Any]], num_to_sample_total: Optional[int]
    ) -> List[Any]:
        """Draw trajectories one at a time, cycling over groups in sorted-key order.

        Each visit to a non-empty group draws (without replacement) one random
        trajectory from it. Sampling stops once ``num_to_sample_total``
        trajectories are drawn, or when every group is exhausted.
        ``num_to_sample_total=None`` means no cap.
        """
        # Sorted keys and sorted members make the draw depend only on the RNG state.
        available_lists = [sorted(key_to_trajs[key]) for key in sorted(key_to_trajs) if key_to_trajs[key]]

        sampled_traj_indices: List[Any] = []
        # cycle() yields the same list objects each round, so removals below are visible on later visits.
        for available_indices in cycle(available_lists):
            if num_to_sample_total is not None and len(sampled_traj_indices) >= num_to_sample_total:
                break
            if not available_indices:
                # Stop once every group is drained; otherwise move on to the next group.
                if all(not lst for lst in available_lists):
                    break
                continue

            sampled_idx = self._local_random.choice(available_indices)
            sampled_traj_indices.append(sampled_idx)
            available_indices.remove(sampled_idx)

        return sampled_traj_indices

    def _sample_per_key(self, key_to_trajs: Dict[Any, List[Any]], num_per_key: int) -> List[Any]:
        """Sample up to ``num_per_key`` trajectories (without replacement) from each group, in sorted-key order."""
        sampled_traj_indices: List[Any] = []
        for key in sorted(key_to_trajs):
            # Sort trajectory indices to ensure deterministic sampling
            traj_indices = sorted(key_to_trajs[key])
            num_to_sample = min(num_per_key, len(traj_indices))
            sampled_traj_indices.extend(self._local_random.sample(traj_indices, num_to_sample))
        return sampled_traj_indices

    def _generate_indices_for_trajectory(self, traj_idx: int, traj: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate sample indices for a single trajectory.

        Args:
            traj_idx: Index of the trajectory in the dataset
            traj: Trajectory dictionary (reads ``num_frames``, ``frames``, ``id``)

        Returns:
            List of sample index dictionaries. In frame-step mode there is one
            entry per prefix end ``frame_step, 2*frame_step, ... <= num_frames``;
            otherwise a single whole-trajectory entry.
        """
        num_frames = traj["num_frames"]
        indices = []

        if self.use_frame_steps:
            # NOTE(review): if num_frames is not a multiple of frame_step, the final
            # full-length prefix is not emitted, and trajectories shorter than
            # frame_step yield no samples at all. Confirm this is intended.
            for end_idx in range(self.frame_step, num_frames + 1, self.frame_step):
                indices.append({
                    "traj_idx": traj_idx,
                    "frame_indices": list(range(end_idx)),
                    "num_frames": num_frames,
                    "video_path": traj["frames"],
                    "id": traj["id"],
                    "use_frame_steps": True,
                })
        else:
            indices.append({
                "traj_idx": traj_idx,
                "video_path": traj["frames"],
                "id": traj["id"],
                "use_frame_steps": False,
            })

        return indices

    def _generate_sample_from_indices(self, sample_idx_info: dict) -> ProgressSample:
        """Build a ``ProgressSample`` from one entry of ``self.sample_indices``.

        In frame-step mode the trajectory is restricted to ``frame_indices`` and
        the metadata gets a ``frame_step`` entry holding the index of the LAST
        frame in that prefix (not the stride).
        """
        traj_idx = sample_idx_info["traj_idx"]
        use_frame_steps = sample_idx_info.get("use_frame_steps", True)

        traj = self.dataset[traj_idx]

        metadata = {
            "quality_label": traj["quality_label"],
            "data_source": traj["data_source"],
            "task": traj["task"],
            "id": traj["id"],
            "video_path": sample_idx_info["video_path"],
        }

        if use_frame_steps:
            frame_indices = sample_idx_info["frame_indices"]
            metadata["frame_step"] = frame_indices[-1] if frame_indices else 0
            trajectory = self._get_traj_from_data(
                traj=traj,
                frame_indices=frame_indices,
                metadata=metadata,
            )
        else:
            trajectory = self._get_traj_from_data(
                traj=traj,
                metadata=metadata,
            )

        return ProgressSample(trajectory=trajectory)

    def __len__(self):
        """Return the number of precomputed sample indices."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Build the progress sample for precomputed sample index ``idx``."""
        return self._generate_sample_from_indices(self.sample_indices[idx])