rewardeval_ui/samplers/progress.py

from typing import Dict, Any, Optional

import random

from robometer.data.dataset_types import ProgressSample
from robometer.data.samplers.base import RBMBaseSampler
from robometer.data.datasets.helpers import (
    DataGenStrat,
    load_embeddings_from_path,
    convert_continuous_to_discrete_bins,
)
from robometer.utils.logger import get_logger

logger = get_logger()


class ProgressSampler(RBMBaseSampler):
    """Data generator for progress samples."""

    def __init__(self, is_evaluation=False, **kwargs):
        super().__init__(**kwargs)
        self.is_evaluation = is_evaluation

    def _generate_sample(self, item: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        return self._create_progress_sample(item, preferred_strategy=preferred_strategy)

    def _execute_strategy(self, strategy: DataGenStrat, traj: Dict[str, Any]) -> tuple[Dict[str, Any], str] | None:
        """Execute a strategy to get a processed trajectory.

        Args:
            strategy: The strategy to execute.
            traj: The trajectory to process.

        Returns:
            Tuple of (processed_traj, subsample_strategy), or None if the strategy failed.
        """
        if strategy == DataGenStrat.FORWARD_PROGRESS:
            return (traj, "subsample_forward")
        elif strategy == DataGenStrat.REVERSE_PROGRESS:
            return (traj, "subsample_reverse")
        elif strategy == DataGenStrat.REWIND:
            return (traj, "subsample_rewind")
        elif strategy == DataGenStrat.DIFFERENT_TASK_INSTRUCTION:
            processed_traj = self._get_different_task_instruction(traj)
            if processed_traj is None:
                return None
            return (processed_traj, "subsample_forward")
        else:
            return None
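
    # For illustration (not executed): _execute_strategy(DataGenStrat.REWIND, traj)
    # returns (traj, "subsample_rewind"), while a strategy this sampler does not
    # handle returns None, which the caller treats as a failed draw and rebalances around.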

    def _create_progress_sample(self, traj: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Create a progress sample using normalized and rebalanced strategy selection.

        Implements four strategies:
        1. Different Task Instruction: pair the trajectory with an instruction from a
           different task (target progress forced to 0.0)
        2. Forward Progress: subsample in forward order (start < middle < end)
        3. Reverse Progress: subsample in reverse order (end < middle < start)
        4. Rewind: subsample with a rewind (start < end < middle)
        """
        # Initialize variables for strategy selection
        processed_traj = None
        strategy_used = None
        subsample_strategy = None

        # Strategy selection: use preferred_strategy if provided, otherwise select based on ratios
        if preferred_strategy is not None:
            # Use the preferred strategy directly
            logger.trace(f"[PROGRESS SAMPLER] Using preferred strategy: {preferred_strategy.value}")
            result = self._execute_strategy(preferred_strategy, traj)
            if result is None:
                logger.trace(f"[PROGRESS SAMPLER] Preferred strategy {preferred_strategy.value} failed, returning None")
                return None
            processed_traj, subsample_strategy = result
            strategy_used = preferred_strategy
            attempt = 1  # A preferred strategy counts as a single attempt
        else:
            # Strategy setup with rebalancing on failure; ratios are ordered as
            # [different_task_instruction, forward_progress, reverse_progress, rewind]
            ordered_strategies = [
                DataGenStrat.DIFFERENT_TASK_INSTRUCTION,
                DataGenStrat.FORWARD_PROGRESS,
                DataGenStrat.REVERSE_PROGRESS,
                DataGenStrat.REWIND,
            ]
            ratios = self.config.progress_strategy_ratio
            strategies = [
                (strat, ratios[i] if len(ratios) > i else 0.0)
                for i, strat in enumerate(ordered_strategies)
            ]

            # Remove strategies with zero probability
            strategies = [(strat, prob) for strat, prob in strategies if prob > 0]

            max_attempts = 10  # Limit retry attempts to prevent infinite loops
            attempt = 0
            while processed_traj is None and attempt < max_attempts:
                attempt += 1
                # Check if we have any strategies left
                if not strategies:
                    return None

                # Rebalance probabilities based on remaining strategies
                total_prob = sum(prob for _, prob in strategies)
                if total_prob == 0:
                    return None

                # Normalize probabilities
                normalized_strategies = [(strat, prob / total_prob) for strat, prob in strategies]

                # Select strategy based on rebalanced probabilities
                prob = random.random()
                cumulative_prob = 0.0
                selected_strategy = None
                for strat, normalized_prob in normalized_strategies:
                    cumulative_prob += normalized_prob
                    if prob <= cumulative_prob:
                        selected_strategy = strat
                        break
                if selected_strategy is None:
                    # Floating-point rounding can leave the scan without a pick; fall back to the last entry
                    selected_strategy = normalized_strategies[-1][0]
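                # Worked example: with normalized probabilities [0.5, 0.3, 0.2], the
                # cumulative sums are 0.5, 0.8, 1.0; a draw of prob = 0.6 falls in the
                # second interval and selects the second remaining strategy.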

                # Execute selected strategy
                result = self._execute_strategy(selected_strategy, traj)
                if result is not None:
                    processed_traj, subsample_strategy = result
                    strategy_used = selected_strategy
                else:
                    # Remove the failed strategy and retry with rebalanced probabilities
                    strategies = [(strat, prob) for strat, prob in strategies if strat != selected_strategy]

        # If we still don't have a sample after all attempts, return None
        if processed_traj is None:
            logger.trace(
                f"[PROGRESS SAMPLER] Failed to generate progress sample after {max_attempts} attempts - all strategies exhausted"
            )
            return None

        progress_traj = self._get_traj_from_data(processed_traj, subsample_strategy=subsample_strategy)
        if progress_traj is None:
            return None

        # Handle special cases
        if strategy_used in [DataGenStrat.DIFFERENT_TASK, DataGenStrat.DIFFERENT_TASK_INSTRUCTION]:
            # Use the original task's embeddings instead of the different task's embeddings
            if self.config.load_embeddings and traj.get("embeddings_path"):
                progress_traj.text_embedding = load_embeddings_from_path(traj["embeddings_path"])["text_embedding"]
            progress_traj.lang_vector = traj["lang_vector"]
            progress_traj.task = traj["task"]

            # A different-task trajectory makes no progress on the original task
            progress_traj.target_progress = [0.0] * len(progress_traj.target_progress)
            if self.config.progress_loss_type.lower() == "discrete":
                progress_traj.target_progress = convert_continuous_to_discrete_bins(
                    progress_traj.target_progress, self.config.progress_discrete_bins
                )

            # Also zero the success labels (predict no success for different-task trajectories)
            if progress_traj.success_label is not None:
                progress_traj.success_label = [0.0] * len(progress_traj.success_label)

        strategy_value = strategy_used.value if isinstance(strategy_used, DataGenStrat) else strategy_used
        sample = ProgressSample(
            trajectory=progress_traj,
            sample_type="progress",
            data_gen_strategy=strategy_value,
        )
        sample.resample_attempts = attempt
        return sample
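
# A minimal usage sketch (hypothetical: assumes RBMBaseSampler accepts a config
# object exposing progress_strategy_ratio, load_embeddings, progress_loss_type,
# and progress_discrete_bins, and that `item` is a raw trajectory record dict):
#
#     sampler = ProgressSampler(is_evaluation=False, config=config)
#     sample = sampler._generate_sample(item)
#     if sample is not None:
#         print(sample.data_gen_strategy, sample.resample_attempts)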