Anthony Liang
updates
88e2e89
#!/usr/bin/env python3
from typing import Optional, Dict, Any, List, Set, Tuple, Union
import numpy as np
import random
import torch
from random import Random
from datasets import Dataset
from robometer.configs.experiment_configs import DataConfig
from robometer.data.datasets.helpers import (
load_frames_from_npz,
get_segment_indices_with_middle,
compute_progress_from_segment,
pad_trajectory_to_max_frames_torch,
pad_trajectory_to_max_frames_np,
compute_success_labels,
create_trajectory_from_dict,
load_embeddings_from_path,
linspace_subsample_frames,
convert_continuous_to_discrete_bins,
)
from robometer.data.dataset_types import Trajectory
from robometer.utils.logger import get_logger
from robometer.data.dataset_category import is_preference_only_ds
logger = get_logger()
class RBMBaseSampler:
    """Base sampler class that provides trajectory retrieval functions for generating samples.

    Subclasses implement ``_generate_sample``. The ``_get_*`` helpers pick related
    trajectories from ``self.dataset`` using the index mappings passed in
    ``combined_indices``. ``_get_traj_from_data`` turns a trajectory dict into a
    subsampled, optionally padded ``Trajectory``.

    NOTE(review): the retrieval and subsampling helpers in this class draw from the
    module-level ``random`` functions, not from ``self._local_random``.
    """

    # Maps ``subsample_strategy`` names to ``_get_subsample_indices`` directions.
    # Any other strategy name falls back to "bidirectional".
    _SUBSAMPLE_DIRECTIONS = {
        "subsample_forward": "forward",
        "subsample_reverse": "reverse",
        "subsample_rewind": "rewind",
    }

    def __init__(
        self,
        config: DataConfig,
        dataset: Dataset,
        combined_indices: Dict[str, Any],
        dataset_success_cutoff_map: Optional[Dict[str, float]] = None,
        verbose: bool = True,
        random_seed: int = 42,
        pad_frames: bool = True,
    ):
        """Initialize sampler with dataset and indices.

        Args:
            config: Configuration object.
            dataset: The loaded dataset; must expose "id", "is_robot", "task" and
                "data_source" columns.
            combined_indices: Dictionary of combined indices from dataset loading.
            dataset_success_cutoff_map: Dictionary mapping dataset names to success cutoff percentages.
            verbose: Verbose flag.
            random_seed: Seed for ``self._local_random``, a private ``random.Random`` instance.
                The helpers defined in this class do not draw from it (they use the
                module-level ``random`` functions); confirm whether subclasses do.
            pad_frames: Stored on ``self.pad_frames``. ``_get_traj_from_data`` takes its own
                ``pad_frames`` argument and does not read this attribute.
        """
        self.config = config
        self.dataset = dataset
        self.verbose = verbose
        self.dataset_success_cutoff_map = dataset_success_cutoff_map or {}
        self._local_random = Random(random_seed)
        self.pad_frames = pad_frames
        # Column caches so ID / embodiment checks don't require loading whole rows.
        self._cached_ids = self.dataset["id"]
        self._cached_is_robot = self.dataset["is_robot"]
        self._build_indices(combined_indices)

    def _build_indices(self, combined_indices):
        """Build all index mappings from combined_indices.

        Args:
            combined_indices: Dictionary of combined indices from dataset loading. Every key
                read below must be present (a missing key raises ``KeyError``).
        """
        self.robot_trajectories = combined_indices["robot_trajectories"]
        self.human_trajectories = combined_indices["human_trajectories"]
        self.optimal_by_task = combined_indices["optimal_by_task"]
        self.suboptimal_by_task = combined_indices["suboptimal_by_task"]
        self.quality_indices = combined_indices["quality_indices"]
        self.task_indices = combined_indices["task_indices"]
        self.source_indices = combined_indices["source_indices"]
        self.partial_success_indices = combined_indices["partial_success_indices"]
        self.paired_human_robot_by_task = combined_indices["paired_human_robot_by_task"]
        self.tasks_with_multiple_quality_labels = combined_indices["tasks_with_multiple_quality_labels"]
        # Build mapping from data source -> available task instructions
        self._build_tasks_by_data_source()

    def _build_tasks_by_data_source(self):
        """Cache a mapping from data_source to the distinct task instructions it contains.

        Rows whose task or data_source is None are skipped. Tasks keep first-seen dataset
        order (a dict used as an ordered set). Iterating a ``set`` instead would give an
        order that, for strings, depends on per-process hash randomization. That would
        make ``random.choice`` over these lists irreproducible across runs even with a
        fixed seed.
        """
        source_to_tasks: Dict[str, Dict[str, None]] = {}
        for task, source in zip(self.dataset["task"], self.dataset["data_source"]):
            if task is None or source is None:
                continue
            source_to_tasks.setdefault(source, {})[task] = None
        self.tasks_by_data_source: Dict[str, List[str]] = {
            source: list(tasks) for source, tasks in source_to_tasks.items()
        }

    def _generate_sample(self, item):
        """Generate a sample from an item.

        This method should be overridden by subclasses to implement their specific
        sample generation logic.

        Args:
            item: An item from the dataset (typically a trajectory dict)

        Returns:
            A sample object (e.g., PreferenceSample, ProgressSample)
        """
        raise NotImplementedError("Subclasses must implement _generate_sample")

    def _sample_index_excluding_id(self, indices: List[int], excluded_id: Any) -> Optional[int]:
        """Randomly pick an index from ``indices`` whose cached ID differs from ``excluded_id``.

        First tries cheap random redraws (at most ``min(10, len(indices))``). If that fails,
        it filters the whole list, so None is returned only when no candidate qualifies
        (or ``indices`` is empty). Only ``self._cached_ids`` is consulted; no dataset rows
        are loaded.
        """
        if not indices:
            return None
        idx = random.choice(indices)
        max_retries = min(10, len(indices))
        retries = 0
        while self._cached_ids[idx] == excluded_id and retries < max_retries:
            idx = random.choice(indices)
            retries += 1
        if self._cached_ids[idx] != excluded_id:
            return idx
        # Random redraws kept hitting the excluded ID: fall back to an exhaustive filter.
        filtered_indices = [i for i in indices if self._cached_ids[i] != excluded_id]
        return random.choice(filtered_indices) if filtered_indices else None

    def _get_other_task_candidates(self, ref_traj: dict) -> List[str]:
        """Return task instructions other than ``ref_traj["task"]`` to sample from.

        Same-source candidates are used only when all of these hold:
        - the reference has a known data_source;
        - that source appears in ``tasks_by_data_source``;
        - a ``random.random()`` draw falls below ``config.traj_same_source_prob``.

        If those conditions fail, or yield no other task, every task key of
        ``optimal_by_task`` (minus the reference task) is returned.
        """
        same_source_prob = self.config.traj_same_source_prob
        ref_task = ref_traj["task"]
        data_source = ref_traj.get("data_source")
        candidates: List[str] = []
        if data_source and data_source in self.tasks_by_data_source and random.random() < same_source_prob:
            candidates = [task for task in self.tasks_by_data_source[data_source] if task != ref_task]
        if not candidates:
            candidates = [task for task in self.optimal_by_task if task != ref_task]
        return candidates

    def _get_same_task_optimal(self, ref_traj: dict) -> dict | None:
        """Get optimal trajectory from same task (different from ref).

        Args:
            ref_traj: Reference trajectory

        Returns:
            Same task optimal trajectory dict or None if not available
        """
        task_name = ref_traj["task"]
        same_task_optimal_indices = self.optimal_by_task.get(task_name, [])
        if not same_task_optimal_indices:
            logger.trace(f"[BASE SAMPLER] _get_same_task_optimal: No optimal indices for task '{task_name}'")
            return None

        chosen_id = ref_traj["id"]
        random_idx = self._sample_index_excluding_id(same_task_optimal_indices, chosen_id)
        if random_idx is None:
            logger.trace(
                f"[BASE SAMPLER] _get_same_task_optimal: All trajectories have same ID '{chosen_id}' for task '{task_name}'"
            )
            return None

        result = self.dataset[random_idx]
        logger.trace(
            f"[BASE SAMPLER] _get_same_task_optimal: Found trajectory {result.get('id', 'unknown')} for task '{task_name}'"
        )
        return result

    def _get_same_task_suboptimal(self, ref_traj: dict) -> dict | None:
        """Get suboptimal trajectory from same task.

        For trajectories with partial_success, uses partial_success logic instead of quality_label logic.

        Args:
            ref_traj: Reference trajectory

        Returns:
            Suboptimal trajectory dict or None if not available
        """
        if ref_traj.get("partial_success") is not None:
            # Partial-success trajectories are paired by partial_success gap, not quality label.
            return self._get_different_partial_success_traj(ref_traj)

        task_name = ref_traj["task"]
        same_task_suboptimal_indices = self.suboptimal_by_task.get(task_name, [])
        if not same_task_suboptimal_indices:
            logger.trace(f"[BASE SAMPLER] _get_same_task_suboptimal: No suboptimal indices for task '{task_name}'")
            return None

        chosen_id = ref_traj["id"]
        random_idx = self._sample_index_excluding_id(same_task_suboptimal_indices, chosen_id)
        if random_idx is None:
            logger.trace(
                f"[BASE SAMPLER] _get_same_task_suboptimal: All trajectories have same ID '{chosen_id}' for task '{task_name}'"
            )
            return None

        result = self.dataset[random_idx]
        logger.trace(
            f"[BASE SAMPLER] _get_same_task_suboptimal: Found trajectory {result.get('id', 'unknown')} for task '{task_name}'"
        )
        return result

    def _get_different_video_traj(self, ref_traj: dict) -> dict | None:
        """Get an optimal trajectory from a task different from the reference's.

        Args:
            ref_traj: Reference trajectory

        Returns:
            Different task trajectory dict or None if not available
        """
        other_tasks = self._get_other_task_candidates(ref_traj)
        if not other_tasks:
            logger.trace(
                f"[BASE SAMPLER] _get_different_video_traj: No other tasks available (ref task: '{ref_traj['task']}')"
            )
            return None

        # Same-source tasks need not appear in optimal_by_task (or may have no optimal
        # indices), so allow a couple of draws before giving up.
        max_retries = 2
        other_task_indices = None
        other_task = None
        for attempt in range(max_retries):
            other_task = random.choice(other_tasks)
            if other_task not in self.optimal_by_task:
                logger.trace(
                    f"[BASE SAMPLER] _get_different_video_traj: Attempt {attempt + 1}/{max_retries}: Task '{other_task}' not found in optimal_by_task"
                )
                continue
            other_task_indices = self.optimal_by_task[other_task]
            if not other_task_indices:
                logger.trace(
                    f"[BASE SAMPLER] _get_different_video_traj: Attempt {attempt + 1}/{max_retries}: Task '{other_task}' has no optimal indices"
                )
                continue
            break

        if not other_task_indices:
            logger.trace(
                f"[BASE SAMPLER] _get_different_video_traj: Failed to find valid task after {max_retries} attempts"
            )
            return None

        other_idx = random.choice(other_task_indices)
        result = self.dataset[other_idx]
        logger.trace(
            f"[BASE SAMPLER] _get_different_video_traj: Found trajectory {result.get('id', 'unknown')} from task '{other_task}'"
        )
        return result

    def _get_different_task_instruction(self, ref_traj: dict) -> dict | None:
        """Get the same trajectory but with a different task instruction.

        The returned dict is a shallow copy of ``ref_traj`` whose "task" is replaced. Its
        language-side fields ("embeddings_path", "lang_vector", "text_embedding") are taken
        from a random optimal trajectory of the new task, so they match the new instruction.

        Args:
            ref_traj: Reference trajectory

        Returns:
            Trajectory dict with different task instruction or None if not available
        """
        candidate_tasks = self._get_other_task_candidates(ref_traj)
        if not candidate_tasks:
            logger.trace(
                f"[BASE SAMPLER] _get_different_task_instruction: No candidate tasks available (ref task: '{ref_traj['task']}')"
            )
            return None

        other_task = random.choice(candidate_tasks)
        other_task_indices = self.optimal_by_task.get(other_task, [])
        if not other_task_indices:
            logger.trace(f"[BASE SAMPLER] _get_different_task_instruction: Task '{other_task}' has no optimal indices")
            return None

        other_task_traj = self.dataset[random.choice(other_task_indices)]

        new_traj = ref_traj.copy()
        new_traj["task"] = other_task
        if "embeddings_path" in other_task_traj:
            new_traj["embeddings_path"] = other_task_traj["embeddings_path"]
        if "lang_vector" in other_task_traj:
            new_traj["lang_vector"] = other_task_traj["lang_vector"]
        # _get_traj_from_data prefers a dict-provided text_embedding over the one loaded
        # from embeddings_path. A stale reference embedding would otherwise describe the
        # old task, so take it from the new task's trajectory or drop it.
        if "text_embedding" in other_task_traj:
            new_traj["text_embedding"] = other_task_traj["text_embedding"]
        else:
            new_traj.pop("text_embedding", None)
        return new_traj

    def _get_paired_human_robot_traj(self, ref_traj: dict) -> dict | None:
        """Get paired human/robot trajectory for the same task.

        Given a reference trajectory, if it's a robot trajectory, returns a human trajectory
        from the same task. If it's a human trajectory, returns a robot trajectory from the
        same task. A missing "is_robot" key is treated as robot.

        Args:
            ref_traj: Reference trajectory (can be robot or human)

        Returns:
            Paired trajectory dict (opposite type) or None if not available
        """
        task = ref_traj["task"]
        is_robot = ref_traj.get("is_robot", True)

        if task not in self.paired_human_robot_by_task:
            logger.trace(
                f"[BASE SAMPLER] _get_paired_human_robot_traj: Task '{task}' not in paired_human_robot_by_task"
            )
            return None

        opposite_key = "human" if is_robot else "robot"
        opposite_indices = self.paired_human_robot_by_task[task].get(opposite_key, [])
        if not opposite_indices:
            logger.trace(f"[BASE SAMPLER] _get_paired_human_robot_traj: No {opposite_key} indices for task '{task}'")
            return None

        chosen_id = ref_traj["id"]
        logger.trace(
            f"[BASE SAMPLER] _get_paired_human_robot_traj: Looking for {opposite_key} trajectory (chosen_id: {chosen_id}, available: {len(opposite_indices)})"
        )
        # ID check uses the cached id column, so only the selected row is loaded.
        paired_idx = self._sample_index_excluding_id(opposite_indices, chosen_id)
        if paired_idx is None:
            logger.trace(
                f"[BASE SAMPLER] _get_paired_human_robot_traj: Failed to find valid paired trajectory (all {opposite_key} trajectories have id '{chosen_id}')"
            )
            return None

        paired_traj = self.dataset[paired_idx]
        logger.trace(
            f"[BASE SAMPLER] _get_paired_human_robot_traj: Found paired trajectory {paired_traj.get('id', 'unknown')}"
        )
        return paired_traj

    def _get_different_partial_success_traj(self, ref_traj: dict) -> dict | None:
        """Get trajectory from same task with different partial_success.

        Finds trajectories with either higher or lower partial_success than the reference,
        keeping those whose absolute difference is >= ``config.partial_success_threshold``
        (0.2 when the config has no such attribute).

        Args:
            ref_traj: Reference trajectory

        Returns:
            Trajectory dict with different partial_success from same task or None if not available
        """
        task_name = ref_traj["task"]
        ref_partial_success = ref_traj.get("partial_success")
        if ref_partial_success is None:
            logger.trace(
                f"[BASE SAMPLER] _get_different_partial_success_traj: No partial_success for trajectory {ref_traj.get('id', 'unknown')}"
            )
            return None

        min_threshold = getattr(self.config, "partial_success_threshold", 0.2)

        same_task_indices = self.task_indices.get(task_name, [])
        if not same_task_indices:
            logger.trace(
                f"[BASE SAMPLER] _get_different_partial_success_traj: No trajectories found for task '{task_name}'"
            )
            return None

        chosen_id = ref_traj["id"]
        candidate_indices = []
        for idx in same_task_indices:
            if self._cached_ids[idx] == chosen_id:
                continue
            # NOTE: loads each same-task row to read its partial_success.
            traj_dict = self.dataset[idx]
            traj_partial_success = traj_dict.get("partial_success", None)
            if traj_partial_success is None:
                logger.trace(
                    f"[BASE SAMPLER] _get_different_partial_success_traj: No partial_success for trajectory {traj_dict.get('id', 'unknown')}, task '{task_name}'"
                )
                continue
            if abs(ref_partial_success - traj_partial_success) >= min_threshold:
                candidate_indices.append(idx)

        if not candidate_indices:
            logger.trace(
                f"[BASE SAMPLER] _get_different_partial_success_traj: No trajectories with different partial_success (threshold: {min_threshold}) for task '{task_name}' (ref: {ref_partial_success})"
            )
            return None

        result = self.dataset[random.choice(candidate_indices)]
        result_partial_success = result.get("partial_success")
        # 1.0 is the maximum partial_success, so any pick is reported as "lower".
        if ref_partial_success == 1.0:
            direction = "lower"
        else:
            direction = "higher" if result_partial_success > ref_partial_success else "lower"
        logger.trace(
            f"[BASE SAMPLER] _get_different_partial_success_traj: Found trajectory {result.get('id', 'unknown')} with partial_success {result_partial_success} ({direction} than {ref_partial_success}, abs diff: {abs(ref_partial_success - result_partial_success):.3f}, threshold: {min_threshold})"
        )
        return result

    def _get_subsample_indices(
        self, data, direction: str = "bidirectional", max_frames: Optional[int] = None
    ) -> Optional[Tuple[int, Optional[int], Optional[int]]]:
        """Get start, middle, and end indices for subsample strategy.

        Samples random frames from the trajectory. For three frames, the relationship
        between indices follows one of three scenarios:
        1. start < middle < end: forward progress - normal forward progression through trajectory
        2. start < end < middle: rewind progress - forward from start to end, then continues to middle (simulating rewind/backtrack)
        3. end < middle < start: reverse progress - backward from start through middle to end (full backward traversal)

        Args:
            data: Trajectory data (frames or embeddings) to sample from
            direction: Sampling direction - "forward" (start < middle < end),
                "reverse" (end < middle < start),
                "rewind" (start < end < middle),
                or "bidirectional" (any of the 3 orderings)
            max_frames: Maximum number of frames to subsample. If 1, returns only start. If 2,
                returns start and end ("rewind" and "bidirectional" behave as "forward").

        Returns:
            Tuple of (start_idx, middle_idx, end_idx), or None if insufficient frames
            (fewer than max_frames when max_frames is 1 or 2, otherwise fewer than 3).
            For max_frames == 1: returns (start_idx, None, None)
            For max_frames == 2: returns (start_idx, None, end_idx)
        """
        num_frames_total = len(data) if hasattr(data, "__len__") else data.shape[0]

        if max_frames == 1:
            # Guard: random.randint(0, -1) would raise ValueError on empty data.
            if num_frames_total < 1:
                logger.trace(f"[BASE SAMPLER] _get_subsample_indices: Not enough frames ({num_frames_total})")
                return None
            random_idx = random.randint(0, num_frames_total - 1)
            logger.trace(f"[BASE SAMPLER] _get_subsample_indices: max_frames=1, randomly sampled idx={random_idx}")
            return (random_idx, None, None)

        if max_frames == 2:
            # Two distinct frames are needed; randint would raise ValueError otherwise.
            if num_frames_total < 2:
                logger.trace(f"[BASE SAMPLER] _get_subsample_indices: Not enough frames ({num_frames_total})")
                return None
            # No rewind possible with only 2 frames: either forward or reverse.
            if direction == "reverse":
                end_idx = random.randint(0, num_frames_total - 2)
                start_idx = random.randint(end_idx + 1, num_frames_total - 1)
            else:
                start_idx = random.randint(0, num_frames_total - 2)
                end_idx = random.randint(start_idx + 1, num_frames_total - 1)
            logger.trace(
                f"[BASE SAMPLER] _get_subsample_indices: max_frames=2, start_idx={start_idx}, end_idx={end_idx}, direction={direction}"
            )
            return (start_idx, None, end_idx)

        if num_frames_total < 3:
            logger.trace(f"[BASE SAMPLER] _get_subsample_indices: Not enough frames ({num_frames_total})")
            return None

        # Three distinct frames in ascending order; direction decides which role each plays.
        low_idx, mid_idx, high_idx = sorted(random.sample(range(num_frames_total), 3))

        if direction == "forward":
            pattern = 1
        elif direction == "reverse":
            pattern = 3
        elif direction == "rewind":
            pattern = 2
        else:  # bidirectional (default)
            pattern = random.choice([1, 2, 3])

        if pattern == 1:  # start < middle < end: forward progress
            start_idx, middle_idx, end_idx = low_idx, mid_idx, high_idx
        elif pattern == 2:  # start < end < middle: rewind progress
            start_idx, end_idx, middle_idx = low_idx, mid_idx, high_idx
        else:  # end < middle < start: reverse progress
            end_idx, middle_idx, start_idx = low_idx, mid_idx, high_idx

        logger.trace(
            f"[BASE SAMPLER] _get_subsample_indices: Selected indices start={start_idx}, middle={middle_idx}, end={end_idx} "
            f"from {num_frames_total} total frames (direction: {direction})"
        )
        return start_idx, middle_idx, end_idx

    def _get_traj_from_data(
        self,
        traj: dict | Trajectory,
        subsample_strategy: str | None = None,
        frame_indices: List[int] | None = None,
        metadata: Dict[str, Any] | None = None,
        pad_frames: bool = True,
    ) -> Trajectory | None:
        """Load, subsample, and optionally pad trajectory data and create a Trajectory object.

        Args:
            traj: Trajectory dict or Trajectory object. A ``Trajectory`` is returned unchanged.
            subsample_strategy: Optional strategy for subsampling ("subsample_forward", "subsample_reverse",
                "subsample_rewind", or any other value for bidirectional). None uses all frames.
                Ignored if frame_indices is provided.
            frame_indices: Optional list of specific frame indices to use. If provided, subsample_strategy is ignored.
            metadata: Optional metadata dict to merge into trajectory metadata.
            pad_frames: Whether to pad the trajectory data to max_frames (only when target_progress is non-empty).

        Returns:
            Trajectory object with loaded and subsampled (optionally padded) data. Returns None
            when subsample_strategy is given but the trajectory has too few frames for it.
        """
        if isinstance(traj, Trajectory):
            return traj

        frames = None
        video_embeddings = None
        # A text_embedding already present in the dict takes priority over the one stored
        # with the video embeddings (lets samplers override it).
        text_embedding = traj.get("text_embedding")

        if self.config.load_embeddings and traj.get("embeddings_path"):
            embeddings = load_embeddings_from_path(traj["embeddings_path"])
            video_embeddings = embeddings["video_embeddings"]
            if text_embedding is None:
                text_embedding = embeddings["text_embedding"]
            data = video_embeddings
        else:
            frames = traj["frames"]
            if isinstance(frames, str):
                frames = load_frames_from_npz(frames)
            data = frames

        num_frames_total = data.shape[0] if hasattr(data, "shape") else len(data)
        success_cutoff = self.dataset_success_cutoff_map.get(traj["data_source"], self.config.max_success)

        # Determine which indices to use (construct indices first, then subsample uniformly)
        if frame_indices is not None:
            indices = frame_indices
        elif subsample_strategy is not None:
            direction = self._SUBSAMPLE_DIRECTIONS.get(subsample_strategy, "bidirectional")
            strategy_indices = self._get_subsample_indices(
                data, direction=direction, max_frames=self.config.max_frames
            )
            if strategy_indices is None:
                logger.trace("[BASE SAMPLER] _get_traj_from_data: Failed to get uniform sample indices")
                return None

            start_idx, middle_idx, end_idx = strategy_indices
            logger.trace(
                f"[BASE SAMPLER] _get_traj_from_data: Subsampling trajectory with strategy: {subsample_strategy}, start_idx: {start_idx}, middle_idx: {middle_idx}, end_idx: {end_idx}"
            )
            # The middle index only shapes the segment for the rewind strategy.
            use_middle = subsample_strategy == "subsample_rewind" and middle_idx is not None and num_frames_total >= 3
            indices = get_segment_indices_with_middle(
                num_frames_total=num_frames_total,
                start_idx=start_idx,
                end_idx=end_idx,
                middle_idx=middle_idx if use_middle else None,
                max_frames=self.config.max_frames,
            )
        else:
            indices = list(range(num_frames_total))

        subsampled = data[indices]
        partial_success = traj.get("partial_success")

        target_progress = compute_progress_from_segment(
            num_frames_total=num_frames_total,
            frame_indices=indices,
            progress_pred_type=self.config.progress_pred_type,
            success_cutoff=success_cutoff,
            partial_success=partial_success,
        )

        # Subsample uniformly if we still have more frames than max_frames.
        current_frame_count = len(subsampled) if hasattr(subsampled, "__len__") else subsampled.shape[0]
        if current_frame_count > self.config.max_frames:
            subsampled, frame_indices_subsample = linspace_subsample_frames(subsampled, self.config.max_frames)
            if target_progress and len(target_progress) == current_frame_count:
                target_progress = [target_progress[i] for i in frame_indices_subsample]
            # Remap indices for any sequence type so they stay aligned with the kept frames.
            indices = [indices[i] for i in frame_indices_subsample]

        if target_progress and pad_frames:
            pad_fn = (
                pad_trajectory_to_max_frames_torch if self.config.load_embeddings else pad_trajectory_to_max_frames_np
            )
            subsampled, target_progress = pad_fn(subsampled, target_progress, self.config.max_frames)

        # predict_last_frame_mask: default all 1.0. When predict_last_frame_partial_progress is on
        # and the trajectory is not a plain full success, mark only the positions holding the
        # original last frame with 1.0, or all 0.0 if that frame was not sampled.
        final_frame_count = len(subsampled)
        predict_last_frame_mask = [1.0] * final_frame_count
        if self.config.predict_last_frame_partial_progress and partial_success is not None:
            if partial_success != 1.0 or is_preference_only_ds(traj["data_source"]):
                last_original_frame_idx = num_frames_total - 1
                index_list = list(indices)
                if last_original_frame_idx in index_list:
                    last_frame_positions = [
                        pos
                        for pos, idx in enumerate(index_list)
                        if idx == last_original_frame_idx and pos < final_frame_count
                    ]
                    if last_frame_positions:
                        predict_last_frame_mask = [0.0] * final_frame_count
                        for pos in last_frame_positions:
                            predict_last_frame_mask[pos] = 1.0
                else:
                    predict_last_frame_mask = [0.0] * final_frame_count

        frames_shape = subsampled.shape if hasattr(subsampled, "shape") else tuple()
        if self.config.load_embeddings:
            video_embeddings = subsampled
        else:
            frames = subsampled

        success_label = compute_success_labels(
            target_progress=target_progress,
            data_source=traj["data_source"],
            dataset_success_percent=self.dataset_success_cutoff_map,
            max_success=self.config.max_success,
            quality_label=traj.get("quality_label"),
        )

        # Convert partial_success and target_progress to discrete bins if in discrete mode
        if self.config.progress_loss_type.lower() == "discrete":
            if partial_success is not None:
                partial_success = convert_continuous_to_discrete_bins(
                    [partial_success], self.config.progress_discrete_bins
                )[0]
            target_progress = convert_continuous_to_discrete_bins(target_progress, self.config.progress_discrete_bins)

        return create_trajectory_from_dict(
            traj,
            overrides={
                "frames": frames,
                "frames_shape": frames_shape,
                "video_embeddings": video_embeddings,
                "text_embedding": text_embedding,
                "target_progress": target_progress,
                "success_label": success_label,
                "partial_success": partial_success,
                "predict_last_frame_mask": predict_last_frame_mask,
                "metadata": metadata,
            },
        )