Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| from typing import Optional, Dict, Any, List, Set, Tuple, Union | |
| import numpy as np | |
| import random | |
| import torch | |
| from random import Random | |
| from datasets import Dataset | |
| from robometer.configs.experiment_configs import DataConfig | |
| from robometer.data.datasets.helpers import ( | |
| load_frames_from_npz, | |
| get_segment_indices_with_middle, | |
| compute_progress_from_segment, | |
| pad_trajectory_to_max_frames_torch, | |
| pad_trajectory_to_max_frames_np, | |
| compute_success_labels, | |
| create_trajectory_from_dict, | |
| load_embeddings_from_path, | |
| linspace_subsample_frames, | |
| convert_continuous_to_discrete_bins, | |
| ) | |
| from robometer.data.dataset_types import Trajectory | |
| from robometer.utils.logger import get_logger | |
| from robometer.data.dataset_category import is_preference_only_ds | |
| logger = get_logger() | |
| class RBMBaseSampler: | |
| """Base sampler class that provides trajectory retrieval functions for generating samples.""" | |
| def __init__( | |
| self, | |
| config: DataConfig, | |
| dataset: Dataset, | |
| combined_indices: Dict[str, Any], | |
| dataset_success_cutoff_map: Optional[Dict[str, float]] = None, | |
| verbose: bool = True, | |
| random_seed: int = 42, | |
| pad_frames: bool = True, | |
| ): | |
| """Initialize sampler with dataset and indices. | |
| Args: | |
| config: Configuration object | |
| dataset: The loaded dataset | |
| combined_indices: Dictionary of combined indices from dataset loading | |
| dataset_success_cutoff_map: Dictionary mapping dataset names to success cutoff percentages | |
| verbose: Verbose flag | |
| random_seed: Random seed for deterministic sampling. Creates a local Random instance to avoid affecting global random state. | |
| """ | |
| self.config = config | |
| self.dataset = dataset | |
| self.verbose = verbose | |
| self.dataset_success_cutoff_map = dataset_success_cutoff_map or {} | |
| self._local_random = Random(random_seed) | |
| self.pad_frames = pad_frames | |
| self._cached_ids = self.dataset["id"] | |
| self._cached_is_robot = self.dataset["is_robot"] | |
| # Build indices from combined_indices | |
| self._build_indices(combined_indices) | |
| def _build_indices(self, combined_indices): | |
| """Build all index mappings from combined_indices. | |
| Args: | |
| combined_indices: Dictionary of combined indices from dataset loading | |
| """ | |
| # Initialize index mappings from the loaded indices | |
| self.robot_trajectories = combined_indices["robot_trajectories"] | |
| self.human_trajectories = combined_indices["human_trajectories"] | |
| self.optimal_by_task = combined_indices["optimal_by_task"] | |
| self.suboptimal_by_task = combined_indices["suboptimal_by_task"] | |
| self.quality_indices = combined_indices["quality_indices"] | |
| self.task_indices = combined_indices["task_indices"] | |
| self.source_indices = combined_indices["source_indices"] | |
| self.partial_success_indices = combined_indices["partial_success_indices"] | |
| self.paired_human_robot_by_task = combined_indices["paired_human_robot_by_task"] | |
| self.tasks_with_multiple_quality_labels = combined_indices["tasks_with_multiple_quality_labels"] | |
| # Build mapping from data source -> available task instructions | |
| self._build_tasks_by_data_source() | |
| def _build_tasks_by_data_source(self): | |
| """Cache mapping from data_source to available task instructions.""" | |
| self.tasks_by_data_source: Dict[str, List[str]] = {} | |
| all_tasks = self.dataset["task"] | |
| all_sources = self.dataset["data_source"] | |
| source_to_tasks: Dict[str, Set[str]] = {} | |
| for task, source in zip(all_tasks, all_sources): | |
| if task is None or source is None: | |
| continue | |
| if source not in source_to_tasks: | |
| source_to_tasks[source] = set() | |
| source_to_tasks[source].add(task) | |
| self.tasks_by_data_source = {source: list(tasks) for source, tasks in source_to_tasks.items()} | |
| def _generate_sample(self, item): | |
| """Generate a sample from an item. | |
| This method should be overridden by subclasses to implement their specific | |
| sample generation logic. | |
| Args: | |
| item: An item from the dataset (typically a trajectory dict) | |
| Returns: | |
| A sample object (e.g., PreferenceSample, ProgressSample) | |
| """ | |
| raise NotImplementedError("Subclasses must implement _generate_sample") | |
| def _get_same_task_optimal(self, ref_traj: dict) -> dict | None: | |
| """Get optimal trajectory from same task (different from ref). | |
| Args: | |
| ref_traj: Reference trajectory | |
| Returns: | |
| Same task optimal trajectory dict or None if not available | |
| """ | |
| task_name = ref_traj["task"] | |
| same_task_optimal_indices = self.optimal_by_task.get(task_name, []) | |
| if not same_task_optimal_indices: | |
| logger.trace(f"[BASE SAMPLER] _get_same_task_optimal: No optimal indices for task '{task_name}'") | |
| return None | |
| # Use cached IDs to check without loading full trajectories | |
| chosen_id = ref_traj["id"] | |
| random_idx = random.choice(same_task_optimal_indices) | |
| # Retry if the selected trajectory has the same ID as ref | |
| max_retries = min(10, len(same_task_optimal_indices)) | |
| retries = 0 | |
| while self._cached_ids[random_idx] == chosen_id and retries < max_retries: | |
| random_idx = random.choice(same_task_optimal_indices) | |
| retries += 1 | |
| # If still matches after retries, fall back to filtering | |
| if self._cached_ids[random_idx] == chosen_id: | |
| filtered_indices = [idx for idx in same_task_optimal_indices if self._cached_ids[idx] != chosen_id] | |
| if filtered_indices: | |
| random_idx = random.choice(filtered_indices) | |
| else: | |
| # No other trajectories available | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_same_task_optimal: All trajectories have same ID '{chosen_id}' for task '{task_name}'" | |
| ) | |
| return None | |
| result = self.dataset[random_idx] | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_same_task_optimal: Found trajectory {result.get('id', 'unknown')} for task '{task_name}'" | |
| ) | |
| return result | |
| def _get_same_task_suboptimal(self, ref_traj: dict) -> dict | None: | |
| """Get suboptimal trajectory from same task. | |
| For trajectories with partial_success, uses partial_success logic instead of quality_label logic. | |
| Args: | |
| ref_traj: Reference trajectory | |
| Returns: | |
| Suboptimal trajectory dict or None if not available | |
| """ | |
| # Check if this trajectory uses partial_success | |
| use_partial_success = ref_traj.get("partial_success") is not None | |
| if use_partial_success: | |
| # For trajectories with partial_success, use partial_success logic | |
| return self._get_different_partial_success_traj(ref_traj) | |
| # For trajectories without partial_success, use the standard suboptimal logic | |
| task_name = ref_traj["task"] | |
| same_task_suboptimal_indices = self.suboptimal_by_task.get(task_name, []) | |
| if not same_task_suboptimal_indices: | |
| logger.trace(f"[BASE SAMPLER] _get_same_task_suboptimal: No suboptimal indices for task '{task_name}'") | |
| return None | |
| # Use cached IDs to check without loading full trajectories | |
| chosen_id = ref_traj["id"] | |
| random_idx = random.choice(same_task_suboptimal_indices) | |
| # Retry if the selected trajectory has the same ID as ref | |
| max_retries = min(10, len(same_task_suboptimal_indices)) | |
| retries = 0 | |
| while self._cached_ids[random_idx] == chosen_id and retries < max_retries: | |
| random_idx = random.choice(same_task_suboptimal_indices) | |
| retries += 1 | |
| # If still matches after retries, fall back to filtering | |
| if self._cached_ids[random_idx] == chosen_id: | |
| filtered_indices = [idx for idx in same_task_suboptimal_indices if self._cached_ids[idx] != chosen_id] | |
| if filtered_indices: | |
| random_idx = random.choice(filtered_indices) | |
| else: | |
| # No other trajectories available | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_same_task_suboptimal: All trajectories have same ID '{chosen_id}' for task '{task_name}'" | |
| ) | |
| return None | |
| result = self.dataset[random_idx] | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_same_task_suboptimal: Found trajectory {result.get('id', 'unknown')} for task '{task_name}'" | |
| ) | |
| return result | |
| def _get_different_video_traj(self, ref_traj: dict) -> dict | None: | |
| """Get trajectory from different task. | |
| Args: | |
| ref_traj: Reference trajectory | |
| Returns: | |
| Different task trajectory dict or None if not available | |
| """ | |
| same_source_prob = self.config.traj_same_source_prob | |
| data_source = ref_traj.get("data_source") | |
| other_tasks = [] | |
| if data_source and data_source in self.tasks_by_data_source and random.random() < same_source_prob: | |
| other_tasks = [task for task in self.tasks_by_data_source[data_source] if task != ref_traj["task"]] | |
| if not other_tasks: | |
| other_tasks = [task for task in self.optimal_by_task.keys() if task != ref_traj["task"]] | |
| if not other_tasks: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_video_traj: No other tasks available (ref task: '{ref_traj['task']}')" | |
| ) | |
| return None | |
| # Try up to 2 times to find a valid task | |
| max_retries = 2 | |
| other_task_indices = None | |
| other_task = None | |
| for attempt in range(max_retries): | |
| other_task = random.choice(other_tasks) | |
| if other_task not in self.optimal_by_task: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_video_traj: Attempt {attempt + 1}/{max_retries}: Task '{other_task}' not found in optimal_by_task" | |
| ) | |
| continue | |
| other_task_indices = self.optimal_by_task[other_task] | |
| if not other_task_indices: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_video_traj: Attempt {attempt + 1}/{max_retries}: Task '{other_task}' has no optimal indices" | |
| ) | |
| continue | |
| # Found a valid task with indices | |
| break | |
| if other_task_indices is None or not other_task_indices: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_video_traj: Failed to find valid task after {max_retries} attempts" | |
| ) | |
| return None | |
| other_idx = random.choice(other_task_indices) | |
| result = self.dataset[other_idx] | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_video_traj: Found trajectory {result.get('id', 'unknown')} from task '{other_task}'" | |
| ) | |
| return result | |
| def _get_different_task_instruction(self, ref_traj: dict) -> dict | None: | |
| """Get the same trajectory but with a different task instruction. | |
| Args: | |
| ref_traj: Reference trajectory | |
| Returns: | |
| Trajectory dict with different task instruction or None if not available | |
| """ | |
| same_source_prob = self.config.traj_same_source_prob | |
| data_source = ref_traj.get("data_source") | |
| candidate_tasks = [] | |
| if data_source and data_source in self.tasks_by_data_source and random.random() < same_source_prob: | |
| candidate_tasks = [task for task in self.tasks_by_data_source[data_source] if task != ref_traj["task"]] | |
| if not candidate_tasks: | |
| candidate_tasks = [task for task in self.optimal_by_task.keys() if task != ref_traj["task"]] | |
| if not candidate_tasks: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_task_instruction: No candidate tasks available (ref task: '{ref_traj['task']}')" | |
| ) | |
| return None | |
| other_task = random.choice(candidate_tasks) | |
| # Get embeddings_path and lang_vector from a random trajectory with the other_task | |
| other_task_indices = self.optimal_by_task.get(other_task, []) | |
| if not other_task_indices: | |
| logger.trace(f"[BASE SAMPLER] _get_different_task_instruction: Task '{other_task}' has no optimal indices") | |
| return None | |
| other_task_idx = random.choice(other_task_indices) | |
| other_task_traj = self.dataset[other_task_idx] | |
| # Create a copy of the trajectory with the task changed | |
| # Use embeddings_path and lang_vector from the other_task trajectory | |
| new_traj = ref_traj.copy() | |
| new_traj["task"] = other_task | |
| # Get embeddings_path and lang_vector from a random trajectory with the other_task | |
| if "embeddings_path" in other_task_traj: | |
| new_traj["embeddings_path"] = other_task_traj["embeddings_path"] | |
| if "lang_vector" in other_task_traj: | |
| new_traj["lang_vector"] = other_task_traj["lang_vector"] | |
| return new_traj | |
| def _get_paired_human_robot_traj(self, ref_traj: dict) -> dict | None: | |
| """Get paired human/robot trajectory for the same task. | |
| Given a reference trajectory, if it's a robot trajectory, returns a human trajectory | |
| from the same task. If it's a human trajectory, returns a robot trajectory from the | |
| same task. | |
| Args: | |
| ref_traj: Reference trajectory (can be robot or human) | |
| Returns: | |
| Paired trajectory dict (opposite type) or None if not available | |
| """ | |
| task = ref_traj["task"] | |
| is_robot = ref_traj.get("is_robot", True) | |
| if task not in self.paired_human_robot_by_task: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_paired_human_robot_traj: Task '{task}' not in paired_human_robot_by_task" | |
| ) | |
| return None | |
| task_pairs = self.paired_human_robot_by_task[task] | |
| # Get opposite type | |
| opposite_key = "human" if is_robot else "robot" | |
| opposite_indices = task_pairs.get(opposite_key, []) | |
| if not opposite_indices: | |
| logger.trace(f"[BASE SAMPLER] _get_paired_human_robot_traj: No {opposite_key} indices for task '{task}'") | |
| return None | |
| # Sample a paired trajectory and verify it's different from reference | |
| chosen_id = ref_traj["id"] | |
| available_indices = opposite_indices.copy() | |
| paired_traj = None | |
| # Add retry limit to prevent infinite loops | |
| max_retries = min(len(available_indices), 10) | |
| retries = 0 | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_paired_human_robot_traj: Looking for {opposite_key} trajectory (chosen_id: {chosen_id}, available: {len(available_indices)})" | |
| ) | |
| while (paired_traj is None or paired_traj.get("id") == chosen_id) and retries < max_retries: | |
| retries += 1 | |
| if not available_indices: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_paired_human_robot_traj: No more available indices after {retries} retries" | |
| ) | |
| return None | |
| paired_idx = random.choice(available_indices) | |
| paired_traj = self.dataset[paired_idx] | |
| # If it matches, remove this index and try again | |
| if paired_traj.get("id") == chosen_id: | |
| available_indices = [idx for idx in available_indices if idx != paired_idx] | |
| paired_traj = None | |
| continue | |
| # If we exhausted retries without finding a valid trajectory, return None | |
| if paired_traj is None or paired_traj.get("id") == chosen_id: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_paired_human_robot_traj: Failed to find valid paired trajectory after {max_retries} retries" | |
| ) | |
| return None | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_paired_human_robot_traj: Found paired trajectory {paired_traj.get('id', 'unknown')} on retry {retries}" | |
| ) | |
| return paired_traj | |
| def _get_different_partial_success_traj(self, ref_traj: dict) -> dict | None: | |
| """Get trajectory from same task with different partial_success. | |
| Finds trajectories with either higher or lower partial_success than the reference, | |
| using absolute difference for threshold checking. | |
| Args: | |
| ref_traj: Reference trajectory | |
| Returns: | |
| Trajectory dict with different partial_success from same task or None if not available | |
| """ | |
| task_name = ref_traj["task"] | |
| ref_partial_success = ref_traj.get("partial_success") | |
| # Check if partial_success is available | |
| if ref_partial_success is None: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_partial_success_traj: No partial_success for trajectory {ref_traj.get('id', 'unknown')}" | |
| ) | |
| return None | |
| # Get minimum threshold from config | |
| min_threshold = getattr(self.config, "partial_success_threshold", 0.2) | |
| # Get all trajectories from the same task | |
| same_task_indices = self.task_indices.get(task_name, []) | |
| if not same_task_indices: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_partial_success_traj: No trajectories found for task '{task_name}'" | |
| ) | |
| return None | |
| # Filter to trajectories with different partial_success that meet the threshold requirement | |
| # Uses absolute difference to allow both higher and lower partial_success | |
| chosen_id = ref_traj["id"] | |
| candidate_indices = [] | |
| for idx in same_task_indices: | |
| # Skip if same trajectory | |
| if self._cached_ids[idx] == chosen_id: | |
| continue | |
| # Get partial_success for this trajectory | |
| traj_dict = self.dataset[idx] | |
| traj_partial_success = traj_dict.get("partial_success", None) | |
| if traj_partial_success is None: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_partial_success_traj: No partial_success for trajectory {traj_dict.get('id', 'unknown')}, task '{task_name}'" | |
| ) | |
| continue | |
| # Include if partial_success differs from reference by at least the threshold (using abs) | |
| partial_success_diff = abs(ref_partial_success - traj_partial_success) | |
| if partial_success_diff >= min_threshold: | |
| candidate_indices.append(idx) | |
| if not candidate_indices: | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_partial_success_traj: No trajectories with different partial_success (threshold: {min_threshold}) for task '{task_name}' (ref: {ref_partial_success})" | |
| ) | |
| return None | |
| # Randomly select from candidates | |
| selected_idx = random.choice(candidate_indices) | |
| result = self.dataset[selected_idx] | |
| result_partial_success = result.get("partial_success") | |
| # If ref_partial_success is 1.0, direction is always "lower" since 1.0 is the maximum | |
| if ref_partial_success == 1.0: | |
| direction = "lower" | |
| else: | |
| direction = "higher" if result_partial_success > ref_partial_success else "lower" | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_different_partial_success_traj: Found trajectory {result.get('id', 'unknown')} with partial_success {result_partial_success} ({direction} than {ref_partial_success}, abs diff: {abs(ref_partial_success - result_partial_success):.3f}, threshold: {min_threshold})" | |
| ) | |
| return result | |
| def _get_subsample_indices( | |
| self, data, direction: str = "bidirectional", max_frames: int = None | |
| ) -> Optional[Tuple[int, int, int]]: | |
| """Get start, middle, and end indices for subsample strategy. | |
| Samples three random frames from the trajectory. The relationship between indices | |
| follows three main scenarios: | |
| 1. start < middle < end: forward progress - normal forward progression through trajectory | |
| 2. start < end < middle: rewind progress - forward from start to end, then continues to middle (simulating rewind/backtrack) | |
| 3. end < middle < start: reverse progress - backward from start through middle to end (full backward traversal) | |
| Args: | |
| data: Trajectory data (frames or embeddings) to sample from | |
| direction: Sampling direction - "forward" (start < middle < end), | |
| "reverse" (end < middle < start), | |
| "rewind" (start < end < middle), | |
| or "bidirectional" (any of the 3 orderings) | |
| max_frames: Maximum number of frames to subsample. If 1, returns only start. If 2, returns start and end. | |
| Returns: | |
| Tuple of (start_idx, middle_idx, end_idx), or None if insufficient frames | |
| For max_frames == 1: returns (start_idx, None, None) | |
| For max_frames == 2: returns (start_idx, None, end_idx) | |
| """ | |
| num_frames_total = len(data) if hasattr(data, "__len__") else data.shape[0] | |
| # Handle edge cases for max_frames == 1 or 2 | |
| if max_frames == 1: | |
| # Randomly sample 1 frame | |
| random_idx = random.randint(0, num_frames_total - 1) | |
| logger.trace(f"[BASE SAMPLER] _get_subsample_indices: max_frames=1, randomly sampled idx={random_idx}") | |
| return (random_idx, None, None) | |
| if max_frames == 2: | |
| # Sample 2 frames: either forward (start < end) or reverse (end < start) | |
| # No rewind possible with only 2 frames | |
| if direction == "reverse": | |
| # Reverse: sample end first, then start (end < start) | |
| end_idx = random.randint(0, num_frames_total - 2) | |
| start_idx = random.randint(end_idx + 1, num_frames_total - 1) | |
| else: | |
| # Forward: sample start first, then end (start < end) | |
| start_idx = random.randint(0, num_frames_total - 2) | |
| end_idx = random.randint(start_idx + 1, num_frames_total - 1) | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_subsample_indices: max_frames=2, start_idx={start_idx}, end_idx={end_idx}, direction={direction}" | |
| ) | |
| return (start_idx, None, end_idx) | |
| if num_frames_total < 3: | |
| logger.trace(f"[BASE SAMPLER] _get_subsample_indices: Not enough frames ({num_frames_total})") | |
| return None | |
| # Sample three random distinct frames | |
| frame_indices = sorted(random.sample(range(num_frames_total), 3)) | |
| frame1_idx, frame2_idx, frame3_idx = frame_indices | |
| # Determine start, middle, and end based on direction | |
| # We only care about 3 cases: | |
| # 1. start < middle < end: forward progress | |
| # 2. start < end < middle: rewind progress | |
| # 3. end < middle < start: reverse progress | |
| if direction == "forward": | |
| # Case 1: start < middle < end | |
| start_idx = frame1_idx | |
| middle_idx = frame2_idx | |
| end_idx = frame3_idx | |
| elif direction == "reverse": | |
| # Case 3: end < middle < start | |
| end_idx = frame1_idx | |
| middle_idx = frame2_idx | |
| start_idx = frame3_idx | |
| elif direction == "rewind": | |
| # Case 2: start < end < middle | |
| start_idx = frame1_idx | |
| end_idx = frame2_idx | |
| middle_idx = frame3_idx | |
| else: # bidirectional (default) | |
| # Randomly choose from the 3 cases | |
| pattern = random.choice([1, 2, 3]) | |
| if pattern == 1: # start < middle < end: forward progress | |
| start_idx = frame1_idx | |
| middle_idx = frame2_idx | |
| end_idx = frame3_idx | |
| elif pattern == 2: # start < end < middle: rewind progress | |
| start_idx = frame1_idx | |
| end_idx = frame2_idx | |
| middle_idx = frame3_idx | |
| else: # pattern == 3: end < middle < start: reverse progress | |
| end_idx = frame1_idx | |
| middle_idx = frame2_idx | |
| start_idx = frame3_idx | |
| logger.trace( | |
| f"[BASE SAMPLER] _get_subsample_indices: Selected indices start={start_idx}, middle={middle_idx}, end={end_idx} " | |
| f"from {num_frames_total} total frames (direction: {direction})" | |
| ) | |
| return start_idx, middle_idx, end_idx | |
    def _get_traj_from_data(
        self,
        traj: dict | Trajectory,
        subsample_strategy: str | None = None,
        frame_indices: List[int] | None = None,
        metadata: Dict[str, Any] | None = None,
        pad_frames: bool = True,
    ) -> Trajectory | None:
        """Load, subsample, and optionally pad trajectory data and create a Trajectory object.

        Pipeline: load frames or embeddings -> choose frame indices -> compute target
        progress -> cap at ``config.max_frames`` -> optionally pad -> build the
        last-frame mask -> compute success labels -> (discrete mode) bin progress ->
        assemble the Trajectory.

        Args:
            traj: Trajectory dict or Trajectory object. A Trajectory instance is returned unchanged.
            subsample_strategy: Optional strategy for subsampling ("subsample_forward",
                "subsample_reverse", "subsample_rewind"; any other non-None value is sampled
                bidirectionally). Ignored if frame_indices is provided. If both this and
                frame_indices are None, all frames are used.
            frame_indices: Optional list of specific frame indices to use. If provided, subsample_strategy is ignored.
            metadata: Optional metadata dict passed as the "metadata" override to create_trajectory_from_dict.
            pad_frames: Whether to pad the trajectory data to max_frames. Only this argument is
                consulted here; ``self.pad_frames`` is not read by this method.

        Returns:
            Trajectory object with loaded and subsampled data, or None when
            ``_get_subsample_indices`` cannot produce indices for the requested strategy.
        """
        # Initialize variables
        frames = None
        video_embeddings = None
        text_embedding = None
        data = None
        if isinstance(traj, Trajectory):
            # If already a Trajectory, just return it
            return traj
        # Load from dict
        # Check if text_embedding is already provided in the dict (for samplers that need to override it)
        if "text_embedding" in traj and traj["text_embedding"] is not None:
            text_embedding = traj["text_embedding"]
        # Embedding mode needs both the config flag and a non-empty embeddings_path;
        # otherwise raw frames are used.
        if self.config.load_embeddings and traj.get("embeddings_path"):
            embeddings = load_embeddings_from_path(traj["embeddings_path"])
            video_embeddings = embeddings["video_embeddings"]
            # Only use loaded text_embedding if not already provided in dict
            if text_embedding is None:
                text_embedding = embeddings["text_embedding"]
            data = video_embeddings
        else:
            # A string in traj["frames"] is passed to load_frames_from_npz; anything else is used as-is.
            if isinstance(traj["frames"], str):
                frames = load_frames_from_npz(traj["frames"])
            else:
                frames = traj["frames"]
            data = frames
        # Get total frames for progress computation
        if hasattr(data, "shape"):
            num_frames_total = data.shape[0]
        else:
            num_frames_total = len(data)
        # Per-dataset success cutoff, falling back to the global config.max_success.
        ds_key = traj["data_source"]
        success_cutoff = self.dataset_success_cutoff_map.get(ds_key, self.config.max_success)
        # Determine which indices to use (construct indices first, then subsample uniformly)
        if frame_indices is not None:
            # Use provided frame indices directly
            indices = frame_indices
        elif subsample_strategy is not None:
            # Use subsampling strategy
            # Get subsample indices (handles edge cases for max_frames == 1 or 2)
            if subsample_strategy == "subsample_forward":
                strategy_indices = self._get_subsample_indices(
                    data, direction="forward", max_frames=self.config.max_frames
                )
            elif subsample_strategy == "subsample_reverse":
                strategy_indices = self._get_subsample_indices(
                    data, direction="reverse", max_frames=self.config.max_frames
                )
            elif subsample_strategy == "subsample_rewind":
                strategy_indices = self._get_subsample_indices(
                    data, direction="rewind", max_frames=self.config.max_frames
                )
            else:
                strategy_indices = self._get_subsample_indices(
                    data, direction="bidirectional", max_frames=self.config.max_frames
                )
            if strategy_indices is None:
                logger.trace("[BASE SAMPLER] _get_traj_from_data: Failed to get uniform sample indices")
                return None
            start_idx, middle_idx, end_idx = strategy_indices
            logger.trace(
                f"[BASE SAMPLER] _get_traj_from_data: Subsampling trajectory with strategy: {subsample_strategy}, start_idx: {start_idx}, middle_idx: {middle_idx}, end_idx: {end_idx}"
            )
            # Use middle_idx only for rewind strategy (requires at least 3 frames)
            use_middle = subsample_strategy == "subsample_rewind" and middle_idx is not None and num_frames_total >= 3
            # Use get_segment_indices_with_middle to construct indices
            indices = get_segment_indices_with_middle(
                num_frames_total=num_frames_total,
                start_idx=start_idx,
                end_idx=end_idx,
                middle_idx=middle_idx if use_middle else None,
                max_frames=self.config.max_frames,
            )
        else:
            # No subsampling strategy or indices provided - use all frames
            indices = list(range(num_frames_total))
        # Extract data using indices
        # NOTE(review): indexing with a list of indices presumably requires `data` to be an
        # array/tensor rather than a plain Python list — confirm against the loaders.
        subsampled = data[indices]
        # Get partial_success early to pass to compute_progress_from_segment
        partial_success = traj.get("partial_success")
        # Compute progress
        target_progress = compute_progress_from_segment(
            num_frames_total=num_frames_total,
            frame_indices=indices,
            progress_pred_type=self.config.progress_pred_type,
            success_cutoff=success_cutoff,
            partial_success=partial_success,
        )
        # Subsample uniformly if needed (if we have more frames than max_frames)
        current_frame_count = len(subsampled) if hasattr(subsampled, "__len__") else subsampled.shape[0]
        if current_frame_count > self.config.max_frames:
            subsampled, frame_indices_subsample = linspace_subsample_frames(subsampled, self.config.max_frames)
            # Update indices and target_progress
            if target_progress and len(target_progress) == current_frame_count:
                target_progress = [target_progress[idx] for idx in frame_indices_subsample]
            # NOTE(review): nesting reconstructed from a flattened source — this line is assumed to
            # sit beside (not inside) the target_progress check above; confirm.
            # Non-list `indices` are left untouched, so they no longer line up with `subsampled`.
            indices = [indices[idx] for idx in frame_indices_subsample] if isinstance(indices, list) else indices
        # Pad if needed
        # Padding is skipped entirely when target_progress is falsy (e.g. None or empty).
        if target_progress and pad_frames:
            if self.config.load_embeddings:
                subsampled, target_progress = pad_trajectory_to_max_frames_torch(
                    subsampled, target_progress, self.config.max_frames
                )
            else:
                subsampled, target_progress = pad_trajectory_to_max_frames_np(
                    subsampled, target_progress, self.config.max_frames
                )
        # Create predict_last_frame_mask: mark the last frame if partial_success < 1.0
        # If predict_last_frame_partial_progress is True and partial_success < 1.0 and the last original frame is in the subsampled indices,
        # mark all positions where it appears with 1.0, all others 0.0. Otherwise, all 1.0s.
        final_frame_count = len(subsampled)
        predict_last_frame_mask = [1.0] * final_frame_count  # Default: all 1.0s (no masking)
        if self.config.predict_last_frame_partial_progress and partial_success is not None:
            # Full success on a dataset that is not preference-only keeps the default all-ones mask.
            if partial_success == 1.0 and not is_preference_only_ds(traj["data_source"]):
                pass
            else:
                last_original_frame_idx = num_frames_total - 1
                if isinstance(indices, list) and last_original_frame_idx in indices:
                    # Find all positions where the last frame index appears
                    last_frame_positions = [
                        i for i, idx in enumerate(indices) if idx == last_original_frame_idx and i < final_frame_count
                    ]
                    if last_frame_positions:
                        # Mark all positions where the last frame appears with 1.0, all others 0.0
                        predict_last_frame_mask = [0.0] * final_frame_count
                        for pos in last_frame_positions:
                            predict_last_frame_mask[pos] = 1.0
                # NOTE(review): nesting reconstructed from a flattened source — this else is assumed
                # to pair with the "last frame in indices" check (last frame absent -> mask all
                # zeros); confirm.
                else:
                    predict_last_frame_mask = [0.0] * final_frame_count
        # Update frames_shape
        frames_shape = subsampled.shape if hasattr(subsampled, "shape") else tuple()
        # Set frames or video_embeddings
        # This checks only the config flag, whereas loading above also required an embeddings_path.
        if self.config.load_embeddings:
            video_embeddings = subsampled
        else:
            frames = subsampled
        # Compute success labels
        success_label = compute_success_labels(
            target_progress=target_progress,
            data_source=traj["data_source"],
            dataset_success_percent=self.dataset_success_cutoff_map,
            max_success=self.config.max_success,
            quality_label=traj.get("quality_label"),
        )
        # Convert partial_success and target_progress to discrete bins if in discrete mode
        # (done after success labels, which are therefore computed from the continuous progress)
        if self.config.progress_loss_type.lower() == "discrete":
            if partial_success is not None:
                partial_success = convert_continuous_to_discrete_bins(
                    [partial_success], self.config.progress_discrete_bins
                )[0]
            target_progress = convert_continuous_to_discrete_bins(target_progress, self.config.progress_discrete_bins)
        # Build the final Trajectory: the source dict plus the values computed above as overrides.
        trajectory = create_trajectory_from_dict(
            traj,
            overrides={
                "frames": frames,
                "frames_shape": frames_shape,
                "video_embeddings": video_embeddings,
                "text_embedding": text_embedding,
                "target_progress": target_progress,
                "success_label": success_label,
                "partial_success": partial_success,
                "predict_last_frame_mask": predict_last_frame_mask,
                "metadata": metadata,
            },
        )
        return trajectory