Spaces:
Running
Running
Deep trim viewer-only release
Browse files- cosmos-framework/cosmos_framework/data/imaginaire/__init__.py +0 -0
- cosmos-framework/cosmos_framework/data/imaginaire/webdataset/__init__.py +0 -0
- cosmos-framework/cosmos_framework/data/vfm/action/action_spec.py +0 -235
- cosmos-framework/cosmos_framework/data/vfm/action/av_dataset.py +18 -23
- cosmos-framework/cosmos_framework/data/vfm/action/bridge_orig_lerobot_dataset.py +0 -9
- cosmos-framework/cosmos_framework/data/vfm/action/camera_dataset.py +0 -15
- cosmos-framework/cosmos_framework/data/vfm/action/cosmos3_action_lerobot.py +2 -171
- cosmos-framework/cosmos_framework/data/vfm/action/domain_utils.py +0 -29
- cosmos-framework/cosmos_framework/data/vfm/action/droid_lerobot_dataset.py +0 -14
- cosmos-framework/cosmos_framework/data/vfm/action/fractal.py +0 -9
- cosmos-framework/cosmos_framework/data/vfm/action/pose_utils.py +0 -206
- cosmos-framework/cosmos_framework/data/vfm/action/robomind_franka_dataset.py +1 -41
- cosmos-framework/cosmos_framework/data/vfm/action/umi_lerobot_dataset.py +0 -4
cosmos-framework/cosmos_framework/data/imaginaire/__init__.py
DELETED
|
File without changes
|
cosmos-framework/cosmos_framework/data/imaginaire/webdataset/__init__.py
DELETED
|
File without changes
|
cosmos-framework/cosmos_framework/data/vfm/action/action_spec.py
DELETED
|
@@ -1,235 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: OpenMDW-1.1
|
| 3 |
-
|
| 4 |
-
"""Action-vector specification: per-dim type label + idle thresholds.
|
| 5 |
-
|
| 6 |
-
Single concept: every column of an action vector has a :class:`DimType` label.
|
| 7 |
-
Idle detection iterates by type and applies the matching algorithm:
|
| 8 |
-
|
| 9 |
-
POS → ‖action[pos_idx]‖ per arm < eps_t
|
| 10 |
-
ROT → distance(rot, identity) per group < eps_r
|
| 11 |
-
GRIPPER → max |Δgripper| < eps_g (frame 0 idle by convention)
|
| 12 |
-
JOINT → max |Δjoint| < joint_threshold (frame 0 idle)
|
| 13 |
-
RESERVED → ignored
|
| 14 |
-
|
| 15 |
-
An :class:`ActionSpec` is just ``names`` + ``types`` + ``rotation_format``.
|
| 16 |
-
Build one declaratively via :func:`build_action_spec` from DSL components::
|
| 17 |
-
|
| 18 |
-
build_action_spec(Pos(), Rot("rot6d"), Gripper()) # 10D single arm
|
| 19 |
-
build_action_spec(Pos(), Rot("rot6d")) # 9D no gripper
|
| 20 |
-
build_action_spec(Joint(n=14, label="arm"), # 30D joint-space
|
| 21 |
-
Joint(n=14, label="end"),
|
| 22 |
-
Joint(n=2, label="gripper"))
|
| 23 |
-
build_action_spec(Pos(prefix="left"), Rot("rot6d", "left"), Gripper(prefix="left"),
|
| 24 |
-
Pos(prefix="right"), Rot("rot6d", "right"), Gripper(prefix="right"))
|
| 25 |
-
|
| 26 |
-
Naming convention:
|
| 27 |
-
Default ``pos_x``, ``rot_0``, ``gripper``, ``arm_0`` ...
|
| 28 |
-
With ``prefix="left"`` (idempotent on trailing ``_``): ``left_pos_x`` ...
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
from __future__ import annotations
|
| 32 |
-
|
| 33 |
-
from dataclasses import dataclass
|
| 34 |
-
from enum import Enum
|
| 35 |
-
from typing import ClassVar
|
| 36 |
-
|
| 37 |
-
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 38 |
-
RotationConvention,
|
| 39 |
-
_identity_rotation_vector,
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class DimType(str, Enum):
|
| 44 |
-
"""Per-column action-dim category (drives idle detection)."""
|
| 45 |
-
|
| 46 |
-
POS = "pos"
|
| 47 |
-
ROT = "rot"
|
| 48 |
-
GRIPPER = "gripper"
|
| 49 |
-
JOINT = "joint"
|
| 50 |
-
RESERVED = "reserved"
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
@dataclass(frozen=True, slots=True)
|
| 54 |
-
class ActionSpec:
|
| 55 |
-
"""Structural description of an action vector: names + per-dim types.
|
| 56 |
-
|
| 57 |
-
All ROT dims share a single ``rotation_format``; mixed formats in one spec
|
| 58 |
-
are not supported (raise at build time).
|
| 59 |
-
|
| 60 |
-
This struct contains no detection thresholds — those are passed at call
|
| 61 |
-
time to :func:`compute_idle_frames` so each dataset can tune them
|
| 62 |
-
independently of layout.
|
| 63 |
-
"""
|
| 64 |
-
|
| 65 |
-
names: list[str]
|
| 66 |
-
types: list[DimType]
|
| 67 |
-
rotation_format: RotationConvention = "rot6d"
|
| 68 |
-
|
| 69 |
-
@property
|
| 70 |
-
def dim(self) -> int:
|
| 71 |
-
return len(self.names)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
# ---------------------------------------------------------------------------
|
| 75 |
-
# DSL components
|
| 76 |
-
# ---------------------------------------------------------------------------
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def _join_prefix(prefix: str, name: str) -> str:
|
| 80 |
-
"""Join ``prefix`` and ``name`` with a single ``_``; idempotent on trailing ``_``."""
|
| 81 |
-
return name if not prefix else f"{prefix.rstrip('_')}_{name}"
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
@dataclass(frozen=True)
|
| 85 |
-
class Pos:
|
| 86 |
-
"""Translation block.
|
| 87 |
-
|
| 88 |
-
Default 3D (``pos_x``, ``pos_y``, ``pos_z``). For planar tasks (e.g. PushT)
|
| 89 |
-
use ``Pos(dim=2)`` → ``pos_x``, ``pos_y``. ``dim >= 4`` falls back to
|
| 90 |
-
indexed names ``pos_0``, ``pos_1``, ...
|
| 91 |
-
"""
|
| 92 |
-
|
| 93 |
-
dim: int = 3
|
| 94 |
-
prefix: str = ""
|
| 95 |
-
type: ClassVar[DimType] = DimType.POS
|
| 96 |
-
|
| 97 |
-
def names(self) -> list[str]:
|
| 98 |
-
if self.dim <= 3:
|
| 99 |
-
return [_join_prefix(self.prefix, f"pos_{c}") for c in "xyz"[: self.dim]]
|
| 100 |
-
return [_join_prefix(self.prefix, f"pos_{i}") for i in range(self.dim)]
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
@dataclass(frozen=True)
|
| 104 |
-
class Rot:
|
| 105 |
-
"""Rotation block; ``format`` selects the encoding.
|
| 106 |
-
|
| 107 |
-
Supported formats and per-dim names:
|
| 108 |
-
|
| 109 |
-
- ``rot6d`` → 6 dims, ``rot_0`` ... ``rot_5`` (identity ``[1,0,0,0,1,0]``)
|
| 110 |
-
- ``rot9d`` → 9 dims, ``rot_0`` ... ``rot_8`` (identity ``[1,0,0,0,1,0,0,0,1]``)
|
| 111 |
-
- ``euler_xyz`` → 3 dims, ``roll``, ``pitch``, ``yaw`` (identity ``[0,0,0]``)
|
| 112 |
-
- ``axisangle`` → 3 dims, ``axang_x/y/z`` (identity ``[0,0,0]``)
|
| 113 |
-
- ``quat_xyzw`` / ``quat_wxyz`` → 4 dims, ``quat_x/y/z/w`` in declared order
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
-
format: RotationConvention = "rot6d"
|
| 117 |
-
prefix: str = ""
|
| 118 |
-
type: ClassVar[DimType] = DimType.ROT
|
| 119 |
-
|
| 120 |
-
@property
|
| 121 |
-
def rotation_format(self) -> RotationConvention:
|
| 122 |
-
return self.format
|
| 123 |
-
|
| 124 |
-
@property
|
| 125 |
-
def dim(self) -> int:
|
| 126 |
-
return _identity_rotation_vector(self.format).shape[0]
|
| 127 |
-
|
| 128 |
-
def names(self) -> list[str]:
|
| 129 |
-
if self.format == "euler_xyz":
|
| 130 |
-
return [_join_prefix(self.prefix, c) for c in ("roll", "pitch", "yaw")]
|
| 131 |
-
if self.format == "axisangle":
|
| 132 |
-
return [_join_prefix(self.prefix, f"axang_{c}") for c in "xyz"]
|
| 133 |
-
if self.format.startswith("quat_"):
|
| 134 |
-
order = self.format.split("_", 1)[1] # "xyzw" or "wxyz"
|
| 135 |
-
return [_join_prefix(self.prefix, f"quat_{c}") for c in order]
|
| 136 |
-
return [_join_prefix(self.prefix, f"rot_{i}") for i in range(self.dim)]
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
@dataclass(frozen=True)
|
| 140 |
-
class Gripper:
|
| 141 |
-
"""1D gripper command (binary 0/1 or continuous). Detected by frame-diff."""
|
| 142 |
-
|
| 143 |
-
prefix: str = ""
|
| 144 |
-
type: ClassVar[DimType] = DimType.GRIPPER
|
| 145 |
-
|
| 146 |
-
@property
|
| 147 |
-
def dim(self) -> int:
|
| 148 |
-
return 1
|
| 149 |
-
|
| 150 |
-
def names(self) -> list[str]:
|
| 151 |
-
return [_join_prefix(self.prefix, "gripper")]
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
@dataclass(frozen=True)
|
| 155 |
-
class Joint:
|
| 156 |
-
"""``n`` joint commands. Detected by frame-diff against ``joint_threshold``."""
|
| 157 |
-
|
| 158 |
-
n: int = 0
|
| 159 |
-
label: str = "joint"
|
| 160 |
-
prefix: str = ""
|
| 161 |
-
type: ClassVar[DimType] = DimType.JOINT
|
| 162 |
-
|
| 163 |
-
@property
|
| 164 |
-
def dim(self) -> int:
|
| 165 |
-
return self.n
|
| 166 |
-
|
| 167 |
-
def names(self) -> list[str]:
|
| 168 |
-
return [_join_prefix(self.prefix, f"{self.label}_{i}") for i in range(self.n)]
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
@dataclass(frozen=True)
|
| 172 |
-
class Reserved:
|
| 173 |
-
"""``n`` dims counted in ``action_dim`` but ignored by idle detection."""
|
| 174 |
-
|
| 175 |
-
n: int = 0
|
| 176 |
-
label: str = "reserved"
|
| 177 |
-
prefix: str = ""
|
| 178 |
-
type: ClassVar[DimType] = DimType.RESERVED
|
| 179 |
-
|
| 180 |
-
@property
|
| 181 |
-
def dim(self) -> int:
|
| 182 |
-
return self.n
|
| 183 |
-
|
| 184 |
-
def names(self) -> list[str]:
|
| 185 |
-
return [_join_prefix(self.prefix, f"{self.label}_{i}") for i in range(self.n)]
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
# ---------------------------------------------------------------------------
|
| 189 |
-
# Builder
|
| 190 |
-
# ---------------------------------------------------------------------------
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
# Type alias for any DSL component. Not a runtime check — only annotation hint.
|
| 194 |
-
Component = Pos | Rot | Gripper | Joint | Reserved
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def build_action_spec(*components: Component) -> ActionSpec:
|
| 198 |
-
"""Compose ``components`` into an :class:`ActionSpec`.
|
| 199 |
-
|
| 200 |
-
Each component contributes its ``names()`` and replicates its ``type`` for
|
| 201 |
-
every column it occupies. The first ROT component's ``rotation_format``
|
| 202 |
-
is captured for the whole spec; mixing formats raises ``ValueError``.
|
| 203 |
-
"""
|
| 204 |
-
names: list[str] = []
|
| 205 |
-
types: list[DimType] = []
|
| 206 |
-
rotation_format: RotationConvention | None = None
|
| 207 |
-
|
| 208 |
-
for c in components:
|
| 209 |
-
names.extend(c.names())
|
| 210 |
-
types.extend([c.type] * c.dim)
|
| 211 |
-
if c.type == DimType.ROT:
|
| 212 |
-
fmt = c.rotation_format # type: ignore[union-attr]
|
| 213 |
-
if rotation_format is None:
|
| 214 |
-
rotation_format = fmt
|
| 215 |
-
elif rotation_format != fmt:
|
| 216 |
-
raise ValueError(f"Mixed rotation_format in one ActionSpec: {rotation_format!r} vs {fmt!r}")
|
| 217 |
-
|
| 218 |
-
return ActionSpec(
|
| 219 |
-
names=names,
|
| 220 |
-
types=types,
|
| 221 |
-
rotation_format=rotation_format or "rot6d",
|
| 222 |
-
)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
__all__ = [
|
| 226 |
-
"ActionSpec",
|
| 227 |
-
"Component",
|
| 228 |
-
"DimType",
|
| 229 |
-
"Gripper",
|
| 230 |
-
"Joint",
|
| 231 |
-
"Pos",
|
| 232 |
-
"Reserved",
|
| 233 |
-
"Rot",
|
| 234 |
-
"build_action_spec",
|
| 235 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cosmos-framework/cosmos_framework/data/vfm/action/av_dataset.py
CHANGED
|
@@ -37,8 +37,6 @@ from torch.utils.data import IterableDataset
|
|
| 37 |
# torch.multiprocessing.set_sharing_strategy("file_system")
|
| 38 |
from cosmos_framework.utils import log
|
| 39 |
from cosmos_framework.utils.easy_io import easy_io
|
| 40 |
-
from cosmos_framework.data.vfm.action.camera_dataset import get_target_size_and_crop
|
| 41 |
-
from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
|
| 42 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 43 |
RotationConvention,
|
| 44 |
build_abs_pose_from_components,
|
|
@@ -46,6 +44,24 @@ from cosmos_framework.data.vfm.action.pose_utils import (
|
|
| 46 |
)
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def decode_video_bytes(
|
| 50 |
video_bytes: bytes,
|
| 51 |
resolution: str | None = None,
|
|
@@ -509,7 +525,6 @@ class AVDataset(IterableDataset):
|
|
| 509 |
resolution: str | None = None,
|
| 510 |
fps: int = 10,
|
| 511 |
mode: str = "policy",
|
| 512 |
-
embodiment_type: str = "av",
|
| 513 |
split: str = "train",
|
| 514 |
seed: int = 0,
|
| 515 |
shuffle: bool = True,
|
|
@@ -529,11 +544,6 @@ class AVDataset(IterableDataset):
|
|
| 529 |
rotation_scale: float = 1.0,
|
| 530 |
max_action_translation_norm: float | None = None,
|
| 531 |
align_opencv_pose: bool = False,
|
| 532 |
-
# When True, use a separate domain ID for inverse dynamics / policy modes
|
| 533 |
-
# so that DomainAwareLinear learns different projections for anchored (conditioning)
|
| 534 |
-
# vs framewise (generation) action representations.
|
| 535 |
-
mode_aware_domain: bool = False,
|
| 536 |
-
inv_embodiment_type: str = "av_inv",
|
| 537 |
):
|
| 538 |
"""Initialize AVDataset.
|
| 539 |
|
|
@@ -543,7 +553,6 @@ class AVDataset(IterableDataset):
|
|
| 543 |
resolution: Target resolution for video frames (e.g. "256", "480"). If None, keeps original resolution.
|
| 544 |
fps: Target frames per second for video and actions.
|
| 545 |
mode: Training mode ('policy', 'forward_dynamics', 'inverse_dynamics', 'image2video', 'joint').
|
| 546 |
-
embodiment_type: Embodiment type for domain ID.
|
| 547 |
split: Dataset split ('train', 'val', or 'full').
|
| 548 |
seed: Random seed for shuffling.
|
| 549 |
shuffle: Whether to shuffle tar files during iteration (for training).
|
|
@@ -570,8 +579,6 @@ class AVDataset(IterableDataset):
|
|
| 570 |
align_opencv_pose: If True, transform pose rotations from car body-frame
|
| 571 |
convention (x=forward, y=left, z=up) to OpenCV camera convention
|
| 572 |
(x=right, y=down, z=forward) before computing relative actions.
|
| 573 |
-
mode_aware_domain: When True, inverse_dynamics/policy modes use a separate domain ID.
|
| 574 |
-
inv_embodiment_type: Embodiment type string for the inverse domain ID.
|
| 575 |
"""
|
| 576 |
super().__init__()
|
| 577 |
|
|
@@ -602,11 +609,6 @@ class AVDataset(IterableDataset):
|
|
| 602 |
self.rotation_scale = rotation_scale
|
| 603 |
self.max_action_translation_norm = max_action_translation_norm
|
| 604 |
self.align_opencv_pose = align_opencv_pose
|
| 605 |
-
# Get domain ID for this embodiment
|
| 606 |
-
self.domain_id = get_domain_id(embodiment_type)
|
| 607 |
-
self.mode_aware_domain = mode_aware_domain
|
| 608 |
-
self.domain_id_inv = get_domain_id(inv_embodiment_type) if mode_aware_domain else self.domain_id
|
| 609 |
-
|
| 610 |
# Validate mode
|
| 611 |
valid_modes = ["joint", "forward_dynamics", "inverse_dynamics", "policy", "image2video"]
|
| 612 |
if mode not in valid_modes:
|
|
@@ -864,11 +866,6 @@ class AVDataset(IterableDataset):
|
|
| 864 |
)
|
| 865 |
# prompt += f"Predict the future {future_duration:.1f}s action trajectory at {self.fps}Hz."
|
| 866 |
|
| 867 |
-
# Select domain ID: use inverse domain for generation modes when mode_aware_domain is on
|
| 868 |
-
if self.mode_aware_domain and mode in ["inverse_dynamics", "policy"]:
|
| 869 |
-
domain_id = self.domain_id_inv
|
| 870 |
-
else:
|
| 871 |
-
domain_id = self.domain_id
|
| 872 |
|
| 873 |
sample = {
|
| 874 |
"video": video,
|
|
@@ -881,7 +878,6 @@ class AVDataset(IterableDataset):
|
|
| 881 |
"ai_caption": prompt,
|
| 882 |
"mode": mode,
|
| 883 |
"__key__": key_tensor,
|
| 884 |
-
"domain_id": torch.tensor(domain_id, dtype=torch.long),
|
| 885 |
"history_length": actual_history_length,
|
| 886 |
"future_length": actual_future_length,
|
| 887 |
"viewpoint": "ego_view",
|
|
@@ -1001,7 +997,6 @@ if __name__ == "__main__":
|
|
| 1001 |
print(f"{'future_length':<25}: {data['future_length']}")
|
| 1002 |
print(f"{'conditioning_fps':<25}: {data['conditioning_fps'].item()}")
|
| 1003 |
print(f"{'mode':<25}: {data['mode']}")
|
| 1004 |
-
print(f"{'domain_id':<25}: {data['domain_id'].item()}")
|
| 1005 |
print(f"{'prompt':<25}: {data['prompt']}")
|
| 1006 |
|
| 1007 |
# save video
|
|
|
|
| 37 |
# torch.multiprocessing.set_sharing_strategy("file_system")
|
| 38 |
from cosmos_framework.utils import log
|
| 39 |
from cosmos_framework.utils.easy_io import easy_io
|
|
|
|
|
|
|
| 40 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 41 |
RotationConvention,
|
| 42 |
build_abs_pose_from_components,
|
|
|
|
| 44 |
)
|
| 45 |
|
| 46 |
|
| 47 |
+
VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
|
| 48 |
+
"256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 320)},
|
| 49 |
+
"480": {"1,1": (640, 640), "4,3": (736, 544), "3,4": (544, 736), "16,9": (832, 480), "9,16": (480, 832)},
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_target_size_and_crop(resolution: str, current_H: int, current_W: int) -> tuple[int, int, int, int]:
|
| 54 |
+
target_resolutions = VIDEO_RES_SIZE_INFO[resolution]
|
| 55 |
+
current_ar = current_W / current_H
|
| 56 |
+
best_key = min(
|
| 57 |
+
target_resolutions,
|
| 58 |
+
key=lambda key: abs((int(key.split(",")[0]) / int(key.split(",")[1])) - current_ar),
|
| 59 |
+
)
|
| 60 |
+
target_canvas_W, target_canvas_H = target_resolutions[best_key]
|
| 61 |
+
scaling_ratio = max(target_canvas_W / current_W, target_canvas_H / current_H)
|
| 62 |
+
return int(scaling_ratio * current_H + 0.5), int(scaling_ratio * current_W + 0.5), target_canvas_H, target_canvas_W
|
| 63 |
+
|
| 64 |
+
|
| 65 |
def decode_video_bytes(
|
| 66 |
video_bytes: bytes,
|
| 67 |
resolution: str | None = None,
|
|
|
|
| 525 |
resolution: str | None = None,
|
| 526 |
fps: int = 10,
|
| 527 |
mode: str = "policy",
|
|
|
|
| 528 |
split: str = "train",
|
| 529 |
seed: int = 0,
|
| 530 |
shuffle: bool = True,
|
|
|
|
| 544 |
rotation_scale: float = 1.0,
|
| 545 |
max_action_translation_norm: float | None = None,
|
| 546 |
align_opencv_pose: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
):
|
| 548 |
"""Initialize AVDataset.
|
| 549 |
|
|
|
|
| 553 |
resolution: Target resolution for video frames (e.g. "256", "480"). If None, keeps original resolution.
|
| 554 |
fps: Target frames per second for video and actions.
|
| 555 |
mode: Training mode ('policy', 'forward_dynamics', 'inverse_dynamics', 'image2video', 'joint').
|
|
|
|
| 556 |
split: Dataset split ('train', 'val', or 'full').
|
| 557 |
seed: Random seed for shuffling.
|
| 558 |
shuffle: Whether to shuffle tar files during iteration (for training).
|
|
|
|
| 579 |
align_opencv_pose: If True, transform pose rotations from car body-frame
|
| 580 |
convention (x=forward, y=left, z=up) to OpenCV camera convention
|
| 581 |
(x=right, y=down, z=forward) before computing relative actions.
|
|
|
|
|
|
|
| 582 |
"""
|
| 583 |
super().__init__()
|
| 584 |
|
|
|
|
| 609 |
self.rotation_scale = rotation_scale
|
| 610 |
self.max_action_translation_norm = max_action_translation_norm
|
| 611 |
self.align_opencv_pose = align_opencv_pose
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
# Validate mode
|
| 613 |
valid_modes = ["joint", "forward_dynamics", "inverse_dynamics", "policy", "image2video"]
|
| 614 |
if mode not in valid_modes:
|
|
|
|
| 866 |
)
|
| 867 |
# prompt += f"Predict the future {future_duration:.1f}s action trajectory at {self.fps}Hz."
|
| 868 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
|
| 870 |
sample = {
|
| 871 |
"video": video,
|
|
|
|
| 878 |
"ai_caption": prompt,
|
| 879 |
"mode": mode,
|
| 880 |
"__key__": key_tensor,
|
|
|
|
| 881 |
"history_length": actual_history_length,
|
| 882 |
"future_length": actual_future_length,
|
| 883 |
"viewpoint": "ego_view",
|
|
|
|
| 997 |
print(f"{'future_length':<25}: {data['future_length']}")
|
| 998 |
print(f"{'conditioning_fps':<25}: {data['conditioning_fps'].item()}")
|
| 999 |
print(f"{'mode':<25}: {data['mode']}")
|
|
|
|
| 1000 |
print(f"{'prompt':<25}: {data['prompt']}")
|
| 1001 |
|
| 1002 |
# save video
|
cosmos-framework/cosmos_framework/data/vfm/action/bridge_orig_lerobot_dataset.py
CHANGED
|
@@ -18,12 +18,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|
| 18 |
|
| 19 |
from cosmos_framework.utils import log
|
| 20 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
| 21 |
-
ActionSpec,
|
| 22 |
BaseActionLeRobotDataset,
|
| 23 |
-
Gripper,
|
| 24 |
-
Pos,
|
| 25 |
-
Rot,
|
| 26 |
-
build_action_spec,
|
| 27 |
)
|
| 28 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 29 |
PoseConvention,
|
|
@@ -111,7 +106,6 @@ class BridgeOrigLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 111 |
split_val_ratio=split_val_ratio,
|
| 112 |
split=split,
|
| 113 |
mode=mode,
|
| 114 |
-
embodiment_type="bridge_orig_lerobot",
|
| 115 |
viewpoint=viewpoint,
|
| 116 |
pose_convention=pose_convention,
|
| 117 |
rotation_format="rot6d",
|
|
@@ -240,9 +234,6 @@ class BridgeOrigLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 240 |
# __getitem__
|
| 241 |
# ------------------------------------------------------------------
|
| 242 |
|
| 243 |
-
def _build_action_spec(self) -> ActionSpec:
|
| 244 |
-
"""Bridge: 10D = ``[Pos, Rot6d, Gripper]``."""
|
| 245 |
-
return build_action_spec(Pos(), Rot("rot6d"), Gripper())
|
| 246 |
|
| 247 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 248 |
""" """
|
|
|
|
| 18 |
|
| 19 |
from cosmos_framework.utils import log
|
| 20 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
|
|
|
| 21 |
BaseActionLeRobotDataset,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
)
|
| 23 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 24 |
PoseConvention,
|
|
|
|
| 106 |
split_val_ratio=split_val_ratio,
|
| 107 |
split=split,
|
| 108 |
mode=mode,
|
|
|
|
| 109 |
viewpoint=viewpoint,
|
| 110 |
pose_convention=pose_convention,
|
| 111 |
rotation_format="rot6d",
|
|
|
|
| 234 |
# __getitem__
|
| 235 |
# ------------------------------------------------------------------
|
| 236 |
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 239 |
""" """
|
cosmos-framework/cosmos_framework/data/vfm/action/camera_dataset.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: OpenMDW-1.1
|
| 3 |
-
|
| 4 |
-
VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
|
| 5 |
-
"256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 320)},
|
| 6 |
-
"480": {"1,1": (640, 640), "4,3": (736, 544), "3,4": (544, 736), "16,9": (832, 480), "9,16": (480, 832)},
|
| 7 |
-
}
|
| 8 |
-
|
| 9 |
-
def get_target_size_and_crop(resolution: str, current_H: int, current_W: int) -> tuple[int, int, int, int]:
|
| 10 |
-
target_resolutions = VIDEO_RES_SIZE_INFO[resolution]
|
| 11 |
-
current_ar = current_W / current_H
|
| 12 |
-
best_key = min(target_resolutions, key=lambda key: abs((int(key.split(',')[0]) / int(key.split(',')[1])) - current_ar))
|
| 13 |
-
target_canvas_W, target_canvas_H = target_resolutions[best_key]
|
| 14 |
-
scaling_ratio = max(target_canvas_W / current_W, target_canvas_H / current_H)
|
| 15 |
-
return int(scaling_ratio * current_H + 0.5), int(scaling_ratio * current_W + 0.5), target_canvas_H, target_canvas_W
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cosmos-framework/cosmos_framework/data/vfm/action/cosmos3_action_lerobot.py
CHANGED
|
@@ -14,7 +14,6 @@ from __future__ import annotations
|
|
| 14 |
|
| 15 |
import importlib
|
| 16 |
import logging as _logging
|
| 17 |
-
import math
|
| 18 |
import os as _os
|
| 19 |
import random
|
| 20 |
from bisect import bisect_right
|
|
@@ -53,23 +52,7 @@ def _ensure_hf_hub_offline() -> None:
|
|
| 53 |
_hf_offline_applied = True
|
| 54 |
|
| 55 |
|
| 56 |
-
from functools import cached_property
|
| 57 |
-
|
| 58 |
from cosmos_framework.utils import log
|
| 59 |
-
# Re-export the action_spec DSL from this module so that subclass datasets
|
| 60 |
-
# only need a single import block (alongside ``BaseActionLeRobotDataset``).
|
| 61 |
-
from cosmos_framework.data.vfm.action.action_spec import ( # noqa: F401 (re-export)
|
| 62 |
-
ActionSpec,
|
| 63 |
-
DimType,
|
| 64 |
-
Gripper,
|
| 65 |
-
Joint,
|
| 66 |
-
Pos,
|
| 67 |
-
Reserved,
|
| 68 |
-
Rot,
|
| 69 |
-
build_action_spec,
|
| 70 |
-
)
|
| 71 |
-
from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
|
| 72 |
-
from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames
|
| 73 |
from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
|
| 74 |
|
| 75 |
# ---------------------------------------------------------------------------
|
|
@@ -278,7 +261,6 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 278 |
split_val_ratio: float,
|
| 279 |
split: str,
|
| 280 |
mode: str,
|
| 281 |
-
embodiment_type: str,
|
| 282 |
viewpoint: Viewpoint,
|
| 283 |
pose_convention: str | None = None,
|
| 284 |
rotation_format: str | None = None,
|
|
@@ -301,7 +283,6 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 301 |
self._split_val_ratio = split_val_ratio
|
| 302 |
self._split = _normalize_split(split)
|
| 303 |
self._mode = mode
|
| 304 |
-
self._embodiment_type = embodiment_type
|
| 305 |
self._viewpoint: Viewpoint = viewpoint
|
| 306 |
self._pose_convention = pose_convention
|
| 307 |
self._rotation_format = rotation_format
|
|
@@ -331,7 +312,6 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 331 |
self._episode_records: list[tuple[int, int, int, int]] = []
|
| 332 |
self._episode_cum_ends: list[int] = []
|
| 333 |
self._num_valid_indices = 0
|
| 334 |
-
self._domain_id = get_domain_id(self._embodiment_type)
|
| 335 |
self._all_shard_roots: list[str] = []
|
| 336 |
|
| 337 |
# -- public properties ---------------------------------------------------
|
|
@@ -356,9 +336,6 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 356 |
def mode(self, value: str) -> None:
|
| 357 |
self._mode = value
|
| 358 |
|
| 359 |
-
@property
|
| 360 |
-
def domain_id(self) -> int:
|
| 361 |
-
return self._domain_id
|
| 362 |
|
| 363 |
# -- source registration -------------------------------------------------
|
| 364 |
|
|
@@ -679,138 +656,6 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 679 |
|
| 680 |
# -- result building -----------------------------------------------------
|
| 681 |
|
| 682 |
-
def _build_action_spec(self) -> ActionSpec | None:
|
| 683 |
-
"""Subclass override: declare this dataset's action layout.
|
| 684 |
-
|
| 685 |
-
Called once per instance — the result is cached by ``self.action_spec``.
|
| 686 |
-
Return ``None`` to skip spec-driven idle detection; in that case
|
| 687 |
-
``_compute_idle_frames`` will log a one-time warning and return
|
| 688 |
-
``None`` for every sample.
|
| 689 |
-
"""
|
| 690 |
-
return None
|
| 691 |
-
|
| 692 |
-
@cached_property
|
| 693 |
-
def action_spec(self) -> ActionSpec | None:
|
| 694 |
-
"""Cached :class:`ActionSpec` from ``_build_action_spec``.
|
| 695 |
-
|
| 696 |
-
Returns ``None`` when the subclass did not declare one; idle detection
|
| 697 |
-
is then skipped (with a one-time warning) until the subclass overrides
|
| 698 |
-
``_build_action_spec``.
|
| 699 |
-
"""
|
| 700 |
-
return self._build_action_spec()
|
| 701 |
-
|
| 702 |
-
@cached_property
|
| 703 |
-
def action_names(self) -> list[str] | None:
|
| 704 |
-
spec = self.action_spec
|
| 705 |
-
return spec.names if spec is not None else None
|
| 706 |
-
|
| 707 |
-
# Idle-detection thresholds. Defined as **velocities** (per second) so the
|
| 708 |
-
# same numeric value means the same physical motion across datasets with
|
| 709 |
-
# different sampling rates; converted to per-frame at call time using
|
| 710 |
-
# ``self._fps`` via :meth:`_resolve_idle_thresholds`.
|
| 711 |
-
#
|
| 712 |
-
# Defaults:
|
| 713 |
-
# - ``idle_eps_t_per_sec`` = 5 mm/s (≈ 1 mm/frame at 5 Hz)
|
| 714 |
-
# - ``idle_eps_r_per_sec`` = 1.5°/s (geodesic, rotation-format aware)
|
| 715 |
-
# - ``idle_eps_g`` = 1e-2 unit gripper Δ (no fps)
|
| 716 |
-
# - ``idle_joint_threshold_per_sec`` = 5e-3 rad/s
|
| 717 |
-
# - ``idle_min_streak`` = 3 require ≥ 3 consecutive
|
| 718 |
-
#
|
| 719 |
-
# Subclasses can either override the ``*_per_sec`` attributes (preferred —
|
| 720 |
-
# keeps the velocity semantics) or set the corresponding ``idle_eps_*`` /
|
| 721 |
-
# ``idle_joint_threshold`` attribute to a non-``None`` value to bypass the
|
| 722 |
-
# per-fps conversion entirely (raw per-frame override).
|
| 723 |
-
idle_eps_t_per_sec: float = 5e-3
|
| 724 |
-
idle_eps_r_per_sec: float = math.radians(1.5)
|
| 725 |
-
idle_eps_g: float = 1e-2
|
| 726 |
-
idle_joint_threshold_per_sec: float = 5e-3
|
| 727 |
-
idle_min_streak: int = 3
|
| 728 |
-
|
| 729 |
-
# Optional per-frame overrides. ``None`` (default) → use the ``*_per_sec``
|
| 730 |
-
# attribute / fps conversion above.
|
| 731 |
-
idle_eps_t: float | None = None
|
| 732 |
-
idle_eps_r: float | None = None
|
| 733 |
-
idle_joint_threshold: float | None = None
|
| 734 |
-
|
| 735 |
-
def _resolve_idle_thresholds(self) -> tuple[float, float, float, float]:
|
| 736 |
-
"""Resolve per-frame idle thresholds for this dataset instance.
|
| 737 |
-
|
| 738 |
-
Returns ``(eps_t, eps_r, eps_g, joint_threshold)`` in raw per-frame
|
| 739 |
-
units. Honours direct per-frame overrides if the subclass sets the
|
| 740 |
-
non-``_per_sec`` attribute; otherwise scales the ``_per_sec`` values
|
| 741 |
-
by ``self._fps``.
|
| 742 |
-
"""
|
| 743 |
-
fps = float(self._fps) if self._fps else 1.0
|
| 744 |
-
eps_t = self.idle_eps_t if self.idle_eps_t is not None else self.idle_eps_t_per_sec / fps
|
| 745 |
-
eps_r = self.idle_eps_r if self.idle_eps_r is not None else self.idle_eps_r_per_sec / fps
|
| 746 |
-
joint_thr = (
|
| 747 |
-
self.idle_joint_threshold
|
| 748 |
-
if self.idle_joint_threshold is not None
|
| 749 |
-
else self.idle_joint_threshold_per_sec / fps
|
| 750 |
-
)
|
| 751 |
-
return float(eps_t), float(eps_r), float(self.idle_eps_g), float(joint_thr)
|
| 752 |
-
|
| 753 |
-
def _compute_idle_frames(self, raw_action: torch.Tensor) -> torch.Tensor | None:
|
| 754 |
-
"""Count idle frames in the *raw* (un-normalized) action chunk.
|
| 755 |
-
|
| 756 |
-
Requires ``self.action_spec`` to be declared via ``_build_action_spec``.
|
| 757 |
-
Returns ``None`` when:
|
| 758 |
-
- ``pose_convention`` is not ``"backward_framewise"`` (TODO: extend),
|
| 759 |
-
- the subclass has not declared an ``ActionSpec`` (logs a one-time warning),
|
| 760 |
-
- the action layout does not match the declared spec.
|
| 761 |
-
|
| 762 |
-
Detection thresholds come from the ``idle_eps_*`` class attributes
|
| 763 |
-
(overridable per dataset). Subclasses can also override this method
|
| 764 |
-
outright, or pass an explicit ``idle_frames`` integer via
|
| 765 |
-
``**extras`` to :meth:`_build_result`.
|
| 766 |
-
"""
|
| 767 |
-
|
| 768 |
-
# conventions (anchored / absolute) need different idle semantics.
|
| 769 |
-
if self._pose_convention != "backward_framewise":
|
| 770 |
-
if not getattr(self, "_warned_pose_convention", False):
|
| 771 |
-
log.warning(
|
| 772 |
-
f"Dataset {self.__class__.__name__}: pose_convention="
|
| 773 |
-
f"{self._pose_convention!r} is not 'backward_framewise'; "
|
| 774 |
-
"skipping idle-frames detection. Centralize the dataset "
|
| 775 |
-
"to backward_framewise to enable IdleFrames captioning."
|
| 776 |
-
)
|
| 777 |
-
self._warned_pose_convention = True
|
| 778 |
-
return None
|
| 779 |
-
|
| 780 |
-
spec = self.action_spec
|
| 781 |
-
if spec is None:
|
| 782 |
-
if not getattr(self, "_warned_no_action_spec", False):
|
| 783 |
-
log.warning(
|
| 784 |
-
f"Dataset {self.__class__.__name__} has no action spec defined; "
|
| 785 |
-
"skipping idle-frames detection. Override _build_action_spec() to enable it."
|
| 786 |
-
)
|
| 787 |
-
self._warned_no_action_spec = True
|
| 788 |
-
return None
|
| 789 |
-
|
| 790 |
-
eps_t, eps_r, eps_g, joint_thr = self._resolve_idle_thresholds()
|
| 791 |
-
try:
|
| 792 |
-
n = compute_idle_frames(
|
| 793 |
-
raw_action,
|
| 794 |
-
spec,
|
| 795 |
-
eps_t=eps_t,
|
| 796 |
-
eps_r=eps_r,
|
| 797 |
-
eps_g=eps_g,
|
| 798 |
-
joint_threshold=joint_thr,
|
| 799 |
-
min_streak=self.idle_min_streak,
|
| 800 |
-
)
|
| 801 |
-
except (ValueError, TypeError) as e:
|
| 802 |
-
if not getattr(self, "_warned_action_layout", False):
|
| 803 |
-
log.warning(
|
| 804 |
-
f"Dataset {self.__class__.__name__}: action layout does "
|
| 805 |
-
f"not match the declared ActionSpec "
|
| 806 |
-
f"(action_dim={int(raw_action.shape[-1])}, "
|
| 807 |
-
f"spec.dim={spec.dim}); skipping idle-frames detection. "
|
| 808 |
-
f"Underlying error: {e}"
|
| 809 |
-
)
|
| 810 |
-
self._warned_action_layout = True
|
| 811 |
-
return None
|
| 812 |
-
return torch.tensor(n, dtype=torch.long)
|
| 813 |
-
|
| 814 |
def _build_result(
|
| 815 |
self,
|
| 816 |
*,
|
|
@@ -823,25 +668,12 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 823 |
"""Assemble the common return dict for ``__getitem__``.
|
| 824 |
|
| 825 |
``video`` is expected in raw LeRobot layout before final formatting.
|
| 826 |
-
Subclasses may pass extra
|
| 827 |
-
``idle_frames`` is auto-computed from the raw (un-normalized) ``action``
|
| 828 |
-
whenever the dataset's pose/rotation conventions allow it; subclasses
|
| 829 |
-
can override by passing ``idle_frames`` (int or scalar tensor) via
|
| 830 |
``**extras``.
|
| 831 |
"""
|
| 832 |
-
# Compute idle_frames from the raw action before normalization, unless
|
| 833 |
-
# the subclass has provided one explicitly via ``**extras``.
|
| 834 |
-
if "idle_frames" not in extras:
|
| 835 |
-
idle_frames = self._compute_idle_frames(action)
|
| 836 |
-
if idle_frames is not None:
|
| 837 |
-
extras = {"idle_frames": idle_frames, **extras}
|
| 838 |
-
|
| 839 |
raw_action = action # [T,D]
|
| 840 |
if self._skip_video_loading:
|
| 841 |
-
|
| 842 |
-
if "idle_frames" in extras:
|
| 843 |
-
result["idle_frames"] = extras["idle_frames"]
|
| 844 |
-
return result
|
| 845 |
formatted_video = self._convert_video(video) # [C,T,H,W] | None
|
| 846 |
return {
|
| 847 |
"ai_caption": ai_caption,
|
|
@@ -849,7 +681,6 @@ class BaseActionLeRobotDataset(Dataset):
|
|
| 849 |
"action": raw_action,
|
| 850 |
"conditioning_fps": torch.tensor(self._fps, dtype=torch.long),
|
| 851 |
"mode": mode,
|
| 852 |
-
"domain_id": torch.tensor(self._domain_id, dtype=torch.long),
|
| 853 |
"viewpoint": self._viewpoint,
|
| 854 |
**extras,
|
| 855 |
}
|
|
|
|
| 14 |
|
| 15 |
import importlib
|
| 16 |
import logging as _logging
|
|
|
|
| 17 |
import os as _os
|
| 18 |
import random
|
| 19 |
from bisect import bisect_right
|
|
|
|
| 52 |
_hf_offline_applied = True
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
| 55 |
from cosmos_framework.utils import log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
|
| 57 |
|
| 58 |
# ---------------------------------------------------------------------------
|
|
|
|
| 261 |
split_val_ratio: float,
|
| 262 |
split: str,
|
| 263 |
mode: str,
|
|
|
|
| 264 |
viewpoint: Viewpoint,
|
| 265 |
pose_convention: str | None = None,
|
| 266 |
rotation_format: str | None = None,
|
|
|
|
| 283 |
self._split_val_ratio = split_val_ratio
|
| 284 |
self._split = _normalize_split(split)
|
| 285 |
self._mode = mode
|
|
|
|
| 286 |
self._viewpoint: Viewpoint = viewpoint
|
| 287 |
self._pose_convention = pose_convention
|
| 288 |
self._rotation_format = rotation_format
|
|
|
|
| 312 |
self._episode_records: list[tuple[int, int, int, int]] = []
|
| 313 |
self._episode_cum_ends: list[int] = []
|
| 314 |
self._num_valid_indices = 0
|
|
|
|
| 315 |
self._all_shard_roots: list[str] = []
|
| 316 |
|
| 317 |
# -- public properties ---------------------------------------------------
|
|
|
|
| 336 |
def mode(self, value: str) -> None:
|
| 337 |
self._mode = value
|
| 338 |
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
# -- source registration -------------------------------------------------
|
| 341 |
|
|
|
|
| 656 |
|
| 657 |
# -- result building -----------------------------------------------------
|
| 658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
def _build_result(
|
| 660 |
self,
|
| 661 |
*,
|
|
|
|
| 668 |
"""Assemble the common return dict for ``__getitem__``.
|
| 669 |
|
| 670 |
``video`` is expected in raw LeRobot layout before final formatting.
|
| 671 |
+
Subclasses may pass extra viewer metadata (e.g. ``initial_pose``) via
|
|
|
|
|
|
|
|
|
|
| 672 |
``**extras``.
|
| 673 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
raw_action = action # [T,D]
|
| 675 |
if self._skip_video_loading:
|
| 676 |
+
return {"action": raw_action}
|
|
|
|
|
|
|
|
|
|
| 677 |
formatted_video = self._convert_video(video) # [C,T,H,W] | None
|
| 678 |
return {
|
| 679 |
"ai_caption": ai_caption,
|
|
|
|
| 681 |
"action": raw_action,
|
| 682 |
"conditioning_fps": torch.tensor(self._fps, dtype=torch.long),
|
| 683 |
"mode": mode,
|
|
|
|
| 684 |
"viewpoint": self._viewpoint,
|
| 685 |
**extras,
|
| 686 |
}
|
cosmos-framework/cosmos_framework/data/vfm/action/domain_utils.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: OpenMDW-1.1
|
| 3 |
-
|
| 4 |
-
"""Domain ID helpers for cross-embodiment action datasets."""
|
| 5 |
-
|
| 6 |
-
EMBODIMENT_TO_DOMAIN_ID: dict[str, int] = {
|
| 7 |
-
"no_action": 0,
|
| 8 |
-
"av": 1,
|
| 9 |
-
"camera_pose": 2,
|
| 10 |
-
"pusht": 4,
|
| 11 |
-
"umi": 6,
|
| 12 |
-
"bridge_orig_lerobot": 7,
|
| 13 |
-
"droid_lerobot": 8,
|
| 14 |
-
"robomind-franka": 8, # Both Droid and RoboMIND-Franka are using robotiq and franka
|
| 15 |
-
"embodiment_b": 9,
|
| 16 |
-
"robomind-franka-dual": 12,
|
| 17 |
-
"fractal": 20,
|
| 18 |
-
}
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def get_domain_id(embodiment_type: str) -> int:
|
| 22 |
-
"""Get the domain ID for a given embodiment type."""
|
| 23 |
-
key = embodiment_type.lower().strip()
|
| 24 |
-
if key not in EMBODIMENT_TO_DOMAIN_ID:
|
| 25 |
-
raise KeyError(
|
| 26 |
-
f"Unknown embodiment type: {embodiment_type!r}. "
|
| 27 |
-
f"Available embodiments: {sorted(EMBODIMENT_TO_DOMAIN_ID.keys())}"
|
| 28 |
-
)
|
| 29 |
-
return EMBODIMENT_TO_DOMAIN_ID[key]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cosmos-framework/cosmos_framework/data/vfm/action/droid_lerobot_dataset.py
CHANGED
|
@@ -11,13 +11,7 @@ from scipy.spatial.transform import Rotation as R
|
|
| 11 |
|
| 12 |
from cosmos_framework.utils import log
|
| 13 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
| 14 |
-
ActionSpec,
|
| 15 |
BaseActionLeRobotDataset,
|
| 16 |
-
Gripper,
|
| 17 |
-
Joint,
|
| 18 |
-
Pos,
|
| 19 |
-
Rot,
|
| 20 |
-
build_action_spec,
|
| 21 |
build_episode_spans,
|
| 22 |
split_episode_ids,
|
| 23 |
)
|
|
@@ -87,7 +81,6 @@ class DROIDLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 87 |
split_val_ratio=split_val_ratio,
|
| 88 |
split=split,
|
| 89 |
mode=mode,
|
| 90 |
-
embodiment_type="droid_lerobot",
|
| 91 |
viewpoint=viewpoint,
|
| 92 |
pose_convention=pose_convention,
|
| 93 |
rotation_format="rot6d",
|
|
@@ -307,13 +300,6 @@ class DROIDLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 307 |
composite = torch.cat([wrist, bottom], dim=-2) # [T,C,3H/2,W]
|
| 308 |
return composite # [T,C,3H/2,W]
|
| 309 |
|
| 310 |
-
def _build_action_spec(self) -> ActionSpec:
|
| 311 |
-
"""DROID: 10D ``[Pos, Rot6d, Gripper]`` for ``ee_pose``,
|
| 312 |
-
8D ``[Joint(7), Gripper]`` for ``joint_pos``.
|
| 313 |
-
"""
|
| 314 |
-
if self._action_space == "joint_pos":
|
| 315 |
-
return build_action_spec(Joint(n=7, label="joint"), Gripper())
|
| 316 |
-
return build_action_spec(Pos(), Rot("rot6d"), Gripper())
|
| 317 |
|
| 318 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 319 |
""" """
|
|
|
|
| 11 |
|
| 12 |
from cosmos_framework.utils import log
|
| 13 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
|
|
|
| 14 |
BaseActionLeRobotDataset,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
build_episode_spans,
|
| 16 |
split_episode_ids,
|
| 17 |
)
|
|
|
|
| 81 |
split_val_ratio=split_val_ratio,
|
| 82 |
split=split,
|
| 83 |
mode=mode,
|
|
|
|
| 84 |
viewpoint=viewpoint,
|
| 85 |
pose_convention=pose_convention,
|
| 86 |
rotation_format="rot6d",
|
|
|
|
| 300 |
composite = torch.cat([wrist, bottom], dim=-2) # [T,C,3H/2,W]
|
| 301 |
return composite # [T,C,3H/2,W]
|
| 302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 305 |
""" """
|
cosmos-framework/cosmos_framework/data/vfm/action/fractal.py
CHANGED
|
@@ -15,12 +15,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|
| 15 |
|
| 16 |
from cosmos_framework.utils import log
|
| 17 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
| 18 |
-
ActionSpec,
|
| 19 |
BaseActionLeRobotDataset,
|
| 20 |
-
Gripper,
|
| 21 |
-
Pos,
|
| 22 |
-
Rot,
|
| 23 |
-
build_action_spec,
|
| 24 |
)
|
| 25 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 26 |
PoseConvention,
|
|
@@ -112,7 +107,6 @@ class FractalLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 112 |
split_val_ratio=split_val_ratio,
|
| 113 |
split=split,
|
| 114 |
mode=mode,
|
| 115 |
-
embodiment_type="fractal",
|
| 116 |
viewpoint=viewpoint,
|
| 117 |
pose_convention=pose_convention,
|
| 118 |
rotation_format="rot6d",
|
|
@@ -141,9 +135,6 @@ class FractalLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 141 |
)
|
| 142 |
return kept
|
| 143 |
|
| 144 |
-
def _build_action_spec(self) -> ActionSpec:
|
| 145 |
-
"""Fractal: 10D = ``[Pos(3), Rot6d(6), Gripper(1)]``."""
|
| 146 |
-
return build_action_spec(Pos(dim=3), Rot("rot6d"), Gripper())
|
| 147 |
|
| 148 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 149 |
"""Return a single training sample."""
|
|
|
|
| 15 |
|
| 16 |
from cosmos_framework.utils import log
|
| 17 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
|
|
|
| 18 |
BaseActionLeRobotDataset,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
)
|
| 20 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 21 |
PoseConvention,
|
|
|
|
| 107 |
split_val_ratio=split_val_ratio,
|
| 108 |
split=split,
|
| 109 |
mode=mode,
|
|
|
|
| 110 |
viewpoint=viewpoint,
|
| 111 |
pose_convention=pose_convention,
|
| 112 |
rotation_format="rot6d",
|
|
|
|
| 135 |
)
|
| 136 |
return kept
|
| 137 |
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 140 |
"""Return a single training sample."""
|
cosmos-framework/cosmos_framework/data/vfm/action/pose_utils.py
CHANGED
|
@@ -19,7 +19,6 @@ dataset stack:
|
|
| 19 |
canonical public entrypoint for representation conversion.
|
| 20 |
"""
|
| 21 |
|
| 22 |
-
import math
|
| 23 |
from typing import Literal
|
| 24 |
|
| 25 |
import numpy as np
|
|
@@ -540,208 +539,3 @@ def pose_rel_to_abs(
|
|
| 540 |
current_pose = next_pose
|
| 541 |
|
| 542 |
return np.stack(poses_abs) # [T,4,4]
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
# -----------------------------------------------------------------------------
|
| 546 |
-
# Idle-frame detection
|
| 547 |
-
# -----------------------------------------------------------------------------
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
def _identity_rotation_vector(rotation_format: RotationConvention) -> np.ndarray:
|
| 551 |
-
"""Return the identity-rotation vector for a given rotation convention.
|
| 552 |
-
|
| 553 |
-
Used by :func:`compute_idle_frames` to test whether a rotation block is
|
| 554 |
-
close to "no rotation" in its current encoding.
|
| 555 |
-
"""
|
| 556 |
-
if rotation_format in ("matrix", "rot9d"):
|
| 557 |
-
return np.array([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=np.float32)
|
| 558 |
-
if rotation_format == "rot6d":
|
| 559 |
-
return np.array([1, 0, 0, 0, 1, 0], dtype=np.float32)
|
| 560 |
-
if rotation_format == "quat_xyzw":
|
| 561 |
-
return np.array([0, 0, 0, 1], dtype=np.float32)
|
| 562 |
-
if rotation_format == "quat_wxyz":
|
| 563 |
-
return np.array([1, 0, 0, 0], dtype=np.float32)
|
| 564 |
-
if rotation_format in ("euler_xyz", "axisangle"):
|
| 565 |
-
return np.array([0, 0, 0], dtype=np.float32)
|
| 566 |
-
raise ValueError(f"Unsupported rotation_format={rotation_format!r}")
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
def _rotation_angle_per_arm(rotations: np.ndarray, rotation_format: str) -> np.ndarray:
|
| 570 |
-
"""Geodesic angle (rad) from identity for each arm at each frame.
|
| 571 |
-
|
| 572 |
-
``rotations`` has shape ``(T, n_arms, n_per_arm)``; the returned array has
|
| 573 |
-
shape ``(T, n_arms)``. The angle is rotation-format aware so a fixed
|
| 574 |
-
``eps_r`` threshold has consistent geometric meaning across formats:
|
| 575 |
-
|
| 576 |
-
- ``rot6d`` → reconstruct ``trace(R)`` in closed form from the two stored
|
| 577 |
-
columns ``a, b`` (already unit-orthogonal as they came from a valid
|
| 578 |
-
rotation matrix). The third column is ``a × b``, so
|
| 579 |
-
``trace(R) = a[0] + b[1] + a[0]·b[1] - a[1]·b[0]``.
|
| 580 |
-
``angle = arccos(clip((trace - 1) / 2, -1, 1))``.
|
| 581 |
-
- ``rot9d`` → reshape to ``(..., 3, 3)`` and use
|
| 582 |
-
``trace(R) = R[0,0] + R[1,1] + R[2,2]``.
|
| 583 |
-
- ``quat_xyzw`` / ``quat_wxyz`` → ``angle = 2 · arccos(|q_w|)``; the
|
| 584 |
-
absolute value handles the double cover (``q`` and ``-q`` represent the
|
| 585 |
-
same rotation).
|
| 586 |
-
- ``axisangle`` → the magnitude of the axis-angle vector *is* the angle.
|
| 587 |
-
- ``euler_xyz`` → no closed-form angle; use ``‖euler‖`` as a conservative
|
| 588 |
-
upper bound (exact for single-axis rotations, an overestimate for
|
| 589 |
-
composed ones — fine for idle detection where small angles are the
|
| 590 |
-
regime of interest).
|
| 591 |
-
"""
|
| 592 |
-
if rotation_format == "rot6d":
|
| 593 |
-
a = rotations[..., :3]
|
| 594 |
-
b = rotations[..., 3:6]
|
| 595 |
-
trace = a[..., 0] + b[..., 1] + a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
|
| 596 |
-
return np.arccos(np.clip((trace - 1.0) / 2.0, -1.0, 1.0))
|
| 597 |
-
if rotation_format == "rot9d":
|
| 598 |
-
mat = rotations.reshape(*rotations.shape[:-1], 3, 3)
|
| 599 |
-
trace = mat[..., 0, 0] + mat[..., 1, 1] + mat[..., 2, 2]
|
| 600 |
-
return np.arccos(np.clip((trace - 1.0) / 2.0, -1.0, 1.0))
|
| 601 |
-
if rotation_format in ("quat_xyzw", "quat_wxyz"):
|
| 602 |
-
qw = rotations[..., 3] if rotation_format == "quat_xyzw" else rotations[..., 0]
|
| 603 |
-
return 2.0 * np.arccos(np.clip(np.abs(qw), 0.0, 1.0))
|
| 604 |
-
if rotation_format == "axisangle":
|
| 605 |
-
return np.linalg.norm(rotations, axis=-1)
|
| 606 |
-
if rotation_format == "euler_xyz":
|
| 607 |
-
# Exact for single-axis rotations, overestimate for composed ones —
|
| 608 |
-
# safe for idle thresholds since overestimation can only mark a frame
|
| 609 |
-
# as non-idle, never spuriously idle.
|
| 610 |
-
return np.linalg.norm(rotations, axis=-1)
|
| 611 |
-
raise ValueError(f"Unsupported rotation_format={rotation_format!r}")
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
def _consecutive_streaks(idle: np.ndarray, min_streak: int) -> np.ndarray:
|
| 615 |
-
"""Zero out idle bits not belonging to a run of ``>= min_streak`` Trues.
|
| 616 |
-
|
| 617 |
-
Pure-numpy two-pointer scan. ``min_streak <= 1`` is a no-op (returns the
|
| 618 |
-
input mask unchanged).
|
| 619 |
-
"""
|
| 620 |
-
if min_streak <= 1:
|
| 621 |
-
return idle
|
| 622 |
-
out = np.zeros_like(idle)
|
| 623 |
-
n = len(idle)
|
| 624 |
-
i = 0
|
| 625 |
-
while i < n:
|
| 626 |
-
if not idle[i]:
|
| 627 |
-
i += 1
|
| 628 |
-
continue
|
| 629 |
-
j = i
|
| 630 |
-
while j < n and idle[j]:
|
| 631 |
-
j += 1
|
| 632 |
-
if j - i >= min_streak:
|
| 633 |
-
out[i:j] = True
|
| 634 |
-
i = j
|
| 635 |
-
return out
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
def compute_idle_frames(
|
| 639 |
-
action_raw: torch.Tensor | np.ndarray,
|
| 640 |
-
spec: "ActionSpec", # noqa: F821 — forward ref, real import is in action_spec.py
|
| 641 |
-
*,
|
| 642 |
-
eps_t: float = 1e-3,
|
| 643 |
-
eps_r: float = math.radians(5.0),
|
| 644 |
-
eps_g: float = 1e-2,
|
| 645 |
-
joint_threshold: float = 5e-4,
|
| 646 |
-
min_streak: int = 3,
|
| 647 |
-
) -> int:
|
| 648 |
-
"""Count idle frames in a raw (un-normalized) action chunk.
|
| 649 |
-
|
| 650 |
-
Idle detection runs per-DimType (driven by ``spec.types``); a frame is
|
| 651 |
-
*raw-idle* iff every relevant type group is idle on that frame, and
|
| 652 |
-
counts toward the final tally only if it belongs to a run of at least
|
| 653 |
-
``min_streak`` consecutive raw-idle frames. The streak filter rejects
|
| 654 |
-
isolated low-motion frames (instantaneous slowdowns) which carry weak
|
| 655 |
-
physical meaning and add noise to the IdleFrames training signal.
|
| 656 |
-
|
| 657 |
-
DimType branches:
|
| 658 |
-
|
| 659 |
-
- ``POS`` → combined ``‖action[pos_idx]‖`` (L2 across all POS dims)
|
| 660 |
-
< ``eps_t``. For single-arm specs (3 dims) this is the standard ``‖t‖``
|
| 661 |
-
check; for multi-arm specs the combined norm is slightly stricter than
|
| 662 |
-
a per-arm check.
|
| 663 |
-
- ``ROT`` → per-arm geodesic rotation angle (rad) from identity
|
| 664 |
-
< ``eps_r``. The angle is computed in a rotation-format aware way (see
|
| 665 |
-
:func:`_rotation_angle_per_arm`) so the threshold has consistent
|
| 666 |
-
geometric meaning regardless of the encoding.
|
| 667 |
-
- ``GRIPPER`` → ``max |action[t] - action[t-1]| < eps_g``. ``np.diff``
|
| 668 |
-
with ``prepend=action[0]`` makes step 0 ``|0|`` (treated as "no change");
|
| 669 |
-
with the streak filter this can no longer create a spurious single-frame
|
| 670 |
-
idle event.
|
| 671 |
-
- ``JOINT`` → same frame-diff scheme as gripper with
|
| 672 |
-
``joint_threshold`` (rad / step).
|
| 673 |
-
- ``RESERVED`` → ignored.
|
| 674 |
-
|
| 675 |
-
Defaults (in the units of the un-normalized action):
|
| 676 |
-
|
| 677 |
-
- ``eps_t = 1e-3`` → 1 mm per-frame translation
|
| 678 |
-
- ``eps_r = 5°`` → 5° per-frame rotation (geodesic angle)
|
| 679 |
-
- ``eps_g = 1e-2`` → 1 % gripper command change
|
| 680 |
-
- ``joint_threshold = 5e-4`` → ~0.03° / step joint angle change
|
| 681 |
-
- ``min_streak = 3`` → require a run of >= 3 consecutive idle frames
|
| 682 |
-
|
| 683 |
-
The input must be **un-normalized** so the identity transform sits at
|
| 684 |
-
known coordinates (translation ≈ 0, rotation ≈ identity). The action
|
| 685 |
-
vector is also assumed to be encoded in a per-step / framewise convention
|
| 686 |
-
(e.g. ``backward_framewise``); anchored conventions (``backward_anchored``)
|
| 687 |
-
accumulate over the chunk and would silently break the POS/ROT idle
|
| 688 |
-
checks. Callers (e.g. the LeRobot base class) gate on pose convention
|
| 689 |
-
before calling this function.
|
| 690 |
-
"""
|
| 691 |
-
if isinstance(action_raw, torch.Tensor):
|
| 692 |
-
action = action_raw.detach().cpu().numpy().astype(np.float32, copy=False)
|
| 693 |
-
else:
|
| 694 |
-
action = np.asarray(action_raw, dtype=np.float32)
|
| 695 |
-
|
| 696 |
-
if action.ndim != 2:
|
| 697 |
-
raise ValueError(f"action_raw must be 2-D (T, D); got shape {action.shape}")
|
| 698 |
-
num_frames, action_dim = action.shape
|
| 699 |
-
if num_frames == 0:
|
| 700 |
-
return 0
|
| 701 |
-
if action_dim != len(spec.types):
|
| 702 |
-
raise ValueError(f"action_dim={action_dim} does not match spec.dim={len(spec.types)}")
|
| 703 |
-
|
| 704 |
-
# Import locally to avoid a circular import at module load time
|
| 705 |
-
# (action_spec.py imports RotationConvention from this file).
|
| 706 |
-
from cosmos_framework.data.vfm.action.action_spec import DimType
|
| 707 |
-
|
| 708 |
-
pos_idx = [i for i, t in enumerate(spec.types) if t == DimType.POS]
|
| 709 |
-
rot_idx = [i for i, t in enumerate(spec.types) if t == DimType.ROT]
|
| 710 |
-
grip_idx = [i for i, t in enumerate(spec.types) if t == DimType.GRIPPER]
|
| 711 |
-
joint_idx = [i for i, t in enumerate(spec.types) if t == DimType.JOINT]
|
| 712 |
-
|
| 713 |
-
idle = np.ones(num_frames, dtype=bool)
|
| 714 |
-
|
| 715 |
-
# POS: combined L2 norm across all translation dims.
|
| 716 |
-
if pos_idx:
|
| 717 |
-
idle &= np.linalg.norm(action[:, pos_idx], axis=1) < eps_t
|
| 718 |
-
|
| 719 |
-
# ROT: per-arm geodesic angle (rad).
|
| 720 |
-
if rot_idx:
|
| 721 |
-
rot_id = _identity_rotation_vector(spec.rotation_format)
|
| 722 |
-
n_per_arm = rot_id.shape[0]
|
| 723 |
-
if len(rot_idx) % n_per_arm != 0:
|
| 724 |
-
raise ValueError(
|
| 725 |
-
f"ROT dims ({len(rot_idx)}) not a multiple of "
|
| 726 |
-
f"rotation_format={spec.rotation_format!r} dim ({n_per_arm})"
|
| 727 |
-
)
|
| 728 |
-
rotations = action[:, rot_idx].reshape(num_frames, -1, n_per_arm)
|
| 729 |
-
angles = _rotation_angle_per_arm(rotations, spec.rotation_format) # (T, n_arms)
|
| 730 |
-
idle &= angles.max(axis=1) < eps_r
|
| 731 |
-
|
| 732 |
-
# GRIPPER: max |Δgripper| across all gripper dims; step 0's diff is 0.
|
| 733 |
-
if grip_idx:
|
| 734 |
-
gripper = action[:, grip_idx]
|
| 735 |
-
diff = np.abs(np.diff(gripper, axis=0, prepend=gripper[:1]))
|
| 736 |
-
idle &= diff.max(axis=1) < eps_g
|
| 737 |
-
|
| 738 |
-
# JOINT: same frame-diff scheme with joint_threshold.
|
| 739 |
-
if joint_idx:
|
| 740 |
-
joints = action[:, joint_idx]
|
| 741 |
-
diff = np.abs(np.diff(joints, axis=0, prepend=joints[:1]))
|
| 742 |
-
idle &= diff.max(axis=1) < joint_threshold
|
| 743 |
-
|
| 744 |
-
if min_streak > 1:
|
| 745 |
-
idle = _consecutive_streaks(idle, min_streak)
|
| 746 |
-
|
| 747 |
-
return int(idle.sum())
|
|
|
|
| 19 |
canonical public entrypoint for representation conversion.
|
| 20 |
"""
|
| 21 |
|
|
|
|
| 22 |
from typing import Literal
|
| 23 |
|
| 24 |
import numpy as np
|
|
|
|
| 539 |
current_pose = next_pose
|
| 540 |
|
| 541 |
return np.stack(poses_abs) # [T,4,4]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cosmos-framework/cosmos_framework/data/vfm/action/robomind_franka_dataset.py
CHANGED
|
@@ -16,7 +16,6 @@
|
|
| 16 |
|
| 17 |
from __future__ import annotations
|
| 18 |
|
| 19 |
-
import math
|
| 20 |
import os
|
| 21 |
from typing import Any, cast
|
| 22 |
|
|
@@ -25,12 +24,7 @@ import torch
|
|
| 25 |
import torch.nn.functional as F
|
| 26 |
|
| 27 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
| 28 |
-
ActionSpec,
|
| 29 |
BaseActionLeRobotDataset,
|
| 30 |
-
Gripper,
|
| 31 |
-
Pos,
|
| 32 |
-
Rot,
|
| 33 |
-
build_action_spec,
|
| 34 |
)
|
| 35 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 36 |
PoseConvention,
|
|
@@ -77,14 +71,6 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
|
|
| 77 |
# 1.5°/s) so a single arm doing a slow approach (~1mm/f at 10 Hz) is no
|
| 78 |
# longer classified as idle.
|
| 79 |
#
|
| 80 |
-
# Class defaults below match single-arm. Dual-arm overrides at instance
|
| 81 |
-
# construction (see ``__init__``).
|
| 82 |
-
_IDLE_EPS_T_SINGLE: float = 22e-3
|
| 83 |
-
_IDLE_EPS_R_SINGLE: float = math.radians(3.0)
|
| 84 |
-
_IDLE_EPS_T_DUAL: float = 5e-3 # = base default; tight enough
|
| 85 |
-
_IDLE_EPS_R_DUAL: float = math.radians(1.5) # for "single-arm-slow" cases
|
| 86 |
-
idle_eps_t_per_sec: float = _IDLE_EPS_T_SINGLE
|
| 87 |
-
idle_eps_r_per_sec: float = _IDLE_EPS_R_SINGLE
|
| 88 |
|
| 89 |
def __init__(
|
| 90 |
self,
|
|
@@ -113,7 +99,6 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
|
|
| 113 |
split_val_ratio=split_val_ratio,
|
| 114 |
split=split,
|
| 115 |
mode=mode,
|
| 116 |
-
embodiment_type=embodiment_type,
|
| 117 |
viewpoint=viewpoint,
|
| 118 |
pose_convention=pose_convention,
|
| 119 |
rotation_format="rot6d",
|
|
@@ -121,15 +106,10 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
|
|
| 121 |
enable_fast_init=enable_fast_init,
|
| 122 |
)
|
| 123 |
|
|
|
|
| 124 |
self._to_opencv: np.ndarray = _ROBOMIND_FRANKA_TO_OPENCV[:3, :3]
|
| 125 |
self._is_concat_view: bool = viewpoint == "concat_view"
|
| 126 |
|
| 127 |
-
# Per-embodiment idle thresholds (instance-level override of the
|
| 128 |
-
# class default which matches single-arm). Dual-arm tightens both
|
| 129 |
-
# eps_t and eps_r to reflect its smaller per-frame motion tail.
|
| 130 |
-
if embodiment_type == "robomind-franka-dual":
|
| 131 |
-
self.idle_eps_t_per_sec = self._IDLE_EPS_T_DUAL
|
| 132 |
-
self.idle_eps_r_per_sec = self._IDLE_EPS_R_DUAL
|
| 133 |
|
| 134 |
embodiment_key = embodiment_type.removeprefix("robomind-")
|
| 135 |
lerobot_roots = LEROBOT_ROOTS[embodiment_key]
|
|
@@ -220,26 +200,6 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
|
|
| 220 |
composite = torch.cat([top_or_front, bottom], dim=-2) # [T,C,3H/2,W]
|
| 221 |
return composite # [T,C,3H/2,W]
|
| 222 |
|
| 223 |
-
def _build_action_spec(self) -> ActionSpec:
|
| 224 |
-
"""RoboMIND Franka: 10D single-arm or 20D dual-arm.
|
| 225 |
-
|
| 226 |
-
Single (``robomind-franka``):
|
| 227 |
-
``[Pos, Rot6d, Gripper]`` (10D)
|
| 228 |
-
|
| 229 |
-
Dual (``robomind-franka-dual``):
|
| 230 |
-
``[L_Pos, L_Rot6d, L_Gripper, R_Pos, R_Rot6d, R_Gripper]`` (20D)
|
| 231 |
-
"""
|
| 232 |
-
if self._embodiment_type == "robomind-franka":
|
| 233 |
-
return build_action_spec(Pos(), Rot("rot6d"), Gripper())
|
| 234 |
-
# dual arm
|
| 235 |
-
return build_action_spec(
|
| 236 |
-
Pos(prefix="left"),
|
| 237 |
-
Rot("rot6d", prefix="left"),
|
| 238 |
-
Gripper(prefix="left"),
|
| 239 |
-
Pos(prefix="right"),
|
| 240 |
-
Rot("rot6d", prefix="right"),
|
| 241 |
-
Gripper(prefix="right"),
|
| 242 |
-
)
|
| 243 |
|
| 244 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 245 |
mode, _, _, sample = self._fetch_sample(idx)
|
|
|
|
| 16 |
|
| 17 |
from __future__ import annotations
|
| 18 |
|
|
|
|
| 19 |
import os
|
| 20 |
from typing import Any, cast
|
| 21 |
|
|
|
|
| 24 |
import torch.nn.functional as F
|
| 25 |
|
| 26 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
|
|
|
|
| 27 |
BaseActionLeRobotDataset,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
from cosmos_framework.data.vfm.action.pose_utils import (
|
| 30 |
PoseConvention,
|
|
|
|
| 71 |
# 1.5°/s) so a single arm doing a slow approach (~1mm/f at 10 Hz) is no
|
| 72 |
# longer classified as idle.
|
| 73 |
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
def __init__(
|
| 76 |
self,
|
|
|
|
| 99 |
split_val_ratio=split_val_ratio,
|
| 100 |
split=split,
|
| 101 |
mode=mode,
|
|
|
|
| 102 |
viewpoint=viewpoint,
|
| 103 |
pose_convention=pose_convention,
|
| 104 |
rotation_format="rot6d",
|
|
|
|
| 106 |
enable_fast_init=enable_fast_init,
|
| 107 |
)
|
| 108 |
|
| 109 |
+
self._embodiment_type = embodiment_type
|
| 110 |
self._to_opencv: np.ndarray = _ROBOMIND_FRANKA_TO_OPENCV[:3, :3]
|
| 111 |
self._is_concat_view: bool = viewpoint == "concat_view"
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
embodiment_key = embodiment_type.removeprefix("robomind-")
|
| 115 |
lerobot_roots = LEROBOT_ROOTS[embodiment_key]
|
|
|
|
| 200 |
composite = torch.cat([top_or_front, bottom], dim=-2) # [T,C,3H/2,W]
|
| 201 |
return composite # [T,C,3H/2,W]
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
def __getitem__(self, idx: int) -> dict[str, Any]:
|
| 205 |
mode, _, _, sample = self._fetch_sample(idx)
|
cosmos-framework/cosmos_framework/data/vfm/action/umi_lerobot_dataset.py
CHANGED
|
@@ -11,7 +11,6 @@ import numpy as np
|
|
| 11 |
import torch
|
| 12 |
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
| 13 |
|
| 14 |
-
from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec
|
| 15 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import BaseActionLeRobotDataset
|
| 16 |
from cosmos_framework.data.vfm.action.pose_utils import PoseConvention, build_abs_pose_from_components, pose_abs_to_rel
|
| 17 |
from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
|
|
@@ -45,7 +44,6 @@ class UMIFastLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 45 |
split_val_ratio=split_val_ratio,
|
| 46 |
split=split,
|
| 47 |
mode=mode,
|
| 48 |
-
embodiment_type="umi",
|
| 49 |
viewpoint=viewpoint,
|
| 50 |
pose_convention=pose_convention,
|
| 51 |
rotation_format="rot6d",
|
|
@@ -62,8 +60,6 @@ class UMIFastLeRobotDataset(BaseActionLeRobotDataset):
|
|
| 62 |
_GRIPPER_FEATURE: observation_ts,
|
| 63 |
}
|
| 64 |
|
| 65 |
-
def _build_action_spec(self) -> ActionSpec:
|
| 66 |
-
return build_action_spec(Pos(), Rot("rot6d"), Gripper())
|
| 67 |
|
| 68 |
def _register_sources(self, shard_indices: list[int] | None = None) -> None:
|
| 69 |
roots = self._all_shard_roots if shard_indices is None else [self._all_shard_roots[i] for i in shard_indices]
|
|
|
|
| 11 |
import torch
|
| 12 |
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
| 13 |
|
|
|
|
| 14 |
from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import BaseActionLeRobotDataset
|
| 15 |
from cosmos_framework.data.vfm.action.pose_utils import PoseConvention, build_abs_pose_from_components, pose_abs_to_rel
|
| 16 |
from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
|
|
|
|
| 44 |
split_val_ratio=split_val_ratio,
|
| 45 |
split=split,
|
| 46 |
mode=mode,
|
|
|
|
| 47 |
viewpoint=viewpoint,
|
| 48 |
pose_convention=pose_convention,
|
| 49 |
rotation_format="rot6d",
|
|
|
|
| 60 |
_GRIPPER_FEATURE: observation_ts,
|
| 61 |
}
|
| 62 |
|
|
|
|
|
|
|
| 63 |
|
| 64 |
def _register_sources(self, shard_indices: list[int] | None = None) -> None:
|
| 65 |
roots = self._all_shard_roots if shard_indices is None else [self._all_shard_roots[i] for i in shard_indices]
|