Alan0928's picture
Upload folder using huggingface_hub
08ff31f verified
Raw
History Blame Contribute Delete
20.2 kB
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from fractions import Fraction
import numpy as np
# Human-readable names for the integer motion classes returned by
# ``_motion_class``: 0 = no motion above threshold, 1 = translation only,
# 2 = rotation only, 3 = translation and rotation.
MOTION_CLASS_NAMES = dict(
    enumerate(("still", "translate", "rotate", "translate_rotate"))
)
@dataclass(frozen=True)
class SpeedTransformConfig:
    """Speed-augmentation knobs for LIBERO 7D delta actions [dx,dy,dz,droll,dpitch,dyaw,gripper].

    The gripper (column 6) is never used for segmentation; it is copied
    discretely.

    Two modes:
    - chunk_aligned_observation=False (legacy): uniform linspace over the segment.
    - chunk_aligned_observation=True: speed q/p splits each segment into chunks
      of q source frames -> p output frames; the leftover < q frames pass
      through 1:1 verbatim. Only chunk-starts (and, for phase=0, trailing
      passthrough rows) are mask=1 in the output. ``chunk_phase`` r in [0, q)
      shifts chunk starts to r, r+q, r+2q, ... within each segment (used by
      the online sliding sampler to randomize which source frames become
      chunk-starts).
    """
    # Per-frame L2-norm thresholds used during segmentation to decide whether
    # a frame counts as translating / rotating (motion class assignment).
    transl_eps: float = 1e-4
    rot_eps: float = 1e-4
    # Thresholds for clean_near_zero_actions: translation / rotation parts
    # whose norm is strictly below them are zeroed before segmentation.
    # OFF by default (0.0 zeroes nothing): LIBERO demos have no frame below
    # 1e-4 in practice. Set 1e-4 explicitly for noisier datasets via the
    # build CLI.
    clean_transl_eps: float = 0.0
    clean_rot_eps: float = 0.0
    # A segment boundary is placed when the cosine between consecutive
    # translation (or rotation) deltas drops below this value (a reversal).
    direction_cos_threshold: float = -0.25
    # Segments shorter than this are merged into a neighbouring segment;
    # a value <= 1 disables merging.
    min_segment_len: int = 1
    # When False, still (motion class 0) segments are dropped entirely.
    keep_still_segments: bool = True
    # NOTE(review): not read by any code in this module section; presumably
    # the source frame rate -- confirm against callers.
    fps: int = 20
    # Selects chunk-aligned boundaries and observation mask (see docstring).
    chunk_aligned_observation: bool = False
    # Chunk offset r; must stay 0 unless chunk_aligned_observation is True.
    chunk_phase: int = 0
def _as_float32_2d(values: Iterable[np.ndarray], expected_dim: int) -> np.ndarray:
    """Coerce ``values`` into a fresh float32 array of shape (T, expected_dim).

    ``values`` may be an ndarray or any iterable of per-frame rows (lists,
    tuples, generators). An empty input is returned as a ``(0, expected_dim)``
    array rather than being rejected for having shape ``(0,)``.

    Raises:
        ValueError: if the result is not 2-D with ``expected_dim`` columns.
    """
    # Arrays convert directly. Only generic iterables (e.g. generators) need
    # materialising first, which avoids building a Python list of row views.
    if not isinstance(values, np.ndarray):
        values = list(values)
    # np.array copies by default, so callers never alias the input buffer.
    arr = np.array(values, dtype=np.float32)
    if arr.ndim == 1 and arr.size == 0:
        arr = arr.reshape(0, expected_dim)
    if arr.ndim != 2 or arr.shape[1] != expected_dim:
        raise ValueError(f"Expected shape (T, {expected_dim}), got {arr.shape}")
    return arr
def clean_near_zero_actions(
    actions: np.ndarray,
    transl_eps: float = 1e-4,
    rot_eps: float = 1e-4,
) -> tuple[np.ndarray, np.ndarray]:
    """Snap sub-threshold translation / rotation deltas to exactly zero.

    Works on a float32 copy; the input is never modified. Columns 0-2
    (translation) of a row are zeroed when their L2 norm is strictly below
    ``transl_eps``; columns 3-5 (rotation) likewise against ``rot_eps``.
    Column 6 onward (gripper) is left untouched.

    Returns:
        ``(cleaned_actions, zeroed_mask)`` where ``zeroed_mask`` is a
        (T, 2) bool array with columns [translation_zeroed, rotation_zeroed].

    Raises:
        ValueError: if ``actions`` is not 2-D with at least 7 columns.
    """
    out = np.asarray(actions, dtype=np.float32).copy()
    if out.ndim != 2 or out.shape[1] < 7:
        raise ValueError(f"Expected 7D actions, got {out.shape}")

    tiny_translation = np.linalg.norm(out[:, :3], axis=1) < transl_eps
    tiny_rotation = np.linalg.norm(out[:, 3:6], axis=1) < rot_eps

    out[tiny_translation, :3] = 0.0
    out[tiny_rotation, 3:6] = 0.0
    return out, np.column_stack((tiny_translation, tiny_rotation))
def _motion_class(action: np.ndarray, transl_eps: float, rot_eps: float) -> int:
    """Classify one action by which motion components reach their thresholds.

    Bit 0 is set when the translation norm (columns 0-2) is >= ``transl_eps``.
    Bit 1 is set when the rotation norm (columns 3-5) is >= ``rot_eps``.
    This yields 0 = still, 1 = translate, 2 = rotate, 3 = translate_rotate.
    """
    translates = float(np.linalg.norm(action[:3])) >= transl_eps
    rotates = float(np.linalg.norm(action[3:6])) >= rot_eps
    return int(translates) | (int(rotates) << 1)
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of ``a`` and ``b``.

    When the product of the norms is <= 1e-12 (either vector is ~zero),
    the pair is treated as aligned and 1.0 is returned, so degenerate
    vectors never trigger a direction-reversal boundary.
    """
    scale = float(np.linalg.norm(a) * np.linalg.norm(b))
    return 1.0 if scale <= 1e-12 else float(np.dot(a, b) / scale)
def segment_actions(
    actions: np.ndarray,
    config: SpeedTransformConfig,
) -> list[tuple[int, int, int]]:
    """Cut the episode into homogeneous-motion segments.

    A boundary is placed before frame ``i`` in either of two cases:
    (a) its motion class (still / translate / rotate / both, from
    ``_motion_class``) differs from frame ``i-1``'s; or
    (b) the translation or rotation direction reverses between the two
    frames (cosine below ``config.direction_cos_threshold``).

    Gripper open/close is NEVER a cut criterion. It is preserved later as an
    interval anchor in ``_segment_boundaries``.

    When ``config.min_segment_len > 1``, a segment shorter than that is
    absorbed into the preceding segment, which keeps its motion class. A
    short run at the very start of the episode has no predecessor, so it is
    absorbed into the following segment and takes that segment's class.

    Returns:
        List of ``(start, end_exclusive, motion_class)`` tuples that tile
        ``[0, len(actions))`` in order; empty for an empty episode.
    """
    n_frames = len(actions)
    if n_frames == 0:
        return []
    classes = np.asarray(
        [_motion_class(a, config.transl_eps, config.rot_eps) for a in actions],
        dtype=np.int32,
    )
    segments: list[tuple[int, int, int]] = []
    start = 0
    for i in range(1, n_frames):
        boundary = classes[i] != classes[i - 1]
        # Reversal checks only apply within an unchanged class: translation
        # direction for classes 1/3, rotation direction for classes 2/3.
        if not boundary and classes[i] in (1, 3):
            boundary = _cosine(actions[i - 1, :3], actions[i, :3]) < config.direction_cos_threshold
        if not boundary and classes[i] in (2, 3):
            boundary = _cosine(actions[i - 1, 3:6], actions[i, 3:6]) < config.direction_cos_threshold
        if boundary:
            segments.append((start, i, int(classes[i - 1])))
            start = i
    segments.append((start, n_frames, int(classes[-1])))
    if config.min_segment_len <= 1:
        return segments

    merged: list[tuple[int, int, int]] = []
    for seg in segments:
        if not merged or (seg[1] - seg[0]) >= config.min_segment_len:
            merged.append(seg)
            continue
        prev_start, _prev_end, prev_cls = merged[-1]
        merged[-1] = (prev_start, seg[1], prev_cls)
    # The first segment had nothing before it to merge into, so it may still
    # be short. merged[1] (if any) is long by construction, so fold into it.
    first_start, first_end, _first_cls = merged[0]
    if len(merged) > 1 and (first_end - first_start) < config.min_segment_len:
        _next_start, next_end, next_cls = merged[1]
        merged = [(first_start, next_end, next_cls), *merged[2:]]
    return merged
def _interp_cumulative(cumulative: np.ndarray, x: float) -> np.ndarray:
    """Linearly interpolate the rows of ``cumulative`` at fractional index ``x``.

    ``x`` is clamped to ``[0, len(cumulative) - 1]``. The result blends the
    two neighbouring rows with a float32 weight, giving a piecewise-linear
    curve through the integer-indexed samples.
    """
    last = cumulative.shape[0] - 1
    pos = float(np.clip(x, 0.0, last))
    lo = int(np.floor(pos))
    hi = lo + 1 if lo < last else last
    frac = np.float32(pos - lo)
    return (1.0 - frac) * cumulative[lo] + frac * cumulative[hi]
# Largest denominator p allowed when expressing a speed as q/p chunks.
_CHUNK_DENOM_LIMIT = 32


def _speed_chunk_ratio(speed: float) -> tuple[int, int]:
    """Decompose ``speed`` as q/p: q source frames map to p output frames.

    The fraction is the closest rational with denominator at most
    ``_CHUNK_DENOM_LIMIT``, and it must match ``speed`` to within 1e-6. All
    ablation speeds (0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 4.0) fit.
    ``Fraction`` keeps lowest terms, so q and p are coprime and q >= 1.

    Raises:
        ValueError: if ``speed`` is not a positive finite number, or cannot
            be approximated within 1e-6 by a positive rational whose
            denominator is <= ``_CHUNK_DENOM_LIMIT``.
    """
    # Reject bad speeds up front: otherwise 0 / negative speeds yield q <= 0
    # (breaking the chunk layout) and inf raises OverflowError inside Fraction.
    speed_value = float(speed)
    if not np.isfinite(speed_value) or speed_value <= 0:
        raise ValueError(f"speed must be a positive finite number, got {speed}")
    f = Fraction(speed).limit_denominator(_CHUNK_DENOM_LIMIT)
    # A tiny positive speed can round to 0/1 within tolerance; q = 0 is invalid.
    if f.numerator == 0 or abs(float(f) - speed) > 1e-6:
        raise ValueError(
            f"speed={speed} cannot be expressed as a rational with denominator <= "
            f"{_CHUNK_DENOM_LIMIT}; chunk_aligned_observation requires a rational speed."
        )
    return f.numerator, f.denominator
def _segment_boundaries(
    src_actions: np.ndarray,
    speed: float,
    *,
    chunk_aligned: bool = False,
    chunk_phase: int = 0,
) -> np.ndarray:
    """Build resampling boundaries in source-frame coordinates.

    Each consecutive pair ``(b[j], b[j+1])`` of the result is one output bin
    on the continuous source-index axis ``[0, n_src]``.

    chunk_aligned=True layout (speed=q/p, phase=r):
        [0..r]          leading passthrough (1:1, only when r > 0)
        [r, r+q]        chunk 0 -> p sub-bins via linspace
        [r+q, r+2q]     chunk 1
        ...
        [r+Nq, n_src]   trailing passthrough (1:1)
    Each full chunk integrates exactly q source frames into p outputs. The
    leftover frames at both ends are emitted 1:1, so total motion is
    preserved end-to-end.

    chunk_aligned=False: uniform ``linspace(0, n_src, n_out+1)`` over the
    segment, with ``n_out = max(1, round(n_src / speed))``.

    In both modes every gripper switch index (|diff(gripper)| > 0.5) is added
    as an anchor boundary. A switch event therefore never gets averaged
    inside a coarse bin (segmentation never looks at the gripper channel).

    Returns:
        Strictly increasing float32 array; for a non-empty segment it starts
        at 0 and ends at ``n_src``.

    Raises:
        ValueError: if ``chunk_phase`` is outside ``[0, q)``, or (via
            ``_speed_chunk_ratio``) if ``speed`` is not a small rational.
    """
    n_src = len(src_actions)
    if chunk_aligned:
        q, p = _speed_chunk_ratio(speed)
        if chunk_phase < 0 or chunk_phase >= q:
            raise ValueError(f"chunk_phase must be in [0, {q}), got {chunk_phase}")
        boundaries: list[float] = []
        # Leading passthrough [0..r]: integer 1:1 boundaries.
        if chunk_phase > 0:
            boundaries.extend(float(i) for i in range(0, min(chunk_phase, n_src) + 1))
        # Full chunks at r, r+q, r+2q, ... while the chunk fits (start + q <= n_src).
        chunk_starts = list(range(chunk_phase, max(n_src - q + 1, chunk_phase), q))
        for start in chunk_starts:
            boundaries.extend(np.linspace(float(start), float(start + q), p + 1).tolist())
        # Trailing passthrough: each remaining source frame -> one verbatim output.
        tail_start = chunk_starts[-1] + q if chunk_starts else min(chunk_phase, n_src)
        if tail_start < n_src:
            boundaries.extend(float(i) for i in range(tail_start, n_src + 1))
    else:
        n_out = max(1, int(round(n_src / speed)))
        boundaries = [float(x) for x in np.linspace(0.0, float(n_src), n_out + 1)]
    # Gripper anchors: index i becomes a boundary when gripper[i] != gripper[i-1].
    switch_indices = np.flatnonzero(np.abs(np.diff(src_actions[:, 6])) > 0.5) + 1
    boundaries.extend(float(i) for i in switch_indices)
    # Deduplicate AFTER the float32 cast. Two float64 values that differ only
    # in their last bits (e.g. an interior linspace point of
    # 2.9999999999999996 and a gripper anchor at 3.0) map to the same float32.
    # Deduplicating first would leave a zero-width bin. np.unique also sorts.
    return np.unique(np.asarray(boundaries, dtype=np.float32))
def _resample_segment(
    actions: np.ndarray,
    states: np.ndarray,
    source_frame_indices: np.ndarray,
    start: int,
    end: int,
    motion_class: int,
    speed: float,
    chunk_aligned_observation: bool = False,
    chunk_phase: int = 0,
) -> dict[str, np.ndarray]:
    """Resample episode rows ``[start, end)`` (one segment) at ``speed``.

    Atomic chunk resampling: cumulative sum + linear interpolation. Output
    bins are the consecutive pairs of boundaries from ``_segment_boundaries``.

    Key invariant: per output bin [left_t, right_t],
        action[:6] = cumulative(right_t) - cumulative(left_t)
    where cumulative is the piecewise-linear interpolation (over source
    indices) of the zero-prefixed cumsum of the source actions. Summing the
    output actions over a full chunk therefore telescopes to the source's
    integrated motion. This holds up to float32 rounding, with no drift from
    non-integer speeds.

    Per output j:
      - state / source_frame_index / source_step_index come from source frame
        floor(left_t), clipped to the segment. At chunk-start boundaries
        (integer left_t) this is exact.
      - gripper is the source gripper at ceil(right_t)-1, clipped to the
        segment, i.e. the last source frame the bin reaches. Gripper-switch
        anchors in ``_segment_boundaries`` ensure a switch is never averaged
        inside a bin.

    Args:
        actions, states, source_frame_indices: full-episode arrays; only rows
            ``start:end`` are read. ``actions`` columns 0-5 are the motion
            deltas that get integrated; column 6 is the gripper.
        start, end: segment bounds (end exclusive) in episode step indices.
        motion_class: copied into every output row's ``motion_class``.
        speed: playback speed. It must be a small rational q/p when
            ``chunk_aligned_observation`` is True.
        chunk_aligned_observation: use chunk-aligned boundaries and the
            chunk-aligned observation mask instead of the legacy ones.
        chunk_phase: chunk offset r in ``[0, q)`` (chunk-aligned mode only).

    Returns:
        Dict of per-output arrays: ``action`` (n_out, 7) float32, ``state``
        float32, ``source_frame_index`` int64, ``source_step_index`` int64
        (absolute episode step, ``start + local``), ``observation_mask`` int8,
        ``segment_id`` (placeholder -1, overwritten by ``transform_episode``),
        and ``motion_class`` int8.

    Raises:
        ValueError: for an empty segment or an out-of-range ``chunk_phase``.
    """
    src_actions = actions[start:end]
    src_states = states[start:end]
    src_frames = source_frame_indices[start:end]
    n_src = end - start
    if n_src <= 0:
        raise ValueError("Cannot resample an empty segment")
    boundaries = _segment_boundaries(
        src_actions,
        speed,
        chunk_aligned=chunk_aligned_observation,
        chunk_phase=chunk_phase,
    )
    n_out = len(boundaries) - 1
    # cumulative[k] = sum of the first k source actions (row 0 is all zeros),
    # so cumulative(b) - cumulative(a) integrates the motion over [a, b].
    cumulative = np.concatenate(
        [np.zeros((1, 6), dtype=np.float32), np.cumsum(src_actions[:, :6], axis=0)],
        axis=0,
    )
    out_actions = np.zeros((n_out, 7), dtype=np.float32)
    out_states = np.zeros((n_out, states.shape[1]), dtype=np.float32)
    out_source_frames = np.zeros(n_out, dtype=np.int64)
    out_source_steps = np.zeros(n_out, dtype=np.int64)
    observation_mask = np.ones(n_out, dtype=np.int8)
    used_sources: set[int] = set()
    for j in range(n_out):
        left_t = float(boundaries[j])
        right_t = float(boundaries[j + 1])
        out_actions[j, :6] = _interp_cumulative(cumulative, right_t) - _interp_cumulative(cumulative, left_t)
        # Source frame the bin starts in (defensively clipped to the segment).
        source_local = min(int(np.floor(left_t)), n_src - 1)
        # Last source frame the bin reaches: for an integer right_t = k this is k-1.
        grip_local = min(max(int(np.ceil(right_t) - 1), 0), n_src - 1)
        out_actions[j, 6] = src_actions[grip_local, 6]
        out_states[j] = src_states[source_local]
        out_source_frames[j] = int(src_frames[source_local])
        out_source_steps[j] = start + source_local
        # Legacy (non-chunk-aligned) slow-speed mask: when the same source frame
        # gets duplicated across consecutive bins, only the first copy is valid.
        if speed < 1.0 and source_local in used_sources:
            observation_mask[j] = 0
        used_sources.add(source_local)
    if chunk_aligned_observation:
        # Replace the legacy mask. mask=1 only at chunk-aligned outputs (and,
        # for phase=0, trailing-passthrough rows). NOTE the phase asymmetry:
        #   phase=0 : full chunk starts + ALL trailing-passthrough rows.
        #   phase>0 : full chunk starts ONLY (leading + trailing passthrough
        #             rows are mask=0). Random phase per access in the online
        #             sampler rotates which source frames become valid.
        q, _p = _speed_chunk_ratio(speed)
        if chunk_phase < 0 or chunk_phase >= q:
            raise ValueError(f"chunk_phase must be in [0, {q}), got {chunk_phase}")
        new_mask = np.zeros(n_out, dtype=np.int8)
        n_full = n_src // q
        tail_threshold = n_full * q  # phase=0 trailing passthrough starts here
        for j in range(n_out):
            left_t = float(boundaries[j])
            if chunk_phase == 0:
                if left_t >= tail_threshold and n_full > 0:
                    new_mask[j] = 1
                else:
                    # Valid iff left_t is (numerically) a multiple of q.
                    # NOTE(review): when the segment is shorter than q
                    # (n_full == 0), the tail rows are skipped above, but the
                    # row at left_t == 0 still passes here. Exactly one row is
                    # therefore marked valid -- confirm this is intended.
                    chunk_idx_f = left_t / q
                    chunk_idx = round(chunk_idx_f)
                    if abs(chunk_idx_f - chunk_idx) < 1e-6 and chunk_idx * q < n_src:
                        new_mask[j] = 1
            else:
                # Valid iff left_t == r + k*q for a chunk that fits entirely.
                chunk_idx_f = (left_t - chunk_phase) / q
                chunk_idx = round(chunk_idx_f)
                chunk_start = chunk_phase + chunk_idx * q
                if abs(chunk_idx_f - chunk_idx) < 1e-6 and chunk_start + q <= n_src:
                    new_mask[j] = 1
        observation_mask = new_mask
    return {
        "action": out_actions,
        "state": out_states,
        "source_frame_index": out_source_frames,
        "source_step_index": out_source_steps,
        "observation_mask": observation_mask,
        "segment_id": np.full(n_out, -1, dtype=np.int32),
        "motion_class": np.full(n_out, motion_class, dtype=np.int8),
    }
def transform_episode(
    actions: np.ndarray,
    states: np.ndarray,
    source_frame_indices: np.ndarray,
    speed: float,
    config: SpeedTransformConfig,
) -> tuple[dict[str, np.ndarray], dict[str, float | int]]:
    """End-to-end episode resampling: clean -> segment -> resample each segment.

    Outputs are concatenated across segments into one continuous trajectory.
    The same ``config.chunk_phase`` is applied to every segment (one phase
    per call).

    Args:
        actions: (T, 7) actions; columns 0-5 are motion deltas, column 6 is
            the gripper.
        states: (T, S) per-step states.
        source_frame_indices: length-T frame indices carried into the output.
        speed: playback speed; must be positive and finite.
        config: segmentation / cleaning / resampling knobs.

    Returns:
        ``(merged, metrics)``.
        ``merged`` holds per-output-frame arrays: the ``_resample_segment``
        fields plus ``cleaned_translation``, ``cleaned_rotation``, ``speed``,
        ``action_mask`` and ``is_padded`` (``1 - observation_mask``).
        ``metrics`` is ``compute_replay_metrics`` of the raw vs. resampled
        actions, extended with segmentation and cleaning statistics.

    Raises:
        ValueError: in any of these cases:
            - a non-finite or non-positive ``speed``;
            - a nonzero ``config.chunk_phase`` without
              ``chunk_aligned_observation``;
            - mis-shaped or length-mismatched inputs;
            - an episode that yields no segments (e.g. all-still with
              ``keep_still_segments=False``).
    """
    # NaN slips past ``speed <= 0``, and inf would collapse every legacy
    # segment to a single bin (n_out = max(1, round(n_src / inf)) = 1).
    if not np.isfinite(float(speed)):
        raise ValueError(f"Speed must be finite, got {speed}")
    if speed <= 0:
        raise ValueError(f"Speed must be positive, got {speed}")
    if config.chunk_phase != 0 and not config.chunk_aligned_observation:
        raise ValueError("chunk_phase is only supported when chunk_aligned_observation=True")
    actions = _as_float32_2d(actions, 7)
    states = np.asarray(states, dtype=np.float32)
    if states.ndim != 2:
        raise ValueError(f"Expected 2D states, got {states.shape}")
    source_frame_indices = np.asarray(source_frame_indices, dtype=np.int64)
    if len(actions) != len(states) or len(actions) != len(source_frame_indices):
        raise ValueError("actions, states, and source_frame_indices must have equal length")
    cleaned_actions, clean_mask = clean_near_zero_actions(actions, config.clean_transl_eps, config.clean_rot_eps)
    segments = segment_actions(cleaned_actions, config)
    if not config.keep_still_segments:
        segments = [s for s in segments if s[2] != 0]
    pieces = []
    for segment_id, (start, end, motion_class) in enumerate(segments):
        piece = _resample_segment(
            cleaned_actions,
            states,
            source_frame_indices,
            start,
            end,
            motion_class,
            speed,
            chunk_aligned_observation=config.chunk_aligned_observation,
            chunk_phase=config.chunk_phase,
        )
        # IDs index the (possibly filtered) segment list, so they stay contiguous.
        piece["segment_id"][:] = segment_id
        pieces.append(piece)
    if not pieces:
        raise ValueError("Episode produced no segments")
    merged = {key: np.concatenate([piece[key] for piece in pieces], axis=0) for key in pieces[0]}
    # source_step_index is an absolute episode step, so it indexes clean_mask directly.
    source_steps = merged["source_step_index"]
    merged["cleaned_translation"] = clean_mask[source_steps, 0].astype(np.int8)
    merged["cleaned_rotation"] = clean_mask[source_steps, 1].astype(np.int8)
    merged["speed"] = np.full(len(merged["action"]), float(speed), dtype=np.float32)
    merged["action_mask"] = np.ones(len(merged["action"]), dtype=np.int8)
    merged["is_padded"] = (1 - merged["observation_mask"]).astype(np.int8)
    # Compared against the raw (uncleaned) actions, so cleaning shows up as error.
    metrics = compute_replay_metrics(actions, merged["action"], speed)
    transl_zeroed = clean_mask[:, 0]
    rot_zeroed = clean_mask[:, 1]
    any_zeroed = transl_zeroed | rot_zeroed
    both_zeroed = transl_zeroed & rot_zeroed
    n_source = int(len(actions))
    denom = max(n_source, 1)
    # Short segments => fewer full chunks => more 1:1 passthrough leftover at
    # non-integer speeds. seg_len distribution is the key diagnostic.
    seg_lens = np.asarray([end - start for (start, end, _) in segments], dtype=np.int64)
    seg_classes = np.asarray([cls for (_, _, cls) in segments], dtype=np.int64)
    metrics.update(
        {
            "segment_count": len(segments),
            "segments": len(segments),
            "source_frames": n_source,
            "output_frames": int(len(merged["action"])),
            "padded_frames": int(np.sum(merged["is_padded"])),
            "padded_ratio": float(np.mean(merged["is_padded"])),
            "cleaned_translation_frames": int(np.sum(transl_zeroed)),
            "cleaned_rotation_frames": int(np.sum(rot_zeroed)),
            "cleaned_any_frames": int(np.sum(any_zeroed)),
            "cleaned_both_frames": int(np.sum(both_zeroed)),
            "cleaned_translation_ratio": float(np.sum(transl_zeroed) / denom),
            "cleaned_rotation_ratio": float(np.sum(rot_zeroed) / denom),
            "cleaned_any_ratio": float(np.sum(any_zeroed) / denom),
            "cleaned_both_ratio": float(np.sum(both_zeroed) / denom),
            "segment_len_min": int(seg_lens.min()) if seg_lens.size else 0,
            "segment_len_max": int(seg_lens.max()) if seg_lens.size else 0,
            "segment_len_mean": float(seg_lens.mean()) if seg_lens.size else 0.0,
            "segment_len_median": float(np.median(seg_lens)) if seg_lens.size else 0.0,
            "segment_len_p10": float(np.percentile(seg_lens, 10)) if seg_lens.size else 0.0,
            "segment_len_p90": float(np.percentile(seg_lens, 90)) if seg_lens.size else 0.0,
            "motion_class_still_count": int(np.sum(seg_classes == 0)),
            "motion_class_translate_count": int(np.sum(seg_classes == 1)),
            "motion_class_rotate_count": int(np.sum(seg_classes == 2)),
            "motion_class_translate_rotate_count": int(np.sum(seg_classes == 3)),
        }
    )
    return merged, metrics
def compute_replay_metrics(
    source_actions: np.ndarray,
    replay_actions: np.ndarray,
    target_speed: float | None = None,
) -> dict[str, float | int]:
    """Sanity-check that resampling preserved integrated 6D motion.

    Compares a source action sequence against its resampled replay and
    reports the following:
      - step counts and the realised speed (source_steps / replay_steps);
      - translation / rotation path lengths (sums of per-step L2 norms) and
        their replay/source ratios;
      - the L2 error between the summed translation and rotation deltas;
      - gripper-switch counts.
    This is NOT a task-success metric.

    Args:
        source_actions, replay_actions: (T, 7) action sequences.
        target_speed: requested speed. When None, ``target_speed`` is
            reported as 1.0 and ``speed_error`` as 0.0.

    Returns:
        Flat dict of scalar metrics.

    Raises:
        ValueError: if either input is not shaped (T, 7).
    """
    src = _as_float32_2d(source_actions, 7)
    rep = _as_float32_2d(replay_actions, 7)

    def path_length(arr: np.ndarray, cols: slice) -> float:
        # Total distance travelled: sum of per-step L2 norms over ``cols``.
        return float(np.linalg.norm(arr[:, cols], axis=1).sum())

    def gripper_switches(arr: np.ndarray) -> int:
        # A switch is any consecutive gripper change larger than 0.5.
        return int(np.sum(np.abs(np.diff(arr[:, 6])) > 0.5))

    src_total = src[:, :6].sum(axis=0)
    rep_total = rep[:, :6].sum(axis=0)
    n_src = len(src)
    n_rep = len(rep)
    realised = n_src / max(n_rep, 1)
    trans_src = path_length(src, slice(0, 3))
    trans_rep = path_length(rep, slice(0, 3))
    rot_src = path_length(src, slice(3, 6))
    rot_rep = path_length(rep, slice(3, 6))
    grip_src = gripper_switches(src)
    grip_rep = gripper_switches(rep)

    if target_speed is None:
        reported_target = 1.0
        speed_error = 0.0
    else:
        reported_target = float(target_speed)
        speed_error = float(abs(realised - target_speed))

    return {
        "target_speed": reported_target,
        "source_steps": int(n_src),
        "replay_steps": int(n_rep),
        "actual_speed": float(realised),
        "speed_error": speed_error,
        "translation_path_source": trans_src,
        "translation_path_replay": trans_rep,
        "translation_path_ratio": trans_rep / max(trans_src, 1e-12),
        "rotation_path_source": rot_src,
        "rotation_path_replay": rot_rep,
        "rotation_path_ratio": rot_rep / max(rot_src, 1e-12),
        "integrated_translation_l2_error": float(np.linalg.norm(src_total[:3] - rep_total[:3])),
        "integrated_rotation_l2_error": float(np.linalg.norm(src_total[3:] - rep_total[3:])),
        "gripper_switches_source": grip_src,
        "gripper_switches_replay": grip_rep,
        "gripper_switch_delta": int(grip_rep - grip_src),
    }