from typing import Dict, Any, Optional
import random
import torch
from robometer.data.dataset_types import ProgressSample, Trajectory
from robometer.data.samplers.base import RBMBaseSampler
from robometer.data.datasets.helpers import (
DataGenStrat,
load_embeddings_from_path,
convert_continuous_to_discrete_bins,
)
from robometer.utils.distributed import rank_0_print
from robometer.utils.logger import get_logger
logger = get_logger()
class ProgressSampler(RBMBaseSampler):
    """Data generator for progress samples.

    Builds ``ProgressSample`` instances from raw trajectory dicts using one
    of four generation strategies (different task instruction, forward
    progress, reverse progress, rewind). The strategy is either supplied
    explicitly by the caller or drawn according to the probabilities in
    ``config.progress_strategy_ratio``, with automatic rebalancing when a
    strategy fails.
    """

    def __init__(self, is_evaluation=False, **kwargs):
        super().__init__(**kwargs)
        # Keep the flag instead of silently discarding it so subclasses and
        # callers can distinguish evaluation-time sampling.
        self.is_evaluation = is_evaluation

    def _generate_sample(self, item: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Base-sampler entry point: delegate to progress-sample creation."""
        return self._create_progress_sample(item, preferred_strategy=preferred_strategy)

    def _execute_strategy(self, strategy: DataGenStrat, traj: Dict[str, Any]) -> tuple[Dict[str, Any], str] | None:
        """Execute a strategy to get processed trajectory.

        Args:
            strategy: The strategy to execute
            traj: The trajectory to process

        Returns:
            Tuple of (processed_traj, subsample_strategy) or None if failed
        """
        if strategy == DataGenStrat.FORWARD_PROGRESS:
            return (traj, "subsample_forward")
        elif strategy == DataGenStrat.REVERSE_PROGRESS:
            return (traj, "subsample_reverse")
        elif strategy == DataGenStrat.REWIND:
            return (traj, "subsample_rewind")
        elif strategy == DataGenStrat.DIFFERENT_TASK_INSTRUCTION:
            # Swapping in another task's instruction can fail (e.g. no other
            # task is available); propagate None so the caller can retry with
            # a different strategy.
            processed_traj = self._get_different_task_instruction(traj)
            if processed_traj is None:
                return None
            return (processed_traj, "subsample_forward")
        else:
            # Unknown strategy (including None) — treat as a failure.
            return None

    def _create_progress_sample(self, traj: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Create a progress sample using normalized and rebalanced strategy selection.

        Implements four strategies:
        1. Different Task: Use trajectory from different task (progress set to 0.0)
        2. Forward Progress: Sample with forward direction (start < middle < end)
        3. Reverse Progress: Sample with reverse direction (end < middle < start)
        4. Rewind: Sample with rewind direction (start < end < middle)

        Args:
            traj: Raw trajectory dict to sample from.
            preferred_strategy: If given, use this strategy directly (no
                weighted selection, no retries).

        Returns:
            A ``ProgressSample`` (with ``resample_attempts`` set), or None if
            no strategy could produce a sample.
        """
        processed_traj = None
        strategy_used = None
        subsample_strategy = None

        if preferred_strategy is not None:
            # Caller pinned the strategy: one shot, no rebalancing.
            logger.trace(f"[PROGRESS SAMPLER] Using preferred strategy: {preferred_strategy.value}")
            result = self._execute_strategy(preferred_strategy, traj)
            if result is None:
                logger.trace(f"[PROGRESS SAMPLER] Preferred strategy {preferred_strategy.value} failed, returning None")
                return None
            processed_traj, subsample_strategy = result
            strategy_used = preferred_strategy
            attempt = 1  # Set attempt for preferred strategy path
        else:
            # Pair each strategy with its configured probability; ratios may
            # be shorter than 4 entries, in which case missing ones are 0.
            # Order: [different_task_instruction, forward, reverse, rewind]
            ratios = self.config.progress_strategy_ratio
            ordered = [
                DataGenStrat.DIFFERENT_TASK_INSTRUCTION,
                DataGenStrat.FORWARD_PROGRESS,
                DataGenStrat.REVERSE_PROGRESS,
                DataGenStrat.REWIND,
            ]
            strategies = [
                (strat, ratios[i] if len(ratios) > i else 0.0)
                for i, strat in enumerate(ordered)
            ]
            # Drop strategies configured with zero probability.
            strategies = [(strat, prob) for strat, prob in strategies if prob > 0]

            max_attempts = 10  # Limit retry attempts to prevent infinite loops
            attempt = 0
            while processed_traj is None and attempt < max_attempts:
                attempt += 1
                if not strategies:
                    # Every strategy has failed and been removed.
                    return None
                # Renormalize the remaining probabilities so they sum to 1.
                total_prob = sum(prob for _, prob in strategies)
                if total_prob == 0:
                    return None
                normalized_strategies = [(strat, prob / total_prob) for strat, prob in strategies]

                # Weighted draw via cumulative probabilities.
                prob = random.random()
                cumulative_prob = 0.0
                selected_strategy = None
                for strat, normalized_prob in normalized_strategies:
                    cumulative_prob += normalized_prob
                    if prob <= cumulative_prob:
                        selected_strategy = strat
                        break
                if selected_strategy is None:
                    # Floating-point rounding can leave the cumulative sum
                    # fractionally below 1.0; fall back to the last strategy
                    # rather than executing None and wasting the attempt.
                    selected_strategy = normalized_strategies[-1][0]

                result = self._execute_strategy(selected_strategy, traj)
                if result is not None:
                    processed_traj, subsample_strategy = result
                    strategy_used = selected_strategy
                else:
                    # Remove the failed strategy; remaining ones are
                    # renormalized on the next iteration.
                    strategies = [(strat, p) for strat, p in strategies if strat != selected_strategy]
                    continue

        if processed_traj is None:
            logger.trace(
                f"[PROGRESS SAMPLER] Failed to generate progress sample after {max_attempts} attempts - all strategies exhausted"
            )
            return None

        progress_traj = self._get_traj_from_data(processed_traj, subsample_strategy=subsample_strategy)
        if progress_traj is None:
            return None

        # Different-task samples must keep the ORIGINAL task's language
        # conditioning while their targets are forced to zero progress.
        if strategy_used in [DataGenStrat.DIFFERENT_TASK, DataGenStrat.DIFFERENT_TASK_INSTRUCTION]:
            # Use the original task embeddings instead of the different task embeddings.
            if self.config.load_embeddings and traj.get("embeddings_path"):
                progress_traj.text_embedding = load_embeddings_from_path(traj["embeddings_path"])["text_embedding"]
            progress_traj.lang_vector = traj["lang_vector"]
            progress_traj.task = traj["task"]
            progress_traj.target_progress = [0.0] * len(progress_traj.target_progress)
            if self.config.progress_loss_type.lower() == "discrete":
                progress_traj.target_progress = convert_continuous_to_discrete_bins(
                    progress_traj.target_progress, self.config.progress_discrete_bins
                )
            # Also set success labels to 0.0 (predict 0 success for different task trajectories)
            if progress_traj.success_label is not None:
                progress_traj.success_label = [0.0] * len(progress_traj.success_label)

        strategy_value = strategy_used.value if isinstance(strategy_used, DataGenStrat) else strategy_used
        sample = ProgressSample(
            trajectory=progress_traj,
            sample_type="progress",
            data_gen_strategy=strategy_value,
        )
        sample.resample_attempts = attempt
        return sample