from typing import Dict, Any, Optional
import random
import torch

from robometer.data.dataset_types import ProgressSample, Trajectory
from robometer.data.samplers.base import RBMBaseSampler
from robometer.data.datasets.helpers import (
    DataGenStrat,
    load_embeddings_from_path,
    convert_continuous_to_discrete_bins,
)
from robometer.utils.distributed import rank_0_print
from robometer.utils.logger import get_logger

logger = get_logger()


class ProgressSampler(RBMBaseSampler):
    """Data generator for progress samples.

    Each call picks a data-generation strategy (either the caller's preferred
    one or one drawn from ``self.config.progress_strategy_ratio``), turns the
    input trajectory dict into a trajectory object via
    ``self._get_traj_from_data`` and wraps it in a ``ProgressSample``.
    """

    def __init__(self, is_evaluation=False, **kwargs):
        """Initialize the sampler.

        Args:
            is_evaluation: Accepted for call-site compatibility; it is neither
                stored nor forwarded to the base class here.
            **kwargs: Forwarded unchanged to ``RBMBaseSampler.__init__``.
        """
        super().__init__(**kwargs)

    def _generate_sample(self, item: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Generate one progress sample for ``item``.

        Thin wrapper around :meth:`_create_progress_sample`.

        Args:
            item: Trajectory dict to build the sample from.
            preferred_strategy: If given, this strategy is used instead of
                drawing one from the configured ratios.

        Returns:
            A ``ProgressSample``, or ``None`` if no sample could be created.
        """
        return self._create_progress_sample(item, preferred_strategy=preferred_strategy)

    def _execute_strategy(self, strategy: DataGenStrat, traj: Dict[str, Any]) -> tuple[Dict[str, Any], str] | None:
        """Execute a strategy to get processed trajectory.

        Args:
            strategy: The strategy to execute
            traj: The trajectory to process

        Returns:
            Tuple of (processed_traj, subsample_strategy) or None if failed
            (unsupported strategy, or no different-task trajectory available).
        """
        if strategy == DataGenStrat.FORWARD_PROGRESS:
            return (traj, "subsample_forward")
        elif strategy == DataGenStrat.REVERSE_PROGRESS:
            return (traj, "subsample_reverse")
        elif strategy == DataGenStrat.REWIND:
            return (traj, "subsample_rewind")
        elif strategy == DataGenStrat.DIFFERENT_TASK_INSTRUCTION:
            # NOTE(review): _get_different_task_instruction is not defined in
            # this class; presumably inherited from RBMBaseSampler.
            processed_traj = self._get_different_task_instruction(traj)
            if processed_traj is None:
                return None
            return (processed_traj, "subsample_forward")
        else:
            return None

    def _create_progress_sample(self, traj: Dict[str, Any], preferred_strategy: Optional[DataGenStrat] = None):
        """Create a progress sample using normalized and rebalanced strategy selection.

        Implements four strategies:
        1. Different Task: Use trajectory from different task (progress set to 0.0)
        2. Forward Progress: Sample with forward direction (start < middle < end)
        3. Reverse Progress: Sample with reverse direction (end < middle < start)
        4. Rewind: Sample with rewind direction (start < end < middle)

        Args:
            traj: Trajectory dict to build the sample from.
            preferred_strategy: If given, only this strategy is tried (no
                retries). Otherwise a strategy is drawn according to
                ``self.config.progress_strategy_ratio``; a strategy that fails
                is dropped and the remaining ratios are renormalized.

        Returns:
            A ``ProgressSample`` whose ``resample_attempts`` attribute records
            how many selection attempts were made, or ``None`` if no strategy
            produced a usable trajectory.
        """
        # Initialize variables for strategy selection
        processed_traj = None
        strategy_used = None
        subsample_strategy = None

        # Defined before branching because the failure log message below reads
        # it on both the preferred-strategy and the ratio-sampling path.
        max_attempts = 10  # Limit retry attempts to prevent infinite loops

        # Strategy selection: use preferred_strategy if provided, otherwise select based on ratios
        if preferred_strategy is not None:
            # Use the preferred strategy directly
            logger.trace(f"[PROGRESS SAMPLER] Using preferred strategy: {preferred_strategy.value}")
            result = self._execute_strategy(preferred_strategy, traj)
            if result is None:
                logger.trace(f"[PROGRESS SAMPLER] Preferred strategy {preferred_strategy.value} failed, returning None")
                return None
            processed_traj, subsample_strategy = result
            strategy_used = preferred_strategy
            attempt = 1  # Set attempt for preferred strategy path
        else:
            # Strategy setup with rebalancing on failure.
            # The i-th entry of progress_strategy_ratio is the weight of the
            # i-th strategy below; missing entries count as weight 0.
            ordered_strategies = (
                DataGenStrat.DIFFERENT_TASK_INSTRUCTION,
                DataGenStrat.FORWARD_PROGRESS,
                DataGenStrat.REVERSE_PROGRESS,
                DataGenStrat.REWIND,
            )
            ratios = self.config.progress_strategy_ratio

            # Keep only strategies with a strictly positive weight
            strategies = [
                (strat, ratios[idx])
                for idx, strat in enumerate(ordered_strategies)
                if idx < len(ratios) and ratios[idx] > 0
            ]

            attempt = 0

            while processed_traj is None and attempt < max_attempts:
                attempt += 1

                # Check if we have any strategies left
                if not strategies:
                    return None

                # Rebalance probabilities based on remaining strategies
                total_prob = sum(prob for _, prob in strategies)
                if total_prob == 0:
                    return None

                # Normalize probabilities
                normalized_strategies = [(strat, prob / total_prob) for strat, prob in strategies]

                # Select strategy based on rebalanced probabilities.
                # Default to the last strategy: float rounding can leave the
                # cumulative sum marginally below 1.0, in which case a draw
                # close to 1.0 would otherwise select nothing and burn an
                # attempt on a ``None`` strategy.
                prob = random.random()
                cumulative_prob = 0.0
                selected_strategy = strategies[-1][0]

                for strat, normalized_prob in normalized_strategies:
                    cumulative_prob += normalized_prob
                    if prob <= cumulative_prob:
                        selected_strategy = strat
                        break

                # Execute selected strategy
                result = self._execute_strategy(selected_strategy, traj)
                if result is not None:
                    processed_traj, subsample_strategy = result
                    strategy_used = selected_strategy
                else:
                    # Remove failed strategy and try again
                    strategies = [(strat, prob) for strat, prob in strategies if strat != selected_strategy]
                    continue

        # If we still don't have a sample after all attempts, return None
        if processed_traj is None:
            logger.trace(
                f"[PROGRESS SAMPLER] Failed to generate progress sample after {max_attempts} attempts - all strategies exhausted"
            )
            return None

        # NOTE(review): _get_traj_from_data is not defined in this class;
        # presumably inherited from RBMBaseSampler.
        progress_traj = self._get_traj_from_data(processed_traj, subsample_strategy=subsample_strategy)
        if progress_traj is None:
            return None

        # Handle special cases
        if strategy_used in [DataGenStrat.DIFFERENT_TASK, DataGenStrat.DIFFERENT_TASK_INSTRUCTION]:
            # We need to use the original task embeddings instead of the different task embeddings
            if self.config.load_embeddings and traj.get("embeddings_path"):
                progress_traj.text_embedding = load_embeddings_from_path(traj["embeddings_path"])["text_embedding"]
            progress_traj.lang_vector = traj["lang_vector"]
            progress_traj.task = traj["task"]

            # Frames come from a different task, so progress toward the
            # original task is zero at every step.
            progress_traj.target_progress = [0.0] * len(progress_traj.target_progress)
            if self.config.progress_loss_type.lower() == "discrete":
                progress_traj.target_progress = convert_continuous_to_discrete_bins(
                    progress_traj.target_progress, self.config.progress_discrete_bins
                )

            # Also set success labels to 0.0 (predict 0 success for different task trajectories)
            if progress_traj.success_label is not None:
                progress_traj.success_label = [0.0] * len(progress_traj.success_label)

        # Store the enum's value when we have an enum; pass anything else through as-is
        strategy_value = strategy_used.value if isinstance(strategy_used, DataGenStrat) else strategy_used

        sample = ProgressSample(
            trajectory=progress_traj,
            sample_type="progress",
            data_gen_strategy=strategy_value,
        )
        sample.resample_attempts = attempt
        return sample