XinKongCosmos commited on
Commit
3264a6d
·
verified ·
1 Parent(s): 80ca707

Trim unused viewer support code

Browse files
cosmos-framework/cosmos_framework/data/imaginaire/webdataset/augmentors/__init__.py DELETED
File without changes
cosmos-framework/cosmos_framework/data/imaginaire/webdataset/augmentors/augmentor.py DELETED
@@ -1,52 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: OpenMDW-1.1
3
-
4
- from collections.abc import Iterable
5
- from typing import Any, Generator, Optional
6
-
7
-
8
- class Augmentor:
9
- def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
10
- r"""Base augmentor class
11
-
12
- Args:
13
- input_keys (list): List of input keys
14
- output_keys (list): List of output keys
15
- args (dict): Arguments associated with the augmentation
16
- """
17
- self.input_keys = input_keys
18
- self.output_keys = output_keys
19
- self.args = args
20
-
21
- def __call__(self, *args: Any, **kwds: Any) -> Any:
22
- raise ValueError("Augmentor not implemented")
23
-
24
-
25
- class IterableAugmentor:
26
- def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
27
- r"""Base augmentor class
28
-
29
- Args:
30
- input_keys (list): List of input keys
31
- output_keys (list): List of output keys
32
- args (dict): Arguments associated with the augmentation
33
- """
34
- self.input_keys = input_keys
35
- self.output_keys = output_keys
36
- self.args = args
37
- self.is_generator = True
38
-
39
- def __call__(self, data: Iterable) -> Generator:
40
- r"""Example usage:
41
-
42
- for data_dict in data:
43
- # Do something to data_dict
44
- data_dict["input"] = data_dict["raw_sequence"][:, :-1]
45
- data_dict["target"] = data_dict["raw_sequence"][:, 1:]
46
- # Skip sample if needed
47
- if data_dict["input"].shape[1] < 64:
48
- continue
49
- # Construct a generator
50
- yield data_dict
51
- """
52
- raise ValueError("Augmentor not implemented")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/action/cosmos3_action_lerobot.py CHANGED
@@ -71,19 +71,6 @@ from cosmos_framework.data.vfm.action.action_spec import ( # noqa: F401 (re-ex
71
  from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
72
  from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames
73
  from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
74
- from cosmos_framework.data.vfm.action_scripts.memprofile import (
75
- deep_size as _deep_size,
76
- )
77
- from cosmos_framework.data.vfm.action_scripts.memprofile import (
78
- fmt_mb as _fmt_mb,
79
- )
80
- from cosmos_framework.data.vfm.action_scripts.memprofile import (
81
- log_worker_memory_breakdown,
82
- rss_tracker,
83
- )
84
- from cosmos_framework.data.vfm.action_scripts.memprofile import (
85
- memprofile_enabled as _memprofile_enabled,
86
- )
87
 
88
  # ---------------------------------------------------------------------------
89
  # LRU-capped VideoDecoderCache
@@ -305,69 +292,47 @@ class BaseActionLeRobotDataset(Dataset):
305
  super().__init__()
306
  _ensure_hf_hub_offline()
307
  _patch_decoder_cache()
308
- self._memprofile = _memprofile_enabled()
309
-
310
  assert sample_stride >= 1, f"sample_stride must be >= 1, got {sample_stride}"
311
  assert fast_init_max_workers >= 1, f"fast_init_max_workers must be >= 1, got {fast_init_max_workers}"
312
- with rss_tracker(f"{self.__class__.__name__}.__init__", enabled=self._memprofile):
313
- self._fps = fps
314
- self._dt = 1.0 / fps
315
- self._chunk_length = chunk_length
316
- self._split_seed = split_seed
317
- self._split_val_ratio = split_val_ratio
318
- self._split = _normalize_split(split)
319
- self._mode = mode
320
- self._embodiment_type = embodiment_type
321
- self._viewpoint: Viewpoint = viewpoint
322
- self._pose_convention = pose_convention
323
- self._rotation_format = rotation_format
324
- self._tolerance_s = tolerance_s
325
- self._max_loaded_datasets = max_loaded_datasets
326
- self._skip_video_loading = skip_video_loading
327
- self._sample_stride = sample_stride
328
- self._enable_fast_init = enable_fast_init
329
- self._fast_init_max_workers = fast_init_max_workers
330
- self._delta_timestamps: dict[str, list[float]] = {}
331
- self._to_opencv: np.ndarray | dict[str, np.ndarray] = np.eye(3, dtype=np.float32)
332
-
333
- if pose_convention is None:
334
- log.warning(
335
- f"{self.__class__.__name__}: pose_convention is not set. "
336
- "Consider specifying 'backward_framewise' or 'backward_anchored'."
337
- )
338
 
339
- self._datasets: list[LeRobotDataset | None] = []
340
- self._dataset_build_args: list[dict[str, Any] | None] = []
341
- self._loaded_lru: OrderedDict[int, None] = OrderedDict()
342
-
343
- # -- Flat index structures (populated by _append_index_records) --
344
- # Together these two lists form a searchable map from a flat
345
- # global index to (dataset, row, episode, frame). One entry per
346
- # episode span across *all* registered sources.
347
- #
348
- # _episode_records[i] = (ds_idx, sample_start, valid_len, episode_id)
349
- # ds_idx – which source dataset (index into _datasets)
350
- # sample_start first row of this span in that dataset's table
351
- # valid_len – number of usable frames in this span
352
- # episode_id – the episode this span belongs to
353
- #
354
- # _episode_cum_ends[i] = running total of valid_len through span i
355
- # Used for O(log N) lookup via bisect_right in _resolve_index.
356
- self._episode_records: list[tuple[int, int, int, int]] = []
357
- self._episode_cum_ends: list[int] = []
358
- self._num_valid_indices = 0
359
- self._domain_id = get_domain_id(self._embodiment_type)
360
-
361
- # Deferred-init shard roots — a list of root paths.
362
- # Subclasses populate this in __init__; _register_sources()
363
- # reads _delta_timestamps and _tolerance_s from self (both
364
- # initialised above, with _delta_timestamps overridden by
365
- # each subclass).
366
- # ActionUnifiedIterableDataset.assign_worker uses len() for
367
- # round-robin shard distribution and _register_sources(indices)
368
- # for deferred loading. When empty, shard distribution is
369
- # skipped (every worker iterates the full dataset).
370
- self._all_shard_roots: list[str] = []
371
 
372
  # -- public properties ---------------------------------------------------
373
 
@@ -428,42 +393,30 @@ class BaseActionLeRobotDataset(Dataset):
428
  if repo_id == "local" and revision is None:
429
  revision = "local"
430
 
431
- with rss_tracker(f"{cls}{label_str} metadata load", enabled=self._memprofile):
432
- if prefetched_meta is not None:
433
- meta = prefetched_meta
434
- else:
435
- meta = LeRobotDatasetMetadata(
436
- repo_id=repo_id,
437
- root=root,
438
- revision=revision,
439
- force_cache_sync=force_cache_sync,
440
- )
441
- ds_idx = len(self._datasets)
442
- self._datasets.append(None)
443
- self._dataset_build_args.append(
444
- {
445
- "repo_id": repo_id,
446
- "root": root,
447
- "delta_timestamps": delta_timestamps,
448
- "tolerance_s": tolerance_s,
449
- "force_cache_sync": force_cache_sync,
450
- "download_videos": download_videos,
451
- "video_backend": video_backend,
452
- "revision": revision,
453
- }
454
  )
455
-
456
- with rss_tracker(
457
- f"{cls}{label_str} — index records",
458
- enabled=self._memprofile,
459
- extras_fn=lambda: [
460
- f"episode_records so far: {len(self._episode_records)} entries, "
461
- f"~{_fmt_mb(_deep_size(self._episode_records) / (1024 * 1024))}",
462
- f"episode_cum_ends so far: {len(self._episode_cum_ends)} entries, "
463
- f"~{_fmt_mb(_deep_size(self._episode_cum_ends) / (1024 * 1024))}",
464
- ],
465
- ):
466
- self._append_index_records(meta=meta, ds_idx=ds_idx, dataset_label=dataset_label)
 
 
 
467
 
468
  return meta
469
 
@@ -584,35 +537,30 @@ class BaseActionLeRobotDataset(Dataset):
584
  evict_idx, _ = self._loaded_lru.popitem(last=False)
585
  self._datasets[evict_idx] = None
586
 
587
- with rss_tracker(
588
- f"[WORKER {_os.getpid()}] Lazy-loaded ds[{ds_idx}]",
589
- enabled=self._memprofile,
590
- extras_fn=lambda: [f"total loaded={len(self._loaded_lru)}/{len(self._datasets)}"],
591
- ):
592
- delta_ts = build_args["delta_timestamps"]
593
- if self._skip_video_loading:
594
- # Covers both LeRobot v2 (``observation.images.<name>``) and
595
- # v3 (``observation.image.<name>``) video-column conventions.
596
- delta_ts = {k: v for k, v in delta_ts.items() if not k.startswith("observation.image")}
597
-
598
- log.info(f"Loading shard root={build_args['root']}")
599
- ds = LeRobotDataset(
600
- repo_id=build_args["repo_id"],
601
- root=build_args["root"],
602
- delta_timestamps=delta_ts,
603
- tolerance_s=build_args["tolerance_s"],
604
- force_cache_sync=build_args["force_cache_sync"],
605
- download_videos=build_args["download_videos"],
606
- video_backend=build_args["video_backend"],
607
- revision=build_args["revision"],
608
- episodes=None,
609
- )
610
- if self._skip_video_loading:
611
- ds.meta.info["features"] = {
612
- k: v for k, v in ds.meta.info["features"].items() if v.get("dtype") != "video"
613
- }
614
- self._datasets[ds_idx] = ds
615
- self._loaded_lru[ds_idx] = None
616
 
617
  return ds
618
 
@@ -688,15 +636,7 @@ class BaseActionLeRobotDataset(Dataset):
688
  mode = self._choose_mode()
689
  dataset_idx, row_idx, _, _ = self._resolve_index(idx)
690
 
691
- self._getitem_count = getattr(self, "_getitem_count", 0) + 1
692
- profile = self._memprofile and self._getitem_count % 50 == 1
693
-
694
- with rss_tracker(
695
- f"[WORKER {_os.getpid()}] __getitem__ transient (dataset_idx={dataset_idx})",
696
- enabled=profile,
697
- after_fn=lambda: log_worker_memory_breakdown(self),
698
- ):
699
- sample = self._get_dataset(dataset_idx)[row_idx]
700
 
701
  if self._skip_video_loading:
702
  sample = defaultdict(lambda: None, sample)
 
71
  from cosmos_framework.data.vfm.action.domain_utils import get_domain_id
72
  from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames
73
  from cosmos_framework.data.vfm.action.viewpoint_utils import Viewpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # ---------------------------------------------------------------------------
76
  # LRU-capped VideoDecoderCache
 
292
  super().__init__()
293
  _ensure_hf_hub_offline()
294
  _patch_decoder_cache()
 
 
295
  assert sample_stride >= 1, f"sample_stride must be >= 1, got {sample_stride}"
296
  assert fast_init_max_workers >= 1, f"fast_init_max_workers must be >= 1, got {fast_init_max_workers}"
297
+ self._fps = fps
298
+ self._dt = 1.0 / fps
299
+ self._chunk_length = chunk_length
300
+ self._split_seed = split_seed
301
+ self._split_val_ratio = split_val_ratio
302
+ self._split = _normalize_split(split)
303
+ self._mode = mode
304
+ self._embodiment_type = embodiment_type
305
+ self._viewpoint: Viewpoint = viewpoint
306
+ self._pose_convention = pose_convention
307
+ self._rotation_format = rotation_format
308
+ self._tolerance_s = tolerance_s
309
+ self._max_loaded_datasets = max_loaded_datasets
310
+ self._skip_video_loading = skip_video_loading
311
+ self._sample_stride = sample_stride
312
+ self._enable_fast_init = enable_fast_init
313
+ self._fast_init_max_workers = fast_init_max_workers
314
+ self._delta_timestamps: dict[str, list[float]] = {}
315
+ self._to_opencv: np.ndarray | dict[str, np.ndarray] = np.eye(3, dtype=np.float32)
316
+
317
+ if pose_convention is None:
318
+ log.warning(
319
+ f"{self.__class__.__name__}: pose_convention is not set. "
320
+ "Consider specifying 'backward_framewise' or 'backward_anchored'."
321
+ )
 
322
 
323
+ self._datasets: list[LeRobotDataset | None] = []
324
+ self._dataset_build_args: list[dict[str, Any] | None] = []
325
+ self._loaded_lru: OrderedDict[int, None] = OrderedDict()
326
+
327
+ # -- Flat index structures (populated by _append_index_records) --
328
+ # Together these two lists form a searchable map from a flat
329
+ # global index to (dataset, row, episode, frame). One entry per
330
+ # episode span across all registered sources.
331
+ self._episode_records: list[tuple[int, int, int, int]] = []
332
+ self._episode_cum_ends: list[int] = []
333
+ self._num_valid_indices = 0
334
+ self._domain_id = get_domain_id(self._embodiment_type)
335
+ self._all_shard_roots: list[str] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  # -- public properties ---------------------------------------------------
338
 
 
393
  if repo_id == "local" and revision is None:
394
  revision = "local"
395
 
396
+ if prefetched_meta is not None:
397
+ meta = prefetched_meta
398
+ else:
399
+ meta = LeRobotDatasetMetadata(
400
+ repo_id=repo_id,
401
+ root=root,
402
+ revision=revision,
403
+ force_cache_sync=force_cache_sync,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
  )
405
+ ds_idx = len(self._datasets)
406
+ self._datasets.append(None)
407
+ self._dataset_build_args.append(
408
+ {
409
+ "repo_id": repo_id,
410
+ "root": root,
411
+ "delta_timestamps": delta_timestamps,
412
+ "tolerance_s": tolerance_s,
413
+ "force_cache_sync": force_cache_sync,
414
+ "download_videos": download_videos,
415
+ "video_backend": video_backend,
416
+ "revision": revision,
417
+ }
418
+ )
419
+ self._append_index_records(meta=meta, ds_idx=ds_idx, dataset_label=dataset_label)
420
 
421
  return meta
422
 
 
537
  evict_idx, _ = self._loaded_lru.popitem(last=False)
538
  self._datasets[evict_idx] = None
539
 
540
+ delta_ts = build_args["delta_timestamps"]
541
+ if self._skip_video_loading:
542
+ # Covers both LeRobot v2 (``observation.images.<name>``) and
543
+ # v3 (``observation.image.<name>``) video-column conventions.
544
+ delta_ts = {k: v for k, v in delta_ts.items() if not k.startswith("observation.image")}
545
+
546
+ log.info(f"Loading shard root={build_args['root']}")
547
+ ds = LeRobotDataset(
548
+ repo_id=build_args["repo_id"],
549
+ root=build_args["root"],
550
+ delta_timestamps=delta_ts,
551
+ tolerance_s=build_args["tolerance_s"],
552
+ force_cache_sync=build_args["force_cache_sync"],
553
+ download_videos=build_args["download_videos"],
554
+ video_backend=build_args["video_backend"],
555
+ revision=build_args["revision"],
556
+ episodes=None,
557
+ )
558
+ if self._skip_video_loading:
559
+ ds.meta.info["features"] = {
560
+ k: v for k, v in ds.meta.info["features"].items() if v.get("dtype") != "video"
561
+ }
562
+ self._datasets[ds_idx] = ds
563
+ self._loaded_lru[ds_idx] = None
 
 
 
 
 
564
 
565
  return ds
566
 
 
636
  mode = self._choose_mode()
637
  dataset_idx, row_idx, _, _ = self._resolve_index(idx)
638
 
639
+ sample = self._get_dataset(dataset_idx)[row_idx]
 
 
 
 
 
 
 
 
640
 
641
  if self._skip_video_loading:
642
  sample = defaultdict(lambda: None, sample)
cosmos-framework/cosmos_framework/data/vfm/action/urdf_visualizer/action_datasets.py CHANGED
@@ -16,9 +16,6 @@ from cosmos_framework.data.vfm.action.fractal import FractalLeRobotDataset
16
  from cosmos_framework.data.vfm.action.robomind_franka_dataset import RoboMINDFrankaDataset
17
  from cosmos_framework.data.vfm.action.umi_lerobot_dataset import UMIFastLeRobotDataset
18
 
19
- _DEFAULT_LUSTRE_DATASET_ROOT = "/lustre/fsw/portfolios/cosmos/projects/cosmos_base_training/cosmos3_action_datasets"
20
-
21
-
22
  @dataclass
23
  class LazyCall:
24
  """Tiny LazyCall replacement sufficient for the standalone viewer."""
 
16
  from cosmos_framework.data.vfm.action.robomind_franka_dataset import RoboMINDFrankaDataset
17
  from cosmos_framework.data.vfm.action.umi_lerobot_dataset import UMIFastLeRobotDataset
18
 
 
 
 
19
  @dataclass
20
  class LazyCall:
21
  """Tiny LazyCall replacement sufficient for the standalone viewer."""
cosmos-framework/cosmos_framework/data/vfm/action/urdf_visualizer/viewer.py CHANGED
@@ -21,12 +21,12 @@ from __future__ import annotations
21
 
22
  import argparse
23
  import importlib
 
24
  import os
25
  import random
26
  import sys
27
  import time as _time
28
  from dataclasses import dataclass, field
29
- from functools import lru_cache
30
  from pathlib import Path
31
  from typing import Any, cast
32
 
@@ -279,92 +279,6 @@ def _format_sample_text(value: Any, max_chars: int | None = None) -> str:
279
  return text[:max_chars]
280
 
281
 
282
- def _build_viewer_idle_action_spec(action_format: ActionFormat) -> Any:
283
- """Build a fallback idle-frame spec from the viewer-declared action format."""
284
-
285
- from cosmos_framework.data.vfm.action.action_spec import Gripper, Pos, Rot, build_action_spec
286
-
287
- if action_format is ActionFormat.EGO_9D:
288
- return build_action_spec(Pos(prefix="ego"), Rot("rot6d", prefix="ego"))
289
- if action_format is ActionFormat.SINGLE_ARM_10D:
290
- return build_action_spec(Pos(), Rot("rot6d"), Gripper())
291
- if action_format is ActionFormat.DUAL_ARM_20D:
292
- return build_action_spec(
293
- Pos(prefix="left"),
294
- Rot("rot6d", prefix="left"),
295
- Gripper(prefix="left"),
296
- Pos(prefix="right"),
297
- Rot("rot6d", prefix="right"),
298
- Gripper(prefix="right"),
299
- )
300
- raise ValueError(f"Unsupported action format for idle-frame detection: {action_format}")
301
-
302
-
303
- def _compute_viewer_idle_frames(
304
- action: Any,
305
- dataset: Any,
306
- action_format: ActionFormat,
307
- ) -> torch.Tensor | None:
308
- """Compute idle frames for a viewer sample when the dataset did not provide them."""
309
-
310
- action_spec = getattr(dataset, "action_spec", None)
311
- compute_idle_frames_method = getattr(dataset, "_compute_idle_frames", None)
312
- if action_spec is not None and compute_idle_frames_method is not None:
313
- return compute_idle_frames_method(action)
314
-
315
- from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames
316
-
317
- spec = _build_viewer_idle_action_spec(action_format)
318
- try:
319
- idle_frames = compute_idle_frames(action, spec)
320
- except (TypeError, ValueError) as error:
321
- log.warning(f"Viewer idle-frame detection skipped for {action_format.value}: {error}")
322
- return None
323
- return torch.tensor(idle_frames, dtype=torch.long) # []
324
-
325
-
326
- @lru_cache(maxsize=1)
327
- def _get_viewer_idle_frames_augmentor() -> Any:
328
- """Return the caption augmentor used by the viewer idle-frame path."""
329
-
330
- from cosmos_framework.data.vfm.augmentors.idle_frames_text_info import IdleFramesTextInfo
331
-
332
- return IdleFramesTextInfo(
333
- input_keys=["ai_caption", "idle_frames", "action"],
334
- output_keys=["ai_caption"],
335
- args={
336
- "caption_key": "ai_caption",
337
- "idle_frames_key": "idle_frames",
338
- "action_key": "action",
339
- "dropout_rate": 0.0,
340
- "enabled": True,
341
- },
342
- )
343
-
344
-
345
- def _enable_viewer_idle_frames(sample: dict[str, Any], dataset: Any, action_format: ActionFormat) -> dict[str, Any]:
346
- """Populate idle-frame metadata and append text in the direct viewer data path."""
347
-
348
- updated_sample = sample
349
- idle_frames = updated_sample.get("idle_frames")
350
- action = updated_sample.get("action")
351
- if idle_frames is None and action is not None:
352
- idle_frames = _compute_viewer_idle_frames(action, dataset, action_format)
353
- if idle_frames is not None:
354
- updated_sample = dict(updated_sample)
355
- updated_sample["idle_frames"] = idle_frames
356
-
357
- if idle_frames is None:
358
- return updated_sample
359
-
360
- updated_sample = dict(updated_sample)
361
- caption = updated_sample.get("ai_caption")
362
- if isinstance(caption, dict):
363
- updated_sample["ai_caption"] = dict(caption)
364
-
365
- augmented_sample = _get_viewer_idle_frames_augmentor()(updated_sample)
366
- return updated_sample if augmented_sample is None else augmented_sample
367
-
368
 
369
  class _IterableToMapDataset:
370
  """Wraps an IterableDataset into a random-access dataset with lazy loading."""
@@ -687,7 +601,7 @@ def launch_viewer(
687
  ep_idx = n_total - 1
688
  ep_input.value = ep_idx
689
 
690
- sample: Any = _enable_viewer_idle_frames(dataset[ep_idx], dataset, effective_action_format)
691
 
692
  action_tensor = sample["action"]
693
  action_raw = (
 
21
 
22
  import argparse
23
  import importlib
24
+ from functools import lru_cache
25
  import os
26
  import random
27
  import sys
28
  import time as _time
29
  from dataclasses import dataclass, field
 
30
  from pathlib import Path
31
  from typing import Any, cast
32
 
 
279
  return text[:max_chars]
280
 
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  class _IterableToMapDataset:
284
  """Wraps an IterableDataset into a random-access dataset with lazy loading."""
 
601
  ep_idx = n_total - 1
602
  ep_input.value = ep_idx
603
 
604
+ sample: Any = dataset[ep_idx]
605
 
606
  action_tensor = sample["action"]
607
  action_raw = (
cosmos-framework/cosmos_framework/data/vfm/action/viewpoint_utils.py CHANGED
@@ -1,114 +1,10 @@
1
  # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
  # SPDX-License-Identifier: OpenMDW-1.1
3
 
4
- """Viewpoint type definitions and caption augmentor for Action datasets.
5
-
6
- Provides a ``Viewpoint`` type alias for camera perspective labels and a
7
- ``ViewpointTextInfo`` augmentor that appends a human-readable viewpoint
8
- description to the caption string.
9
- """
10
 
11
  from __future__ import annotations
12
 
13
  from typing import Literal
14
 
15
- from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor
16
- from cosmos_framework.utils import log
17
-
18
  Viewpoint = Literal["ego_view", "third_person_view", "wrist_view", "concat_view"]
19
-
20
- DEFAULT_VIEWPOINT_TEMPLATES: dict[str, str] = {
21
- "ego_view": "This video is captured from a first-person perspective looking at the scene.",
22
- "third_person_view": "This video is captured from a third-person perspective looking towards the agent from the front.",
23
- "wrist_view": "This video is captured from a wrist-mounted camera.",
24
- "concat_view": "This video contains concatenated views from multiple camera perspectives.",
25
- }
26
-
27
-
28
- class ViewpointTextInfo(Augmentor):
29
- """Augmentor that appends viewpoint type description to captions.
30
-
31
- Reads a viewpoint label from ``data_dict[viewpoint_key]`` and appends
32
- the corresponding template sentence to the caption. Designed to run
33
- after the raw ``ai_caption`` is set but before duration/FPS metadata
34
- is appended.
35
-
36
- Args:
37
- input_keys: Input keys (kept for API compatibility).
38
- output_keys: Output keys (kept for API compatibility).
39
- args: Configuration arguments:
40
- - caption_key (str): Key for caption in data_dict. Default: ``"ai_caption"``
41
- - viewpoint_key (str): Key for viewpoint label. Default: ``"viewpoint"``
42
- - templates (dict): Override mapping from viewpoint to sentence.
43
- Default: :data:`DEFAULT_VIEWPOINT_TEMPLATES`
44
- - separator (str): Separator between caption and metadata. Default: ``". "``
45
- - enabled (bool): Whether augmentation is enabled. Default: ``True``
46
- """
47
-
48
- def __init__(
49
- self,
50
- input_keys: list | None = None,
51
- output_keys: list | None = None,
52
- args: dict | None = None,
53
- ) -> None:
54
- super().__init__(input_keys or [], output_keys or [], args)
55
-
56
- self.caption_key: str = args.get("caption_key", "ai_caption") if args else "ai_caption"
57
- self.viewpoint_key: str = args.get("viewpoint_key", "viewpoint") if args else "viewpoint"
58
- self.templates: dict[str, str] = (
59
- args.get("templates", DEFAULT_VIEWPOINT_TEMPLATES) if args else DEFAULT_VIEWPOINT_TEMPLATES
60
- )
61
- self.default_separator: str = args.get("separator", ". ") if args else ". "
62
- self.enabled: bool = args.get("enabled", True) if args else True
63
-
64
- def __call__(self, data_dict: dict) -> dict | None:
65
- """Append viewpoint description to the caption.
66
-
67
- If the sample provides an ``"additional_view_description"`` key (a
68
- free-form string describing the concatenated camera layout), it is
69
- appended after the generic ``concat_view`` template. This allows each
70
- dataset to supply its own description of which cameras are tiled and
71
- how.
72
-
73
- Args:
74
- data_dict: Sample dictionary containing caption and viewpoint.
75
-
76
- Returns:
77
- The mutated *data_dict*, or the original unchanged if the
78
- viewpoint key is missing or unrecognized.
79
- """
80
- if not self.enabled:
81
- return data_dict
82
-
83
- viewpoint = data_dict.get(self.viewpoint_key)
84
- if viewpoint is None:
85
- raise ValueError(
86
- f"ViewpointTextInfo: missing key {self.viewpoint_key!r} in data_dict. "
87
- f"All action datasets must provide a viewpoint label."
88
- )
89
-
90
- # Append dataset-specific concat_view details after the base template.
91
- additional_view_description = data_dict.pop("additional_view_description", None)
92
- template = self.templates.get(viewpoint)
93
-
94
- if template is None:
95
- log.warning(
96
- f"ViewpointTextInfo: unrecognized viewpoint {viewpoint!r}. "
97
- f"Known viewpoints: {sorted(self.templates.keys())}. Skipping.",
98
- rank0_only=False,
99
- )
100
- return data_dict
101
-
102
- if additional_view_description:
103
- separator = " " if template.endswith(".") else self.default_separator
104
- template = template + separator + additional_view_description.rstrip()
105
-
106
- caption = data_dict.get(self.caption_key)
107
- if not isinstance(caption, str) or caption == "":
108
- return data_dict
109
-
110
- caption = caption.rstrip()
111
- separator = " " if caption.endswith(".") else self.default_separator
112
- data_dict[self.caption_key] = caption + separator + template
113
-
114
- return data_dict
 
1
  # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
  # SPDX-License-Identifier: OpenMDW-1.1
3
 
4
+ """Viewpoint type definitions for release action datasets."""
 
 
 
 
 
5
 
6
  from __future__ import annotations
7
 
8
  from typing import Literal
9
 
 
 
 
10
  Viewpoint = Literal["ego_view", "third_person_view", "wrist_view", "concat_view"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/action_scripts/__init__.py DELETED
File without changes
cosmos-framework/cosmos_framework/data/vfm/action_scripts/memprofile.py DELETED
@@ -1,254 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: OpenMDW-1.1
3
-
4
- """Lightweight CPU memory-profiling helpers.
5
-
6
- Only depends on ``os``, ``sys``, and ``psutil`` so it can be imported safely
7
- from dataset modules without pulling in heavy dependencies.
8
-
9
- Enable per-stage logging by setting the ``MEMORY_PROFILE`` env var::
10
-
11
- MEMORY_PROFILE=1 torchrun ...
12
- """
13
-
14
- import contextlib
15
- import gc
16
- import logging
17
- import os
18
- import sys
19
- from collections.abc import Callable, Iterator
20
-
21
- import psutil
22
-
23
- _log = logging.getLogger(__name__)
24
-
25
-
26
- def memprofile_enabled() -> bool:
27
- """Return ``True`` when the ``MEMORY_PROFILE`` env var is truthy."""
28
- return os.environ.get("MEMORY_PROFILE", "").strip() not in ("", "0", "false")
29
-
30
-
31
- def fmt_mb(mb: float) -> str:
32
- """Format a MiB value as a human-readable string (MiB or GiB)."""
33
- if mb >= 1024:
34
- return f"{mb / 1024:.2f} GiB"
35
- return f"{mb:.1f} MiB"
36
-
37
-
38
- @contextlib.contextmanager
39
- def rss_tracker(
40
- label: str,
41
- *,
42
- enabled: bool | None = None,
43
- extras_fn: Callable[[], list[str]] | None = None,
44
- after_fn: Callable[[], None] | None = None,
45
- ) -> Iterator[None]:
46
- """Track RSS delta across a block. No-op when profiling is disabled.
47
-
48
- When *enabled* is ``False`` (or ``None`` and ``MEMORY_PROFILE`` is unset)
49
- the context manager yields immediately with zero overhead -- no
50
- ``gc.collect()`` and no ``psutil`` calls.
51
-
52
- Args:
53
- label: Human-readable description included in the log line.
54
- enabled: Explicit toggle. When ``None``, falls back to
55
- ``memprofile_enabled()`` (i.e. the ``MEMORY_PROFILE`` env var).
56
- extras_fn: Optional callback invoked *after* the measured block.
57
- Each returned string is logged as a supplementary detail line.
58
- after_fn: Optional side-effect callback invoked after logging.
59
- Use for actions that should only run when profiling is active
60
- (e.g. detailed worker memory breakdowns).
61
- """
62
- if enabled is None:
63
- enabled = memprofile_enabled()
64
- if not enabled:
65
- yield
66
- return
67
- gc.collect()
68
- rss_before = get_rss_mb()
69
- yield
70
- gc.collect()
71
- rss_after = get_rss_mb()
72
- _log.debug(
73
- "[MEMPROFILE] %s | RSS: %s (delta: +%s)",
74
- label,
75
- fmt_mb(rss_after),
76
- fmt_mb(rss_after - rss_before),
77
- )
78
- if extras_fn is not None:
79
- for line in extras_fn():
80
- _log.debug("[MEMPROFILE] %s", line)
81
- if after_fn is not None:
82
- after_fn()
83
-
84
-
85
- def get_rss_mb() -> float:
86
- """Return the current process RSS in MiB."""
87
- return psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
88
-
89
-
90
- def get_process_tree_rss_mb() -> float:
91
- """Return RSS of the current process + all children in MiB."""
92
- proc = psutil.Process(os.getpid())
93
- total = proc.memory_info().rss
94
- for child in proc.children(recursive=True):
95
- try:
96
- total += child.memory_info().rss
97
- except (psutil.NoSuchProcess, psutil.AccessDenied):
98
- pass
99
- return total / (1024 * 1024)
100
-
101
-
102
- def get_worker_memory_breakdown() -> list[tuple[int, float]]:
103
- """Return a list of ``(pid, rss_mib)`` for each child process."""
104
- proc = psutil.Process(os.getpid())
105
- result: list[tuple[int, float]] = []
106
- for child in proc.children(recursive=True):
107
- try:
108
- rss_mb = child.memory_info().rss / (1024 * 1024)
109
- result.append((child.pid, rss_mb))
110
- except (psutil.NoSuchProcess, psutil.AccessDenied):
111
- pass
112
- return result
113
-
114
-
115
- def get_worker_memory_detailed() -> list[dict[str, float]]:
116
- """Return RSS, USS (Unique Set Size), and PSS for each child process.
117
-
118
- USS is the memory *unique* to a process -- not shared with any other.
119
- It directly measures CoW-duplicated pages plus worker-only allocations.
120
-
121
- PSS counts shared pages proportionally (shared_page / num_sharers).
122
-
123
- Returns list of dicts with keys: ``pid``, ``rss``, ``uss``, ``pss`` (all in MiB).
124
- Falls back to RSS-only if ``memory_full_info()`` is unavailable.
125
- """
126
- proc = psutil.Process(os.getpid())
127
- result: list[dict[str, float]] = []
128
- for child in proc.children(recursive=True):
129
- try:
130
- full = child.memory_full_info()
131
- result.append(
132
- {
133
- "pid": float(child.pid),
134
- "rss": full.rss / (1024 * 1024),
135
- "uss": full.uss / (1024 * 1024),
136
- "pss": full.pss / (1024 * 1024),
137
- }
138
- )
139
- except (psutil.NoSuchProcess, psutil.AccessDenied, AttributeError):
140
- try:
141
- rss_mb = child.memory_info().rss / (1024 * 1024)
142
- result.append(
143
- {
144
- "pid": float(child.pid),
145
- "rss": rss_mb,
146
- "uss": -1.0,
147
- "pss": -1.0,
148
- }
149
- )
150
- except (psutil.NoSuchProcess, psutil.AccessDenied):
151
- pass
152
- return result
153
-
154
-
155
- def get_uss_mb() -> float:
156
- """Return USS (Unique Set Size) of the current process in MiB.
157
-
158
- Falls back to RSS if ``memory_full_info()`` is unavailable.
159
- """
160
- proc = psutil.Process(os.getpid())
161
- try:
162
- return proc.memory_full_info().uss / (1024 * 1024)
163
- except (AttributeError, psutil.AccessDenied):
164
- return proc.memory_info().rss / (1024 * 1024)
165
-
166
-
167
- def log_worker_memory_breakdown(dataset: object) -> None:
168
- """Log a detailed memory breakdown from inside a dataloader worker.
169
-
170
- Designed to be called periodically from ``__getitem__`` when
171
- ``MEMORY_PROFILE=1``. Inspects the dataset's internal state to
172
- report how many ``LeRobotDataset`` instances are loaded, HuggingFace
173
- Arrow table sizes, and the LeRobot ``VideoDecoderCache`` size.
174
-
175
- Args:
176
- dataset: A ``BaseActionLeRobotDataset`` instance (or compatible).
177
- """
178
- import gc
179
- import logging
180
-
181
- pid = os.getpid()
182
- rss = get_rss_mb()
183
- uss = get_uss_mb()
184
- logger = logging.getLogger(f"memprofile.worker.{pid}")
185
-
186
- logger.warning(f"[WORKER {pid}] RSS={fmt_mb(rss)} USS={fmt_mb(uss)}")
187
-
188
- # --- LeRobotDataset instances ---
189
- datasets_list = getattr(dataset, "_datasets", [])
190
- loaded_count = sum(1 for ds in datasets_list if ds is not None)
191
- total_count = len(datasets_list)
192
- logger.warning(f"[WORKER {pid}] LeRobotDataset: {loaded_count}/{total_count} loaded")
193
-
194
- total_arrow_bytes = 0
195
- total_hf_rows = 0
196
- for i, ds in enumerate(datasets_list):
197
- if ds is None:
198
- continue
199
- hf_ds = getattr(ds, "hf_dataset", None)
200
- if hf_ds is None:
201
- logger.warning(f"[WORKER {pid}] ds[{i}]: hf_dataset not yet loaded")
202
- continue
203
-
204
- num_rows = len(hf_ds)
205
- total_hf_rows += num_rows
206
-
207
- arrow_bytes = 0
208
- data_table = getattr(hf_ds, "_data", None)
209
- if data_table is not None and hasattr(data_table, "nbytes"):
210
- arrow_bytes = data_table.nbytes
211
- total_arrow_bytes += arrow_bytes
212
-
213
- logger.warning(f"[WORKER {pid}] ds[{i}]: rows={num_rows}, arrow={fmt_mb(arrow_bytes / (1024 * 1024))}")
214
-
215
- if loaded_count > 0:
216
- logger.warning(
217
- f"[WORKER {pid}] Total HF rows={total_hf_rows}, total arrow={fmt_mb(total_arrow_bytes / (1024 * 1024))}"
218
- )
219
-
220
- # --- VideoDecoderCache ---
221
- try:
222
- from lerobot.datasets.video_utils import _default_decoder_cache
223
-
224
- cache_size = _default_decoder_cache.size()
225
- logger.warning(f"[WORKER {pid}] VideoDecoderCache entries: {cache_size}")
226
- except Exception:
227
- pass
228
-
229
- # --- GC stats ---
230
- gc_counts = gc.get_count()
231
- all_objects = len(gc.get_objects())
232
- logger.warning(f"[WORKER {pid}] GC counts={gc_counts}, tracked objects={all_objects}")
233
-
234
-
235
- def deep_size(obj: object, seen: set | None = None) -> int:
236
- """Approximate deep memory size in bytes for nested Python containers.
237
-
238
- Recursively walks ``dict``, ``list``, ``tuple``, ``set``, and ``frozenset``.
239
- Does **not** follow arbitrary object attributes.
240
- """
241
- if seen is None:
242
- seen = set()
243
- obj_id = id(obj)
244
- if obj_id in seen:
245
- return 0
246
- seen.add(obj_id)
247
- size = sys.getsizeof(obj)
248
- if isinstance(obj, dict):
249
- for k, v in obj.items():
250
- size += deep_size(k, seen) + deep_size(v, seen)
251
- elif isinstance(obj, (list, tuple, set, frozenset)):
252
- for item in obj:
253
- size += deep_size(item, seen)
254
- return size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cosmos-framework/cosmos_framework/data/vfm/augmentors/__init__.py DELETED
File without changes
cosmos-framework/cosmos_framework/data/vfm/augmentors/idle_frames_text_info.py DELETED
@@ -1,10 +0,0 @@
1
- class IdleFramesTextInfo:
2
- """Minimal standalone replacement for viewer caption augmentation."""
3
-
4
- def __init__(self, input_keys=None, output_keys=None, args=None):
5
- self.input_keys = input_keys or []
6
- self.output_keys = output_keys or []
7
- self.args = args or {}
8
-
9
- def __call__(self, sample, *args, **kwargs):
10
- return sample