XinKongCosmos committed on
Commit
381b35a
·
verified ·
1 Parent(s): 3264a6d

Deep trim viewer-only release

Browse files
cosmos-framework/cosmos_framework/data/imaginaire/__init__.py DELETED
File without changes
cosmos-framework/cosmos_framework/data/imaginaire/webdataset/__init__.py DELETED
File without changes
cosmos-framework/cosmos_framework/data/vfm/action/action_spec.py DELETED
@@ -1,235 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: OpenMDW-1.1
3
-
4
- """Action-vector specification: per-dim type label + idle thresholds.
5
-
6
- Single concept: every column of an action vector has a :class:`DimType` label.
7
- Idle detection iterates by type and applies the matching algorithm:
8
-
9
- POS → ‖action[pos_idx]‖ per arm < eps_t
10
- ROT → distance(rot, identity) per group < eps_r
11
- GRIPPER → max |Δgripper| < eps_g (frame 0 idle by convention)
12
- JOINT → max |Δjoint| < joint_threshold (frame 0 idle)
13
- RESERVED → ignored
14
-
15
- An :class:`ActionSpec` is just ``names`` + ``types`` + ``rotation_format``.
16
- Build one declaratively via :func:`build_action_spec` from DSL components::
17
-
18
- build_action_spec(Pos(), Rot("rot6d"), Gripper()) # 10D single arm
19
- build_action_spec(Pos(), Rot("rot6d")) # 9D no gripper
20
- build_action_spec(Joint(n=14, label="arm"), # 30D joint-space
21
- Joint(n=14, label="end"),
22
- Joint(n=2, label="gripper"))
23
- build_action_spec(Pos(prefix="left"), Rot("rot6d", "left"), Gripper(prefix="left"),
24
- Pos(prefix="right"), Rot("rot6d", "right"), Gripper(prefix="right"))
25
-
26
- Naming convention:
27
- Default ``pos_x``, ``rot_0``, ``gripper``, ``arm_0`` ...
28
- With ``prefix="left"`` (idempotent on trailing ``_``): ``left_pos_x`` ...
29
- """
30
-
31
- from __future__ import annotations
32
-
33
- from dataclasses import dataclass
34
- from enum import Enum
35
- from typing import ClassVar
36
-
37
- from cosmos_framework.data.vfm.action.pose_utils import (
38
- RotationConvention,
39
- _identity_rotation_vector,
40
- )
41
-
42
-
43
- class DimType(str, Enum):
44
- """Per-column action-dim category (drives idle detection)."""
45
-
46
- POS = "pos"
47
- ROT = "rot"
48
- GRIPPER = "gripper"
49
- JOINT = "joint"
50
- RESERVED = "reserved"
51
-
52
-
53
- @dataclass(frozen=True, slots=True)
54
- class ActionSpec:
55
- """Structural description of an action vector: names + per-dim types.
56
-
57
- All ROT dims share a single ``rotation_format``; mixed formats in one spec
58
- are not supported (raise at build time).
59
-
60
- This struct contains no detection thresholds — those are passed at call
61
- time to :func:`compute_idle_frames` so each dataset can tune them
62
- independently of layout.
63
- """
64
-
65
- names: list[str]
66
- types: list[DimType]
67
- rotation_format: RotationConvention = "rot6d"
68
-
69
- @property
70
- def dim(self) -> int:
71
- return len(self.names)
72
-
73
-
74
- # ---------------------------------------------------------------------------
75
- # DSL components
76
- # ---------------------------------------------------------------------------
77
-
78
-
79
- def _join_prefix(prefix: str, name: str) -> str:
80
- """Join ``prefix`` and ``name`` with a single ``_``; idempotent on trailing ``_``."""
81
- return name if not prefix else f"{prefix.rstrip('_')}_{name}"
82
-
83
-
84
- @dataclass(frozen=True)
85
- class Pos:
86
- """Translation block.
87
-
88
- Default 3D (``pos_x``, ``pos_y``, ``pos_z``). For planar tasks (e.g. PushT)
89
- use ``Pos(dim=2)`` → ``pos_x``, ``pos_y``. ``dim >= 4`` falls back to
90
- indexed names ``pos_0``, ``pos_1``, ...
91
- """
92
-
93
- dim: int = 3
94
- prefix: str = ""
95
- type: ClassVar[DimType] = DimType.POS
96
-
97
- def names(self) -> list[str]:
98
- if self.dim <= 3:
99
- return [_join_prefix(self.prefix, f"pos_{c}") for c in "xyz"[: self.dim]]
100
- return [_join_prefix(self.prefix, f"pos_{i}") for i in range(self.dim)]
101
-
102
-
103
- @dataclass(frozen=True)
104
- class Rot:
105
- """Rotation block; ``format`` selects the encoding.
106
-
107
- Supported formats and per-dim names:
108
-
109
- - ``rot6d`` → 6 dims, ``rot_0`` ... ``rot_5`` (identity ``[1,0,0,0,1,0]``)
110
- - ``rot9d`` → 9 dims, ``rot_0`` ... ``rot_8`` (identity ``[1,0,0,0,1,0,0,0,1]``)
111
- - ``euler_xyz`` → 3 dims, ``roll``, ``pitch``, ``yaw`` (identity ``[0,0,0]``)
112
- - ``axisangle`` → 3 dims, ``axang_x/y/z`` (identity ``[0,0,0]``)
113
- - ``quat_xyzw`` / ``quat_wxyz`` → 4 dims, ``quat_x/y/z/w`` in declared order
114
- """
115
-
116
- format: RotationConvention = "rot6d"
117
- prefix: str = ""
118
- type: ClassVar[DimType] = DimType.ROT
119
-
120
- @property
121
- def rotation_format(self) -> RotationConvention:
122
- return self.format
123
-
124
- @property
125
- def dim(self) -> int:
126
- return _identity_rotation_vector(self.format).shape[0]
127
-
128
- def names(self) -> list[str]:
129
- if self.format == "euler_xyz":
130
- return [_join_prefix(self.prefix, c) for c in ("roll", "pitch", "yaw")]
131
- if self.format == "axisangle":
132
- return [_join_prefix(self.prefix, f"axang_{c}") for c in "xyz"]
133
- if self.format.startswith("quat_"):
134
- order = self.format.split("_", 1)[1] # "xyzw" or "wxyz"
135
- return [_join_prefix(self.prefix, f"quat_{c}") for c in order]
136
- return [_join_prefix(self.prefix, f"rot_{i}") for i in range(self.dim)]
137
-
138
-
139
- @dataclass(frozen=True)
140
- class Gripper:
141
- """1D gripper command (binary 0/1 or continuous). Detected by frame-diff."""
142
-
143
- prefix: str = ""
144
- type: ClassVar[DimType] = DimType.GRIPPER
145
-
146
- @property
147
- def dim(self) -> int:
148
- return 1
149
-
150
- def names(self) -> list[str]:
151
- return [_join_prefix(self.prefix, "gripper")]
152
-
153
-
154
- @dataclass(frozen=True)
155
- class Joint:
156
- """``n`` joint commands. Detected by frame-diff against ``joint_threshold``."""
157
-
158
- n: int = 0
159
- label: str = "joint"
160
- prefix: str = ""
161
- type: ClassVar[DimType] = DimType.JOINT
162
-
163
- @property
164
- def dim(self) -> int:
165
- return self.n
166
-
167
- def names(self) -> list[str]:
168
- return [_join_prefix(self.prefix, f"{self.label}_{i}") for i in range(self.n)]
169
-
170
-
171
- @dataclass(frozen=True)
172
- class Reserved:
173
- """``n`` dims counted in ``action_dim`` but ignored by idle detection."""
174
-
175
- n: int = 0
176
- label: str = "reserved"
177
- prefix: str = ""
178
- type: ClassVar[DimType] = DimType.RESERVED
179
-
180
- @property
181
- def dim(self) -> int:
182
- return self.n
183
-
184
- def names(self) -> list[str]:
185
- return [_join_prefix(self.prefix, f"{self.label}_{i}") for i in range(self.n)]
186
-
187
-
188
- # ---------------------------------------------------------------------------
189
- # Builder
190
- # ---------------------------------------------------------------------------
191
-
192
-
193
- # Type alias for any DSL component. Not a runtime check — only annotation hint.
194
- Component = Pos | Rot | Gripper | Joint | Reserved
195
-
196
-
197
- def build_action_spec(*components: Component) -> ActionSpec:
198
- """Compose ``components`` into an :class:`ActionSpec`.
199
-
200
- Each component contributes its ``names()`` and replicates its ``type`` for
201
- every column it occupies. The first ROT component's ``rotation_format``
202
- is captured for the whole spec; mixing formats raises ``ValueError``.
203
- """
204
- names: list[str] = []
205
- types: list[DimType] = []
206
- rotation_format: RotationConvention | None = None
207
-
208
- for c in components:
209
- names.extend(c.names())
210
- types.extend([c.type] * c.dim)
211
- if c.type == DimType.ROT:
212
- fmt = c.rotation_format # type: ignore[union-attr]
213
- if rotation_format is None:
214
- rotation_format = fmt
215
- elif rotation_format != fmt:
216
- raise ValueError(f"Mixed rotation_format in one ActionSpec: {rotation_format!r} vs {fmt!r}")
217
-
218
- return ActionSpec(
219
- names=names,
220
- types=types,
221
- rotation_format=rotation_format or "rot6d",
222
- )
223
-
224
-
225
- __all__ = [
226
- "ActionSpec",
227
- "Component",
228
- "DimType",
229
- "Gripper",
230
- "Joint",
231
- "Pos",
232
- "Reserved",
233
- "Rot",
234
- "build_action_spec",
235
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/action/av_dataset.py CHANGED
@@ -37,8 +37,6 @@ from torch.utils.data import IterableDataset
37
  # torch.multiprocessing.set_sharing_strategy("file_system")
38
  from cosmos_framework.utils import log
39
  from cosmos_framework.utils.easy_io import easy_io
40
- from cosmos_framework.data.vfm.action.camera_dataset import get_target_size_and_crop
41
- from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
42
  from cosmos_framework.data.vfm.action.pose_utils import (
43
  RotationConvention,
44
  build_abs_pose_from_components,
@@ -46,6 +44,24 @@ from cosmos_framework.data.vfm.action.pose_utils import (
46
  )
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def decode_video_bytes(
50
  video_bytes: bytes,
51
  resolution: str | None = None,
@@ -509,7 +525,6 @@ class AVDataset(IterableDataset):
509
  resolution: str | None = None,
510
  fps: int = 10,
511
  mode: str = "policy",
512
- embodiment_type: str = "av",
513
  split: str = "train",
514
  seed: int = 0,
515
  shuffle: bool = True,
@@ -529,11 +544,6 @@ class AVDataset(IterableDataset):
529
  rotation_scale: float = 1.0,
530
  max_action_translation_norm: float | None = None,
531
  align_opencv_pose: bool = False,
532
- # When True, use a separate domain ID for inverse dynamics / policy modes
533
- # so that DomainAwareLinear learns different projections for anchored (conditioning)
534
- # vs framewise (generation) action representations.
535
- mode_aware_domain: bool = False,
536
- inv_embodiment_type: str = "av_inv",
537
  ):
538
  """Initialize AVDataset.
539
 
@@ -543,7 +553,6 @@ class AVDataset(IterableDataset):
543
  resolution: Target resolution for video frames (e.g. "256", "480"). If None, keeps original resolution.
544
  fps: Target frames per second for video and actions.
545
  mode: Training mode ('policy', 'forward_dynamics', 'inverse_dynamics', 'image2video', 'joint').
546
- embodiment_type: Embodiment type for domain ID.
547
  split: Dataset split ('train', 'val', or 'full').
548
  seed: Random seed for shuffling.
549
  shuffle: Whether to shuffle tar files during iteration (for training).
@@ -570,8 +579,6 @@ class AVDataset(IterableDataset):
570
  align_opencv_pose: If True, transform pose rotations from car body-frame
571
  convention (x=forward, y=left, z=up) to OpenCV camera convention
572
  (x=right, y=down, z=forward) before computing relative actions.
573
- mode_aware_domain: When True, inverse_dynamics/policy modes use a separate domain ID.
574
- inv_embodiment_type: Embodiment type string for the inverse domain ID.
575
  """
576
  super().__init__()
577
 
@@ -602,11 +609,6 @@ class AVDataset(IterableDataset):
602
  self.rotation_scale = rotation_scale
603
  self.max_action_translation_norm = max_action_translation_norm
604
  self.align_opencv_pose = align_opencv_pose
605
- # Get domain ID for this embodiment
606
- self.domain_id = get_domain_id(embodiment_type)
607
- self.mode_aware_domain = mode_aware_domain
608
- self.domain_id_inv = get_domain_id(inv_embodiment_type) if mode_aware_domain else self.domain_id
609
-
610
  # Validate mode
611
  valid_modes = ["joint", "forward_dynamics", "inverse_dynamics", "policy", "image2video"]
612
  if mode not in valid_modes:
@@ -864,11 +866,6 @@ class AVDataset(IterableDataset):
864
  )
865
  # prompt += f"Predict the future {future_duration:.1f}s action trajectory at {self.fps}Hz."
866
 
867
- # Select domain ID: use inverse domain for generation modes when mode_aware_domain is on
868
- if self.mode_aware_domain and mode in ["inverse_dynamics", "policy"]:
869
- domain_id = self.domain_id_inv
870
- else:
871
- domain_id = self.domain_id
872
 
873
  sample = {
874
  "video": video,
@@ -881,7 +878,6 @@ class AVDataset(IterableDataset):
881
  "ai_caption": prompt,
882
  "mode": mode,
883
  "__key__": key_tensor,
884
- "domain_id": torch.tensor(domain_id, dtype=torch.long),
885
  "history_length": actual_history_length,
886
  "future_length": actual_future_length,
887
  "viewpoint": "ego_view",
@@ -1001,7 +997,6 @@ if __name__ == "__main__":
1001
  print(f"{'future_length':<25}: {data['future_length']}")
1002
  print(f"{'conditioning_fps':<25}: {data['conditioning_fps'].item()}")
1003
  print(f"{'mode':<25}: {data['mode']}")
1004
- print(f"{'domain_id':<25}: {data['domain_id'].item()}")
1005
  print(f"{'prompt':<25}: {data['prompt']}")
1006
 
1007
  # save video
 
37
  # torch.multiprocessing.set_sharing_strategy("file_system")
38
  from cosmos_framework.utils import log
39
  from cosmos_framework.utils.easy_io import easy_io
 
 
40
  from cosmos_framework.data.vfm.action.pose_utils import (
41
  RotationConvention,
42
  build_abs_pose_from_components,
 
44
  )
45
 
46
 
47
+ VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
48
+ "256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 320)},
49
+ "480": {"1,1": (640, 640), "4,3": (736, 544), "3,4": (544, 736), "16,9": (832, 480), "9,16": (480, 832)},
50
+ }
51
+
52
+
53
+ def get_target_size_and_crop(resolution: str, current_H: int, current_W: int) -> tuple[int, int, int, int]:
54
+ target_resolutions = VIDEO_RES_SIZE_INFO[resolution]
55
+ current_ar = current_W / current_H
56
+ best_key = min(
57
+ target_resolutions,
58
+ key=lambda key: abs((int(key.split(",")[0]) / int(key.split(",")[1])) - current_ar),
59
+ )
60
+ target_canvas_W, target_canvas_H = target_resolutions[best_key]
61
+ scaling_ratio = max(target_canvas_W / current_W, target_canvas_H / current_H)
62
+ return int(scaling_ratio * current_H + 0.5), int(scaling_ratio * current_W + 0.5), target_canvas_H, target_canvas_W
63
+
64
+
65
  def decode_video_bytes(
66
  video_bytes: bytes,
67
  resolution: str | None = None,
 
525
  resolution: str | None = None,
526
  fps: int = 10,
527
  mode: str = "policy",
 
528
  split: str = "train",
529
  seed: int = 0,
530
  shuffle: bool = True,
 
544
  rotation_scale: float = 1.0,
545
  max_action_translation_norm: float | None = None,
546
  align_opencv_pose: bool = False,
 
 
 
 
 
547
  ):
548
  """Initialize AVDataset.
549
 
 
553
  resolution: Target resolution for video frames (e.g. "256", "480"). If None, keeps original resolution.
554
  fps: Target frames per second for video and actions.
555
  mode: Training mode ('policy', 'forward_dynamics', 'inverse_dynamics', 'image2video', 'joint').
 
556
  split: Dataset split ('train', 'val', or 'full').
557
  seed: Random seed for shuffling.
558
  shuffle: Whether to shuffle tar files during iteration (for training).
 
579
  align_opencv_pose: If True, transform pose rotations from car body-frame
580
  convention (x=forward, y=left, z=up) to OpenCV camera convention
581
  (x=right, y=down, z=forward) before computing relative actions.
 
 
582
  """
583
  super().__init__()
584
 
 
609
  self.rotation_scale = rotation_scale
610
  self.max_action_translation_norm = max_action_translation_norm
611
  self.align_opencv_pose = align_opencv_pose
 
 
 
 
 
612
  # Validate mode
613
  valid_modes = ["joint", "forward_dynamics", "inverse_dynamics", "policy", "image2video"]
614
  if mode not in valid_modes:
 
866
  )
867
  # prompt += f"Predict the future {future_duration:.1f}s action trajectory at {self.fps}Hz."
868
 
 
 
 
 
 
869
 
870
  sample = {
871
  "video": video,
 
878
  "ai_caption": prompt,
879
  "mode": mode,
880
  "__key__": key_tensor,
 
881
  "history_length": actual_history_length,
882
  "future_length": actual_future_length,
883
  "viewpoint": "ego_view",
 
997
  print(f"{'future_length':<25}: {data['future_length']}")
998
  print(f"{'conditioning_fps':<25}: {data['conditioning_fps'].item()}")
999
  print(f"{'mode':<25}: {data['mode']}")
 
1000
  print(f"{'prompt':<25}: {data['prompt']}")
1001
 
1002
  # save video
cosmos-framework/cosmos_framework/data/vfm/action/bridge_orig_lerobot_dataset.py CHANGED
@@ -18,12 +18,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
18
 
19
  from cosmos_framework.utils import log
20
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
21
- ActionSpec,
22
  BaseActionLeRobotDataset,
23
- Gripper,
24
- Pos,
25
- Rot,
26
- build_action_spec,
27
  )
28
  from cosmos_framework.data.vfm.action.pose_utils import (
29
  PoseConvention,
@@ -111,7 +106,6 @@ class BridgeOrigLeRobotDataset(BaseActionLeRobotDataset):
111
  split_val_ratio=split_val_ratio,
112
  split=split,
113
  mode=mode,
114
- embodiment_type="bridge_orig_lerobot",
115
  viewpoint=viewpoint,
116
  pose_convention=pose_convention,
117
  rotation_format="rot6d",
@@ -240,9 +234,6 @@ class BridgeOrigLeRobotDataset(BaseActionLeRobotDataset):
240
  # __getitem__
241
  # ------------------------------------------------------------------
242
 
243
- def _build_action_spec(self) -> ActionSpec:
244
- """Bridge: 10D = ``[Pos, Rot6d, Gripper]``."""
245
- return build_action_spec(Pos(), Rot("rot6d"), Gripper())
246
 
247
  def __getitem__(self, idx: int) -> dict[str, Any]:
248
  """ """
 
18
 
19
  from cosmos_framework.utils import log
20
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
 
21
  BaseActionLeRobotDataset,
 
 
 
 
22
  )
23
  from cosmos_framework.data.vfm.action.pose_utils import (
24
  PoseConvention,
 
106
  split_val_ratio=split_val_ratio,
107
  split=split,
108
  mode=mode,
 
109
  viewpoint=viewpoint,
110
  pose_convention=pose_convention,
111
  rotation_format="rot6d",
 
234
  # __getitem__
235
  # ------------------------------------------------------------------
236
 
 
 
 
237
 
238
  def __getitem__(self, idx: int) -> dict[str, Any]:
239
  """ """
cosmos-framework/cosmos_framework/data/vfm/action/camera_dataset.py DELETED
@@ -1,15 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: OpenMDW-1.1
3
-
4
- VIDEO_RES_SIZE_INFO: dict[str, dict[str, tuple[int, int]]] = {
5
- "256": {"1,1": (256, 256), "4,3": (320, 256), "3,4": (256, 320), "16,9": (320, 192), "9,16": (192, 320)},
6
- "480": {"1,1": (640, 640), "4,3": (736, 544), "3,4": (544, 736), "16,9": (832, 480), "9,16": (480, 832)},
7
- }
8
-
9
- def get_target_size_and_crop(resolution: str, current_H: int, current_W: int) -> tuple[int, int, int, int]:
10
- target_resolutions = VIDEO_RES_SIZE_INFO[resolution]
11
- current_ar = current_W / current_H
12
- best_key = min(target_resolutions, key=lambda key: abs((int(key.split(',')[0]) / int(key.split(',')[1])) - current_ar))
13
- target_canvas_W, target_canvas_H = target_resolutions[best_key]
14
- scaling_ratio = max(target_canvas_W / current_W, target_canvas_H / current_H)
15
- return int(scaling_ratio * current_H + 0.5), int(scaling_ratio * current_W + 0.5), target_canvas_H, target_canvas_W
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/action/cosmos3_action_lerobot.py CHANGED
@@ -14,7 +14,6 @@ from __future__ import annotations
14
 
15
  import importlib
16
  import logging as _logging
17
- import math
18
  import os as _os
19
  import random
20
  from bisect import bisect_right
@@ -53,23 +52,7 @@ def _ensure_hf_hub_offline() -> None:
53
  _hf_offline_applied = True
54
 
55
 
56
- from functools import cached_property
57
-
58
  from cosmos_framework.utils import log
59
- # Re-export the action_spec DSL from this module so that subclass datasets
60
- # only need a single import block (alongside ``BaseActionLeRobotDataset``).
61
- from cosmos_framework.data.vfm.action.action_spec import ( # noqa: F401 (re-export)
62
- ActionSpec,
63
- DimType,
64
- Gripper,
65
- Joint,
66
- Pos,
67
- Reserved,
68
- Rot,
69
- build_action_spec,
70
- )
71
- from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
72
- from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames
73
  from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
74
 
75
  # ---------------------------------------------------------------------------
@@ -278,7 +261,6 @@ class BaseActionLeRobotDataset(Dataset):
278
  split_val_ratio: float,
279
  split: str,
280
  mode: str,
281
- embodiment_type: str,
282
  viewpoint: Viewpoint,
283
  pose_convention: str | None = None,
284
  rotation_format: str | None = None,
@@ -301,7 +283,6 @@ class BaseActionLeRobotDataset(Dataset):
301
  self._split_val_ratio = split_val_ratio
302
  self._split = _normalize_split(split)
303
  self._mode = mode
304
- self._embodiment_type = embodiment_type
305
  self._viewpoint: Viewpoint = viewpoint
306
  self._pose_convention = pose_convention
307
  self._rotation_format = rotation_format
@@ -331,7 +312,6 @@ class BaseActionLeRobotDataset(Dataset):
331
  self._episode_records: list[tuple[int, int, int, int]] = []
332
  self._episode_cum_ends: list[int] = []
333
  self._num_valid_indices = 0
334
- self._domain_id = get_domain_id(self._embodiment_type)
335
  self._all_shard_roots: list[str] = []
336
 
337
  # -- public properties ---------------------------------------------------
@@ -356,9 +336,6 @@ class BaseActionLeRobotDataset(Dataset):
356
  def mode(self, value: str) -> None:
357
  self._mode = value
358
 
359
- @property
360
- def domain_id(self) -> int:
361
- return self._domain_id
362
 
363
  # -- source registration -------------------------------------------------
364
 
@@ -679,138 +656,6 @@ class BaseActionLeRobotDataset(Dataset):
679
 
680
  # -- result building -----------------------------------------------------
681
 
682
- def _build_action_spec(self) -> ActionSpec | None:
683
- """Subclass override: declare this dataset's action layout.
684
-
685
- Called once per instance — the result is cached by ``self.action_spec``.
686
- Return ``None`` to skip spec-driven idle detection; in that case
687
- ``_compute_idle_frames`` will log a one-time warning and return
688
- ``None`` for every sample.
689
- """
690
- return None
691
-
692
- @cached_property
693
- def action_spec(self) -> ActionSpec | None:
694
- """Cached :class:`ActionSpec` from ``_build_action_spec``.
695
-
696
- Returns ``None`` when the subclass did not declare one; idle detection
697
- is then skipped (with a one-time warning) until the subclass overrides
698
- ``_build_action_spec``.
699
- """
700
- return self._build_action_spec()
701
-
702
- @cached_property
703
- def action_names(self) -> list[str] | None:
704
- spec = self.action_spec
705
- return spec.names if spec is not None else None
706
-
707
- # Idle-detection thresholds. Defined as **velocities** (per second) so the
708
- # same numeric value means the same physical motion across datasets with
709
- # different sampling rates; converted to per-frame at call time using
710
- # ``self._fps`` via :meth:`_resolve_idle_thresholds`.
711
- #
712
- # Defaults:
713
- # - ``idle_eps_t_per_sec`` = 5 mm/s (≈ 1 mm/frame at 5 Hz)
714
- # - ``idle_eps_r_per_sec`` = 1.5°/s (geodesic, rotation-format aware)
715
- # - ``idle_eps_g`` = 1e-2 unit gripper Δ (no fps)
716
- # - ``idle_joint_threshold_per_sec`` = 5e-3 rad/s
717
- # - ``idle_min_streak`` = 3 require ≥ 3 consecutive
718
- #
719
- # Subclasses can either override the ``*_per_sec`` attributes (preferred —
720
- # keeps the velocity semantics) or set the corresponding ``idle_eps_*`` /
721
- # ``idle_joint_threshold`` attribute to a non-``None`` value to bypass the
722
- # per-fps conversion entirely (raw per-frame override).
723
- idle_eps_t_per_sec: float = 5e-3
724
- idle_eps_r_per_sec: float = math.radians(1.5)
725
- idle_eps_g: float = 1e-2
726
- idle_joint_threshold_per_sec: float = 5e-3
727
- idle_min_streak: int = 3
728
-
729
- # Optional per-frame overrides. ``None`` (default) → use the ``*_per_sec``
730
- # attribute / fps conversion above.
731
- idle_eps_t: float | None = None
732
- idle_eps_r: float | None = None
733
- idle_joint_threshold: float | None = None
734
-
735
- def _resolve_idle_thresholds(self) -> tuple[float, float, float, float]:
736
- """Resolve per-frame idle thresholds for this dataset instance.
737
-
738
- Returns ``(eps_t, eps_r, eps_g, joint_threshold)`` in raw per-frame
739
- units. Honours direct per-frame overrides if the subclass sets the
740
- non-``_per_sec`` attribute; otherwise scales the ``_per_sec`` values
741
- by ``self._fps``.
742
- """
743
- fps = float(self._fps) if self._fps else 1.0
744
- eps_t = self.idle_eps_t if self.idle_eps_t is not None else self.idle_eps_t_per_sec / fps
745
- eps_r = self.idle_eps_r if self.idle_eps_r is not None else self.idle_eps_r_per_sec / fps
746
- joint_thr = (
747
- self.idle_joint_threshold
748
- if self.idle_joint_threshold is not None
749
- else self.idle_joint_threshold_per_sec / fps
750
- )
751
- return float(eps_t), float(eps_r), float(self.idle_eps_g), float(joint_thr)
752
-
753
- def _compute_idle_frames(self, raw_action: torch.Tensor) -> torch.Tensor | None:
754
- """Count idle frames in the *raw* (un-normalized) action chunk.
755
-
756
- Requires ``self.action_spec`` to be declared via ``_build_action_spec``.
757
- Returns ``None`` when:
758
- - ``pose_convention`` is not ``"backward_framewise"`` (TODO: extend),
759
- - the subclass has not declared an ``ActionSpec`` (logs a one-time warning),
760
- - the action layout does not match the declared spec.
761
-
762
- Detection thresholds come from the ``idle_eps_*`` class attributes
763
- (overridable per dataset). Subclasses can also override this method
764
- outright, or pass an explicit ``idle_frames`` integer via
765
- ``**extras`` to :meth:`_build_result`.
766
- """
767
-
768
- # conventions (anchored / absolute) need different idle semantics.
769
- if self._pose_convention != "backward_framewise":
770
- if not getattr(self, "_warned_pose_convention", False):
771
- log.warning(
772
- f"Dataset {self.__class__.__name__}: pose_convention="
773
- f"{self._pose_convention!r} is not 'backward_framewise'; "
774
- "skipping idle-frames detection. Centralize the dataset "
775
- "to backward_framewise to enable IdleFrames captioning."
776
- )
777
- self._warned_pose_convention = True
778
- return None
779
-
780
- spec = self.action_spec
781
- if spec is None:
782
- if not getattr(self, "_warned_no_action_spec", False):
783
- log.warning(
784
- f"Dataset {self.__class__.__name__} has no action spec defined; "
785
- "skipping idle-frames detection. Override _build_action_spec() to enable it."
786
- )
787
- self._warned_no_action_spec = True
788
- return None
789
-
790
- eps_t, eps_r, eps_g, joint_thr = self._resolve_idle_thresholds()
791
- try:
792
- n = compute_idle_frames(
793
- raw_action,
794
- spec,
795
- eps_t=eps_t,
796
- eps_r=eps_r,
797
- eps_g=eps_g,
798
- joint_threshold=joint_thr,
799
- min_streak=self.idle_min_streak,
800
- )
801
- except (ValueError, TypeError) as e:
802
- if not getattr(self, "_warned_action_layout", False):
803
- log.warning(
804
- f"Dataset {self.__class__.__name__}: action layout does "
805
- f"not match the declared ActionSpec "
806
- f"(action_dim={int(raw_action.shape[-1])}, "
807
- f"spec.dim={spec.dim}); skipping idle-frames detection. "
808
- f"Underlying error: {e}"
809
- )
810
- self._warned_action_layout = True
811
- return None
812
- return torch.tensor(n, dtype=torch.long)
813
-
814
  def _build_result(
815
  self,
816
  *,
@@ -823,25 +668,12 @@ class BaseActionLeRobotDataset(Dataset):
823
  """Assemble the common return dict for ``__getitem__``.
824
 
825
  ``video`` is expected in raw LeRobot layout before final formatting.
826
- Subclasses may pass extra keys (e.g. ``initial_pose``) via ``**extras``.
827
- ``idle_frames`` is auto-computed from the raw (un-normalized) ``action``
828
- whenever the dataset's pose/rotation conventions allow it; subclasses
829
- can override by passing ``idle_frames`` (int or scalar tensor) via
830
  ``**extras``.
831
  """
832
- # Compute idle_frames from the raw action before normalization, unless
833
- # the subclass has provided one explicitly via ``**extras``.
834
- if "idle_frames" not in extras:
835
- idle_frames = self._compute_idle_frames(action)
836
- if idle_frames is not None:
837
- extras = {"idle_frames": idle_frames, **extras}
838
-
839
  raw_action = action # [T,D]
840
  if self._skip_video_loading:
841
- result: dict[str, Any] = {"action": raw_action}
842
- if "idle_frames" in extras:
843
- result["idle_frames"] = extras["idle_frames"]
844
- return result
845
  formatted_video = self._convert_video(video) # [C,T,H,W] | None
846
  return {
847
  "ai_caption": ai_caption,
@@ -849,7 +681,6 @@ class BaseActionLeRobotDataset(Dataset):
849
  "action": raw_action,
850
  "conditioning_fps": torch.tensor(self._fps, dtype=torch.long),
851
  "mode": mode,
852
- "domain_id": torch.tensor(self._domain_id, dtype=torch.long),
853
  "viewpoint": self._viewpoint,
854
  **extras,
855
  }
 
14
 
15
  import importlib
16
  import logging as _logging
 
17
  import os as _os
18
  import random
19
  from bisect import bisect_right
 
52
  _hf_offline_applied = True
53
 
54
 
 
 
55
  from cosmos_framework.utils import log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
57
 
58
  # ---------------------------------------------------------------------------
 
261
  split_val_ratio: float,
262
  split: str,
263
  mode: str,
 
264
  viewpoint: Viewpoint,
265
  pose_convention: str | None = None,
266
  rotation_format: str | None = None,
 
283
  self._split_val_ratio = split_val_ratio
284
  self._split = _normalize_split(split)
285
  self._mode = mode
 
286
  self._viewpoint: Viewpoint = viewpoint
287
  self._pose_convention = pose_convention
288
  self._rotation_format = rotation_format
 
312
  self._episode_records: list[tuple[int, int, int, int]] = []
313
  self._episode_cum_ends: list[int] = []
314
  self._num_valid_indices = 0
 
315
  self._all_shard_roots: list[str] = []
316
 
317
  # -- public properties ---------------------------------------------------
 
336
  def mode(self, value: str) -> None:
337
  self._mode = value
338
 
 
 
 
339
 
340
  # -- source registration -------------------------------------------------
341
 
 
656
 
657
  # -- result building -----------------------------------------------------
658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
  def _build_result(
660
  self,
661
  *,
 
668
  """Assemble the common return dict for ``__getitem__``.
669
 
670
  ``video`` is expected in raw LeRobot layout before final formatting.
671
+ Subclasses may pass extra viewer metadata (e.g. ``initial_pose``) via
 
 
 
672
  ``**extras``.
673
  """
 
 
 
 
 
 
 
674
  raw_action = action # [T,D]
675
  if self._skip_video_loading:
676
+ return {"action": raw_action}
 
 
 
677
  formatted_video = self._convert_video(video) # [C,T,H,W] | None
678
  return {
679
  "ai_caption": ai_caption,
 
681
  "action": raw_action,
682
  "conditioning_fps": torch.tensor(self._fps, dtype=torch.long),
683
  "mode": mode,
 
684
  "viewpoint": self._viewpoint,
685
  **extras,
686
  }
cosmos-framework/cosmos_framework/data/vfm/action/domain_utils.py DELETED
@@ -1,29 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: OpenMDW-1.1
3
-
4
- """Domain ID helpers for cross-embodiment action datasets."""
5
-
6
- EMBODIMENT_TO_DOMAIN_ID: dict[str, int] = {
7
- "no_action": 0,
8
- "av": 1,
9
- "camera_pose": 2,
10
- "pusht": 4,
11
- "umi": 6,
12
- "bridge_orig_lerobot": 7,
13
- "droid_lerobot": 8,
14
- "robomind-franka": 8, # Both Droid and RoboMIND-Franka are using robotiq and franka
15
- "embodiment_b": 9,
16
- "robomind-franka-dual": 12,
17
- "fractal": 20,
18
- }
19
-
20
-
21
- def get_domain_id(embodiment_type: str) -> int:
22
- """Get the domain ID for a given embodiment type."""
23
- key = embodiment_type.lower().strip()
24
- if key not in EMBODIMENT_TO_DOMAIN_ID:
25
- raise KeyError(
26
- f"Unknown embodiment type: {embodiment_type!r}. "
27
- f"Available embodiments: {sorted(EMBODIMENT_TO_DOMAIN_ID.keys())}"
28
- )
29
- return EMBODIMENT_TO_DOMAIN_ID[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/action/droid_lerobot_dataset.py CHANGED
@@ -11,13 +11,7 @@ from scipy.spatial.transform import Rotation as R
11
 
12
  from cosmos_framework.utils import log
13
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
14
- ActionSpec,
15
  BaseActionLeRobotDataset,
16
- Gripper,
17
- Joint,
18
- Pos,
19
- Rot,
20
- build_action_spec,
21
  build_episode_spans,
22
  split_episode_ids,
23
  )
@@ -87,7 +81,6 @@ class DROIDLeRobotDataset(BaseActionLeRobotDataset):
87
  split_val_ratio=split_val_ratio,
88
  split=split,
89
  mode=mode,
90
- embodiment_type="droid_lerobot",
91
  viewpoint=viewpoint,
92
  pose_convention=pose_convention,
93
  rotation_format="rot6d",
@@ -307,13 +300,6 @@ class DROIDLeRobotDataset(BaseActionLeRobotDataset):
307
  composite = torch.cat([wrist, bottom], dim=-2) # [T,C,3H/2,W]
308
  return composite # [T,C,3H/2,W]
309
 
310
- def _build_action_spec(self) -> ActionSpec:
311
- """DROID: 10D ``[Pos, Rot6d, Gripper]`` for ``ee_pose``,
312
- 8D ``[Joint(7), Gripper]`` for ``joint_pos``.
313
- """
314
- if self._action_space == "joint_pos":
315
- return build_action_spec(Joint(n=7, label="joint"), Gripper())
316
- return build_action_spec(Pos(), Rot("rot6d"), Gripper())
317
 
318
  def __getitem__(self, idx: int) -> dict[str, Any]:
319
  """ """
 
11
 
12
  from cosmos_framework.utils import log
13
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
 
14
  BaseActionLeRobotDataset,
 
 
 
 
 
15
  build_episode_spans,
16
  split_episode_ids,
17
  )
 
81
  split_val_ratio=split_val_ratio,
82
  split=split,
83
  mode=mode,
 
84
  viewpoint=viewpoint,
85
  pose_convention=pose_convention,
86
  rotation_format="rot6d",
 
300
  composite = torch.cat([wrist, bottom], dim=-2) # [T,C,3H/2,W]
301
  return composite # [T,C,3H/2,W]
302
 
 
 
 
 
 
 
 
303
 
304
  def __getitem__(self, idx: int) -> dict[str, Any]:
305
  """ """
cosmos-framework/cosmos_framework/data/vfm/action/fractal.py CHANGED
@@ -15,12 +15,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
15
 
16
  from cosmos_framework.utils import log
17
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
18
- ActionSpec,
19
  BaseActionLeRobotDataset,
20
- Gripper,
21
- Pos,
22
- Rot,
23
- build_action_spec,
24
  )
25
  from cosmos_framework.data.vfm.action.pose_utils import (
26
  PoseConvention,
@@ -112,7 +107,6 @@ class FractalLeRobotDataset(BaseActionLeRobotDataset):
112
  split_val_ratio=split_val_ratio,
113
  split=split,
114
  mode=mode,
115
- embodiment_type="fractal",
116
  viewpoint=viewpoint,
117
  pose_convention=pose_convention,
118
  rotation_format="rot6d",
@@ -141,9 +135,6 @@ class FractalLeRobotDataset(BaseActionLeRobotDataset):
141
  )
142
  return kept
143
 
144
- def _build_action_spec(self) -> ActionSpec:
145
- """Fractal: 10D = ``[Pos(3), Rot6d(6), Gripper(1)]``."""
146
- return build_action_spec(Pos(dim=3), Rot("rot6d"), Gripper())
147
 
148
  def __getitem__(self, idx: int) -> dict[str, Any]:
149
  """Return a single training sample."""
 
15
 
16
  from cosmos_framework.utils import log
17
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
 
18
  BaseActionLeRobotDataset,
 
 
 
 
19
  )
20
  from cosmos_framework.data.vfm.action.pose_utils import (
21
  PoseConvention,
 
107
  split_val_ratio=split_val_ratio,
108
  split=split,
109
  mode=mode,
 
110
  viewpoint=viewpoint,
111
  pose_convention=pose_convention,
112
  rotation_format="rot6d",
 
135
  )
136
  return kept
137
 
 
 
 
138
 
139
  def __getitem__(self, idx: int) -> dict[str, Any]:
140
  """Return a single training sample."""
cosmos-framework/cosmos_framework/data/vfm/action/pose_utils.py CHANGED
@@ -19,7 +19,6 @@ dataset stack:
19
  canonical public entrypoint for representation conversion.
20
  """
21
 
22
- import math
23
  from typing import Literal
24
 
25
  import numpy as np
@@ -540,208 +539,3 @@ def pose_rel_to_abs(
540
  current_pose = next_pose
541
 
542
  return np.stack(poses_abs) # [T,4,4]
543
-
544
-
545
- # -----------------------------------------------------------------------------
546
- # Idle-frame detection
547
- # -----------------------------------------------------------------------------
548
-
549
-
550
- def _identity_rotation_vector(rotation_format: RotationConvention) -> np.ndarray:
551
- """Return the identity-rotation vector for a given rotation convention.
552
-
553
- Used by :func:`compute_idle_frames` to test whether a rotation block is
554
- close to "no rotation" in its current encoding.
555
- """
556
- if rotation_format in ("matrix", "rot9d"):
557
- return np.array([1, 0, 0, 0, 1, 0, 0, 0, 1], dtype=np.float32)
558
- if rotation_format == "rot6d":
559
- return np.array([1, 0, 0, 0, 1, 0], dtype=np.float32)
560
- if rotation_format == "quat_xyzw":
561
- return np.array([0, 0, 0, 1], dtype=np.float32)
562
- if rotation_format == "quat_wxyz":
563
- return np.array([1, 0, 0, 0], dtype=np.float32)
564
- if rotation_format in ("euler_xyz", "axisangle"):
565
- return np.array([0, 0, 0], dtype=np.float32)
566
- raise ValueError(f"Unsupported rotation_format={rotation_format!r}")
567
-
568
-
569
- def _rotation_angle_per_arm(rotations: np.ndarray, rotation_format: str) -> np.ndarray:
570
- """Geodesic angle (rad) from identity for each arm at each frame.
571
-
572
- ``rotations`` has shape ``(T, n_arms, n_per_arm)``; the returned array has
573
- shape ``(T, n_arms)``. The angle is rotation-format aware so a fixed
574
- ``eps_r`` threshold has consistent geometric meaning across formats:
575
-
576
- - ``rot6d`` → reconstruct ``trace(R)`` in closed form from the two stored
577
- columns ``a, b`` (already unit-orthogonal as they came from a valid
578
- rotation matrix). The third column is ``a × b``, so
579
- ``trace(R) = a[0] + b[1] + a[0]·b[1] - a[1]·b[0]``.
580
- ``angle = arccos(clip((trace - 1) / 2, -1, 1))``.
581
- - ``rot9d`` → reshape to ``(..., 3, 3)`` and use
582
- ``trace(R) = R[0,0] + R[1,1] + R[2,2]``.
583
- - ``quat_xyzw`` / ``quat_wxyz`` → ``angle = 2 · arccos(|q_w|)``; the
584
- absolute value handles the double cover (``q`` and ``-q`` represent the
585
- same rotation).
586
- - ``axisangle`` → the magnitude of the axis-angle vector *is* the angle.
587
- - ``euler_xyz`` → no closed-form angle; use ``‖euler‖`` as a conservative
588
- upper bound (exact for single-axis rotations, an overestimate for
589
- composed ones — fine for idle detection where small angles are the
590
- regime of interest).
591
- """
592
- if rotation_format == "rot6d":
593
- a = rotations[..., :3]
594
- b = rotations[..., 3:6]
595
- trace = a[..., 0] + b[..., 1] + a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
596
- return np.arccos(np.clip((trace - 1.0) / 2.0, -1.0, 1.0))
597
- if rotation_format == "rot9d":
598
- mat = rotations.reshape(*rotations.shape[:-1], 3, 3)
599
- trace = mat[..., 0, 0] + mat[..., 1, 1] + mat[..., 2, 2]
600
- return np.arccos(np.clip((trace - 1.0) / 2.0, -1.0, 1.0))
601
- if rotation_format in ("quat_xyzw", "quat_wxyz"):
602
- qw = rotations[..., 3] if rotation_format == "quat_xyzw" else rotations[..., 0]
603
- return 2.0 * np.arccos(np.clip(np.abs(qw), 0.0, 1.0))
604
- if rotation_format == "axisangle":
605
- return np.linalg.norm(rotations, axis=-1)
606
- if rotation_format == "euler_xyz":
607
- # Exact for single-axis rotations, overestimate for composed ones —
608
- # safe for idle thresholds since overestimation can only mark a frame
609
- # as non-idle, never spuriously idle.
610
- return np.linalg.norm(rotations, axis=-1)
611
- raise ValueError(f"Unsupported rotation_format={rotation_format!r}")
612
-
613
-
614
- def _consecutive_streaks(idle: np.ndarray, min_streak: int) -> np.ndarray:
615
- """Zero out idle bits not belonging to a run of ``>= min_streak`` Trues.
616
-
617
- Pure-numpy two-pointer scan. ``min_streak <= 1`` is a no-op (returns the
618
- input mask unchanged).
619
- """
620
- if min_streak <= 1:
621
- return idle
622
- out = np.zeros_like(idle)
623
- n = len(idle)
624
- i = 0
625
- while i < n:
626
- if not idle[i]:
627
- i += 1
628
- continue
629
- j = i
630
- while j < n and idle[j]:
631
- j += 1
632
- if j - i >= min_streak:
633
- out[i:j] = True
634
- i = j
635
- return out
636
-
637
-
638
- def compute_idle_frames(
639
- action_raw: torch.Tensor | np.ndarray,
640
- spec: "ActionSpec", # noqa: F821 — forward ref, real import is in action_spec.py
641
- *,
642
- eps_t: float = 1e-3,
643
- eps_r: float = math.radians(5.0),
644
- eps_g: float = 1e-2,
645
- joint_threshold: float = 5e-4,
646
- min_streak: int = 3,
647
- ) -> int:
648
- """Count idle frames in a raw (un-normalized) action chunk.
649
-
650
- Idle detection runs per-DimType (driven by ``spec.types``); a frame is
651
- *raw-idle* iff every relevant type group is idle on that frame, and
652
- counts toward the final tally only if it belongs to a run of at least
653
- ``min_streak`` consecutive raw-idle frames. The streak filter rejects
654
- isolated low-motion frames (instantaneous slowdowns) which carry weak
655
- physical meaning and add noise to the IdleFrames training signal.
656
-
657
- DimType branches:
658
-
659
- - ``POS`` → combined ``‖action[pos_idx]‖`` (L2 across all POS dims)
660
- < ``eps_t``. For single-arm specs (3 dims) this is the standard ``‖t‖``
661
- check; for multi-arm specs the combined norm is slightly stricter than
662
- a per-arm check.
663
- - ``ROT`` → per-arm geodesic rotation angle (rad) from identity
664
- < ``eps_r``. The angle is computed in a rotation-format aware way (see
665
- :func:`_rotation_angle_per_arm`) so the threshold has consistent
666
- geometric meaning regardless of the encoding.
667
- - ``GRIPPER`` → ``max |action[t] - action[t-1]| < eps_g``. ``np.diff``
668
- with ``prepend=action[0]`` makes step 0 ``|0|`` (treated as "no change");
669
- with the streak filter this can no longer create a spurious single-frame
670
- idle event.
671
- - ``JOINT`` → same frame-diff scheme as gripper with
672
- ``joint_threshold`` (rad / step).
673
- - ``RESERVED`` → ignored.
674
-
675
- Defaults (in the units of the un-normalized action):
676
-
677
- - ``eps_t = 1e-3`` → 1 mm per-frame translation
678
- - ``eps_r = 5°`` → 5° per-frame rotation (geodesic angle)
679
- - ``eps_g = 1e-2`` → 1 % gripper command change
680
- - ``joint_threshold = 5e-4`` → ~0.03° / step joint angle change
681
- - ``min_streak = 3`` → require a run of >= 3 consecutive idle frames
682
-
683
- The input must be **un-normalized** so the identity transform sits at
684
- known coordinates (translation ≈ 0, rotation ≈ identity). The action
685
- vector is also assumed to be encoded in a per-step / framewise convention
686
- (e.g. ``backward_framewise``); anchored conventions (``backward_anchored``)
687
- accumulate over the chunk and would silently break the POS/ROT idle
688
- checks. Callers (e.g. the LeRobot base class) gate on pose convention
689
- before calling this function.
690
- """
691
- if isinstance(action_raw, torch.Tensor):
692
- action = action_raw.detach().cpu().numpy().astype(np.float32, copy=False)
693
- else:
694
- action = np.asarray(action_raw, dtype=np.float32)
695
-
696
- if action.ndim != 2:
697
- raise ValueError(f"action_raw must be 2-D (T, D); got shape {action.shape}")
698
- num_frames, action_dim = action.shape
699
- if num_frames == 0:
700
- return 0
701
- if action_dim != len(spec.types):
702
- raise ValueError(f"action_dim={action_dim} does not match spec.dim={len(spec.types)}")
703
-
704
- # Import locally to avoid a circular import at module load time
705
- # (action_spec.py imports RotationConvention from this file).
706
- from cosmos_framework.data.vfm.action.action_spec import DimType
707
-
708
- pos_idx = [i for i, t in enumerate(spec.types) if t == DimType.POS]
709
- rot_idx = [i for i, t in enumerate(spec.types) if t == DimType.ROT]
710
- grip_idx = [i for i, t in enumerate(spec.types) if t == DimType.GRIPPER]
711
- joint_idx = [i for i, t in enumerate(spec.types) if t == DimType.JOINT]
712
-
713
- idle = np.ones(num_frames, dtype=bool)
714
-
715
- # POS: combined L2 norm across all translation dims.
716
- if pos_idx:
717
- idle &= np.linalg.norm(action[:, pos_idx], axis=1) < eps_t
718
-
719
- # ROT: per-arm geodesic angle (rad).
720
- if rot_idx:
721
- rot_id = _identity_rotation_vector(spec.rotation_format)
722
- n_per_arm = rot_id.shape[0]
723
- if len(rot_idx) % n_per_arm != 0:
724
- raise ValueError(
725
- f"ROT dims ({len(rot_idx)}) not a multiple of "
726
- f"rotation_format={spec.rotation_format!r} dim ({n_per_arm})"
727
- )
728
- rotations = action[:, rot_idx].reshape(num_frames, -1, n_per_arm)
729
- angles = _rotation_angle_per_arm(rotations, spec.rotation_format) # (T, n_arms)
730
- idle &= angles.max(axis=1) < eps_r
731
-
732
- # GRIPPER: max |Δgripper| across all gripper dims; step 0's diff is 0.
733
- if grip_idx:
734
- gripper = action[:, grip_idx]
735
- diff = np.abs(np.diff(gripper, axis=0, prepend=gripper[:1]))
736
- idle &= diff.max(axis=1) < eps_g
737
-
738
- # JOINT: same frame-diff scheme with joint_threshold.
739
- if joint_idx:
740
- joints = action[:, joint_idx]
741
- diff = np.abs(np.diff(joints, axis=0, prepend=joints[:1]))
742
- idle &= diff.max(axis=1) < joint_threshold
743
-
744
- if min_streak > 1:
745
- idle = _consecutive_streaks(idle, min_streak)
746
-
747
- return int(idle.sum())
 
19
  canonical public entrypoint for representation conversion.
20
  """
21
 
 
22
  from typing import Literal
23
 
24
  import numpy as np
 
539
  current_pose = next_pose
540
 
541
  return np.stack(poses_abs) # [T,4,4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/action/robomind_franka_dataset.py CHANGED
@@ -16,7 +16,6 @@
16
 
17
  from __future__ import annotations
18
 
19
- import math
20
  import os
21
  from typing import Any, cast
22
 
@@ -25,12 +24,7 @@ import torch
25
  import torch.nn.functional as F
26
 
27
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
28
- ActionSpec,
29
  BaseActionLeRobotDataset,
30
- Gripper,
31
- Pos,
32
- Rot,
33
- build_action_spec,
34
  )
35
  from cosmos_framework.data.vfm.action.pose_utils import (
36
  PoseConvention,
@@ -77,14 +71,6 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
77
  # 1.5°/s) so a single arm doing a slow approach (~1mm/f at 10 Hz) is no
78
  # longer classified as idle.
79
  #
80
- # Class defaults below match single-arm. Dual-arm overrides at instance
81
- # construction (see ``__init__``).
82
- _IDLE_EPS_T_SINGLE: float = 22e-3
83
- _IDLE_EPS_R_SINGLE: float = math.radians(3.0)
84
- _IDLE_EPS_T_DUAL: float = 5e-3 # = base default; tight enough
85
- _IDLE_EPS_R_DUAL: float = math.radians(1.5) # for "single-arm-slow" cases
86
- idle_eps_t_per_sec: float = _IDLE_EPS_T_SINGLE
87
- idle_eps_r_per_sec: float = _IDLE_EPS_R_SINGLE
88
 
89
  def __init__(
90
  self,
@@ -113,7 +99,6 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
113
  split_val_ratio=split_val_ratio,
114
  split=split,
115
  mode=mode,
116
- embodiment_type=embodiment_type,
117
  viewpoint=viewpoint,
118
  pose_convention=pose_convention,
119
  rotation_format="rot6d",
@@ -121,15 +106,10 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
121
  enable_fast_init=enable_fast_init,
122
  )
123
 
 
124
  self._to_opencv: np.ndarray = _ROBOMIND_FRANKA_TO_OPENCV[:3, :3]
125
  self._is_concat_view: bool = viewpoint == "concat_view"
126
 
127
- # Per-embodiment idle thresholds (instance-level override of the
128
- # class default which matches single-arm). Dual-arm tightens both
129
- # eps_t and eps_r to reflect its smaller per-frame motion tail.
130
- if embodiment_type == "robomind-franka-dual":
131
- self.idle_eps_t_per_sec = self._IDLE_EPS_T_DUAL
132
- self.idle_eps_r_per_sec = self._IDLE_EPS_R_DUAL
133
 
134
  embodiment_key = embodiment_type.removeprefix("robomind-")
135
  lerobot_roots = LEROBOT_ROOTS[embodiment_key]
@@ -220,26 +200,6 @@ class RoboMINDFrankaDataset(BaseActionLeRobotDataset):
220
  composite = torch.cat([top_or_front, bottom], dim=-2) # [T,C,3H/2,W]
221
  return composite # [T,C,3H/2,W]
222
 
223
- def _build_action_spec(self) -> ActionSpec:
224
- """RoboMIND Franka: 10D single-arm or 20D dual-arm.
225
-
226
- Single (``robomind-franka``):
227
- ``[Pos, Rot6d, Gripper]`` (10D)
228
-
229
- Dual (``robomind-franka-dual``):
230
- ``[L_Pos, L_Rot6d, L_Gripper, R_Pos, R_Rot6d, R_Gripper]`` (20D)
231
- """
232
- if self._embodiment_type == "robomind-franka":
233
- return build_action_spec(Pos(), Rot("rot6d"), Gripper())
234
- # dual arm
235
- return build_action_spec(
236
- Pos(prefix="left"),
237
- Rot("rot6d", prefix="left"),
238
- Gripper(prefix="left"),
239
- Pos(prefix="right"),
240
- Rot("rot6d", prefix="right"),
241
- Gripper(prefix="right"),
242
- )
243
 
244
  def __getitem__(self, idx: int) -> dict[str, Any]:
245
  mode, _, _, sample = self._fetch_sample(idx)
 
16
 
17
  from __future__ import annotations
18
 
 
19
  import os
20
  from typing import Any, cast
21
 
 
24
  import torch.nn.functional as F
25
 
26
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import (
 
27
  BaseActionLeRobotDataset,
 
 
 
 
28
  )
29
  from cosmos_framework.data.vfm.action.pose_utils import (
30
  PoseConvention,
 
71
  # 1.5°/s) so a single arm doing a slow approach (~1mm/f at 10 Hz) is no
72
  # longer classified as idle.
73
  #
 
 
 
 
 
 
 
 
74
 
75
  def __init__(
76
  self,
 
99
  split_val_ratio=split_val_ratio,
100
  split=split,
101
  mode=mode,
 
102
  viewpoint=viewpoint,
103
  pose_convention=pose_convention,
104
  rotation_format="rot6d",
 
106
  enable_fast_init=enable_fast_init,
107
  )
108
 
109
+ self._embodiment_type = embodiment_type
110
  self._to_opencv: np.ndarray = _ROBOMIND_FRANKA_TO_OPENCV[:3, :3]
111
  self._is_concat_view: bool = viewpoint == "concat_view"
112
 
 
 
 
 
 
 
113
 
114
  embodiment_key = embodiment_type.removeprefix("robomind-")
115
  lerobot_roots = LEROBOT_ROOTS[embodiment_key]
 
200
  composite = torch.cat([top_or_front, bottom], dim=-2) # [T,C,3H/2,W]
201
  return composite # [T,C,3H/2,W]
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  def __getitem__(self, idx: int) -> dict[str, Any]:
205
  mode, _, _, sample = self._fetch_sample(idx)
cosmos-framework/cosmos_framework/data/vfm/action/umi_lerobot_dataset.py CHANGED
@@ -11,7 +11,6 @@ import numpy as np
11
  import torch
12
  from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
13
 
14
- from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec
15
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import BaseActionLeRobotDataset
16
  from cosmos_framework.data.vfm.action.pose_utils import PoseConvention, build_abs_pose_from_components, pose_abs_to_rel
17
  from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
@@ -45,7 +44,6 @@ class UMIFastLeRobotDataset(BaseActionLeRobotDataset):
45
  split_val_ratio=split_val_ratio,
46
  split=split,
47
  mode=mode,
48
- embodiment_type="umi",
49
  viewpoint=viewpoint,
50
  pose_convention=pose_convention,
51
  rotation_format="rot6d",
@@ -62,8 +60,6 @@ class UMIFastLeRobotDataset(BaseActionLeRobotDataset):
62
  _GRIPPER_FEATURE: observation_ts,
63
  }
64
 
65
- def _build_action_spec(self) -> ActionSpec:
66
- return build_action_spec(Pos(), Rot("rot6d"), Gripper())
67
 
68
  def _register_sources(self, shard_indices: list[int] | None = None) -> None:
69
  roots = self._all_shard_roots if shard_indices is None else [self._all_shard_roots[i] for i in shard_indices]
 
11
  import torch
12
  from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
13
 
 
14
  from cosmos_framework.data.vfm.action.cosmos3_action_lerobot import BaseActionLeRobotDataset
15
  from cosmos_framework.data.vfm.action.pose_utils import PoseConvention, build_abs_pose_from_components, pose_abs_to_rel
16
  from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
 
44
  split_val_ratio=split_val_ratio,
45
  split=split,
46
  mode=mode,
 
47
  viewpoint=viewpoint,
48
  pose_convention=pose_convention,
49
  rotation_format="rot6d",
 
60
  _GRIPPER_FEATURE: observation_ts,
61
  }
62
 
 
 
63
 
64
  def _register_sources(self, shard_indices: list[int] | None = None) -> None:
65
  roots = self._all_shard_roots if shard_indices is None else [self._all_shard_roots[i] for i in shard_indices]