File size: 7,947 Bytes
3e462dd
 
 
 
 
88e2e89
 
 
3e462dd
 
 
 
88e2e89
 
3e462dd
 
 
 
88e2e89
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88e2e89
 
 
3e462dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from typing import Dict, Any, Optional

import random
import torch

from robometer.data.dataset_types import ProgressSample, Trajectory
from robometer.data.samplers.base import RBMBaseSampler
from robometer.data.datasets.helpers import (
    DataGenStrat,
    load_embeddings_from_path,
    convert_continuous_to_discrete_bins,
)
from robometer.utils.distributed import rank_0_print
from robometer.utils.logger import get_logger

logger = get_logger()


class ProgressSampler(RBMBaseSampler):
    """Data generator for progress samples.

    Builds ``ProgressSample`` objects from raw trajectory dicts. The
    data-generation strategy is either supplied explicitly by the caller
    (``preferred_strategy``) or drawn at random according to
    ``config.progress_strategy_ratio``, with the weight pool rebalanced
    whenever a strategy fails for the given trajectory.
    """

    def __init__(self, is_evaluation=False, **kwargs):
        # NOTE(review): `is_evaluation` is accepted but never read in this
        # class; kept only for signature compatibility with callers — confirm
        # whether it should be forwarded or stored.
        super().__init__(**kwargs)

    def _generate_sample(self, item: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Generate one progress sample for `item`; see `_create_progress_sample`."""
        return self._create_progress_sample(item, preferred_strategy=preferred_strategy)

    def _execute_strategy(self, strategy: DataGenStrat, traj: Dict[str, Any]) -> tuple[Dict[str, Any], str] | None:
        """Execute a strategy to get a processed trajectory.

        Args:
            strategy: The strategy to execute.
            traj: The trajectory to process.

        Returns:
            Tuple of (processed_traj, subsample_strategy) or None if the
            strategy failed (no alternative task instruction available, or
            the strategy is unknown / None).
        """
        if strategy == DataGenStrat.FORWARD_PROGRESS:
            return (traj, "subsample_forward")
        if strategy == DataGenStrat.REVERSE_PROGRESS:
            return (traj, "subsample_reverse")
        if strategy == DataGenStrat.REWIND:
            return (traj, "subsample_rewind")
        if strategy == DataGenStrat.DIFFERENT_TASK_INSTRUCTION:
            processed_traj = self._get_different_task_instruction(traj)
            if processed_traj is None:
                return None
            return (processed_traj, "subsample_forward")
        return None

    def _build_strategy_pool(self) -> list[tuple[DataGenStrat, float]]:
        """Pair each strategy with its configured weight, dropping zero-weight entries.

        Weight order in ``config.progress_strategy_ratio``:
        [different_task_instruction, forward_progress, reverse_progress, rewind].
        Missing trailing entries default to 0.0 (i.e. the strategy is disabled).
        """
        ratios = self.config.progress_strategy_ratio
        ordered = [
            DataGenStrat.DIFFERENT_TASK_INSTRUCTION,
            DataGenStrat.FORWARD_PROGRESS,
            DataGenStrat.REVERSE_PROGRESS,
            DataGenStrat.REWIND,
        ]
        pool = []
        for idx, strat in enumerate(ordered):
            weight = ratios[idx] if len(ratios) > idx else 0.0
            if weight > 0:
                pool.append((strat, weight))
        return pool

    @staticmethod
    def _pick_strategy(pool: list[tuple[DataGenStrat, float]]) -> Optional[DataGenStrat]:
        """Sample one strategy from (strategy, weight) pairs, renormalizing weights.

        Returns None when the pool is empty or all weights are zero.

        BUGFIX vs. previous version: `random.random()` is in [0, 1) while the
        cumulative sum of normalized weights can fall fractionally short of 1.0
        due to floating-point rounding, leaving the draw unmatched. We now fall
        back to the last pool entry instead of returning nothing (the old code
        then executed a None strategy and wasted a retry attempt).
        """
        if not pool:
            return None
        total = sum(weight for _, weight in pool)
        if total == 0:
            return None
        draw = random.random()
        cumulative = 0.0
        for strat, weight in pool:
            cumulative += weight / total
            if draw <= cumulative:
                return strat
        # Rounding shortfall in the cumulative sum — the draw belongs to the tail.
        return pool[-1][0]

    def _create_progress_sample(self, traj: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Create a progress sample using normalized and rebalanced strategy selection.

        Implements four strategies:
        1. Different Task: Use trajectory from different task (progress set to 0.0)
        2. Forward Progress: Sample with forward direction (start < middle < end)
        3. Reverse Progress: Sample with reverse direction (end < middle < start)
        4. Rewind: Sample with rewind direction (start < end < middle)

        Args:
            traj: Raw trajectory dict to sample from.
            preferred_strategy: When given, this strategy is used directly
                (no ratio-based sampling, no retries).

        Returns:
            A ``ProgressSample`` with ``resample_attempts`` set, or None when
            no strategy could produce a usable trajectory.
        """
        if preferred_strategy is not None:
            # Caller pinned the strategy: single attempt, no fallback.
            logger.trace(f"[PROGRESS SAMPLER] Using preferred strategy: {preferred_strategy.value}")
            result = self._execute_strategy(preferred_strategy, traj)
            if result is None:
                logger.trace(f"[PROGRESS SAMPLER] Preferred strategy {preferred_strategy.value} failed, returning None")
                return None
            processed_traj, subsample_strategy = result
            strategy_used = preferred_strategy
            attempt = 1
        else:
            pool = self._build_strategy_pool()
            processed_traj = None
            subsample_strategy = None
            strategy_used = None
            max_attempts = 10  # bound retries so repeatedly failing strategies cannot loop forever
            attempt = 0

            while processed_traj is None and attempt < max_attempts:
                attempt += 1
                selected_strategy = self._pick_strategy(pool)
                if selected_strategy is None:
                    # No strategies left (or all weights zero) — give up.
                    return None
                result = self._execute_strategy(selected_strategy, traj)
                if result is None:
                    # Drop the failed strategy; its weight is redistributed on
                    # the next pick by renormalization.
                    pool = [(strat, weight) for strat, weight in pool if strat != selected_strategy]
                    continue
                processed_traj, subsample_strategy = result
                strategy_used = selected_strategy

            if processed_traj is None:
                logger.trace(
                    f"[PROGRESS SAMPLER] Failed to generate progress sample after {max_attempts} attempts - all strategies exhausted"
                )
                return None

        progress_traj = self._get_traj_from_data(processed_traj, subsample_strategy=subsample_strategy)
        if progress_traj is None:
            return None

        # Different-task samples keep the ORIGINAL task's text/embeddings but
        # are labeled with zero progress (and zero success), because the
        # instruction does not match the shown trajectory.
        if strategy_used in [DataGenStrat.DIFFERENT_TASK, DataGenStrat.DIFFERENT_TASK_INSTRUCTION]:
            # We need to use the original task embeddings instead of the different task embeddings
            if self.config.load_embeddings and traj.get("embeddings_path"):
                progress_traj.text_embedding = load_embeddings_from_path(traj["embeddings_path"])["text_embedding"]
            progress_traj.lang_vector = traj["lang_vector"]
            progress_traj.task = traj["task"]
            progress_traj.target_progress = [0.0] * len(progress_traj.target_progress)
            if self.config.progress_loss_type.lower() == "discrete":
                progress_traj.target_progress = convert_continuous_to_discrete_bins(
                    progress_traj.target_progress, self.config.progress_discrete_bins
                )
            # Also set success labels to 0.0 (predict 0 success for different task trajectories)
            if progress_traj.success_label is not None:
                progress_traj.success_label = [0.0] * len(progress_traj.success_label)

        strategy_value = strategy_used.value if isinstance(strategy_used, DataGenStrat) else strategy_used
        sample = ProgressSample(
            trajectory=progress_traj,
            sample_type="progress",
            data_gen_strategy=strategy_value,
        )
        sample.resample_attempts = attempt
        return sample