prithivMLmods committed on
Commit
24cca24
·
verified ·
1 Parent(s): edffd89

Delete lyra_2

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. lyra_2/__init__.py +0 -14
  2. lyra_2/_ext/__init__.py +0 -4
  3. lyra_2/_ext/imaginaire/__init__.py +0 -14
  4. lyra_2/_ext/imaginaire/checkpointer/__init__.py +0 -14
  5. lyra_2/_ext/imaginaire/checkpointer/base.py +0 -177
  6. lyra_2/_ext/imaginaire/checkpointer/dcp.py +0 -1003
  7. lyra_2/_ext/imaginaire/checkpointer/dummy.py +0 -47
  8. lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py +0 -323
  9. lyra_2/_ext/imaginaire/config.py +0 -445
  10. lyra_2/_ext/imaginaire/flags.py +0 -52
  11. lyra_2/_ext/imaginaire/functional/batch_ops.py +0 -61
  12. lyra_2/_ext/imaginaire/functional/lr_scheduler.py +0 -178
  13. lyra_2/_ext/imaginaire/lazy_config/__init__.py +0 -60
  14. lyra_2/_ext/imaginaire/lazy_config/file_io.py +0 -24
  15. lyra_2/_ext/imaginaire/lazy_config/instantiate.py +0 -120
  16. lyra_2/_ext/imaginaire/lazy_config/lazy.py +0 -430
  17. lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py +0 -65
  18. lyra_2/_ext/imaginaire/lazy_config/registry.py +0 -74
  19. lyra_2/_ext/imaginaire/model.py +0 -129
  20. lyra_2/_ext/imaginaire/trainer.py +0 -379
  21. lyra_2/_ext/imaginaire/types/denoise_prediction.py +0 -28
  22. lyra_2/_ext/imaginaire/utils/__init__.py +0 -14
  23. lyra_2/_ext/imaginaire/utils/callback.py +0 -524
  24. lyra_2/_ext/imaginaire/utils/checkpointer.py +0 -372
  25. lyra_2/_ext/imaginaire/utils/config_helper.py +0 -214
  26. lyra_2/_ext/imaginaire/utils/context_managers.py +0 -55
  27. lyra_2/_ext/imaginaire/utils/count_params.py +0 -29
  28. lyra_2/_ext/imaginaire/utils/device.py +0 -114
  29. lyra_2/_ext/imaginaire/utils/distributed.py +0 -443
  30. lyra_2/_ext/imaginaire/utils/easy_io/__init__.py +0 -14
  31. lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py +0 -34
  32. lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py +0 -64
  33. lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py +0 -60
  34. lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py +0 -841
  35. lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py +0 -565
  36. lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py +0 -91
  37. lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py +0 -550
  38. lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py +0 -130
  39. lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py +0 -1085
  40. lyra_2/_ext/imaginaire/utils/easy_io/file_client.py +0 -458
  41. lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py +0 -29
  42. lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py +0 -44
  43. lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py +0 -39
  44. lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py +0 -42
  45. lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py +0 -33
  46. lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py +0 -168
  47. lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py +0 -49
  48. lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py +0 -80
  49. lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py +0 -89
  50. lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py +0 -31
lyra_2/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- """
2
- Copied from cosmos repository, with necessary modifications.
3
- Original source: https://github.com/nvidia-cosmos/cosmos-predict2.5/
4
- """
 
 
 
 
 
lyra_2/_ext/imaginaire/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/checkpointer/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/checkpointer/base.py DELETED
@@ -1,177 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- from abc import ABC, abstractmethod
18
- from typing import Optional
19
-
20
- import torch
21
-
22
- from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig
23
- from lyra_2._ext.imaginaire.model import ImaginaireModel
24
- from lyra_2._ext.imaginaire.utils import callback
25
- from lyra_2._ext.imaginaire.utils.easy_io import easy_io
26
-
27
-
28
- class AbstractCheckpointer(ABC):
29
- """The checkpointer class. Supports checkpoint saving/loading to either local disk or an object store."""
30
-
31
- def __init__(
32
- self,
33
- config_checkpoint: CheckpointConfig,
34
- config_job: JobConfig,
35
- callbacks: Optional[callback.CallBackGroup] = None,
36
- ):
37
- """Constructor of the checkpointer.
38
-
39
- Args:
40
- config_checkpoint (CheckpointConfig): The config object for the checkpointer.
41
- """
42
- self.config_checkpoint = config_checkpoint
43
- # Set the callback functions.
44
- self.callbacks = callbacks
45
- self.save_to_object_store = config_checkpoint.save_to_object_store.enabled
46
- self.load_from_object_store = config_checkpoint.load_from_object_store.enabled
47
-
48
- # Set checkpoint directories for local and object store paths
49
- self._local_dirname = os.path.join(config_job.path_local, "checkpoints")
50
- self._object_store_dirname = os.path.join(config_job.path, "checkpoints")
51
-
52
- self.strict_resume = config_checkpoint.strict_resume
53
- self.load_path = config_checkpoint.load_path or None
54
- self.load_training_state = config_checkpoint.load_training_state
55
- self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
56
- self.save_thread = None
57
- self.verbose = config_checkpoint.verbose
58
- self.keys_not_to_resume = config_checkpoint.keys_not_to_resume
59
- self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem
60
- # Create the object store client interface.
61
- if config_checkpoint.load_from_object_store.enabled:
62
- self.load_s3_backend_key = "_ckpt_s3_loader"
63
- easy_io.set_s3_backend(
64
- key="_ckpt_s3_loader",
65
- backend_args={
66
- "backend": "s3",
67
- "path_mapping": {
68
- "s3://ckpt/": f"s3://{config_checkpoint.load_from_object_store.bucket}/",
69
- },
70
- "s3_credential_path": config_checkpoint.load_from_object_store.credentials,
71
- },
72
- )
73
- else:
74
- self.load_s3_backend_key = None
75
-
76
- if config_checkpoint.save_to_object_store.enabled:
77
- self.save_s3_backend_key = "_ckpt_s3_saver"
78
- easy_io.set_s3_backend(
79
- key="_ckpt_s3_saver",
80
- backend_args={
81
- "backend": "s3",
82
- "path_mapping": {
83
- "s3://ckpt/": f"s3://{config_checkpoint.save_to_object_store.bucket}/",
84
- },
85
- "s3_credential_path": config_checkpoint.save_to_object_store.credentials,
86
- },
87
- )
88
- else:
89
- self.save_s3_backend_key = None
90
-
91
- @abstractmethod
92
- def save(
93
- self,
94
- model: ImaginaireModel,
95
- optimizer: torch.optim.Optimizer,
96
- scheduler: torch.optim.lr_scheduler.LRScheduler,
97
- grad_scaler: torch.amp.GradScaler,
98
- iteration: int,
99
- ) -> None:
100
- pass
101
-
102
- @abstractmethod
103
- def load(
104
- self,
105
- model: ImaginaireModel,
106
- optimizer: Optional[torch.optim.Optimizer] = None,
107
- scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
108
- grad_scaler: Optional[torch.amp.GradScaler] = None,
109
- ) -> int:
110
- pass
111
-
112
- @property
113
- def save_bucket(self):
114
- """Get the bucket name for saving checkpoints."""
115
- return self.config_checkpoint.save_to_object_store.bucket if self.save_to_object_store else None
116
-
117
- @property
118
- def load_bucket(self):
119
- """Get the bucket name for loading checkpoints."""
120
- return self.config_checkpoint.load_from_object_store.bucket if self.load_from_object_store else None
121
-
122
- @property
123
- def save_dirname(self):
124
- return (
125
- f"s3://{self.save_bucket}/{self._object_store_dirname}"
126
- if self.save_to_object_store
127
- else self._local_dirname
128
- )
129
-
130
- @property
131
- def load_dirname(self):
132
- return (
133
- f"s3://{self.load_bucket}/{self._object_store_dirname}"
134
- if self.load_from_object_store
135
- else self._local_dirname
136
- )
137
-
138
- def finalize(self) -> None:
139
- """Finalize the checkpointer."""
140
- if self.save_thread:
141
- self.save_thread.join()
142
-
143
- def _read_latest_checkpoint_file(self) -> str | None:
144
- """Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
145
-
146
- Returns:
147
- checkpoint_file (str | None): file name of the latest saved checkpoint.
148
- """
149
- checkpoint_file = None
150
- checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt")
151
- if easy_io.exists(f"{checkpoint_path}", backend_key=self.load_s3_backend_key):
152
- checkpoint_file = easy_io.load(f"{checkpoint_path}", backend_key=self.load_s3_backend_key).strip()
153
-
154
- return checkpoint_file
155
-
156
- def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
157
- """Track the file name of the latest saved checkpoint.
158
-
159
- Args:
160
- checkpoint_file (str): file name of the latest saved checkpoint.
161
- """
162
- content = f"{checkpoint_file}\n"
163
- checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt")
164
- easy_io.dump(
165
- content,
166
- checkpoint_path,
167
- backend_key=self.save_s3_backend_key,
168
- )
169
-
170
- def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
171
- """If the file checkpoint_path does not exist, raise an error.
172
-
173
- Args:
174
- checkpoint_path (str): full path to the checkpoint.
175
- """
176
- if not easy_io.exists(f"{checkpoint_path}", backend_key=self.load_s3_backend_key):
177
- raise FileNotFoundError(f"File not found (object store): {checkpoint_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/checkpointer/dcp.py DELETED
@@ -1,1003 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- """
18
- Distributed checkpoint (DCP) directory structure and storage backends.
19
-
20
- The checkpointer saves model state in a sharded format across multiple processes:
21
-
22
- self.save_dirname/
23
- ├── iter_000000005/ # Checkpoint at iteration 5
24
- │ ├── model/ # Model state shards
25
- │ │ ├── __0_0.distcp # Shard 0 from rank 0
26
- │ │ └── __1_0.distcp # Shard 1 from rank 1
27
- │ ├── optim/ # Optimizer state shards
28
- │ │ ├── __0_0.distcp # Shard 0 from rank 0
29
- │ │ └── __1_0.distcp # Shard 1 from rank 1
30
- │ ├── scheduler/ # Learning rate scheduler state
31
- │ │ ├── __0_0.distcp # Shard 0 from rank 0
32
- │ │ └── __1_0.distcp # Shard 1 from rank 1
33
- │ └── trainer/ # Additional training state
34
- │ ├── __0_0.distcp # Shard 0 from rank 0
35
- │ └── __1_0.distcp # Shard 1 from rank 1
36
- └── latest_checkpoint.txt # Points to most recent checkpoint folder, e.g. iter_000000005
37
-
38
- Storage path format:
39
- self.save_dirname = "{config_job.path_local}/checkpoints"
40
-
41
- The sharded format enables efficient distributed saving/loading by:
42
- 1. Parallelizing I/O across processes
43
- 2. Reducing memory usage per process
44
- 3. Supporting both local and cloud storage backends
45
- """
46
-
47
- import enum
48
- import functools
49
- import multiprocessing
50
- import os
51
- import queue
52
- import re
53
- import time
54
- import warnings
55
- from collections import namedtuple
56
- from multiprocessing import get_context
57
- from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
58
-
59
- import torch
60
- import torch.distributed
61
- import torch.distributed.checkpoint as dcp
62
- from torch import nn
63
- from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
64
- from torch.distributed.checkpoint._storage_utils import _storage_setup
65
- from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
66
- from torch.distributed.checkpoint.logger import _dcp_method_logger
67
- from torch.distributed.checkpoint.state_dict import (
68
- StateDictOptions,
69
- get_model_state_dict,
70
- get_optimizer_state_dict,
71
- set_model_state_dict,
72
- set_optimizer_state_dict,
73
- )
74
- from torch.distributed.checkpoint.stateful import Stateful
75
- from torch.distributed.checkpoint.storage import StorageReader
76
- from torch.distributed.checkpoint.utils import _api_bc_check, _DistWrapper, _profile
77
- from torch.distributed.tensor import distribute_tensor
78
-
79
- from lyra_2._ext.imaginaire.checkpointer.base import AbstractCheckpointer
80
- from lyra_2._ext.imaginaire.checkpointer.s3_filesystem import S3StorageReader, S3StorageWriter
81
- from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig
82
- from lyra_2._ext.imaginaire.model import ImaginaireModel
83
- from lyra_2._ext.imaginaire.utils import callback, distributed, log, misc
84
- from lyra_2._ext.imaginaire.utils.easy_io import easy_io
85
- from lyra_2._src.models.wan_t2v_model import WANDiffusionModel as DiffusionModel
86
-
87
- try:
88
- """
89
- We override the default load function to skip loading _extra_state keys created by transformer-engine.
90
- """
91
- from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner as _DefaultLoadPlanner
92
- from torch.distributed.checkpoint.default_planner import (
93
- DTensor,
94
- LoadPlan,
95
- _create_read_items,
96
- _version,
97
- flatten_state_dict,
98
- )
99
- from torch.distributed.checkpoint.metadata import Metadata, TensorStorageMetadata
100
-
101
- def create_default_local_load_plan(
102
- state_dict: dict[str, Any], metadata: Metadata, strict: bool = True, dcp_allow_mismatched_size: bool = False
103
- ) -> LoadPlan:
104
- requests = []
105
- """
106
- Create the ``LoadPlan`` used by DefaultLoadPlanner.
107
-
108
- It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
109
-
110
- The default behavior is to match key exactly between state_dict and metadata.
111
- It handles resharding by issuing multiple read requests against storage in order to match
112
- load requirements.
113
- """
114
-
115
- for fqn, obj in state_dict.items():
116
- if fqn.endswith("._extra_state"): # dirty TE attention package!
117
- continue
118
- # ignore state_dict keys which do not exist in the checkpoint metadata if strict=False
119
- if fqn not in metadata.state_dict_metadata:
120
- if strict:
121
- raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
122
- else:
123
- log.warning(f"Local load plan: Missing key in checkpoint state_dict: {fqn}.")
124
- continue
125
-
126
- md = metadata.state_dict_metadata[fqn]
127
-
128
- if not dcp_allow_mismatched_size:
129
- if (
130
- isinstance(md, TensorStorageMetadata)
131
- and getattr(obj, "size", None) is not None
132
- and md.size != obj.size()
133
- ):
134
- if not strict:
135
- log.critical(f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}")
136
- continue
137
- else:
138
- raise ValueError(
139
- f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
140
- )
141
- # Since DTensor supports submesh, adding extra check to ensure _create_read_items()
142
- # gets called only when the current rank is part of the mesh for the corresponding DTensor.
143
- if isinstance(obj, DTensor):
144
- if obj.device_mesh.get_coordinate() is not None:
145
- requests += _create_read_items(fqn, md, obj)
146
- else:
147
- requests += _create_read_items(fqn, md, obj)
148
-
149
- return LoadPlan(requests)
150
-
151
- class DefaultLoadPlanner(_DefaultLoadPlanner):
152
- def set_partial_channel_weight(self, dcp_allow_mismatched_size: bool):
153
- self.dcp_allow_mismatched_size = dcp_allow_mismatched_size
154
-
155
- def create_local_plan(self) -> LoadPlan:
156
- assert self.metadata is not None
157
- if self.flatten_state_dict:
158
- # To support checkpoints that are saved before v2.4, we have to
159
- # differentiate if the missing keys are due to old checkpoints.
160
- # The contracts are:
161
- # 1. There are 4 cases when we found a missing key.
162
- # 1.1 Actual missing key, but allow_partial_load is False
163
- # 1.2 Actual missing key, but allow_partial_load is True
164
- # 1.3 Old checkpoint, but allow_partial_load is False
165
- # 1.4 Old checkpoint, but allow_partial_load is True
166
- # 2. If we found a missing key, we first convert the keys back to
167
- # the key format of v2.3
168
- # 3. If the previous missing keys are in the v2.3 keys, we assume
169
- # this is an old checkpoint.
170
- # 4. Pass the state_dict to `create_default_local_load_plan()`,
171
- # which has the logic to check missing for allow_partial_load.
172
- # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to
173
- # `create_default_local_load_plan()`. The logic here is to determine
174
- # whether the checkpoint belongs to 2.3 (or before) or 2.4 (or after).
175
- current_keys = set(self.state_dict.keys())
176
- load_keys = set(self.metadata.state_dict_metadata.keys())
177
- missing_keys = load_keys - current_keys
178
- if missing_keys:
179
- _version._derived_version = "2_3"
180
- old_state_dict, old_mappings = flatten_state_dict(self.original_state_dict)
181
- old_keys = set(old_state_dict.keys())
182
- if old_keys & missing_keys:
183
- self.state_dict, self.mappings = old_state_dict, old_mappings
184
- # _derived_version is only used by flatten_state_dict now.
185
- # Set it back to None so that later we can save to a new version.
186
- _version._derived_version = None
187
-
188
- return create_default_local_load_plan(
189
- self.state_dict,
190
- self.metadata,
191
- not self.allow_partial_load,
192
- getattr(self, "dcp_allow_mismatched_size", False),
193
- )
194
-
195
- log.info("for the back comptiable pytorch! New DefaultLoadPlanner class is created.")
196
-
197
- @_dcp_method_logger(log_exceptions=True)
198
- @_api_bc_check
199
- def load(
200
- state_dict: dict[str, Any],
201
- *,
202
- checkpoint_id: Union[str, os.PathLike, None] = None,
203
- storage_reader: Optional[StorageReader] = None,
204
- planner: Optional[DefaultLoadPlanner] = None,
205
- process_group: Optional[torch.distributed.ProcessGroup] = None,
206
- no_dist: bool = False,
207
- ) -> None:
208
- """
209
- We override the default load function to perform a load plan check for mismatched/missing keys.
210
- ==========================Original Doc string=====================================
211
- Load a checkpoint into a distributed state dict in SPMD style.
212
-
213
- Each rank must have the same keys in their ``state_dict`` provided to this
214
- API. Mismatched keys may result in hangs or errors. If unsure, you can use
215
- the ``utils._assert_same_keys`` API to check (but may incur communication
216
- costs).
217
-
218
- Each rank will try to read the least amount of data necessary
219
- to fulfill the requested `state_dict`. When loading :class:`ShardedTensor`
220
- or :class:`DTensor` instances, each rank only reads data for their local shards.
221
-
222
- For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
223
- load will first call ``state_dict`` before attempting deserialization, followed by
224
- ``load_state_dict`` once the deserialization is complete.
225
- For each non-``Stateful`` object, load will deserialize the object, and then replace
226
- it in the ``state_dict`` with the deserialized object.
227
-
228
- .. warning::
229
- All tensors in ``state_dict`` must be allocated on their
230
- destination device *prior to* calling this function.
231
-
232
- All non-tensor data is loaded using `torch.load()` and modified in place
233
- on state_dict.
234
-
235
- .. warning::
236
- Users must call `load_state_dict` on the root module to ensure load
237
- post-processing and non-tensor data properly propagates.
238
-
239
- .. note::
240
- If no process group is initialized, this function will assume the intent
241
- is to load a checkpoint into the local process. This can be useful in the
242
- case of local inference, and when using regular Tensors (as opposed to DTensor
243
- or ShardedTensor)
244
-
245
- .. note::
246
- Rank 0 is assumed to be the coordinator rank.
247
-
248
- Args:
249
- state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
250
- checkpoint_id (Union[str, os.PathLike, None]):
251
- The ID of this checkpoint instance. The meaning of the checkpoint_id
252
- depends on the storage. It can be a path to a folder or to a file.
253
- It can also be a key if the storage is a key-value store.
254
- (Default: ``None``)
255
- storage_reader (Optional[StorageReader]):
256
- Instance of StorageReader used to perform reads. If this is not
257
- specified, DCP will automatically infer the reader based on the
258
- checkpoint_id. If checkpoint_id is also None, an exception will
259
- be raised. (Default: ``None``)
260
- planner (Optional[LoadPlanner]):
261
- Instance of LoadPlanner. If this is not specified, the default
262
- planner will be used. (Default: ``None``)
263
- process_group (Optional[ProcessGroup]):
264
- ProcessGroup to be used for cross-rank synchronization.
265
- (Default: ``None``)
266
- no_dist (bool): If ``True``, this function will assume the intent is to load
267
- a checkpoint without using cross-rank synchronization. (Default: ``False``)
268
- Returns:
269
- None.
270
-
271
- Examples
272
- >>> # xdoctest: +SKIP
273
- >>> my_model = MyModule()
274
- >>> optimizer = Adagrad(my_model.parameters())
275
- >>> model_state_dict = my_model.state_dict()
276
- >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
277
- ... "/checkpoint/1"
278
- ... )
279
-
280
- >>> torch.distributed.checkpoint.load_state_dict(
281
- >>> state_dict=model_state_dict,
282
- >>> storage_reader=fs_storage_reader,
283
- >>> )
284
-
285
- >>> # module.load_state_dict() function might have customized steps
286
- >>> # to flush the state_dict, must call it to
287
- >>> # ensure correct behavior.
288
- >>> my_model.load_state_dict(model_state_dict)
289
-
290
- .. note::
291
- load_state_dict uses collectives to coordinate reads across ranks.
292
- For NCCL-based process groups, internal tensor representations of
293
- objects must be moved to the GPU device before communication takes place.
294
- In this case, the device used is given by ``torch.cuda.current_device()``
295
- and it is the user's responsibility to ensure that this is set so that each
296
- rank has an individual GPU, via ``torch.cuda.set_device()``.
297
- """
298
-
299
- no_dist = no_dist or (not torch.distributed.is_available()) or (not torch.distributed.is_initialized())
300
- if no_dist:
301
- warnings.warn(
302
- "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process."
303
- )
304
-
305
- with _profile():
306
- storage_reader = cast(StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True))
307
-
308
- # All ranks must have the same keys in their `state_dict` provided to
309
- # this API. See documentation for more details.
310
- # Here we simply sort the keys to ensure that all ranks load values in
311
- # the same order.
312
- keys = sorted(state_dict.keys())
313
-
314
- statetful_sd = {}
315
- for key in keys:
316
- if key not in state_dict:
317
- continue
318
- elem = state_dict[key]
319
- statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
320
-
321
- _load_state_dict(
322
- state_dict=statetful_sd,
323
- storage_reader=storage_reader,
324
- process_group=process_group,
325
- no_dist=no_dist,
326
- planner=planner,
327
- )
328
- for key in keys:
329
- if key not in state_dict:
330
- continue
331
- elem = state_dict[key]
332
- if isinstance(elem, Stateful):
333
- # If the state_dict is a Stateful object,
334
- # DCP does an in-place load in the original state dict.
335
- elem.load_state_dict(statetful_sd[key])
336
- else:
337
- # Otherwise, replace the state_dict with the loaded state_dict.
338
- state_dict[key] = statetful_sd[key]
339
-
340
- def _load_state_dict(
341
- state_dict: dict[str, Any],
342
- storage_reader: StorageReader,
343
- process_group: Optional[torch.distributed.ProcessGroup] = None,
344
- coordinator_rank: int = 0,
345
- no_dist: bool = False,
346
- planner: Optional[DefaultLoadPlanner] = None,
347
- ) -> None:
348
- torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
349
-
350
- distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
351
- if planner is None:
352
- planner = DefaultLoadPlanner()
353
-
354
- ckpt_kwargs = {}
355
- if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
356
- ckpt_kwargs["checkpoint_id"] = ckpt_id
357
- ckpt_kwargs["process_group"] = distW.group
358
-
359
- @_dcp_method_logger(**ckpt_kwargs)
360
- def local_step():
361
- assert planner is not None
362
- metadata = storage_reader.read_metadata()
363
- planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
364
- storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
365
-
366
- local_plan = planner.create_local_plan()
367
- local_plan = storage_reader.prepare_local_plan(local_plan)
368
- return local_plan
369
-
370
- @_dcp_method_logger(**ckpt_kwargs)
371
- def global_step(all_local_plans):
372
- assert planner is not None
373
- all_local_plans = planner.create_global_plan(all_local_plans)
374
- all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
375
- return all_local_plans
376
-
377
- central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
378
- if distW.is_coordinator:
379
- # Compare central_plan items with storage_reader.storage_data keys
380
- dest_fqns = set()
381
- storage_fqns = set()
382
-
383
- # Extract FQNs from central_plan items
384
- for item in central_plan.items:
385
- if hasattr(item, "dest_index") and hasattr(item.dest_index, "fqn"):
386
- dest_fqns.add(item.dest_index.fqn)
387
- if hasattr(item, "storage_index") and hasattr(item.storage_index, "fqn"):
388
- storage_fqns.add(item.storage_index.fqn)
389
-
390
- # Get storage data keys
391
- storage_data_keys = set()
392
- if hasattr(storage_reader, "storage_data") and storage_reader.storage_data is not None:
393
- storage_data_keys = set(item[0].fqn for item in storage_reader.storage_data.items())
394
- state_dict_keys = set(state_dict.keys())
395
- # Compare sets and log differences
396
- # Remove any item that has "_extra_state" as substring in the sets
397
- state_dict_keys = {fqn for fqn in state_dict_keys if "_extra_state" not in fqn}
398
- dest_fqns = {fqn for fqn in dest_fqns if "_extra_state" not in fqn}
399
- storage_fqns = {fqn for fqn in storage_fqns if "_extra_state" not in fqn}
400
- storage_data_keys = {fqn for fqn in storage_data_keys if "_extra_state" not in fqn}
401
-
402
- log.info("=== Load Plan FQN Analysis ===")
403
- log.info(f"State Dict FQNs count: {len(state_dict_keys)}")
404
- log.info(f"Destination FQNs count (without _extra_state): {len(dest_fqns)}")
405
- log.info(f"Loaded FQNs count (without _extra_state): {len(storage_fqns)}")
406
- log.info(f"In Storage keys count (without _extra_state): {len(storage_data_keys)}")
407
-
408
- # Find missing keys in each direction
409
- state_dict_missing_from_dest = state_dict_keys - dest_fqns
410
- storage_data_missing_from_storage_fqns = storage_data_keys - storage_fqns
411
-
412
- if state_dict_missing_from_dest:
413
- log.info(
414
- f"State Dict FQNs missing from load plan ({len(state_dict_missing_from_dest)} items): {sorted(state_dict_missing_from_dest)}"
415
- )
416
- else:
417
- log.info("✓ All State Dict FQNs found in storage_data")
418
-
419
- if storage_data_missing_from_storage_fqns:
420
- # If there are more than 100 "net_ema" keys in storage_data_missing_from_storage_fqns, summarize them
421
- net_ema_keys = {k for k in storage_data_missing_from_storage_fqns if "net_ema" in k}
422
- if len(net_ema_keys) > 100:
423
- storage_data_missing_from_storage_fqns = storage_data_missing_from_storage_fqns - net_ema_keys
424
- storage_data_missing_from_storage_fqns = set(
425
- storage_data_missing_from_storage_fqns
426
- ) # ensure set type
427
- storage_data_missing_from_storage_fqns.add("net_ema")
428
- log.info(
429
- f"Summarized {len(net_ema_keys)} 'net_ema' keys as 'net_ema' in missing storage data keys."
430
- )
431
- log.info(
432
- f"Storage data keys not loaded by load plan ({len(storage_data_missing_from_storage_fqns)} items): {sorted(storage_data_missing_from_storage_fqns)}"
433
- )
434
- else:
435
- log.info("✓ All storage data keys found in Loaded FQNs")
436
-
437
- log.info("=== End Load Plan FQN Analysis ===")
438
-
439
- @_dcp_method_logger(**ckpt_kwargs)
440
- def read_data():
441
- assert planner is not None
442
- final_local_plan = planner.finish_plan(central_plan)
443
- all_reads = storage_reader.read_data(final_local_plan, planner)
444
-
445
- all_reads.wait()
446
- return None
447
-
448
- _ = distW.all_gather("read", read_data)
449
-
450
- except ImportError as e:
451
- from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
452
-
453
- log.critical(f"{e}, using default planner")
454
-
455
# Lightweight record pairing a state dict with the path it should be saved to.
StateDictItemPath = namedtuple("StateDictItemPath", "state_dict save_path")
456
-
457
-
458
- class ModelWrapper(Stateful):
459
- """Wrapper for model state dict handling"""
460
-
461
- def __init__(self, model: Union[nn.Module, List[nn.Module]], load_ema_to_reg: bool = False):
462
- self.model = [model] if isinstance(model, nn.Module) else model
463
- self.load_ema_to_reg = load_ema_to_reg
464
- if self.load_ema_to_reg:
465
- assert isinstance(model, DiffusionModel)
466
-
467
- def state_dict(self) -> Dict[str, Any]:
468
- _state_dict = {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
469
- if self.load_ema_to_reg:
470
- if not self.model[0].config.ema.enabled:
471
- all_keys = list(_state_dict.keys())
472
- assert all(k.startswith("net.") for k in all_keys), "All keys must start with net."
473
- for k in all_keys:
474
- _state_dict[k.replace("net.", "net_ema.")] = _state_dict.pop(k)
475
- else:
476
- log.warning("EMA is enabled, will only load EMA weights from checkpoint file.")
477
- all_keys = list(_state_dict.keys())
478
- for k in all_keys:
479
- if k.startswith("net_ema."):
480
- break
481
- else:
482
- raise ValueError("No EMA keys found in state_dict")
483
- # do not load .net keys, since we do not need them anyway.
484
- _state_dict = {k: _state_dict[k] for k in all_keys if not k.startswith("net.")}
485
-
486
- if hasattr(self.model[0].config, "lora_config") and self.model[0].config.lora_config.enabled:
487
- """
488
- When using LoRA, `inject_adapter_in_model` modifies the target modules in place.
489
- For example, `blocks[0].attn.q_proj.weight` will be modified to `blocks[0].attn.q_proj.base_layer.weight`.
490
- This means that the model will have the key `blocks[0].attn.q_proj.base_layer.weight`,
491
- but the checkpoint will have the key `blocks[0].attn.q_proj.weight`.
492
- We need to map the model key to the checkpoint key.
493
- """
494
- self.checkpoint_to_model_key = {}
495
- mapping_keys = {"base_layer.": "", "base_model.model.": ""}
496
- keys_to_update = []
497
- for k in _state_dict.keys():
498
- new_key = k
499
- for from_key, to_key in mapping_keys.items():
500
- new_key = new_key.replace(from_key, to_key)
501
- if new_key != k:
502
- keys_to_update.append((k, new_key))
503
- self.checkpoint_to_model_key[new_key] = k
504
- for k, new_key in keys_to_update:
505
- _state_dict[new_key] = _state_dict.pop(k)
506
-
507
- return _state_dict
508
-
509
- def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
510
- if hasattr(self.model[0].config, "lora_config") and self.model[0].config.lora_config.enabled:
511
- if hasattr(self, "checkpoint_to_model_key"):
512
- for checkpoint_key, model_key in self.checkpoint_to_model_key.items():
513
- state_dict[model_key] = state_dict.pop(checkpoint_key)
514
- else:
515
- raise ValueError("checkpoint_to_model_key is not set by `state_dict`")
516
-
517
- if self.load_ema_to_reg:
518
- if not self.model[0].config.ema.enabled:
519
- all_keys = list(state_dict.keys())
520
- assert all(k.startswith("net_ema.") for k in all_keys), "All keys must start with net_ema."
521
- for k in all_keys:
522
- state_dict[k.replace("net_ema.", "net.")] = state_dict.pop(k)
523
- else:
524
- log.warning("EMA is enabled, will load EMA weights to regular model weights")
525
- all_keys = list(state_dict.keys())
526
- assert all(not k.startswith("net.") for k in all_keys), "No .net keys should be in state_dict"
527
- for k in all_keys:
528
- if k.startswith("net_ema."):
529
- state_dict[k.replace("net_ema.", "net.")] = torch.clone(state_dict[k])
530
- func = functools.partial(
531
- set_model_state_dict,
532
- model_state_dict=state_dict,
533
- options=StateDictOptions(strict=False),
534
- )
535
- list(map(func, self.model))
536
-
537
-
538
- class OptimizerWrapper(Stateful):
539
- def __init__(
540
- self,
541
- model: Union[nn.Module, List[nn.Module]],
542
- optim: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
543
- ) -> None:
544
- self.model = [model] if isinstance(model, nn.Module) else model
545
- self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
546
-
547
- def state_dict(self) -> Dict[str, Any]:
548
- func = functools.partial(
549
- get_optimizer_state_dict,
550
- options=StateDictOptions(flatten_optimizer_state_dict=True),
551
- )
552
- return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()}
553
-
554
- def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
555
- func = functools.partial(
556
- set_optimizer_state_dict,
557
- optim_state_dict=state_dict,
558
- options=StateDictOptions(flatten_optimizer_state_dict=True),
559
- )
560
- list(map(func, self.model, self.optim))
561
-
562
-
563
class AsyncMode(str, enum.Enum):
    """How checkpoint saving is scheduled; members are also plain strings."""

    # Save in the calling process.
    DISABLED = "disabled"
    # Stage to a CPU copy, then save from a background process.
    ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"
566
-
567
-
568
class Terminate:
    """Marker object passed through the checkpoint queues to signal shutdown."""
570
-
571
-
572
class SaveDone:
    """Result of one background checkpoint save.

    Args:
        iteration: Training iteration the checkpoint belongs to.
        elapsed_time: Time spent on the save attempt.
        succeeded: Whether the save completed without raising.
    """

    def __init__(self, iteration: int, elapsed_time: float, succeeded: bool):
        self.iteration, self.elapsed_time, self.succeeded = iteration, elapsed_time, succeeded

    def __str__(self):
        fields = f"iteration={self.iteration}, elapsed_time={self.elapsed_time}, succeeded={self.succeeded}"
        return f"SaveDone({fields})"
580
-
581
-
582
def save_checkpoint_in_background(
    receiver_queue: multiprocessing.Queue,
    sender_queue: multiprocessing.Queue,
    checkpoint_config: CheckpointConfig,
    job_config: JobConfig,
) -> None:
    """
    Entry point of the background checkpoint process.

    Runs in a dedicated process so that writing checkpoints does not block the
    main training loop.

    Args:
        receiver_queue: Queue delivering ``(state_dict, checkpoint_path)`` tuples
            or a ``Terminate`` object from the main process.
        sender_queue: Queue used to report a ``SaveDone`` per save attempt, and a
            final ``Terminate`` on shutdown, back to the main process.
        checkpoint_config: Configuration settings for checkpoint saving behavior.
        job_config: Configuration settings for the training job.

    Flow:
        1. Initialise the distributed environment for this process.
        2. Block on ``receiver_queue`` for the next item.
        3. Save the received state dict synchronously inside this process.
        4. Report the outcome on ``sender_queue``.
        5. On ``Terminate``, echo it back, close ``sender_queue`` and return.

    Raises:
        AssertionError: If a received object is neither ``Terminate`` nor a tuple.

    Note:
        - ``MASTER_PORT`` is shifted by 2 so this process does not collide with
          the main process.
        - ``TORCHELASTIC_USE_AGENT_STORE`` is set to ``"False"``.
        - The process group is destroyed on every exit path.
    """
    # Configure distributed environment
    main_port = int(os.environ["MASTER_PORT"])
    os.environ["MASTER_PORT"] = str(main_port + 2)
    os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"

    # Set up GPU device and distributed processing
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    distributed.init()

    # The worker-side checkpointer must itself be synchronous.
    worker = DistributedCheckpointer(checkpoint_config, job_config, None, disable_async=True)

    try:
        while True:
            log.debug("Checkpoint background process is ready for next task, waiting for new state_dict")
            message = receiver_queue.get()
            log.debug("Received new state_dict")

            if isinstance(message, Terminate):
                log.info("Received termination signal in checkpoint background process, closing sender queue")
                sender_queue.put(Terminate())
                sender_queue.close()
                return

            assert isinstance(message, tuple), "Received data must be a tuple of (state_dict, checkpoint_path)"
            state_dict, checkpoint_path = message

            # Save checkpoint and measure time taken
            started = time.monotonic()
            iteration = state_dict["trainer"][0]["iteration"]
            duration = 0  # 0 doubles as "not measured yet"
            ok = False
            try:
                worker.save_state_dict_worker(state_dict, checkpoint_path)
                duration = time.monotonic() - started
                log.info(
                    f"Checkpoint saved successfully in background process. Time taken: {duration:.2f} seconds, iteration: {iteration}"
                )
                ok = True
            except Exception as e:
                # Keep serving: if this process exited, the main process would
                # keep enqueueing state dicts that nobody consumes.
                log.error(f"Error saving checkpoint to {checkpoint_path}: {e}")
            finally:
                if duration == 0:
                    duration = time.monotonic() - started
                sender_queue.put(SaveDone(iteration, duration, ok))

    finally:
        log.info("Cleaning up: destroying distributed process group")
        torch.distributed.destroy_process_group()
662
-
663
-
664
class DistributedCheckpointer(AbstractCheckpointer):
    """Checkpointer built on ``torch.distributed.checkpoint`` (DCP).

    Each checkpoint ``iter_XXXXXXXXX`` holds one sub-directory per entry of
    ``KEYS_TO_SAVE``. Saving is either synchronous or, when
    ``config_checkpoint.dcp_async_mode_enabled`` is set, staged to a CPU copy
    and written by a background process (``save_checkpoint_in_background``).
    """

    KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"]

    def __init__(
        self,
        config_checkpoint: CheckpointConfig,
        config_job: JobConfig,
        callbacks: Optional[callback.CallBackGroup] = None,
        disable_async: bool = False,
    ):
        """Set up the checkpointer and, in async mode, spawn the background saver.

        Args:
            config_checkpoint: Checkpoint configuration.
            config_job: Job configuration.
            callbacks: Optional callback group notified on load/save events.
            disable_async: Force synchronous saving regardless of the config
                (used by the background process itself).
        """
        super().__init__(config_checkpoint, config_job, callbacks)
        self.config_checkpoint = config_checkpoint
        if config_checkpoint.dcp_async_mode_enabled and not disable_async:
            self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
        else:
            self.async_mode = AsyncMode.DISABLED

        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
            ctx = get_context("spawn")
            self.mp_queue_send = ctx.Queue()
            self.mp_queue_recv = ctx.Queue()
            self.mp = ctx.Process(
                target=save_checkpoint_in_background,
                args=(
                    self.mp_queue_send,
                    self.mp_queue_recv,
                    config_checkpoint,
                    config_job,
                ),
                daemon=True,
            )
            self.mp.start()
            # Staging state: CPU copy of the state dict and the stream used to fill it.
            self.cpu_offload_state_dict = None
            self.staging = False
            self.staging_ckpt_file = None
            self.staging_stream = torch.cuda.Stream()

    def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]:
        """Decide which checkpoint to load and which of ``KEYS_TO_SAVE`` to restore.

        Priority: (1) the latest checkpoint recorded for this job; (2) a DCP
        checkpoint at ``self.load_path``; (3) a ``.pth`` file at
        ``self.load_path``; otherwise nothing.

        Returns:
            ``(resume_keys, checkpoint_path)``; ``checkpoint_path`` is None when
            there is nothing to load.
        """
        latest_checkpoint_file = self._read_latest_checkpoint_file()

        resume_keys = []

        if latest_checkpoint_file is not None:
            # 1. Resume training from latest_checkpoint.txt under the same name.
            checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file)
            resume_keys.extend(self.KEYS_TO_SAVE)
        elif self.load_path and not self.load_path.endswith(".pth"):
            # 2. Load the module weights specified by config_checkpoint.path.
            checkpoint_path = self.load_path
            if self.load_s3_backend_key:
                checkpoint_path = f"s3://{self.config_checkpoint.load_from_object_store.bucket}/{checkpoint_path}"
            if not re.search(r"/checkpoints/iter_\d{9}/?$", checkpoint_path):
                old_ckpt_path = checkpoint_path
                # If path doesn't end with specific checkpoint, read latest checkpoint file
                latest_ckpt_path = os.path.join(checkpoint_path, "checkpoints/latest_checkpoint.txt")
                if easy_io.exists(latest_ckpt_path, backend_key=self.load_s3_backend_key):
                    checkpoint_file = easy_io.load(
                        latest_ckpt_path, backend_key=self.load_s3_backend_key
                    ).strip()
                    checkpoint_path = f"{checkpoint_path}/checkpoints/{checkpoint_file}"
                else:
                    log.warning(
                        f"Latest checkpoint file {latest_ckpt_path} not found, load from {old_ckpt_path}"
                    )
                    checkpoint_path = old_ckpt_path
            if self.load_training_state:
                resume_keys.extend(self.KEYS_TO_SAVE)
            else:
                resume_keys.append("model")
                if self.only_load_scheduler_state:
                    resume_keys.append("scheduler")
        elif self.load_path and self.load_path.endswith(".pth"):
            # 3. Plain .pth file: handled separately in `load`, no DCP keys.
            checkpoint_path = self.load_path
        else:
            checkpoint_path = None
        if len(self.keys_not_to_resume) > 0:
            for key in self.keys_not_to_resume:
                assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}"
            resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume]
        return set(resume_keys), checkpoint_path

    @misc.timer("checkpoint loading")
    def load(
        self,
        model: ImaginaireModel,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
        """Restore state from the checkpoint chosen by ``keys_to_resume_during_load``.

        Args:
            model: Model to load weights into.
            optimizer: Optimizer to restore (used for the "optim" key).
            scheduler: LR scheduler to restore (used for the "scheduler" key).
            grad_scaler: Gradient scaler to restore (used for the "trainer" key).

        Returns:
            The iteration stored in the checkpoint's trainer state, or 0 when
            the trainer state was not loaded.

        Raises:
            ValueError: If a resume key is not one of the supported keys.
        """
        if self.callbacks is not None:
            self.callbacks.on_load_checkpoint_start(model)

        resume_keys, checkpoint_path = self.keys_to_resume_during_load()
        resume_keys = sorted(resume_keys)
        log.info(f"Resuming ckpt {checkpoint_path} with keys: {resume_keys}")

        iteration = 0

        if checkpoint_path is not None and not checkpoint_path.endswith(".pth"):
            self._check_checkpoint_exists(checkpoint_path)
            for key in resume_keys:
                load_planner = DefaultLoadPlanner(allow_partial_load=True)
                # Only some planner implementations expose this hook.
                if hasattr(load_planner, "set_partial_channel_weight"):
                    log.info(f"set_partial_channel_weight: {self.config_checkpoint.dcp_allow_mismatched_size}")
                    load_planner.set_partial_channel_weight(self.config_checkpoint.dcp_allow_mismatched_size)
                cur_key_ckpt_full_path = os.path.join(checkpoint_path, key)
                log.info(f"Start loading checkpoint from {checkpoint_path}")
                storage_reader = self.get_storage_reader(cur_key_ckpt_full_path)
                torch.distributed.barrier()
                log.info(f"starting {cur_key_ckpt_full_path}", rank0_only=False)
                if key == "model":
                    log.info("- Loading the model...")
                    _model_wrapper = ModelWrapper(model)
                    _state_dict = _model_wrapper.state_dict()
                    # Bare `load` resolves to a module-level name, not this method.
                    load(_state_dict, storage_reader=storage_reader, planner=load_planner)
                    _model_wrapper.load_state_dict(_state_dict)
                elif key == "optim":
                    log.info("- Loading the optimizer...")
                    _optim_wrapper = OptimizerWrapper(model, optimizer)
                    _state_dict = _optim_wrapper.state_dict()
                    dcp.load(
                        _state_dict,
                        storage_reader=storage_reader,
                        planner=load_planner,
                    )
                    _optim_wrapper.load_state_dict(_state_dict)
                elif key == "scheduler":
                    log.info("- Loading the scheduler...")
                    _state_dict = scheduler.state_dict()
                    dcp.load(
                        _state_dict,
                        storage_reader=storage_reader,
                        planner=load_planner,
                    )
                    scheduler.load_state_dict(_state_dict)
                elif key == "trainer":
                    log.info("- Loading the trainer...")
                    _state_dict = {
                        "grad_scaler": grad_scaler.state_dict(),
                        "iteration": iteration,
                    }
                    dcp.load(
                        _state_dict,
                        storage_reader=storage_reader,
                        planner=load_planner,
                    )
                    grad_scaler.load_state_dict(_state_dict["grad_scaler"])
                    iteration = _state_dict["iteration"]
                else:
                    raise ValueError(f"Invalid key: {key}. not support to resume.")
            if self.callbacks is not None:
                self.callbacks.on_load_checkpoint(model, state_dict=_state_dict)
            log.info(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}")
        elif checkpoint_path is not None and checkpoint_path.endswith(".pth"):
            state = easy_io.load(checkpoint_path)
            model_state = model.net.state_dict()

            for k, v in list(state.items()):
                tgt = model_state.get(k, None)
                if tgt is None:
                    continue
                # If target param/buffer is a DTensor and checkpoint value is not, distribute it
                if isinstance(tgt, DTensor) and not isinstance(v, DTensor):
                    # Match device, dtype, and placements from the target DTensor
                    v = v.to(tgt.device, dtype=tgt.dtype, copy=False)
                    v = distribute_tensor(v, tgt.device_mesh, tgt.placements)
                    state[k] = v
                # If target is a plain Tensor but checkpoint is a DTensor, bring it local
                if not isinstance(tgt, DTensor) and isinstance(v, DTensor):
                    state[k] = v.to_local().to(tgt.device, dtype=tgt.dtype, copy=False)

            model.load_state_dict(state, strict=False, pretrain_copy=True)
            # Clear unused reserved memory from fp32
            torch.cuda.empty_cache()
            log.critical(f"Loaded checkpoint from {checkpoint_path}. **** This only happen at iteration 0 **** ")
        else:
            log.info("Training from scratch.")
        torch.cuda.empty_cache()

        if self.callbacks is not None:
            self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
        return iteration

    def _async_with_pinned_memory(self, checkpoint_file: str, state_dict: Dict[str, Tuple[Any, str]]) -> None:
        """Copy ``state_dict`` into the reusable CPU buffer and hand it to the saver.

        Args:
            checkpoint_file: Checkpoint name (``iter_XXXXXXXXX``).
            state_dict: Mapping of key to ``(state, output_dir)``.

        Raises:
            ImportError: If the private PyTorch staging helpers are unavailable.
        """
        try:
            from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
        except ImportError as e:
            raise ImportError(
                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
            ) from e
        if self.cpu_offload_state_dict is None:
            # Allocate the CPU buffer once; later saves reuse it.
            log.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
            self.cpu_offload_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True, share_memory=True)

        log.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
        with torch.cuda.stream(self.staging_stream):
            self.cpu_offload_state_dict = _copy_state_dict(
                state_dict,
                self.cpu_offload_state_dict,
                non_blocking=True,
            )
            self.staging = True
            self.staging_ckpt_file = checkpoint_file

        self.maybe_wait_for_staging()

    def maybe_wait_for_staging(self) -> None:
        """If a staging copy is pending, wait for it and enqueue it for saving."""
        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM and self.staging:
            # The non-blocking copies must have landed before the buffer is sent.
            if not self.staging_stream.query():
                self.staging_stream.synchronize()

            self.mp_queue_send.put_nowait((self.cpu_offload_state_dict, self.staging_ckpt_file))
            self.staging = False

    def get_previous_checkpoint_results(self, wait_for: int = 0) -> None:
        """Get the results of previously submitted checkpoints and pass them to callbacks if checkpoint succeeded

        Args:
            wait_for: Seconds to keep polling an empty result queue; with 0,
                only results that are already available are consumed.
        """
        if self.async_mode != AsyncMode.ASYNC_WITH_PINNED_MEM:
            return
        try:
            start_time = time.monotonic()
            while not self.mp_queue_recv.empty() or wait_for > 0:
                try:
                    ret = self.mp_queue_recv.get(timeout=1)
                    if isinstance(ret, Terminate):
                        log.info("Received termination event from checkpoint background process")
                        break
                    save_done: SaveDone = ret
                    log.logger.info(f"Received checkpoint save result: {save_done}")
                    if self.callbacks is not None and save_done.succeeded:
                        self.callbacks.on_save_checkpoint_success(
                            iteration=save_done.iteration, elapsed_time=save_done.elapsed_time
                        )
                except queue.Empty:
                    elapsed_time = time.monotonic() - start_time
                    if elapsed_time > wait_for:
                        break
        except (EOFError, BrokenPipeError):
            log.info("Queue was closed by checkpoint background process")

    def get_storage_writer(self, checkpoint_path: str) -> Union[S3StorageWriter, FileSystemWriter]:
        """Return an S3 writer when saving to an object store, else a filesystem writer."""
        if self.save_to_object_store:
            return S3StorageWriter(
                credential_path=self.config_checkpoint.save_to_object_store.credentials,
                path=checkpoint_path,
            )
        return FileSystemWriter(path=checkpoint_path)

    def get_storage_reader(self, checkpoint_path: str) -> Union[S3StorageReader, FileSystemReader]:
        """Return an S3 reader when loading from an object store, else a filesystem reader."""
        if self.load_from_object_store:
            return S3StorageReader(
                credential_path=self.config_checkpoint.load_from_object_store.credentials,
                path=checkpoint_path,
            )
        return FileSystemReader(checkpoint_path)

    def save_state_dict_worker(self, to_save_dict: Dict[str, Tuple[Any, str]], checkpoint_file: str) -> None:
        """Write every entry of ``to_save_dict`` with DCP, then record the latest checkpoint.

        Args:
            to_save_dict: Mapping of key to ``(state, output_dir)``.
            checkpoint_file: Checkpoint name written to the latest-checkpoint file.
        """
        for v, full_checkpoint_path in to_save_dict.values():
            storage_writer = self.get_storage_writer(full_checkpoint_path)
            dcp.save(
                v,
                storage_writer=storage_writer,
                planner=DefaultSavePlanner(dedup_save_to_lowest_rank=True),
            )

        if distributed.is_rank0():
            print(f"Saving last checkpoint file {checkpoint_file}")
            self._write_latest_checkpoint_file(checkpoint_file)

        log.critical(f"Saved checkpoint to {os.path.join(self.save_dirname, checkpoint_file)}", rank0_only=True)

    def save(
        self,
        model: ImaginaireModel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save network weights, optimizer parameters, scheduler parameters to a checkpoint.

        In the synchronous path, ``on_save_checkpoint_success`` is reported only
        when the save actually completed; a failing save propagates its
        exception without a success notification.

        Args:
            model (ImaginaireModel): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            iteration (int): Current iteration number.
        """
        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
            # Drain results of earlier async saves without waiting.
            self.get_previous_checkpoint_results(wait_for=0)

        if self.callbacks is not None:
            self.callbacks.on_save_checkpoint_start(model, iteration)

        checkpoint_file = f"iter_{iteration:09}"
        to_save_dict = {
            "model": ModelWrapper(model).state_dict(),
            "optim": OptimizerWrapper(model, optimizer).state_dict(),
            "scheduler": scheduler.state_dict(),
            "trainer": {
                "grad_scaler": grad_scaler.state_dict(),
                "iteration": iteration,
            },
        }
        # Pair each state with the directory it will be written to.
        for k in to_save_dict.keys():
            output_dirname = os.path.join(self.save_dirname, f"iter_{iteration:09}/{k}")
            to_save_dict[k] = (to_save_dict[k], output_dirname)

        if self.callbacks is not None:
            self.callbacks.on_save_checkpoint(model, state_dict=to_save_dict)

        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
            self._async_with_pinned_memory(checkpoint_file, to_save_dict)
        else:
            start_time = time.monotonic()
            self.save_state_dict_worker(to_save_dict, checkpoint_file)
            # Reached only if the save did not raise.
            if self.callbacks is not None:
                self.callbacks.on_save_checkpoint_success(
                    iteration=iteration, elapsed_time=time.monotonic() - start_time
                )

        # This measures exposed (synchronous) checkpoint time, on_save_checkpoint_success()
        # is instead called to measure the entire duration for asynchronous checkpoint for the async case too.
        if self.callbacks is not None:
            self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)

    def finalize(self) -> None:
        """Shut down the background saver (if any) after collecting its last results."""
        super().finalize()
        if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
            if self.mp and self.mp.is_alive():
                self.mp_queue_send.put(Terminate())
                self.get_previous_checkpoint_results(wait_for=60)
                self.mp.join()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/checkpointer/dummy.py DELETED
@@ -1,47 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional
17
-
18
- import torch
19
- import torch.distributed
20
-
21
- from lyra_2._ext.imaginaire.checkpointer.base import AbstractCheckpointer
22
- from lyra_2._ext.imaginaire.model import ImaginaireModel
23
-
24
-
25
class Checkpointer(AbstractCheckpointer):
    """
    No-op checkpointer: ``save`` writes nothing and ``load`` restores nothing.

    Useful for debugging jobs or for sharing a workload with collaborators.
    """

    def save(
        self,
        model: ImaginaireModel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Accept the usual save arguments and do nothing."""
        return None

    def load(
        self,
        model: ImaginaireModel,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        grad_scaler: Optional[torch.amp.GradScaler] = None,
    ) -> int:
        """Load nothing; always return iteration 0."""
        start_iteration = 0
        return start_iteration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py DELETED
@@ -1,323 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import io
17
- import json
18
- import os
19
- import time
20
- from contextlib import contextmanager
21
- from typing import Generator, Union
22
- from urllib.parse import urlparse
23
-
24
- import boto3
25
- from botocore.config import Config
26
- from botocore.exceptions import ClientError
27
- from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
28
- from torch.distributed.checkpoint.filesystem import FileSystemBase
29
-
30
- from lyra_2._ext.imaginaire.utils import log
31
-
32
-
33
class S3Stream(io.BytesIO):
    """
    In-memory byte buffer whose public close() does not release the buffer.

    Workaround for PyTorch closing the stream before it can be uploaded to S3:
    close() only flushes, and the owner calls _true_close() once it is done
    with the data.
    The commit at fault is https://github.com/pytorch/pytorch/commit/9c909bf3bb122db2cce95e2eb7459bbe50dfa15a
    """

    def close(self):
        # Deliberately flush only; the buffer must stay readable after the caller "closes" it.
        self.flush()

    def _true_close(self):
        # Perform the real BytesIO close that close() skips.
        io.BytesIO.close(self)
46
-
47
-
48
- class S3FileSystem(FileSystemBase):
49
- """Implementation of FileSystemBase for AWS S3 storage."""
50
-
51
- def __init__(
52
- self,
53
- credential_path: str,
54
- max_attempts: int = 20,
55
- initial_backoff: float = 1.0,
56
- max_backoff: float = 30.0,
57
- backoff_factor: float = 2.0,
58
- ) -> None:
59
- """
60
- Initialize S3FileSystem with retry configuration.
61
-
62
- Args:
63
- credential_path: Path to AWS credentials JSON file
64
- max_attempts: Maximum number of retry attempts
65
- initial_backoff: Initial backoff time in seconds
66
- max_backoff: Maximum backoff time in seconds
67
- backoff_factor: Multiplicative factor for backoff time
68
- """
69
- with open(credential_path, "r") as f:
70
- conf = json.load(f)
71
-
72
- # Configure boto3 with retry settings
73
- config = Config(
74
- retries=dict(max_attempts=max_attempts, mode="adaptive"), # Adaptive mode automatically handles throttling
75
- connect_timeout=60,
76
- read_timeout=60,
77
- request_checksum_calculation="when_required", # Data integrity check for uploads and downloads
78
- response_checksum_validation="when_required", # Data integrity check for uploads and downloads
79
- )
80
-
81
- self.s3_client = boto3.client("s3", config=config, **conf)
82
- self.max_attempts = max_attempts
83
- self.initial_backoff = initial_backoff
84
- self.max_backoff = max_backoff
85
- self.backoff_factor = backoff_factor
86
-
87
- def _retry_with_backoff(self, operation_func, *args, **kwargs):
88
- """
89
- Execute an operation with exponential backoff retry logic.
90
-
91
- Args:
92
- operation_func: Function to execute
93
- *args: Positional arguments for the function
94
- **kwargs: Keyword arguments for the function
95
-
96
- Returns:
97
- Result of the operation function
98
-
99
- Raises:
100
- Exception: If all retry attempts fail
101
- """
102
- last_exception = None
103
- backoff = self.initial_backoff
104
-
105
- for attempt in range(self.max_attempts):
106
- try:
107
- return operation_func(*args, **kwargs)
108
- except ClientError as e:
109
- error_code = e.response.get("Error", {}).get("Code", "")
110
- log.info(f"S3 Filesystem: Received ClientError: {error_code}", rank0_only=False)
111
-
112
- # Handle specific error cases
113
- if error_code in ["SlowDown", "ThrottlingException", "RequestLimitExceeded", "InternalError"]:
114
- last_exception = e
115
- if attempt < self.max_attempts - 1: # Don't sleep on last attempt
116
- current_backoff = min(backoff, self.max_backoff)
117
- log.info(f"S3 Filesystem: Retrying in {current_backoff} seconds", rank0_only=False)
118
- time.sleep(current_backoff)
119
- backoff *= self.backoff_factor
120
- continue
121
- # For other client errors, raise immediately
122
- raise
123
- except Exception as e:
124
- log.info(f"S3 Filesystem: Received Exception: {str(e)}", rank0_only=False)
125
- last_exception = e
126
- if attempt < self.max_attempts - 1:
127
- current_backoff = min(backoff, self.max_backoff)
128
- log.info(f"S3 Filesystem: Retrying in {current_backoff} seconds", rank0_only=False)
129
- time.sleep(current_backoff)
130
- backoff *= self.backoff_factor
131
- continue
132
-
133
- raise last_exception
134
-
135
- @contextmanager
136
- def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]:
137
- """Create a stream for reading from or writing to S3 with retry logic."""
138
- path_str = str(path)
139
- bucket, key = self._parse_s3_uri(path_str)
140
- log.info(f"S3 Filesystem: Creating stream for {key} in bucket {bucket}", rank0_only=False)
141
-
142
- if mode == "rb":
143
- stream = io.BytesIO()
144
- try:
145
-
146
- def download_operation():
147
- self.s3_client.download_fileobj(bucket, key, stream)
148
- stream.seek(0)
149
-
150
- log.info(f"S3 Filesystem: Downloading {key} from bucket {bucket}", rank0_only=False)
151
- self._retry_with_backoff(download_operation)
152
- log.info("S3 Filesystem: Download complete", rank0_only=False)
153
- yield stream
154
- finally:
155
- stream.close()
156
- elif mode == "wb":
157
- stream = S3Stream()
158
- try:
159
- yield stream
160
-
161
- def upload_operation():
162
- stream.seek(0)
163
- self.s3_client.upload_fileobj(stream, bucket, key)
164
-
165
- log.info(f"S3 Filesystem: Uploading {key} to bucket {bucket}", rank0_only=False)
166
- self._retry_with_backoff(upload_operation)
167
- log.info("S3 Filesystem: Upload complete", rank0_only=False)
168
- finally:
169
- stream._true_close()
170
- else:
171
- raise ValueError(f"Unsupported mode: {mode}")
172
-
173
- def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]:
174
- """Concatenate S3 path with suffix."""
175
- path_str = str(path)
176
- if path_str.endswith("/"):
177
- return f"{path_str}{suffix}"
178
- return f"{path_str}/{suffix}"
179
-
180
- def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
181
- """Initialize and validate S3 path."""
182
- path_str = str(path)
183
- if not path_str.startswith("s3://"):
184
- raise ValueError(f"Invalid S3 URI: {path_str}. Must start with 's3://'")
185
- return path_str
186
-
187
- def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None:
188
- """Rename (move) an object in S3 with retry logic."""
189
- src_bucket, src_key = self._parse_s3_uri(str(path))
190
- dst_bucket, dst_key = self._parse_s3_uri(str(new_path))
191
-
192
- def copy_operation():
193
- copy_source = {"Bucket": src_bucket, "Key": src_key}
194
- self.s3_client.copy(copy_source, dst_bucket, dst_key)
195
-
196
- self._retry_with_backoff(copy_operation)
197
-
198
- def delete_operation():
199
- self.s3_client.delete_object(Bucket=src_bucket, Key=src_key)
200
-
201
- self._retry_with_backoff(delete_operation)
202
-
203
- def mkdir(self, path: Union[str, os.PathLike]) -> None:
204
- """
205
- Create a "directory" in S3.
206
-
207
- Note: S3 doesn't have real directories, but we can create an empty object
208
- with a trailing slash to simulate a directory.
209
- """
210
- # Creating same buckets from different ranks can cause rate limit issues in GCP.
211
- # In object store, we don't need to create a directory.
212
- pass
213
-
214
- @classmethod
215
- def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
216
- """Validate if the checkpoint_id is a valid S3 URI."""
217
- checkpoint_id_str = str(checkpoint_id)
218
- try:
219
- if not checkpoint_id_str.startswith("s3://"):
220
- return False
221
- parsed = urlparse(checkpoint_id_str)
222
- return bool(parsed.netloc and parsed.path) # Must have bucket and key
223
- except Exception:
224
- return False
225
-
226
- def exists(self, path: Union[str, os.PathLike]) -> bool:
227
- """Check if an object exists in S3 with retry logic."""
228
- bucket, key = self._parse_s3_uri(str(path))
229
- try:
230
-
231
- def head_operation():
232
- self.s3_client.head_object(Bucket=bucket, Key=key)
233
-
234
- self._retry_with_backoff(head_operation)
235
- return True
236
- except ClientError as e:
237
- if e.response.get("Error", {}).get("Code", "") == "404":
238
- return False
239
- raise
240
-
241
- def rm_file(self, path: Union[str, os.PathLike]) -> None:
242
- """Remove a file from S3 with retry logic."""
243
- bucket, key = self._parse_s3_uri(str(path))
244
-
245
- def delete_operation():
246
- self.s3_client.delete_object(Bucket=bucket, Key=key)
247
-
248
- self._retry_with_backoff(delete_operation)
249
-
250
- def _parse_s3_uri(self, uri: str) -> tuple[str, str]:
251
- """
252
- Parse an S3 URI into bucket and key.
253
-
254
- Args:
255
- uri: S3 URI in the format s3://bucket-name/key
256
-
257
- Returns:
258
- Tuple of (bucket_name, key)
259
-
260
- Raises:
261
- ValueError: If the URI is invalid
262
- """
263
- uri = uri if isinstance(uri, str) else str(uri)
264
- if not uri.startswith("s3://"):
265
- raise ValueError(f"Invalid S3 URI: {uri}. Must start with 's3://'")
266
-
267
- parsed = urlparse(uri)
268
- bucket = parsed.netloc
269
-
270
- # Remove leading slash from key
271
- key = parsed.path.lstrip("/")
272
-
273
- if not bucket:
274
- raise ValueError(f"Invalid S3 URI: {uri}. No bucket specified")
275
-
276
- return bucket, key
277
-
278
-
279
class S3StorageWriter(FileSystemWriter):
    """FileSystemWriter whose filesystem backend is replaced by an S3FileSystem."""

    def __init__(
        self,
        credential_path: str,
        path: str,
        **kwargs,
    ) -> None:
        """
        Initialize an S3 writer for distributed checkpointing.

        Args:
            credential_path (str): Path to the credentials JSON file handed to :class:`S3FileSystem`.
            path (str): The S3 URI to write checkpoints to.
            kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`.
                ``sync_files`` is always passed as False, so it must not be supplied here.

        Raises:
            ValueError: If ``path`` does not start with ``s3://``.
        """
        super().__init__(
            path=path,
            sync_files=False,
            **kwargs,
        )
        # Swap in the S3 backend after the parent has set itself up, then
        # re-derive the path through it so the URI is validated and kept as a str.
        self.fs = S3FileSystem(credential_path)  # type: ignore
        self.path = self.fs.init_path(path)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return whether ``checkpoint_id`` is an S3 URI, as judged by S3FileSystem."""
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)
305
-
306
-
307
class S3StorageReader(FileSystemReader):
    """FileSystemReader whose filesystem backend is replaced by an S3FileSystem."""

    def __init__(self, credential_path: str, path: Union[str, os.PathLike]) -> None:
        """
        Initialize an S3 reader for distributed checkpointing.

        Args:
            credential_path (str): Path to the credentials JSON file handed to :class:`S3FileSystem`.
            path (Union[str, os.PathLike]): The S3 path to read checkpoints from.

        Raises:
            ValueError: If ``path`` does not start with ``s3://``.
        """
        super().__init__(path)
        # Swap in the S3 backend after the parent has set itself up, then
        # re-derive the path through it so the URI is validated and kept as a str.
        self.fs = S3FileSystem(credential_path)  # type: ignore
        self.path = self.fs.init_path(path)
        self.sync_files = False

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Return whether ``checkpoint_id`` is an S3 URI, as judged by S3FileSystem."""
        return S3FileSystem.validate_checkpoint_id(checkpoint_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/config.py DELETED
@@ -1,445 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """Training config system for Imaginaire4"""
17
-
18
- from __future__ import annotations
19
-
20
- import os
21
- from typing import Any, Dict, Optional, Type, TypeVar, Union
22
-
23
- import attrs
24
- import torch
25
- import torch.utils.data
26
- import torch.utils.data.distributed
27
-
28
- try:
29
- from megatron.core import ModelParallelConfig
30
-
31
- USE_MEGATRON = True
32
- except ImportError:
33
- USE_MEGATRON = False
34
- print("Megatron-core is not installed.")
35
-
36
- from lyra_2._ext.imaginaire.lazy_config import LazyCall as L
37
- from lyra_2._ext.imaginaire.lazy_config import LazyDict
38
- from lyra_2._ext.imaginaire.utils import callback, distributed
39
- from lyra_2._ext.imaginaire.utils.misc import Color
40
-
41
- T = TypeVar("T")
42
-
43
-
44
- def _is_attrs_instance(obj: object) -> bool:
45
- """
46
- Helper function to check if an object is an instance of an attrs-defined class.
47
-
48
- Args:
49
- obj: The object to check.
50
-
51
- Returns:
52
- bool: True if the object is an instance of an attrs-defined class, False otherwise.
53
- """
54
- return hasattr(obj, "__attrs_attrs__")
55
-
56
-
57
- def make_freezable(cls: T) -> T:
58
- """
59
- A decorator that adds the capability to freeze instances of an attrs-defined class.
60
-
61
- NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
62
- to hack on a "_is_frozen" attribute.
63
-
64
- This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
65
- Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
66
- any attrs-defined objects that are attributes of the class.
67
-
68
- Usage:
69
- @make_freezable
70
- @attrs.define(slots=False)
71
- class MyClass:
72
- attribute1: int
73
- attribute2: str
74
-
75
- obj = MyClass(1, 'a')
76
- obj.freeze() # Freeze the instance
77
- obj.attribute1 = 2 # Raises AttributeError
78
-
79
- Args:
80
- cls: The class to be decorated.
81
-
82
- Returns:
83
- The decorated class with added freezing capability.
84
- """
85
-
86
- if not hasattr(cls, "__dict__"):
87
- raise TypeError(
88
- "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
89
- "class was defined with `@attrs.define(slots=False)`"
90
- )
91
-
92
- original_setattr = cls.__setattr__
93
-
94
- def setattr_override(self, key, value) -> None: # noqa: ANN001
95
- """
96
- Override __setattr__ to allow modifications during initialization
97
- and prevent modifications once the instance is frozen.
98
- """
99
- if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
100
- raise AttributeError("Cannot modify frozen instance")
101
- original_setattr(self, key, value) # type: ignore
102
-
103
- cls.__setattr__ = setattr_override # type: ignore
104
-
105
- def freeze(self: object) -> None:
106
- """
107
- Freeze the instance and all its attrs-defined attributes.
108
- """
109
- for _, value in attrs.asdict(self, recurse=False).items():
110
- if _is_attrs_instance(value) and hasattr(value, "freeze"):
111
- value.freeze()
112
- self._is_frozen = True # type: ignore
113
-
114
- cls.freeze = freeze # type: ignore
115
-
116
- return cls
117
-
118
-
119
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
    """
    Render an attrs instance as a bulleted, indented listing, one line per field.

    Fields whose value is itself an attrs instance get a header line and are
    rendered recursively one indent level deeper. Color helpers are applied
    only when ``use_color`` is True.

    Args:
        obj: Instance of an attrs-defined class (asserted).
        indent: Current nesting depth; each level adds one space of padding.
        use_color: Whether to wrap bullets, names and values in Color helpers.

    Returns:
        The rendered lines joined with newlines.
    """

    assert attrs.has(obj.__class__)

    pad = " " * indent
    rendered: list[str] = []
    for field in attrs.fields(obj.__class__):
        name = field.name
        value = getattr(obj, name)

        if attrs.has(value.__class__):
            # Nested attrs object: emit a header, then the child's own listing.
            if use_color:
                header = pad + Color.cyan("* ") + Color.green(name) + ":"
            else:
                header = pad + "* " + name + ":"
            rendered.append(header)
            rendered.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
        elif use_color:
            rendered.append(pad + Color.cyan("* ") + Color.green(name) + ": " + Color.yellow(value))
        else:
            rendered.append(pad + "* " + name + ": " + str(value))
    return "\n".join(rendered)
143
-
144
-
145
def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str:
    """
    Pretty prints overrides.

    Each override is rendered on its own line beneath an "overrides" header:
    ``~name`` entries are shown with the value None, ``name=value`` entries are
    split at the first "=" only (so values may themselves contain "="), and a
    bare "--" entry is skipped. An entry without "=" is shown with the value
    None.

    Args:
        overrides: Override strings; None is treated as an empty list.
        use_color: Whether to colorize the per-override lines.

    Returns:
        The header and override lines joined with newlines.
    """

    lines: list[str] = []
    # NOTE(review): the header is colorized even when use_color is False -- kept
    # as-is for backward compatibility; confirm whether that is intended.
    lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
    for override in overrides or []:
        if override == "--":
            continue
        if override.startswith("~"):
            attribute_name: str = override[1:]
            attribute_value: Optional[str] = None
        else:
            # partition() splits at the first "=" and never raises, unlike a
            # two-target unpack of split("=").
            attribute_name, separator, attribute_value = override.partition("=")
            if not separator:
                attribute_value = None
        if use_color:
            lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
        else:
            lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))

    return "\n".join(lines)
166
-
167
-
168
@make_freezable
@attrs.define(slots=False)  # make_freezable stores state on the instance, so slots must stay disabled.
class ObjectStoreConfig:
    """Settings for doing file I/O against an object store instead of local disk."""

    # Use the object store (rather than local disk) for file I/O.
    enabled: bool = False
    # Location of the object store credentials file.
    credentials: str = ""
    # Bucket that objects are read from / written to.
    bucket: str = ""
177
-
178
-
179
@make_freezable
@attrs.define(slots=False)
class JobConfig:
    """Identity of a job (project / group / name) and the paths derived from it."""

    # Name of the project.
    project: str = ""
    # Name of the experiment.
    group: str = ""
    # Name of the run/job.
    name: str = ""

    @property
    def path(self) -> str:
        """Relative job path in the form ``project/group/name``."""
        return f"{self.project}/{self.group}/{self.name}"

    @property
    def path_local(self) -> str:
        """Job path under the local output root.

        The root is read from the IMAGINAIRE_OUTPUT_ROOT environment variable
        and falls back to /tmp/imaginaire4-output.
        """
        output_root = os.getenv("IMAGINAIRE_OUTPUT_ROOT", "/tmp/imaginaire4-output")
        return f"{output_root}/{self.path}"
196
-
197
-
198
@make_freezable
@attrs.define(slots=False)
class EMAConfig:
    """Settings for tracking exponential moving average (EMA) weights."""

    # Track a set of EMA weights.
    enabled: bool = False
    # Decay rate of the EMA.
    beta: float = 0.9999
    # Strip the "_orig_mod-" marker that torch.compile adds to buffer names.
    torch_compile_buffer_renaming: bool = False
207
-
208
-
209
@make_freezable
@attrs.define(slots=False)
class PowerEMAConfig:
    """Settings for EMA weight tracking with the EDM2-paper decay parameter."""

    # Track a set of exponential moving average (EMA) weights.
    enabled: bool = False
    # EMA decay rate as defined in the EDM2 paper.
    s: float = 0.1
    # Strip the "_orig_mod-" marker that torch.compile adds to buffer names.
    torch_compile_buffer_renaming: bool = False
218
-
219
-
220
@make_freezable
@attrs.define(slots=False)
class DDPConfig:
    """Settings for distributed data parallel (DDP) training."""

    # Walk the computation graph looking for parameters that receive no gradient.
    find_unused_parameters: bool = False
    # True when the computation graph stays the same for the whole training loop.
    static_graph: bool = True
    # True to synchronize buffers; False when that sync is handled elsewhere.
    broadcast_buffers: bool = True
229
-
230
-
231
@make_freezable
@attrs.define(slots=False)
class CuDNNConfig:
    """cuDNN settings."""

    # Use only deterministic cudnn functions, for better reproducibility of results.
    deterministic: bool = False
    # Let cudnn benchmark several algorithms and pick the fastest one.
    benchmark: bool = True
238
-
239
-
240
@make_freezable
@attrs.define(slots=False)
class JITConfig:
    """Settings for exporting a JIT compiled model."""

    # Export a JIT compiled model.
    enabled: bool = False
    # Shape of the example input tensor.
    input_shape: Union[list[int], None] = None
    # Device to compile onto.
    device: str = "cuda"
    # Data type to compile onto.
    dtype: str = "bfloat16"
    # Strict mode for PyTorch JIT.
    strict: bool = True
253
-
254
-
255
@make_freezable
@attrs.define(slots=False)
class CheckpointConfig:
    """Settings for saving and loading checkpoints."""

    # possible checkpoint class
    type: Optional[Dict] = None
    # for dcp, whether to use async mode
    dcp_async_mode_enabled: bool = False
    # Configs for saving the checkpoints to object store.
    save_to_object_store: ObjectStoreConfig = attrs.field(factory=ObjectStoreConfig)
    # Save the checkpoint every N iterations.
    save_iter: int = 999999999
    # Configs for loading the checkpoints from object store.
    load_from_object_store: ObjectStoreConfig = attrs.field(factory=ObjectStoreConfig)
    # Path of model weights to resume the checkpoint from.
    load_path: str = ""
    # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
    load_training_state: bool = False
    # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
    only_load_scheduler_state: bool = False
    # Load state_dict to the models in strict mode.
    strict_resume: bool = True
    # Configs for JIT compiling EMA model.
    jit: JITConfig = attrs.field(factory=JITConfig)
    # Print detailed information during checkpoint saving/loading.
    verbose: bool = True
    # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"]
    # A factory is used so each instance gets its own list; a literal [] default
    # would be a single list object shared by every CheckpointConfig.
    keys_not_to_resume: list[str] = attrs.field(factory=list)
    # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
    broadcast_via_filesystem: bool = False
    # Used in inference: load EMA weights to the regular model.
    load_ema_to_reg: bool = False
    # In dcp planner, skip the weight shape check, load weights into the model even weight shape is different
    dcp_allow_mismatched_size: bool = False
287
-
288
-
289
@make_freezable
@attrs.define(slots=False)
class NVTXConfig:
    """Settings for the NVTX ranges used in the main training loop.

    See tutorials/nanogpt for more details on how to integrate profiling into your model."""

    # Turn the NVTX ranges on.
    enabled: bool = False
    # Synchronize everything inside each NVTX range.
    cuda_synchronize: bool = False
300
-
301
-
302
@make_freezable
@attrs.define(slots=False)
class StragglerDetectionConfig:
    """Settings for the straggler detection tool."""

    # Turn straggler detection on.
    enabled: bool = False
    # How often straggler reports are generated.
    report_freq: int = 100
    # How often iterations are profiled.
    profile_freq: int = 1
    # Maximum relative difference between GPUs beyond which they count as stragglers.
    max_diff: float = 2.0
    # Raise an error when a straggler is detected.
    raise_error: bool = True
    # Analyze kernels in the forward pass.
    analyze_forward: bool = True
    # Analyze kernels in the backward pass.
    analyze_backward: bool = True
    # Analyze kernels in the optimizer.
    analyze_optimizer: bool = True
    # Analyze data loading time.
    analyze_dataloading: bool = True
325
-
326
-
327
@make_freezable
@attrs.define(slots=False)
class Profiling:
    """Settings for profiling and memory snapshots."""

    enable_profiling: bool = False
    enable_memory_snapshot: bool = False
    save_s3: bool = False
    profile_freq: int = 1
    # Target ranks for profiling, each entry must be >=0 and < world_size.
    # Built by a factory so every instance owns its list; a plain list default
    # would be one object shared (and mutated) across all Profiling instances.
    target_ranks: list[int] = attrs.field(factory=lambda: list(range(8)))
    # Set `record_shape` and `profile_memory` to False to reduce profile size.
    record_shape: bool = False
    profile_memory: bool = False
    with_stack: bool = True
    with_modules: bool = True
341
-
342
-
343
@make_freezable
@attrs.define(slots=False)
class TrainerConfig:
    """Settings for the trainer: class, callbacks, parallelism, schedule and diagnostics."""

    # NOTE(review): imported inside the class body, presumably to avoid a circular
    # import with the trainer module -- confirm before moving it to file level.
    from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer

    # Trainer class to use.
    type: Type[ImaginaireTrainer] = ImaginaireTrainer
    # Callbacks to run; the two entries below are the defaults.
    # NOTE(review): this default is a single LazyDict object created at class
    # definition time, so it looks shared by all instances -- verify that is intended.
    callbacks: LazyDict = LazyDict(
        dict(
            ema=L(callback.EMAModelCallback)(),
            progress_bar=L(callback.ProgressBarCallback)(),
        )
    )
    # Strategy used for distributed parallelism.
    distributed_parallelism: str = "ddp"
    # Configs for distributed data parallel.
    ddp: DDPConfig = attrs.field(factory=DDPConfig)
    # Configs for cuDNN.
    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
    # Random seed.
    seed: int = 0
    # Derive the random seed from the current timestamp.
    timestamp_seed: bool = False
    # Arguments for the gradient scaler (torch.amp.GradScaler).
    grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
    # Upper bound on the number of training iterations.
    max_iter: int = 999999999
    # Upper bound on validation iterations. If None, validate on the entire dataset.
    max_val_iter: int | None = None
    # Interval (in iterations) between training-stat logs.
    logging_iter: int = 100
    # Run the validation routines.
    run_validation: bool = True
    # Interval (in iterations) between evaluations on the validation set.
    validation_iter: int = 999999999
    # Kill the process N seconds after the last iteration (usually means dead job).
    timeout_period: int = 999999999
    # Memory organization format of tensors.
    memory_format: torch.memory_format = torch.preserve_format
    # Gradient accumulation (update step every N iterations).
    grad_accum_iter: int = 1
    # Configs for straggler detection.
    straggler_detection: StragglerDetectionConfig = attrs.field(factory=StragglerDetectionConfig)
    # Configs for profiling.
    profiling: Profiling = attrs.field(factory=Profiling)
389
-
390
-
391
@make_freezable
@attrs.define(slots=False)
class Config:
    """Top-level config for an imaginaire4 job.

    See /README.md/Configuration System for more info.
    """

    # Configs for the model.
    model: LazyDict
    # Configs for the optimizer.
    optimizer: LazyDict
    # Configs for the scheduler.
    scheduler: LazyDict
    # Configs for the training data.
    dataloader_train: LazyDict
    # Configs for the validation data.
    dataloader_val: LazyDict

    # Configs for the training job.
    job: JobConfig = attrs.field(factory=JobConfig)

    # Configs for the trainer.
    trainer: TrainerConfig = attrs.field(factory=TrainerConfig)

    if USE_MEGATRON:
        # Configs for Megatron-Core (only when megatron.core could be imported).
        model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
    else:
        model_parallel: None = None

    # Configs for the checkpointer.
    checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)

    # Upload the reproducible setup to s3.
    upload_reproducible_setup: bool = False

    def pretty_print(self, use_color: bool = False) -> str:
        """Return this config rendered as an indented listing, optionally colorized."""
        return _pretty_print_attrs_instance(self, indent=0, use_color=use_color)

    def to_dict(self) -> dict[str, Any]:
        """Return the config converted to a dict via ``attrs.asdict``."""
        as_mapping = attrs.asdict(self)
        return as_mapping

    def validate(self) -> None:
        """Synchronize ``job.name`` from rank 0 and assert the job identity is filled in.

        Raises:
            AssertionError: If ``job.project``, ``job.group`` or ``job.name`` is empty.
        """

        # Broadcast job.name across all ranks so that it is consistent; otherwise
        # unaligned job names lead to unaligned checkpoint paths.
        # NOTE(review): the tensor is sized from the local name, so this presumably
        # relies on every rank starting with a name of the same byte length -- confirm.
        encoded_name = bytearray(self.job.name, "utf-8")
        name_tensor = torch.ByteTensor(encoded_name).cuda()
        distributed.broadcast(name_tensor, 0)
        self.job.name = name_tensor.cpu().numpy().tobytes().decode("utf-8")

        for required_field in (self.job.project, self.job.group, self.job.name):
            assert required_field != ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/flags.py DELETED
@@ -1,52 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """Feature flags."""
17
-
18
- import os
19
- from dataclasses import dataclass
20
-
21
-
22
- def _parse_bool(value: str) -> bool:
23
- """Parse string to a boolean."""
24
- return value.lower() in ["true", "1", "yes", "y"]
25
-
26
-
27
# Each flag is read once, at import time, from its COSMOS_* environment variable;
# an unset variable is treated as "0" (disabled).
INTERNAL = _parse_bool(os.getenv("COSMOS_INTERNAL", "0"))
"""Whether internal features are enabled."""

SMOKE = _parse_bool(os.getenv("COSMOS_SMOKE", "0"))
"""Whether the smoke test is enabled.

Disables expensive operations such as checkpoint loading.
"""

VERBOSE = _parse_bool(os.getenv("COSMOS_VERBOSE", "0"))
"""Whether verbose output is enabled."""

EXPERIMENTAL_CHECKPOINTS = _parse_bool(os.getenv("COSMOS_EXPERIMENTAL_CHECKPOINTS", "0"))
"""Whether experimental checkpoints are enabled."""
41
-
42
-
43
@dataclass
class Flags:
    """Bundle of the feature flags; each field defaults to the matching module-level constant."""

    internal: bool = INTERNAL
    smoke: bool = SMOKE
    verbose: bool = VERBOSE
    experimental_checkpoints: bool = EXPERIMENTAL_CHECKPOINTS


# Shared instance built from the defaults above.
FLAGS = Flags()
"""Convenience object for accessing flags."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/functional/batch_ops.py DELETED
@@ -1,61 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # Functions for performing operations with broadcasting to the right axis
17
- #
18
- # Example
19
- # input1: tensor of size (N1, N2)
20
- # input2: tensor of size (N1, N2, N3, N4)
21
- # batch_mul(input1, input2) = input1[:, :, None, None] * input2
22
- #
23
- # If the common dimensions don't match, we raise an assertion error.
24
-
25
- from torch import Tensor
26
-
27
-
28
def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
    """Right-pad the lower-rank tensor with singleton axes so both ranks match.

    The leading axes the two tensors share must have equal sizes.

    Args:
        x: First operand.
        y: Second operand.

    Returns:
        ``(x, y)`` where the operand with fewer dimensions has been reshaped
        with trailing size-1 axes; the other operand is returned untouched.

    Raises:
        AssertionError: If a shared leading axis differs in size.
    """
    # zip() stops at the shorter shape, i.e. exactly the shared leading axes.
    for axis, (size_x, size_y) in enumerate(zip(x.shape, y.shape)):
        assert size_x == size_y, "Dimensions not equal at axis {}".format(axis)

    rank_gap = x.ndim - y.ndim
    if rank_gap < 0:
        x = x.reshape(x.shape + (1,) * (-rank_gap))
    elif rank_gap > 0:
        y = y.reshape(y.shape + (1,) * rank_gap)

    return x, y
42
-
43
-
44
def batch_add(x: Tensor, y: Tensor) -> Tensor:
    """Return ``x + y`` after aligning ranks with :func:`common_broadcast`."""
    lhs, rhs = common_broadcast(x, y)
    total = lhs + rhs
    return total
47
-
48
-
49
def batch_mul(x: Tensor, y: Tensor) -> Tensor:
    """Return ``x * y`` after aligning ranks with :func:`common_broadcast`."""
    lhs, rhs = common_broadcast(x, y)
    product = lhs * rhs
    return product
52
-
53
-
54
def batch_sub(x: Tensor, y: Tensor) -> Tensor:
    """Return ``x - y`` after aligning ranks with :func:`common_broadcast`."""
    lhs, rhs = common_broadcast(x, y)
    difference = lhs - rhs
    return difference
57
-
58
-
59
def batch_div(x: Tensor, y: Tensor) -> Tensor:
    """Return ``x / y`` after aligning ranks with :func:`common_broadcast`."""
    lhs, rhs = common_broadcast(x, y)
    quotient = lhs / rhs
    return quotient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/functional/lr_scheduler.py DELETED
@@ -1,178 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional
17
-
18
- import numpy as np
19
-
20
- from lyra_2._ext.imaginaire.utils import distributed, log
21
-
22
-
23
class TeroPolyScheduler:
    """Learning-rate multiplier driven by the number of samples seen.

    The multiplier is computed from ``model.sample_counter`` expressed in
    millions of samples ("Mimg"); the step index handed to :meth:`schedule`
    is not used. Attach the model through the ``model`` property before use.
    While no model is attached the sample counter is treated as 0.
    """

    def __init__(
        self,
        total_Mimg: int,
        batch_size: int,
        ref_Mimg: Optional[int] = None,
        ref_batches: float = 70e3 / 1024,
        max_lr_ratio: Optional[float] = 1.0,
        min_lr_ratio: Optional[float] = None,
        rampup_Mimg: float = 0,
        rampdown_Mimg: int = 0,
        verbosity_interval: int = 0,
        formula: str = "poly",
        poly_exp: float = 0.5,
    ):
        """Store the schedule hyper-parameters.

        Args:
            total_Mimg: Training length in millions of samples; anchors the
                ramp-down window.
            batch_size: Batch size; stored multiplied by
                ``distributed.get_world_size()``.
            ref_Mimg: Reference sample count (millions) for the ``"poly"``
                formula. When falsy it is derived as
                ``ref_batches * batch_size / 1e6``.
            ref_batches: Number of batches used to derive ``ref_Mimg``.
            max_lr_ratio: Upper clamp for the multiplier, or ``None``.
            min_lr_ratio: Lower clamp for the multiplier, or ``None``.
            rampup_Mimg: Length of the linear ramp-up window (millions).
            rampdown_Mimg: Length of the linear ramp-down window (millions).
            verbosity_interval: Stored only; not read by this class.
            formula: ``"constant"`` or ``"poly"``.
            poly_exp: Exponent of the polynomial decay.
        """
        self.total_Mimg = total_Mimg
        self.batch_size = batch_size * distributed.get_world_size()
        # NOTE(review): the derived reference uses the batch_size argument as
        # passed, not the world-size-scaled self.batch_size; confirm intended.
        self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6
        self.ref_batches = ref_batches
        self.max_lr_ratio = max_lr_ratio
        self.min_lr_ratio = min_lr_ratio
        self.rampup_Mimg = rampup_Mimg
        self.rampdown_Mimg = rampdown_Mimg
        self.verbosity_interval = verbosity_interval
        self.formula = formula
        self.poly_exp = poly_exp

        self._model = None

    @property
    def model(self):
        """The attached model whose ``sample_counter`` drives the schedule."""
        return self._model

    @model.setter
    def model(self, model):
        self._model = model

    def schedule(self, n, **kwargs):
        """Return the learning-rate multiplier for the current sample count.

        Args:
            n: Step index; accepted for interface compatibility and unused.
            **kwargs: Ignored.

        Returns:
            The multiplier, never negative during ramp-down.

        Raises:
            ValueError: If ``formula`` is neither ``"constant"`` nor ``"poly"``.
        """
        cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6

        if self.formula == "constant":
            lr = 1.0
        elif self.formula == "poly":
            # The 1e-8 floor keeps the negative power finite at zero samples.
            lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp
        else:
            raise ValueError(f'Invalid learning rate formula "{self.formula}"')

        if self.max_lr_ratio is not None:
            lr = min(lr, self.max_lr_ratio)
        if self.min_lr_ratio is not None:
            lr = max(lr, self.min_lr_ratio)

        if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg:
            lr *= cur_Mimg / self.rampup_Mimg
        if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg:
            # Clamp at zero: once training runs past total_Mimg the remaining
            # budget is negative, which would otherwise flip the sign of the LR.
            lr *= max(self.total_Mimg - cur_Mimg, 0.0) / self.rampdown_Mimg

        return lr

    def __call__(self, n, **kwargs):
        """Alias for :meth:`schedule`."""
        return self.schedule(n, **kwargs)
84
-
85
-
86
class LambdaWarmUpCosineScheduler:
    """
    A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles.
    It supports different configurations for each cycle, including the number of warm-up steps, minimum
    and maximum scaling factors for the learning rate.

    The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning
    rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler.

    Steps beyond the end of the last cycle are treated as belonging to the last cycle, so the
    multiplier holds at that cycle's ``f_min``.

    Parameters:
        warm_up_steps (list[int]): List of integers where each element represents the number of warm-up
            steps for the corresponding cycle.
        f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up.
        f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle.
        f_start (list[float]): List of starting scaling factors for each warm-up phase.
        cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps.
        verbosity_interval (int, optional): Interval of training steps at which to print current step and
            scaling factor information. Set to 0 by default to disable verbosity.

    Example::

        scheduler = LambdaWarmUpCosineScheduler(
            warm_up_steps=[10, 10],
            f_min=[0.1, 0.1],
            f_max=[1.0, 1.0],
            f_start=[0.01, 0.01],
            cycle_lengths=[50, 50],
            verbosity_interval=10)
        for step in range(100):
            lr_multiplier = scheduler(step)
            print(f"Step {step}: LR Multiplier = {lr_multiplier}")
    """

    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # Cumulative cycle boundaries with a leading 0: cycle i spans
        # (cum_cycles[i], cum_cycles[i + 1]].
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Return the index of the cycle that global step ``n`` falls into.

        A step exactly on a boundary belongs to the cycle ending there. Steps
        past the final boundary map to the last cycle instead of yielding
        ``None``, which previously broke the indexing in :meth:`schedule`.
        """
        for interval, cycle_end in enumerate(self.cum_cycles[1:]):
            if n <= cycle_end:
                return interval
        return len(self.cycle_lengths) - 1

    def schedule(self, n, **kwargs):
        """Return the multiplier at global step ``n`` (linear warm-up, then cosine decay)."""
        cycle = self.find_in_interval(n)
        # Convert the global step into a step relative to the cycle start.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            # Progress is capped at 1 so the cosine never wraps back upwards.
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        """Alias for :meth:`schedule`."""
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler):
    """
    Linear instead of cosine decay for the main part of the cycle.
    """

    def schedule(self, n, **kwargs):
        """Return the multiplier at global step ``n`` (linear warm-up, then linear decay)."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            # Clamp the remaining steps at 0 so that steps past the end of the
            # cycle hold at f_min instead of extrapolating below it.
            remaining = max(self.cycle_lengths[cycle] - n, 0)
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * remaining / (
                self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
            )
            self.last_f = f
            return f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/lazy_config/__init__.py DELETED
@@ -1,60 +0,0 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- import os
3
-
4
- from omegaconf import DictConfig, OmegaConf
5
-
6
- from lyra_2._ext.imaginaire.lazy_config.instantiate import instantiate
7
- from lyra_2._ext.imaginaire.lazy_config.lazy import LazyCall, LazyConfig
8
- from lyra_2._ext.imaginaire.lazy_config.omegaconf_patch import to_object
9
-
10
# Monkey-patch: route OmegaConf.to_object through the implementation imported
# from lyra_2._ext.imaginaire.lazy_config.omegaconf_patch.
OmegaConf.to_object = to_object

PLACEHOLDER = None
LazyDict = DictConfig

__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"]


# NOTE(review): when the variable is set this is a string, so any non-empty
# value (including "0" or "False") is truthy.
DOC_BUILDING = os.environ.get("_DOC_BUILDING", False)  # set in docs/conf.py
19
-
20
-
21
def fixup_module_metadata(module_name, namespace, keys=None):
    """
    Fix the __qualname__ of module members to be their exported api name, so
    when they are referenced in docs, sphinx can find them. Reference:
    https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241

    Does nothing unless ``DOC_BUILDING`` is truthy.

    Args:
        module_name: Name to install as ``__module__`` on the exported objects.
        namespace: Mapping of exported names to objects (e.g. ``globals()``).
        keys: Names in ``namespace`` to process; defaults to all of them.
            Names starting with ``_`` are skipped.
    """
    if not DOC_BUILDING:
        return
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is not None and mod.startswith((module_name, "fvcore.")):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # Extend this object's own qualname (not the top-level
                    # export name) so members of nested classes keep their
                    # full dotted path.
                    fix_one(qualname + "." + attr_name, attr_name, attr_value)

    if keys is None:
        keys = namespace.keys()
    for objname in keys:
        if not objname.startswith("_"):
            obj = namespace[objname]
            fix_one(objname, objname, obj)
57
-
58
-
59
- fixup_module_metadata(__name__, globals(), __all__)
60
- del fixup_module_metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/lazy_config/file_io.py DELETED
@@ -1,24 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
17
- from iopath.common.file_io import PathManager as PathManagerBase
18
-
19
__all__ = ["PathManager", "PathHandler"]


# Module-level manager instance with the HTTP(S) URL and OneDrive handlers
# registered, in that order.
PathManager = PathManagerBase()
for _handler_cls in (HTTPURLHandler, OneDrivePathHandler):
    PathManager.register_handler(_handler_cls())
del _handler_cls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/lazy_config/instantiate.py DELETED
@@ -1,120 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import collections.abc as abc
17
- import dataclasses
18
- import logging
19
- from typing import Any
20
-
21
- import attrs
22
-
23
- from lyra_2._ext.imaginaire.lazy_config.registry import _convert_target_to_string, locate
24
-
25
- __all__ = ["dump_dataclass", "instantiate"]
26
-
27
-
28
def is_dataclass_or_attrs(target):
    """Return whether ``target`` is a dataclass or an attrs-decorated class."""
    if dataclasses.is_dataclass(target):
        return True
    return attrs.has(target)
30
-
31
-
32
def dump_dataclass(obj: Any):
    """
    Dump a dataclass recursively into a dict that can be later instantiated.

    Dataclass values are dumped recursively, both as direct field values and
    as elements of list/tuple fields; list and tuple fields come back as lists.

    Args:
        obj: a dataclass object

    Returns:
        dict
    """
    assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
        "dump_dataclass() requires an instance of a dataclass."
    )

    def _dump_if_dataclass(item):
        return dump_dataclass(item) if dataclasses.is_dataclass(item) else item

    dumped = {"_target_": _convert_target_to_string(type(obj))}
    for field in dataclasses.fields(obj):
        value = _dump_if_dataclass(getattr(obj, field.name))
        if isinstance(value, (list, tuple)):
            value = [_dump_if_dataclass(element) for element in value]
        dumped[field.name] = value
    return dumped
54
-
55
-
56
def instantiate(cfg, *args, **kwargs):
    """
    Recursively instantiate objects defined in dictionaries by
    "_target_" and arguments.

    A ``"_recursive_"`` key (default ``True``) controls whether nested values
    are instantiated before the call; it is never passed to the target.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments
        args: Optional positional parameters pass-through.
        kwargs: Optional named parameters pass-through.

    Returns:
        object instantiated by cfg
    """
    from omegaconf import DictConfig, ListConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        lst = [instantiate(x) for x in cfg]
        return ListConfig(lst, flags={"allow_objects": True})
    if isinstance(cfg, list):
        # Specialize for list, because many classes take
        # list[objects] as arguments, such as ResNet, DatasetMapper
        return [instantiate(x) for x in cfg]

    # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
    # instantiate it to the actual dataclass.
    if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
        # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
        # but faster: https://github.com/facebookresearch/hydra/issues/1200
        #
        # Read the switch as a mapping key: attribute access never sees the
        # keys of a plain dict, which made "_recursive_": False a silent no-op
        # for non-DictConfig mappings.
        is_recursive = cfg.get("_recursive_", True)
        if is_recursive:
            cfg = {k: instantiate(v) for k, v in cfg.items()}
        else:
            cfg = dict(cfg.items())
        # pop the _recursive_ key to avoid passing it as a parameter
        cfg.pop("_recursive_", None)
        cls = cfg.pop("_target_")
        cls = instantiate(cls)

        if isinstance(cls, str):
            cls_name = cls
            cls = locate(cls_name)
            assert cls is not None, cls_name
        else:
            try:
                cls_name = cls.__module__ + "." + cls.__qualname__
            except Exception:
                # target could be anything, so the above could fail
                cls_name = str(cls)
        assert callable(cls), f"_target_ {cls} does not define a callable object"
        try:
            # override config with kwargs
            instantiate_kwargs = {}
            instantiate_kwargs.update(cfg)
            instantiate_kwargs.update(kwargs)
            return cls(*args, **instantiate_kwargs)
        except TypeError:
            logger = logging.getLogger(__name__)
            logger.error(f"Error when instantiating {cls_name}!")
            raise
    return cfg  # return as-is if don't know what to do
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/lazy_config/lazy.py DELETED
@@ -1,430 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import ast
17
- import builtins
18
- import collections.abc as abc
19
- import importlib
20
- import inspect
21
- import logging
22
- import os
23
- import pickle
24
- import uuid
25
- from collections import OrderedDict
26
- from contextlib import contextmanager
27
- from copy import deepcopy
28
- from dataclasses import is_dataclass
29
- from typing import Any, Dict, List, Tuple, Union
30
-
31
- import attrs
32
- import yaml
33
- from omegaconf import DictConfig, ListConfig, OmegaConf
34
-
35
- try:
36
- import dill as dill_pickle
37
- except ImportError:
38
- dill_pickle = None
39
-
40
- try:
41
- import cloudpickle
42
- except ImportError:
43
- cloudpickle = None
44
-
45
- from lyra_2._ext.imaginaire.lazy_config.file_io import PathManager
46
- from lyra_2._ext.imaginaire.lazy_config.registry import _convert_target_to_string
47
-
48
- __all__ = ["LazyCall", "LazyConfig"]
49
-
50
-
51
def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]:
    """Return an ``OrderedDict`` holding the items of ``d`` ordered by key."""
    return OrderedDict((key, d[key]) for key in sorted(d))
53
-
54
-
55
def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
    """Represent ``data`` as a plain YAML mapping node, preserving its item order."""
    mapping_tag = "tag:yaml.org,2002:map"
    return dumper.represent_mapping(mapping_tag, data.items())
57
-
58
-
59
def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]:
    """Recursively sort every dict inside ``obj`` by key.

    Dicts become key-ordered ``OrderedDict`` objects, lists are rebuilt with
    their elements processed in place order, and any other value is returned
    unchanged.
    """
    if isinstance(obj, list):
        return [sort_recursive(element) for element in obj]
    if not isinstance(obj, dict):
        return obj
    processed = {key: sort_recursive(value) for key, value in obj.items()}
    return OrderedDict(sorted(processed.items(), key=lambda pair: pair[0]))
65
-
66
-
67
- yaml.add_representer(OrderedDict, dict_representer)
68
-
69
- OmegaConf.register_new_resolver("add", lambda *vals: sum(vals))
70
- OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:]))
71
-
72
-
73
def get_default_params(cls_or_func):
    """Return the parameters of ``cls_or_func`` that declare a default value.

    Args:
        cls_or_func: A callable (function or class). For a non-callable value
            the signature of its ``__init__`` is inspected instead.

    Returns:
        dict mapping parameter name to default value. Empty when the target
        exposes no introspectable signature (as some builtins and extension
        callables do), so callers that only use the defaults as a convenience
        do not crash.
    """
    # Classes are callable too, so the fallback is only taken for
    # non-callable targets such as a dotted-path string.
    target = cls_or_func if callable(cls_or_func) else cls_or_func.__init__
    try:
        signature = inspect.signature(target)
    except (TypeError, ValueError):
        # inspect.signature raises ValueError when no signature can be found
        # and TypeError when the object is not supported.
        return {}
    return {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }
85
-
86
-
87
class LazyCall:
    """
    Wrap a callable so that when it's called, the call will not be executed,
    but returns a dict that describes the call.

    LazyCall object has to be called with only keyword arguments. Positional
    arguments are not yet supported.

    Examples:
    ::
        from detectron2.config import instantiate, LazyCall

        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
        layer_cfg.out_channels = 64   # can edit it afterwards
        layer = instantiate(layer_cfg)
    """

    def __init__(self, target):
        acceptable = callable(target) or isinstance(target, (str, abc.Mapping))
        if not acceptable:
            raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
        self._target = target

    def __call__(self, **kwargs):
        """Return a DictConfig describing the call: target defaults overlaid with ``kwargs``."""
        # omegaconf object cannot hold dataclass type
        # https://github.com/omry/omegaconf/issues/784
        needs_string_target = is_dataclass(self._target) or attrs.has(self._target)
        kwargs["_target_"] = _convert_target_to_string(self._target) if needs_string_target else self._target

        # Start from the target's declared defaults so explicit keyword
        # arguments (and "_target_") take precedence.
        call_description = get_default_params(self._target)
        call_description.update(kwargs)

        return DictConfig(content=call_description, flags={"allow_objects": True})
122
-
123
-
124
def _visit_dict_config(cfg, func):
    """
    Apply func recursively to all DictConfig in cfg.

    A DictConfig is visited before its values; ListConfig elements are
    traversed but the list itself is not passed to func.
    """
    if isinstance(cfg, DictConfig):
        func(cfg)
        children = cfg.values()
    elif isinstance(cfg, ListConfig):
        children = cfg
    else:
        return
    for child in children:
        _visit_dict_config(child, func)
135
-
136
-
137
def _validate_py_syntax(filename):
    """Read ``filename`` through PathManager and raise SyntaxError if it is not valid Python."""
    # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
    with PathManager.open(filename, "r") as stream:
        source_text = stream.read()
    try:
        ast.parse(source_text)
    except SyntaxError as err:
        raise SyntaxError(f"Config file {filename} has syntax error!") from err
145
-
146
-
147
def _cast_to_config(obj):
    """Wrap a plain dict in a DictConfig (objects allowed); return anything else unchanged."""
    return DictConfig(obj, flags={"allow_objects": True}) if isinstance(obj, dict) else obj
152
-
153
-
154
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
"""
A namespace to put all imported config into.
"""


def _random_package_name(filename):
    """Build a package name for a config file: prefix, 4 random hex chars, ".", file basename."""
    # generate a random package name when loading config files
    random_tag = uuid.uuid4().hex[:4]
    return _CFG_PACKAGE_NAME + random_tag + "." + os.path.basename(filename)
163
-
164
-
165
@contextmanager
def _patch_import():
    """
    Enhance relative import statements in config files, so that they:
    1. locate files purely based on relative location, regardless of packages.
       e.g. you can import file without having __init__
    2. do not cache modules globally; modifications of module states has no side effect
    3. support other storage system through PathManager, so config files can be in the cloud
    4. imported dict are turned into omegaconf.DictConfig automatically

    While the context is active ``builtins.__import__`` is replaced; the
    original hook is restored on exit even if the body raises.

    Yields:
        The replacement import function.
    """
    old_import = builtins.__import__

    def find_relative_file(original_file, relative_import_path, level):
        # NOTE: "from . import x" is not handled. Because then it's unclear
        # if such import should produce `x` as a python module or DictConfig.
        # This can be discussed further if needed.
        relative_import_err = """
Relative import of directories is not allowed within config files.
Within a config file, relative import can only import other config files.
""".replace("\n", " ")
        if not len(relative_import_path):
            raise ImportError(relative_import_err)

        # Walk up one directory per extra leading dot, then down the dotted path.
        cur_file = os.path.dirname(original_file)
        for _ in range(level - 1):
            cur_file = os.path.dirname(cur_file)
        cur_name = relative_import_path.lstrip(".")
        for part in cur_name.split("."):
            cur_file = os.path.join(cur_file, part)
        if not cur_file.endswith(".py"):
            cur_file += ".py"
        if not PathManager.isfile(cur_file):
            cur_file_no_suffix = cur_file[: -len(".py")]
            if PathManager.isdir(cur_file_no_suffix):
                raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
            else:
                raise ImportError(
                    f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist."
                )
        return cur_file

    def new_import(name, globals=None, locals=None, fromlist=(), level=0):
        if (
            # Only deal with relative imports inside config files
            level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
        ):
            cur_file = find_relative_file(globals["__file__"], name, level)
            _validate_py_syntax(cur_file)
            spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
            module = importlib.util.module_from_spec(spec)
            module.__file__ = cur_file
            with PathManager.open(cur_file) as f:
                content = f.read()
            exec(compile(content, cur_file, "exec"), module.__dict__)
            for imported_name in fromlist:  # turn imported dict into DictConfig automatically
                val = _cast_to_config(module.__dict__[imported_name])
                module.__dict__[imported_name] = val
            return module
        return old_import(name, globals, locals, fromlist=fromlist, level=level)

    builtins.__import__ = new_import
    try:
        yield new_import
    finally:
        # Always undo the global patch: a config file that raises while being
        # executed must not leave every later import in the process routed
        # through new_import.
        builtins.__import__ = old_import
228
-
229
-
230
- class LazyConfig:
231
- """
232
- Provide methods to save, load, and overrides an omegaconf config object
233
- which may contain definition of lazily-constructed objects.
234
- """
235
-
236
- @staticmethod
237
- def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
238
- """
239
- Similar to :meth:`load()`, but load path relative to the caller's
240
- source file.
241
-
242
- This has the same functionality as a relative import, except that this method
243
- accepts filename as a string, so more characters are allowed in the filename.
244
- """
245
- caller_frame = inspect.stack()[1]
246
- caller_fname = caller_frame[0].f_code.co_filename
247
- assert caller_fname != "<string>", "load_rel Unable to find caller"
248
- caller_dir = os.path.dirname(caller_fname)
249
- filename = os.path.join(caller_dir, filename)
250
- return LazyConfig.load(filename, keys)
251
-
252
- @staticmethod
253
- def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
254
- """
255
- Load a config file.
256
-
257
- Args:
258
- filename: absolute path or relative path w.r.t. the current working directory
259
- keys: keys to load and return. If not given, return all keys
260
- (whose values are config objects) in a dict.
261
- """
262
- has_keys = keys is not None
263
- filename = filename.replace("/./", "/") # redundant
264
- if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
265
- raise ValueError(f"Config file {filename} has to be a python or yaml file.")
266
- if filename.endswith(".py"):
267
- _validate_py_syntax(filename)
268
-
269
- with _patch_import():
270
- # Record the filename
271
- module_namespace = {
272
- "__file__": filename,
273
- "__package__": _random_package_name(filename),
274
- }
275
- with PathManager.open(filename) as f:
276
- content = f.read()
277
- # Compile first with filename to:
278
- # 1. make filename appears in stacktrace
279
- # 2. make load_rel able to find its parent's (possibly remote) location
280
- exec(compile(content, filename, "exec"), module_namespace)
281
-
282
- ret = module_namespace
283
- else:
284
- with PathManager.open(filename) as f:
285
- obj = yaml.unsafe_load(f)
286
- ret = OmegaConf.create(obj, flags={"allow_objects": True})
287
-
288
- if has_keys:
289
- if isinstance(keys, str):
290
- return _cast_to_config(ret[keys])
291
- else:
292
- return tuple(_cast_to_config(ret[a]) for a in keys)
293
- else:
294
- if filename.endswith(".py"):
295
- # when not specified, only load those that are config objects
296
- ret = DictConfig(
297
- {
298
- name: _cast_to_config(value)
299
- for name, value in ret.items()
300
- if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
301
- },
302
- flags={"allow_objects": True},
303
- )
304
- return ret
305
-
306
- @staticmethod
307
- def save_pkl(cfg, filename: str) -> str:
308
- """
309
- Saves a Config object to a file using pickle serialization. This method is typically used
310
- when the configuration object contains complex objects, such as lambdas, that are not supported by
311
- simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration
312
- object before serialization to ensure that the original object remains unmodified.
313
-
314
- Args:
315
- cfg: A Config object to be serialized and saved.
316
- filename: The path and name of the file where the configuration should be saved. The function
317
- assumes the file extension indicates a pickle format (e.g., .pkl).
318
-
319
- Returns:
320
- str: The filename to which the configuration was saved. This can be used to verify the file location
321
- or log the outcome.
322
-
323
- Notes:
324
- - The function logs a warning if the configuration is successfully saved using pickle.
325
- - If saving fails, an error is logged with the exception details.
326
- """
327
- logger = logging.getLogger(__name__)
328
- try:
329
- cfg = deepcopy(cfg)
330
- except Exception:
331
- pass
332
-
333
- try:
334
- with PathManager.open(filename, "wb") as f:
335
- pickle.dump(cfg, f)
336
- logger.warning(f"Config is saved using pickle at {filename}.")
337
- except Exception as e:
338
- logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead")
339
- if dill_pickle:
340
- try:
341
- with PathManager.open(filename, "wb") as f:
342
- pickle.dump(dill_pickle.dumps(cfg, recurse=True), f)
343
- logger.warning(f"Config is saved using dill at {filename}.")
344
- except Exception as e:
345
- logger.error(f"Failed to save config to {filename}: {e}.")
346
- if cloudpickle:
347
- try:
348
- with PathManager.open(filename, "wb") as f:
349
- pickle.dump(cloudpickle.dumps(cfg), f)
350
- logger.warning(f"Config is saved using cloudpickle at {filename}.")
351
- except Exception as e:
352
- logger.error(f"Failed to save config to {filename}: {e}.")
353
- else:
354
- logger.error("cloudpickle is not available. Cannot save the config.")
355
- raise e
356
-
357
- return filename
358
-
359
- @staticmethod
360
- def save_yaml(cfg, filename: str) -> str:
361
- """
362
- Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization.
363
-
364
- Args:
365
- cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types.
366
- filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'.
367
-
368
- Returns:
369
- str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome.
370
-
371
- Notes:
372
- - The function logs a warning if the configuration is successfully saved using YAML.
373
- - If saving fails, an error is logged with the exception details.
374
- """
375
- logger = logging.getLogger(__name__)
376
- try:
377
- cfg = deepcopy(cfg)
378
- except Exception:
379
- pass
380
-
381
- # Define a function to check if an item is serializable to YAML
382
- def is_serializable(item):
383
- try:
384
- OmegaConf.to_yaml(item)
385
- return True
386
- except Exception as e:
387
- return False
388
-
389
- # Function to convert unserializable items to strings
390
- def serialize_config(config):
391
- if isinstance(config, DictConfig):
392
- for key, value in config.items():
393
- if isinstance(value, (DictConfig, ListConfig)):
394
- try:
395
- if "_target_" in value:
396
- default_params = get_default_params(value["_target_"])
397
- for default_key, default_v in default_params.items():
398
- if default_key not in value:
399
- value[default_key] = default_v
400
- except Exception as e:
401
- logger.error(f"Failed to add default argument values: {e}")
402
-
403
- serialize_config(value)
404
- else:
405
- if not is_serializable(value) and value is not None:
406
- config[key] = str(value)
407
- elif isinstance(config, ListConfig):
408
- for i, item in enumerate(config):
409
- if isinstance(item, (DictConfig, ListConfig)):
410
- serialize_config(item)
411
- else:
412
- if not is_serializable(item) and item is not None:
413
- config[i] = str(item)
414
- else:
415
- raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
416
- return config
417
-
418
- # Convert Config object to a DictConfig object.
419
- config_dict = attrs.asdict(cfg)
420
- config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
421
-
422
- # Serialize the DictConfig object by converting non-serializable objects to strings.
423
- config_omegaconf = serialize_config(config_omegaconf)
424
-
425
- config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)
426
- sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict)
427
- with open(filename, "w") as f:
428
- yaml.dump(sorted_config, f, default_flow_style=False)
429
- logger.warning(f"Config is saved using omegaconf at {filename}.")
430
- return filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py DELETED
@@ -1,65 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Any, Dict, List, Union
17
-
18
- from omegaconf import OmegaConf
19
- from omegaconf.base import DictKeyType, SCMode
20
- from omegaconf.dictconfig import DictConfig # pragma: no cover
21
-
22
-
23
- def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
24
- """
25
- Converts an OmegaConf configuration object to a native Python container (dict or list), unless
26
- the configuration is specifically created by LazyCall, in which case the original configuration
27
- is returned directly.
28
-
29
- This function serves as a modification of the original `to_object` method from OmegaConf,
30
- preventing DictConfig objects created by LazyCall from being automatically converted to Python
31
- dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
32
- structure and behavior.
33
-
34
- Differences from OmegaConf's original `to_object`:
35
- - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
36
-
37
- Reference:
38
- - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
39
-
40
- Args:
41
- cfg (Any): The OmegaConf configuration object to convert.
42
-
43
- Returns:
44
- Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
45
- `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
46
-
47
- Examples:
48
- >>> cfg = DictConfig({"key": "value", "_target_": "Model"})
49
- >>> to_object(cfg)
50
- DictConfig({"key": "value", "_target_": "Model"})
51
-
52
- >>> cfg = DictConfig({"list": [1, 2, 3]})
53
- >>> to_object(cfg)
54
- {'list': [1, 2, 3]}
55
- """
56
- if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
57
- return cfg
58
-
59
- return OmegaConf.to_container(
60
- cfg=cfg,
61
- resolve=True,
62
- throw_on_missing=True,
63
- enum_to_str=False,
64
- structured_config_mode=SCMode.INSTANTIATE,
65
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/lazy_config/registry.py DELETED
@@ -1,74 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import pydoc
17
- from typing import Any
18
-
19
- from fvcore.common.registry import Registry # for backward compatibility.
20
-
21
- """
22
- ``Registry`` and `locate` provide ways to map a string (typically found
23
- in config files) to callable objects.
24
- """
25
-
26
- __all__ = ["Registry", "locate"]
27
-
28
-
29
- def _convert_target_to_string(t: Any) -> str:
30
- """
31
- Inverse of ``locate()``.
32
-
33
- Args:
34
- t: any object with ``__module__`` and ``__qualname__``
35
- """
36
- module, qualname = t.__module__, t.__qualname__
37
-
38
- # Compress the path to this object, e.g. ``module.submodule._impl.class``
39
- # may become ``module.submodule.class``, if the later also resolves to the same
40
- # object. This simplifies the string, and also is less affected by moving the
41
- # class implementation.
42
- module_parts = module.split(".")
43
- for k in range(1, len(module_parts)):
44
- prefix = ".".join(module_parts[:k])
45
- candidate = f"{prefix}.{qualname}"
46
- try:
47
- if locate(candidate) is t:
48
- return candidate
49
- except ImportError:
50
- pass
51
- return f"{module}.{qualname}"
52
-
53
-
54
- def locate(name: str) -> Any:
55
- """
56
- Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
57
- such as "module.submodule.class_name".
58
-
59
- Raise Exception if it cannot be found.
60
- """
61
- obj = pydoc.locate(name)
62
-
63
- # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
64
- # by pydoc.locate. Try a private function from hydra.
65
- if obj is None:
66
- try:
67
- # from hydra.utils import get_method - will print many errors
68
- from hydra.utils import _locate
69
- except ImportError as e:
70
- raise ImportError(f"Cannot dynamically locate object {name}!") from e
71
- else:
72
- obj = _locate(name) # it raises if fails
73
-
74
- return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/model.py DELETED
@@ -1,129 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Any
17
-
18
- import torch
19
-
20
- from lyra_2._ext.imaginaire.lazy_config import LazyDict, instantiate
21
-
22
-
23
- class ImaginaireModel(torch.nn.Module):
24
- """The base model class of Imaginaire. It is inherited from torch.nn.Module.
25
-
26
- All models in Imaginaire should inherit ImaginaireModel. It should include the implementions for all the
27
- computation graphs. All inheriting child classes should implement the following methods:
28
- - training_step(): The training step of the model, including the loss computation.
29
- - validation_step(): The validation step of the model, including the loss computation.
30
- - forward(): The computation graph for model inference.
31
- The following methods have default implementations in ImaginaireModel:
32
- - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
33
- """
34
-
35
- def __init__(self) -> None:
36
- super().__init__()
37
-
38
- def init_optimizer_scheduler(
39
- self, optimizer_config: LazyDict, scheduler_config: LazyDict
40
- ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
41
- """Creates the optimizer and scheduler for the model.
42
-
43
- Args:
44
- config_model (ModelConfig): The config object for the model.
45
-
46
- Returns:
47
- optimizer (torch.optim.Optimizer): The model optimizer.
48
- scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
49
- """
50
- optimizer_config.params = self.parameters()
51
- optimizer = instantiate(optimizer_config)
52
- scheduler_config.optimizer = optimizer
53
- scheduler = instantiate(scheduler_config)
54
- return optimizer, scheduler
55
-
56
- def training_step(
57
- self, data_batch: dict[str, torch.Tensor], iteration: int
58
- ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
59
- """The training step of the model, including the loss computation.
60
-
61
- Args:
62
- data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
63
- iteration (int): Current iteration number.
64
-
65
- Returns:
66
- output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
67
- loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
68
- """
69
- raise NotImplementedError
70
-
71
- @torch.no_grad()
72
- def validation_step(
73
- self, data_batch: dict[str, torch.Tensor], iteration: int
74
- ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
75
- """The validation step of the model, including the loss computation.
76
-
77
- Args:
78
- data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
79
- iteration (int): Current iteration number.
80
-
81
- Returns:
82
- output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
83
- loss (torch.Tensor): The total loss (weighted sum of various losses).
84
- """
85
- raise NotImplementedError
86
-
87
- @torch.inference_mode()
88
- def forward(self, *args: Any, **kwargs: Any) -> Any:
89
- """The computation graph for model inference.
90
-
91
- Args:
92
- *args: Whatever you decide to pass into the forward method.
93
- **kwargs: Keyword arguments are also possible.
94
-
95
- Return:
96
- Your model's output.
97
- """
98
- raise NotImplementedError
99
-
100
- def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
101
- """The model preparation before the training is launched
102
-
103
- Args:
104
- memory_format (torch.memory_format): Memory format of the model.
105
- """
106
- pass
107
-
108
- def on_before_zero_grad(
109
- self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
110
- ) -> None:
111
- """Hook before zero_grad() is called.
112
-
113
- Args:
114
- optimizer (torch.optim.Optimizer): The model optimizer.
115
- scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
116
- iteration (int): Current iteration number.
117
- """
118
- pass
119
-
120
- def on_after_backward(self, iteration: int = 0) -> None:
121
- """Hook after loss.backward() is called.
122
-
123
- This method is called immediately after the backward pass, allowing for custom operations
124
- or modifications to be performed on the gradients before the optimizer step.
125
-
126
- Args:
127
- iteration (int): Current iteration number.
128
- """
129
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/trainer.py DELETED
@@ -1,379 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import functools
17
- import inspect
18
- import os
19
- import signal
20
-
21
- import torch
22
- import torch.distributed as dist
23
- import torch.utils.data
24
-
25
- from lyra_2._ext.imaginaire.utils.context_managers import distributed_init
26
- from lyra_2._ext.imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
27
-
28
- try:
29
- from megatron.core import parallel_state
30
-
31
- USE_MEGATRON = True
32
- except ImportError:
33
- USE_MEGATRON = False
34
- print("Megatron-core is not installed.")
35
-
36
-
37
- from lyra_2._ext.imaginaire.lazy_config import LazyConfig, instantiate
38
- from lyra_2._ext.imaginaire.model import ImaginaireModel
39
- from lyra_2._ext.imaginaire.utils import callback, distributed, ema, log, misc
40
- from lyra_2._ext.imaginaire.utils.checkpointer import Checkpointer
41
- from lyra_2._ext.imaginaire.utils.misc import StragglerDetectorV2
42
-
43
- try:
44
- from lyra_2._src.callbacks.smart_stop import TimeoutException
45
- except ImportError:
46
- # Define a dummy exception if smart_stop is not available
47
- class TimeoutException(Exception):
48
- pass
49
-
50
-
51
- class ImaginaireTrainer:
52
- """The base trainer class of Imaginaire.
53
-
54
- All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training
55
- (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA),
56
- mixed-precision training (fp16/bf16).
57
-
58
- Attributes:
59
- checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
60
- training_timer (misc.Timer): Timer object to time code blocks and functions.
61
- """
62
-
63
- def __init__(self, config):
64
- """Constructor of the trainer.
65
-
66
- Args:
67
- config (Config): The config object for the Imaginaire codebase.
68
- """
69
- super().__init__()
70
- self.config = config
71
- # Set up the distributed computing environment.
72
- with distributed_init():
73
- distributed.init()
74
- # Set up parallel states.
75
- if hasattr(config.model, "context_parallel_size"):
76
- if config.model_parallel.context_parallel_size > 1:
77
- raise ValueError(
78
- "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
79
- "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
80
- )
81
- else:
82
- log.critical(
83
- "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
84
- )
85
- config.model_parallel.context_parallel_size = config.model.context_parallel_size
86
- if USE_MEGATRON:
87
- if (
88
- "create_gloo_process_groups"
89
- in inspect.signature(parallel_state.initialize_model_parallel).parameters
90
- ):
91
- parallel_state.initialize_model_parallel(
92
- pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
93
- tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
94
- context_parallel_size=config.model_parallel.context_parallel_size,
95
- create_gloo_process_groups=False,
96
- )
97
- else:
98
- parallel_state.initialize_model_parallel(
99
- pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
100
- tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
101
- context_parallel_size=config.model_parallel.context_parallel_size,
102
- )
103
- # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
104
- # It is not part of the original `parallel_state` API, so we need to set it manually.
105
- parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
106
- if parallel_state.sequence_parallel:
107
- os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
108
-
109
- # Create the local job directory, save the config file, and pipe to a local log.
110
- if distributed.is_rank0():
111
- os.makedirs(config.job.path_local, exist_ok=True)
112
- # Save the config as .pkl for reproducibility.
113
- LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
114
- # Save the config as .yaml for reading or parsing experiment hyperparameters.
115
- LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
116
- dist.barrier()
117
- log.init_loguru_file(f"{config.job.path_local}/stdout.log")
118
- if distributed.is_rank0():
119
- # Print important environment variables and the effective config.
120
- log.info("Config:\n" + config.pretty_print(use_color=True))
121
- misc.print_environ_variables(["HF_HOME", "IMAGINAIRE_OUTPUT_ROOT"])
122
- # Set the random seed. If multi-GPU, different ranks are set with different seeds.
123
- misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
124
- # Initialize cuDNN.
125
- torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
126
- torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
127
- # Floating-point precision settings.
128
- torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
129
- # Initialize the callback functions.
130
- self.callbacks = callback.CallBackGroup(config=config, trainer=self)
131
- # Initialize the model checkpointer.
132
- if config.checkpoint.type is None:
133
- self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
134
- else:
135
- self.checkpointer: Checkpointer = instantiate(
136
- config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
137
- )
138
- # Initialize the timer for speed benchmarking.
139
- self.training_timer = misc.TrainingTimer()
140
- # Initialize Straggler Detection
141
- self.straggler_detector = StragglerDetectorV2(
142
- enabled=self.config.trainer.straggler_detection.enabled,
143
- report_freq=self.config.trainer.straggler_detection.report_freq,
144
- profile_freq=self.config.trainer.straggler_detection.profile_freq,
145
- max_diff=self.config.trainer.straggler_detection.max_diff,
146
- raise_error=self.config.trainer.straggler_detection.raise_error,
147
- )
148
- self.straggler_detector.initialize()
149
- # Send a TimeoutError if a training step takes over timeout_period seconds.
150
- signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore
151
-
152
- def train(
153
- self,
154
- model: ImaginaireModel,
155
- dataloader_train: torch.utils.data.DataLoader,
156
- dataloader_val: torch.utils.data.DataLoader,
157
- ) -> None:
158
- """The training function.
159
-
160
- Args:
161
- model (ImaginaireModel): The PyTorch model.
162
- dataloader_train (torch.utils.data.DataLoader): The training data loader.
163
- dataloader_val (torch.utils.data.DataLoader): The validation data loader.
164
- """
165
- # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models.
166
- model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore
167
- model.on_train_start(self.config.trainer.memory_format)
168
-
169
- # Initialize the optimizer, scheduler, and grad_scaler.
170
- self.callbacks.on_optimizer_init_start()
171
- optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
172
- grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
173
- self.callbacks.on_optimizer_init_end()
174
- # Load the model checkpoint and get the starting iteration number.
175
- iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
176
- grad_accum_iter = 0
177
- log.info(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
178
- if self.config.trainer.distributed_parallelism == "ddp":
179
- # Create a DDP model wrapper.
180
- model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
181
- elif self.config.trainer.distributed_parallelism == "fsdp":
182
- model_ddp = model
183
- else:
184
- raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
185
-
186
- log.info("Starting training...")
187
- self.callbacks.on_train_start(model, iteration=iteration)
188
- # Initial validation.
189
- if self.config.trainer.run_validation and iteration == 0:
190
- self.validate(model, dataloader_val, iteration=iteration)
191
- _end_training = False
192
- _smart_stop_triggered = False
193
- _oom_triggered = False
194
- with (
195
- maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler,
196
- maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler,
197
- ):
198
- while True:
199
- dataloader_train_iter = iter(dataloader_train)
200
- while True:
201
- self.callbacks.on_before_dataloading(iteration)
202
- try:
203
- with (
204
- self.training_timer("dataloader_train"),
205
- self.straggler_detector.profile_section(
206
- "dataloading",
207
- self.config.trainer.straggler_detection.analyze_dataloading,
208
- profile_cuda=False,
209
- ),
210
- ):
211
- data_batch = next(dataloader_train_iter)
212
- except StopIteration:
213
- break
214
- finally:
215
- self.callbacks.on_after_dataloading(iteration)
216
- # If max_iter is reached, exit the training loop.
217
- if iteration >= self.config.trainer.max_iter:
218
- _end_training = True
219
- break
220
- # Move all tensors in the data batch to GPU device.
221
- data_batch = misc.to(data_batch, device="cuda")
222
- # The actual training step.
223
- self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
224
- self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration)
225
- if not model.training:
226
- model_ddp.train()
227
- assert model_ddp.training, "model_ddp is not in training mode."
228
- assert model.training, "model is not in training mode."
229
- try:
230
- output_batch, loss, grad_accum_iter = self.training_step(
231
- model_ddp,
232
- optimizer,
233
- scheduler,
234
- grad_scaler,
235
- data_batch,
236
- iteration=iteration,
237
- grad_accum_iter=grad_accum_iter,
238
- )
239
- except torch.OutOfMemoryError as e:
240
- # CUDA OOM error - save checkpoint and exit gracefully
241
- _oom_triggered = True
242
- log.error(f"CUDA Out of Memory error caught: {e}")
243
- log.info("Saving checkpoint due to CUDA OOM error...")
244
- self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
245
- log.success("Checkpoint saved successfully. Exiting training gracefully due to OOM.")
246
- _end_training = True
247
- break
248
- self.callbacks.on_training_step_batch_end(
249
- model, data_batch, output_batch, loss, iteration=iteration
250
- )
251
- # If the gradients are still being accumulated, continue to load the next training batch.
252
- if grad_accum_iter != 0:
253
- continue
254
- # Do the following when an actual optimizer (update) step has been made.
255
- iteration += 1
256
- # Save checkpoint.
257
- if iteration % self.config.checkpoint.save_iter == 0:
258
- self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
259
- # Call on_training_step_end with SmartStop exception handling
260
- try:
261
- self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
262
- except TimeoutException as e:
263
- # Smart stop triggered - save checkpoint and exit gracefully
264
- _smart_stop_triggered = True
265
- log.warning(f"SmartStop exception caught: {e}")
266
- log.info("Saving checkpoint due to time limit...")
267
- self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
268
- log.success("Checkpoint saved successfully. Exiting training gracefully.")
269
- _end_training = True
270
- break
271
- # Validation.
272
- if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
273
- self.validate(model, dataloader_val, iteration=iteration)
274
- # This iteration is successful; reset the timeout signal.
275
- signal.alarm(self.config.trainer.timeout_period)
276
- self.straggler_detector.generate_report(iteration)
277
- if torch_profiler:
278
- torch_profiler.step()
279
- if memory_profiler:
280
- memory_profiler.step()
281
- if _end_training:
282
- break
283
- log.success("Done with training.")
284
- if iteration % self.config.checkpoint.save_iter != 0 and not _smart_stop_triggered and not _oom_triggered:
285
- self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
286
- self.callbacks.on_train_end(model, iteration=iteration)
287
- self.checkpointer.finalize()
288
- distributed.barrier()
289
- self.callbacks.on_app_end()
290
-
291
- def training_step(
292
- self,
293
- model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
294
- optimizer: torch.optim.Optimizer,
295
- scheduler: torch.optim.lr_scheduler.LRScheduler,
296
- grad_scaler: torch.amp.GradScaler,
297
- data: dict[str, torch.Tensor],
298
- iteration: int = 0,
299
- grad_accum_iter: int = 0,
300
- ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
301
- """The training step.
302
-
303
- Args:
304
- model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare
305
- module, depending on whether distributed training is enabled or not.
306
- optimizer (torch.optim.Optimizer): The model optimizer.
307
- scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
308
- grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
309
- data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
310
- iteration (int): Current iteration number.
311
- grad_accum_iter (int): Number of gradient accumulation iterations.
312
-
313
- Returns:
314
- output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
315
- loss (torch.Tensor): The total loss of the training data batch.
316
- """
317
- # Only let DDP sync gradient at the last iteration of the gradient accumulation window
318
- with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
319
- self.callbacks.on_before_forward(iteration=iteration)
320
- with self.training_timer("forward"):
321
- with self.straggler_detector.profile_section(
322
- "fwd", self.config.trainer.straggler_detection.analyze_forward
323
- ):
324
- output_batch, loss = model_ddp.training_step(data, iteration)
325
- self.callbacks.on_after_forward(iteration=iteration)
326
- self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
327
- with self.training_timer("backward"):
328
- with self.straggler_detector.profile_section(
329
- "bwd", self.config.trainer.straggler_detection.analyze_backward
330
- ):
331
- loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
332
- loss_scaled.backward()
333
- if self.config.trainer.distributed_parallelism == "ddp":
334
- model_ddp.module.on_after_backward()
335
- else:
336
- model_ddp.on_after_backward()
337
- self.callbacks.on_after_backward(model_ddp, iteration=iteration)
338
- grad_accum_iter += 1
339
- if grad_accum_iter == self.config.trainer.grad_accum_iter:
340
- with self.training_timer("optimizer_step"):
341
- with self.straggler_detector.profile_section(
342
- "opt", self.config.trainer.straggler_detection.analyze_optimizer
343
- ):
344
- self.callbacks.on_before_optimizer_step(
345
- model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
346
- )
347
- grad_scaler.step(optimizer)
348
- grad_scaler.update()
349
- scheduler.step()
350
- self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
351
- if self.config.trainer.distributed_parallelism == "ddp":
352
- model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
353
- else:
354
- model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
355
- optimizer.zero_grad(set_to_none=True)
356
- grad_accum_iter = 0
357
- return output_batch, loss, grad_accum_iter
358
-
359
- @torch.no_grad()
360
- def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
361
- """Validate on the full validation dataset.
362
-
363
- Args:
364
- model (ImaginaireModel): The PyTorch model.
365
- dataloader_val (torch.utils.data.DataLoader): The validation data loader.
366
- iteration (int): Current iteration number.
367
- """
368
- self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
369
- model.eval()
370
- # Evaluate on the full validation set.
371
- with ema.ema_scope(model, enabled=model.config.ema.enabled):
372
- for val_iter, data_batch in enumerate(dataloader_val):
373
- if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
374
- break
375
- data_batch = misc.to(data_batch, device="cuda")
376
- self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
377
- output_batch, loss = model.validation_step(data_batch, iteration)
378
- self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
379
- self.callbacks.on_validation_end(model, iteration=iteration)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/types/denoise_prediction.py DELETED
@@ -1,28 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- from dataclasses import dataclass
19
- from typing import Optional
20
-
21
- import torch
22
-
23
-
24
@dataclass
class DenoisePrediction:
    """Bundle of tensors produced by one denoising prediction.

    Only ``x0`` is required; ``eps`` and ``logvar`` default to ``None``.
    """

    # Clean data prediction.
    x0: torch.Tensor
    # Noise prediction.
    eps: Optional[torch.Tensor] = None
    # Log variance of the noise prediction; can be used as a confidence / uncertainty measure.
    logvar: Optional[torch.Tensor] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/callback.py DELETED
@@ -1,524 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import sys
19
- import time
20
- import warnings
21
- from typing import TYPE_CHECKING, Any, Callable, Optional
22
-
23
- import omegaconf
24
- import torch
25
- import torch.distributed as dist
26
- import torch.utils.data
27
- import tqdm
28
- from lyra_2._ext.imaginaire.lazy_config import instantiate
29
- from lyra_2._ext.imaginaire.utils import distributed, log, misc
30
- from lyra_2._ext.imaginaire.utils.misc import get_local_tensor_if_DTensor
31
-
32
- try:
33
- from megatron.core import parallel_state
34
- except ImportError:
35
- parallel_state = None
36
- print("Megatron-core is not installed.")
37
-
38
-
39
- if TYPE_CHECKING:
40
- from lyra_2._ext.imaginaire.config import Config
41
- from lyra_2._ext.imaginaire.model import ImaginaireModel
42
- from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer
43
-
44
-
45
class CallBackGroup:
    """A class for hosting a collection of callback objects.

    It is used to execute callback functions of multiple callback objects with the same method name.
    When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
    self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.

    Attributes:
        _callbacks (list[Callback]): List of callback objects.
    """

    def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None:
        """Initializes the list of callback objects.

        Entries of ``config.trainer.callbacks`` that lack a ``_target_`` field are logged and skipped.
        Every instantiated callback gets ``config`` and ``trainer`` attached as attributes.

        Args:
            config (Config): The config object for the Imaginaire codebase.
            trainer (ImaginaireTrainer): The main trainer.
        """
        self._callbacks = []
        callback_configs = config.trainer.callbacks
        if not callback_configs:
            return
        if isinstance(callback_configs, (list, omegaconf.listconfig.ListConfig)):
            warnings.warn(
                "The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
                "Please update your code",
                DeprecationWarning,
                stacklevel=2,
            )
            # Legacy list form: synthesize names so the rest of the code only deals with a mapping.
            callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
        for callback_name, current_callback_cfg in callback_configs.items():
            if "_target_" not in current_callback_cfg:
                log.critical(
                    f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}"
                )
                continue
            log.info(f"Instantiating callback {callback_name}: {current_callback_cfg}")
            _callback = instantiate(current_callback_cfg)
            assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
            _callback.config = config
            _callback.trainer = trainer
            self._callbacks.append(_callback)

    def __getattr__(self, method_name: str) -> Callable:
        """Builds a dispatcher that calls ``method_name`` on every hosted callback, in order.

        ``__getattr__`` only runs when normal attribute lookup fails, so it must not
        answer for private or dunder names: protocols such as ``copy``/``pickle`` probe
        for optional dunder methods and would otherwise receive a bogus dispatcher.

        Args:
            method_name (str): Callback method name.

        Returns:
            Callable: A function forwarding its arguments to each callback's method; it returns None.

        Raises:
            AttributeError: If ``method_name`` starts with an underscore.
        """
        if method_name.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {method_name!r}")

        def multi_callback_wrapper(*args, **kwargs) -> None:
            for callback in self._callbacks:
                assert hasattr(callback, method_name), f"{callback!r} has no method {method_name!r}"
                method = getattr(callback, method_name)
                assert callable(method), f"{method_name!r} of {callback!r} is not callable"
                # Return values of individual callbacks are intentionally discarded.
                _ = method(*args, **kwargs)

        return multi_callback_wrapper
102
-
103
-
104
class Callback:
    """The base class for all callbacks.

    All callbacks should inherit from this class and adhere to the established method names and signatures.
    Every hook defined here is a no-op that returns None; subclasses override only the hooks they need.
    """

    def __init__(self, config: Optional["Config"] = None, trainer: Optional["ImaginaireTrainer"] = None):
        """Initializes a Callback object.

        Args:
            config (Optional[Config]): The configuration object for the Imaginaire codebase, if available.
            trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available.

        Notes:
            The config and trainer parameters are optional to maintain backward compatibility.
            In future releases, these parameters will be removed. Upon using these parameters, a deprecation
            warning will be issued. Neither value is stored by this constructor.

        """
        if config is not None or trainer is not None:
            warnings.warn(
                "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
                "Please update your code to create Callback instances without these parameters.",
                DeprecationWarning,
                stacklevel=2,
            )
        del config, trainer

    def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook for the start of training."""
        pass

    def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        """
        Called before the training step, for each batch. This is paired with on_training_step_end() but note that
        when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated,
        this function is called for every batch.
        Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
        for every batch, albeit with the same iteration number.
        FIXME - should this either be deprecated, or called only when a new training step is started after having updated
        the optimizer?
        """
        pass

    def on_training_step_batch_start(
        self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
    ) -> None:
        """
        Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with
        on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation.
        Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for multiple invocations.
        """
        pass

    def on_before_forward(self, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook just before the forward pass."""
        pass

    def on_after_forward(self, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook just after the forward pass."""
        pass

    def on_before_backward(
        self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
    ) -> None:
        """No-op by default; per its name, the hook just before the backward pass."""
        pass

    def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook just after the backward pass."""
        pass

    def on_before_dataloading(self, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook just before a batch is fetched."""
        pass

    def on_after_dataloading(self, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook just after a batch is fetched."""
        pass

    def on_optimizer_init_start(self) -> None:
        """No-op by default; per its name, the hook before optimizer initialization."""
        pass

    def on_optimizer_init_end(self) -> None:
        """No-op by default; per its name, the hook after optimizer initialization."""
        pass

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        """No-op by default; per its name, the hook just before the optimizer step."""
        pass

    def on_before_zero_grad(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        iteration: int = 0,
    ) -> None:
        """No-op by default; per its name, the hook just before gradients are zeroed."""
        pass

    def on_training_step_batch_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """
        Called at the end of a training step for every batch even when using gradient accumulation.
        This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the optimizer is updated,
        and therefore it may be the same for multiple batches.
        """
        pass

    def on_training_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """
        Called at the end of a training step, but note that when using gradient accumulation, this is only called
        when the optimizer is updated, and the iteration incremented, whereas on_training_step_start is called every time.
        Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
        for every batch.
        """
        pass

    def on_validation_start(
        self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
    ) -> None:
        """No-op by default; per its name, the hook before a validation pass begins."""
        pass

    def on_validation_step_start(
        self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
    ) -> None:
        """No-op by default; per its name, the hook before each validation batch is processed."""
        pass

    def on_validation_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """No-op by default; per its name, the hook after each validation batch is processed."""
        pass

    def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook after a validation pass finishes."""
        pass

    def on_load_checkpoint_start(self, model: ImaginaireModel) -> None:
        """No-op by default; per its name, the hook before checkpoint loading starts."""
        pass

    def on_load_checkpoint_end(
        self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: Optional[str] = None
    ) -> None:
        """No-op by default; per its name, the hook after checkpoint loading finishes."""
        pass

    def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[str, Any]) -> None:
        """
        Called when checkpoint loading is about to start, but after on_load_checkpoint_start().
        NOTE(review): the previous wording referred to on_save_checkpoint_start(), which looks like a
        copy-paste slip; confirm the exact call order against the checkpointer.
        FIXME - why do we need this callback, can't we just use on_load_checkpoint_start()?
        """
        pass

    def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """
        Called when checkpoint saving is about to start.
        """
        pass

    def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """
        Called when the synchronous part of checkpointing is finished, this function can be used
        along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time.
        Note that for asynchronous checkpoint, the checkpoint may still be ongoing, so this function
        does not mean the checkpoint is finished for the asynchronous case, use on_save_checkpoint_success()
        for that.
        """
        pass

    def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None:
        """
        Called when checkpoint saving is fully finished, and succeeded. Not called if checkpoint failed.
        For synchronous checkpoint, it is called at the same time as on_save_checkpoint_end(), but for asynchronous
        checkpoint, it is called after the asynchronous part has also finished. For checkpointers with out-of-process
        checkpointing, this function is called as soon as the notification is received from the checkpointer process,
        which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure
        the full checkpoint duration for the asynchronous part, use the elapsed_time parameter, do not measure it directly
        as this would be a significant overestimate.
        """
        pass

    def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[str, Any]) -> None:
        """No-op by default; per its name and signature, a hook that receives the state dict being saved."""
        pass

    def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """No-op by default; per its name, the hook for the end of training."""
        pass

    def on_app_end(self) -> None:
        """No-op by default; per its name, the hook for application shutdown."""
        pass
307
-
308
-
309
class EMAModelCallback(Callback):
    """Callback that maintains the exponential-moving-average (EMA) copy of the model weights."""

    def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Check that the model's EMA state matches its config and keep the EMA weights in FP32."""
        if not model.config.ema.enabled:
            assert not hasattr(model, "ema"), "There should be no EMA initialized."
            return
        assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel"
        # The EMA weights are always held in full precision.
        model.ema = model.ema.to(dtype=torch.float32)

    def on_training_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """Fold the freshly updated regular weights into the EMA weights (when EMA is enabled)."""
        if not model.config.ema.enabled:
            return
        model.ema.update_average(model, iteration)
332
-
333
-
334
class ProgressBarCallback(Callback):
    """Callback that draws tqdm progress bars for training and validation in the console.

    Every hook is wrapped in ``distributed.rank0_only``.
    """

    @distributed.rank0_only
    def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Create the training bar, sized by ``config.trainer.max_iter`` and starting at ``iteration``."""
        total_iters = self.config.trainer.max_iter
        self.train_pbar = tqdm.trange(total_iters, initial=iteration, desc="Training")

    @distributed.rank0_only
    def on_training_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """Advance the training bar by one step."""
        self.train_pbar.update()

    @distributed.rank0_only
    def on_validation_start(
        self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
    ) -> None:
        """Create the validation bar.

        Its length is ``config.trainer.max_val_iter`` when that is set, otherwise ``len(dataloader_val)``.
        """
        val_cap = self.config.trainer.max_val_iter
        num_iter = val_cap if val_cap is not None else len(dataloader_val)
        assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
        self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)

    @distributed.rank0_only
    def on_validation_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """Advance the validation bar by one step."""
        self.val_pbar.update()

    @distributed.rank0_only
    def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Close the validation bar."""
        self.val_pbar.close()

    @distributed.rank0_only
    def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Finalize the trainer's checkpointer, then close the training bar."""
        self.trainer.checkpointer.finalize()
        self.train_pbar.close()
382
-
383
-
384
class IterationLoggerCallback(Callback):
    """The callback class for periodically logging the average iteration time and the loss.

    Every hook is wrapped in ``distributed.rank0_only``.
    """

    @distributed.rank0_only
    def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Start the step timer and reset the accumulated step time."""
        self.start_iteration_time = time.time()
        self.elapsed_iteration_time = 0

    @distributed.rank0_only
    def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        """Restart the step timer."""
        self.start_iteration_time = time.time()

    @distributed.rank0_only
    def on_training_step_end(
        self,
        model: ImaginaireModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor],
        loss: torch.Tensor,
        iteration: int = 0,
    ) -> None:
        """Accumulate the step time and log the average every ``config.trainer.logging_iter`` iterations.

        The accumulator is reset after each log line. (Indentation was lost in the source this was
        recovered from; resetting inside the logging branch is the only reading consistent with
        dividing the accumulated time by ``logging_iter``.)
        """
        # but this is only called when the optimizer is updated, so it's only the time for the last batch.
        self.elapsed_iteration_time += time.time() - self.start_iteration_time

        if iteration % self.config.trainer.logging_iter == 0:
            avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
            # ".2f"/".4f" set the precision; the former "2f"/"4f" only set a minimum field width.
            log.info(f"Iteration: {iteration}, average iter time: {avg_time:.2f}, total loss {loss.item():.4f}")

            self.elapsed_iteration_time = 0
414
-
415
-
416
class LowPrecisionCallback(Callback):
    """The callback class handling low precision training.

    It casts floating-point tensors of each batch to the model precision and, every
    ``update_iter`` iterations, copies the optimizer's master parameters back into the
    model parameters when the optimizer exposes a truthy ``master_weights`` attribute.
    """

    def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
        """Store the master-weight sync period.

        Args:
            config (Config): Accepted for signature compatibility; not used here.
            trainer (ImaginaireTrainer): Accepted for signature compatibility; not used here.
            update_iter (int): Period, in iterations, of the master-weight copy.
        """
        # Called without arguments so the base class does not emit its deprecation warning.
        super().__init__()
        self.update_iter = update_iter

    def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
        """Record the model precision; for fp32, push the master-weight sync period to ``sys.maxsize``."""
        if model.precision == torch.float32:
            log.critical("Using fp32. We should disable master weights update.")
            self.update_iter = sys.maxsize
        else:
            assert model.precision in [
                torch.bfloat16,
                torch.float16,
                torch.half,
            ], "LowPrecisionCallback must use a low precision dtype."
        # Assigned on both paths so the casting hooks always have a target dtype.
        # NOTE(review): indentation was lost in the source this was recovered from; if the
        # assignment originally lived only in the low-precision branch, the fp32 path would
        # fail with AttributeError in the casting hooks.
        self.precision_type = model.precision

    def _cast_floating_tensors(self, data: dict[str, torch.Tensor]) -> None:
        """Cast, in place, every floating-point tensor value of ``data`` to ``self.precision_type``."""
        # Only existing keys are reassigned, so mutating the dict during iteration is safe.
        for key, value in data.items():
            if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
                data[key] = value.to(dtype=self.precision_type)

    def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
        """Cast the floating-point tensors of a training batch to the model precision."""
        self._cast_floating_tensors(data)

    def on_validation_step_start(
        self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
    ) -> None:
        """Cast the floating-point tensors of a validation batch to the model precision."""
        self._cast_floating_tensors(data)

    def on_before_zero_grad(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        iteration: int = 0,
    ) -> None:
        """Copy the optimizer's master parameters into the model parameters every ``update_iter`` iterations."""
        if iteration % self.update_iter != 0:
            return
        if not getattr(optimizer, "master_weights", False):
            return
        params, master_params = [], []
        for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master):
            for p, p_master in zip(group["params"], group_master["params"]):
                params.append(get_local_tensor_if_DTensor(p.data))
                master_params.append(p_master.data)
        # One fused copy over all parameters instead of a Python-level loop of copies.
        torch._foreach_copy_(params, master_params)
461
-
462
-
463
class NVTXCallback(Callback):
    """Callback that brackets the phases of a training step with NVTX ranges.

    Ranges are pushed for "forward", "backward", "optimizer_step" and "dataloading" and popped
    by the matching hook. When ``synchronize`` is true, the forward/backward/optimizer hooks call
    ``torch.cuda.synchronize()`` before touching the range stack; the dataloading hooks never do.
    """

    def __init__(
        self,
        synchronize: bool = False,
        config: Optional["Config"] = None,
        trainer: Optional["ImaginaireTrainer"] = None,
    ):
        """Store the synchronization flag and forward ``config``/``trainer`` to the base class."""
        super().__init__(config, trainer)
        self.synchronize = synchronize

    def _sync_if_requested(self) -> None:
        """Call ``torch.cuda.synchronize()`` when ``self.synchronize`` is truthy."""
        if self.synchronize:
            torch.cuda.synchronize()

    def on_before_forward(self, iteration: int = 0) -> None:
        """Open the "forward" range."""
        self._sync_if_requested()
        torch.cuda.nvtx.range_push("forward")

    def on_after_forward(self, iteration: int = 0) -> None:
        """Pop the innermost NVTX range."""
        self._sync_if_requested()
        torch.cuda.nvtx.range_pop()

    def on_before_backward(
        self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
    ) -> None:
        """Open the "backward" range."""
        self._sync_if_requested()
        torch.cuda.nvtx.range_push("backward")

    def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
        """Pop the innermost NVTX range."""
        self._sync_if_requested()
        torch.cuda.nvtx.range_pop()

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        """Open the "optimizer_step" range."""
        self._sync_if_requested()
        torch.cuda.nvtx.range_push("optimizer_step")

    def on_before_zero_grad(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        iteration: int = 0,
    ) -> None:
        """Pop the innermost NVTX range (no other hook in this class pops the "optimizer_step" range)."""
        self._sync_if_requested()
        torch.cuda.nvtx.range_pop()

    def on_before_dataloading(self, iteration: int = 0) -> None:
        """Open the "dataloading" range (never synchronizes)."""
        torch.cuda.nvtx.range_push("dataloading")

    def on_after_dataloading(self, iteration: int = 0) -> None:
        """Pop the innermost NVTX range (never synchronizes)."""
        torch.cuda.nvtx.range_pop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/checkpointer.py DELETED
@@ -1,372 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import os
19
- import threading
20
- from typing import TYPE_CHECKING, List, NamedTuple, Tuple
21
-
22
- import torch
23
-
24
- from lyra_2._ext.imaginaire.model import ImaginaireModel
25
- from lyra_2._ext.imaginaire.utils import callback, distributed, log, misc, object_store
26
-
27
- if TYPE_CHECKING:
28
- from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig
29
-
30
- TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
31
- if TORCH_VERSION >= (1, 11):
32
- from torch.ao import quantization
33
- from torch.ao.quantization import FakeQuantizeBase, ObserverBase
34
- elif (
35
- TORCH_VERSION >= (1, 8)
36
- and hasattr(torch.quantization, "FakeQuantizeBase")
37
- and hasattr(torch.quantization, "ObserverBase")
38
- ):
39
- from torch import quantization
40
- from torch.quantization import FakeQuantizeBase, ObserverBase
41
-
42
-
43
- class _IncompatibleKeys(
44
- NamedTuple(
45
- "IncompatibleKeys",
46
- [
47
- ("missing_keys", List[str]),
48
- ("unexpected_keys", List[str]),
49
- ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
50
- ],
51
- )
52
- ):
53
- pass
54
-
55
-
56
class Checkpointer:
    """The checkpointer class. Supports checkpoint saving/loading to both local disk or object store."""

    def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
        """Constructor of the checkpointer.

        Args:
            config_checkpoint (CheckpointConfig): The config object for the checkpointer.
            config_job (JobConfig): The job config; ``path_local`` and ``path`` are used as the roots of
                the local and object-store checkpoint directories respectively.
            callbacks (callback.CallBackGroup): Callbacks notified about save/load events.
        """
        # Set the callback functions.
        self.callbacks = callbacks

        self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
        self.checkpoint_dir_object_store = f"{config_job.path}/checkpoints"
        self.save_to_object_store = config_checkpoint.save_to_object_store.enabled
        self.load_from_object_store = config_checkpoint.load_from_object_store.enabled
        self.strict_resume = config_checkpoint.strict_resume
        self.load_path = config_checkpoint.load_path or None
        self.load_training_state = config_checkpoint.load_training_state
        self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
        # Handle of the most recent background saver thread (None until the first save).
        self.save_thread = None
        # Create the object store client interface.
        if self.save_to_object_store:
            self.object_store_saver = object_store.ObjectStore(config_checkpoint.save_to_object_store)
        if self.load_from_object_store:
            self.object_store_loader = object_store.ObjectStore(config_checkpoint.load_from_object_store)

    def save(
        self,
        model: ImaginaireModel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save network weights, optimizer parameters, scheduler parameters to a checkpoint.

        Only rank 0 collects the state and spawns the saver thread; every rank fires the
        start/end callbacks.

        Args:
            model (ImaginaireModel): The PyTorch model.
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
            iteration (int): Current iteration number.
        """
        self.callbacks.on_save_checkpoint_start(model, iteration)

        checkpoint_file = f"iter_{iteration:09}.pt"

        if distributed.get_rank() == 0:
            state_dict = dict(
                model=model.state_dict(),
                optimizer=optimizer.state_dict(),
                scheduler=scheduler.state_dict(),
                grad_scaler=grad_scaler.state_dict(),
                iteration=iteration,
            )
            state_dict = misc.to(state_dict, device="cpu")
            self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
            # Wait for previous saver thread to end.
            if self.save_thread:
                self.save_thread.join()
            # Run the checkpoint saver in a separate thread.
            self.save_thread = threading.Thread(
                target=self._save_worker_object_store if self.save_to_object_store else self._save_worker_local,
                daemon=False,
                args=(state_dict, checkpoint_file, distributed.get_rank()),
            )
            self.save_thread.start()

        # Note: Checkpoints are saved on a separate thread and this callback is not accurate.
        # Please check logs from on_save_checkpoint_success() for better accuracy
        self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)

    @misc.timer("checkpoint saving (local)")
    def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None:
        """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).

        Failures are logged and swallowed so that a failed save does not kill the saver thread's caller.

        Args:
            state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
            checkpoint_file (str): The file name of the model checkpoint.
            rank (int): GPU device (default: 0).
        """
        checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
        os.makedirs(self.checkpoint_dir_local, exist_ok=True)
        try:
            torch.save(state_dict, checkpoint_path)
            if rank == 0:
                self._write_latest_checkpoint_file(checkpoint_file)
            log.success(f"Saved checkpoint (local): {checkpoint_path}")
            # Recover the iteration number from the "iter_<number>.pt" file name.
            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:  # noqa: BLE001
            log.exception(f"Checkpoint failed to save (local): {e}")

    @misc.timer("checkpoint saving (object store)")
    def _save_worker_object_store(
        self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0
    ) -> None:
        """Worker to upload checkpoint to object store, spawned with a child thread (in parallel with the training).

        Failures are logged and swallowed so that a failed upload does not kill the saver thread's caller.

        Args:
            state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
            checkpoint_file (str): The file name of the model checkpoint.
            rank (int): GPU device (default: 0).
        """
        checkpoint_path = os.path.join(self.checkpoint_dir_object_store, checkpoint_file)
        try:
            self.object_store_saver.save_object(state_dict, key=checkpoint_path, type="torch")
            if rank == 0:
                self._write_latest_checkpoint_file(checkpoint_file)
            log.success(f"Saved checkpoint (object store): {checkpoint_path}")
            # Recover the iteration number from the "iter_<number>.pt" file name.
            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:  # noqa: BLE001
            log.exception(f"Checkpoint failed to upload (object store): {e}")

    @misc.timer("checkpoint loading")
    def load(
        self,
        model: ImaginaireModel,
        optimizer: torch.optim.Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        grad_scaler: torch.amp.GradScaler | None = None,
    ) -> int:
        """Load network weights and optimizer states from a checkpoint in a single process.

        The priority of the checkpoint loading logic is:
        1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
        2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.load_path.
            - This is typically used for inference mode.
            - If config_checkpoint.load_training_state is True, then also load the optimizer, scheduler and
              gradient scaler states.
            - If config_checkpoint.only_load_scheduler_state is True, the scheduler state (and iteration) is loaded.
        3. If none of the above, randomly initialize the model parameters and train from scratch.

        Args:
            model (ImaginaireModel): The PyTorch model.
            optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
            scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
            grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).

        Returns:
            iteration (int): the iteration number to start/resume from.

        Raises:
            FileNotFoundError: If the selected checkpoint does not exist.
            AssertionError: If a state has to be restored but the matching object was not passed in.
        """
        self.callbacks.on_load_checkpoint_start(model)

        latest_checkpoint_file = self._read_latest_checkpoint_file()
        if latest_checkpoint_file is not None:
            # 1. Resume training from latest_checkpoint.txt under the same name.
            checkpoint_dir = (
                self.checkpoint_dir_object_store if self.load_from_object_store else self.checkpoint_dir_local
            )
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
            resume = True
            only_resume_scheduler = True
        elif self.load_path:
            # 2. Load the module weights specified by config_checkpoint.load_path.
            checkpoint_path = self.load_path
            resume = self.load_training_state
            only_resume_scheduler = self.only_load_scheduler_state
        else:
            # 3. Randomly initialize the model parameters and train from scratch.
            checkpoint_path = None
            resume = False
            only_resume_scheduler = False
        # Load checkpoint.
        if checkpoint_path is not None:
            self._check_checkpoint_exists(checkpoint_path)
            if self.load_from_object_store:
                log.info(f"Loading checkpoint (object store): {checkpoint_path}")
                state_dict = self.object_store_loader.load_object(key=checkpoint_path, type="torch", max_attempts=20)
                log.success(f"Complete loading checkpoint (object store): {checkpoint_path}")
            else:
                log.info(f"Loading checkpoint (local): {checkpoint_path}")
                # Keep every storage where it was deserialized instead of restoring it to its saved device.
                state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
                log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
            self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
            # Load the state dicts.
            log.info("- Loading the model...")
            model.load_state_dict(state_dict["model"], strict=self.strict_resume)
            if resume or only_resume_scheduler:
                iteration = state_dict["iteration"]
                assert scheduler
                log.info("- Loading the scheduler...")
                scheduler.load_state_dict(state_dict["scheduler"])
                scheduler.last_epoch = iteration
            else:
                iteration = 0
            if resume:
                assert optimizer
                # Validate before touching the optimizer so a missing scaler cannot leave a half-restored state.
                assert grad_scaler is not None, "grad_scaler is required to resume the training state."
                log.info("- Loading the optimizer...")
                optimizer.load_state_dict(state_dict["optimizer"])
                log.info("- Loading the gradient scaler...")
                grad_scaler.load_state_dict(state_dict["grad_scaler"])
                log.success(f"Done with loading the checkpoint (iteration {iteration}).")
            else:
                log.success("Done with loading the checkpoint.")
        else:
            # Checkpoint not found and not specified. We will train everything from scratch.
            iteration = 0
            log.info("Training from scratch.")
        torch.cuda.empty_cache()

        self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)

        return iteration

    def _read_latest_checkpoint_file(self) -> str | None:
        """Get the file name of the latest saved checkpoint. If it doesn't exist, return None.

        Returns:
            checkpoint_file (str | None): file name of the latest saved checkpoint.
        """
        checkpoint_file = None
        if self.load_from_object_store:
            latest_path = os.path.join(self.checkpoint_dir_object_store, "latest_checkpoint.txt")
            if self.object_store_loader.object_exists(key=latest_path):
                checkpoint_file = self.object_store_loader.load_object(key=latest_path, type="text").strip()
        else:
            latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
            if os.path.isfile(latest_path):
                # Use a context manager so the file handle is closed deterministically.
                with open(latest_path) as file:
                    checkpoint_file = file.read().strip()
        return checkpoint_file

    def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
        """Track the file name of the latest saved checkpoint.

        Args:
            checkpoint_file (str): file name of the latest saved checkpoint.
        """
        content = f"{checkpoint_file}\n"
        if self.save_to_object_store:
            latest_path = os.path.join(self.checkpoint_dir_object_store, "latest_checkpoint.txt")
            self.object_store_saver.save_object(content, key=latest_path, type="text")
        else:
            latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
            with open(latest_path, "w") as file:
                file.write(content)

    def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
        """If the file checkpoint_path does not exist, raise an error.

        Args:
            checkpoint_path (str): full path to the checkpoint.

        Raises:
            FileNotFoundError: If the checkpoint is missing from the backend selected for loading.
        """
        if self.load_from_object_store:
            if not self.object_store_loader.object_exists(key=checkpoint_path):
                raise FileNotFoundError(f"File not found (object store): {checkpoint_path}")
        else:
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"File not found (local): {checkpoint_path}")

    def finalize(self) -> None:
        """Finalize the checkpointer by waiting for the last saver thread, if any."""
        if self.save_thread:
            self.save_thread.join()
311
-
312
-
313
- # https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py
314
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    """Load ``checkpoint_state_dict`` into ``model`` while tolerating shape mismatches.

    Entries whose shape differs from the model's are removed from
    ``checkpoint_state_dict`` (the dict is modified in place) and reported,
    then the remainder is loaded with ``strict=False``.

    Args:
        model: Module receiving the weights.
        checkpoint_state_dict: State dict read from a checkpoint.

    Returns:
        _IncompatibleKeys: missing keys, unexpected keys (both without
        "_extra_state" entries) and the dropped ``(key, checkpoint shape, model shape)`` triples.

    Raises:
        ValueError: If the model's state dict holds a non-tensor value for a shared key.
    """

    def _owning_module(root: torch.nn.Module, key: str) -> torch.nn.Module:
        # foo.bar.param_or_buffer_name -> root.foo.bar
        node = root
        for part in key.split(".")[:-1]:
            node = getattr(node, part)
        return node

    # workaround https://github.com/pytorch/pytorch/issues/24139
    current_state = model.state_dict()
    mismatched = []
    for key in list(checkpoint_state_dict):
        if key not in current_state:
            continue
        if "_extra_state" in key:  # Key introduced by TransformerEngine for FP8
            log.warning(f"Skipping key {key} introduced by TransformerEngine for FP8 in the checkpoint.")
            continue
        target = current_state[key]
        # Allow mismatch for uninitialized parameters
        if TORCH_VERSION >= (1, 8) and isinstance(target, torch.nn.parameter.UninitializedParameter):
            continue
        if not isinstance(target, torch.Tensor):
            raise ValueError(
                f"Find non-tensor parameter {key} in the model. type: {type(target)} {type(checkpoint_state_dict[key])}, please check if this key is safe to skip or not."
            )

        model_shape = tuple(target.shape)
        loaded_shape = tuple(checkpoint_state_dict[key].shape)
        if model_shape == loaded_shape:
            continue

        observers_available = (
            TORCH_VERSION >= (1, 8)
            and hasattr(quantization, "ObserverBase")
            and hasattr(quantization, "FakeQuantizeBase")
        )
        if observers_available:
            # Per-channel quantization observers are expected to disagree on buffer shapes;
            # they reconcile the difference themselves in _load_from_state_dict, so keep them.
            tolerated_types = (
                ObserverBase,
                FakeQuantizeBase,
            )
            if isinstance(_owning_module(model, key), tolerated_types):
                continue

        mismatched.append((key, loaded_shape, model_shape))
        checkpoint_state_dict.pop(key)

    load_report = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling
    return _IncompatibleKeys(
        missing_keys=[name for name in load_report.missing_keys if "_extra_state" not in name],
        unexpected_keys=[name for name in load_report.unexpected_keys if "_extra_state" not in name],
        incorrect_shapes=mismatched,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/config_helper.py DELETED
@@ -1,214 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import importlib
17
- import os
18
- import pkgutil
19
- import sys
20
- from dataclasses import fields as dataclass_fields
21
- from dataclasses import is_dataclass
22
- from typing import Any, Dict, Optional
23
-
24
- import attr
25
- import attrs
26
- from hydra import compose, initialize
27
- from hydra.core.config_store import ConfigStore
28
- from hydra.core.global_hydra import GlobalHydra
29
- from omegaconf import DictConfig, OmegaConf
30
-
31
- from lyra_2._ext.imaginaire.config import Config
32
- from lyra_2._ext.imaginaire.utils import log
33
-
34
-
35
def is_attrs_or_dataclass(obj) -> bool:
    """
    Tell whether ``obj`` is a dataclass or an instance of an attrs class.

    Args:
        obj: The object to check.

    Returns:
        bool: True if ``is_dataclass(obj)`` holds or ``type(obj)`` is an attrs class, False otherwise.
    """
    if is_dataclass(obj):
        return True
    return attr.has(type(obj))
46
-
47
-
48
def get_fields(obj):
    """
    List the field names of a dataclass or of an instance of an attrs class.

    Args:
        obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.

    Returns:
        list: A list of field names.

    Raises:
        ValueError: If the object is neither an attrs class nor a dataclass.
    """
    if is_dataclass(obj):
        declared = dataclass_fields(obj)
    elif attr.has(type(obj)):
        declared = attr.fields(type(obj))
    else:
        raise ValueError("The object is neither an attrs class nor a dataclass.")
    return [entry.name for entry in declared]
67
-
68
-
69
def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
    """
    Apply Hydra-style overrides to a config object and rebuild it with its original types.

    :param config: the instance of class `Config` (usually from `make_config`)
    :param overrides: list of overrides for config; when non-empty its first element must be "--"
    :return: the composed instance of class `Config`
    """

    def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
        """
        Rebuild an object of the same type as ref_instance from composed data.

        Args:
            ref_instance: The reference instance that determines the type and fields when needed
            kwargs: A dictionary / DictConfig of values for a structured instance, or primitive / unstructured data

        Returns:
            Any: kwargs itself when ref_instance is not an attrs/dataclass instance, otherwise a new
            instance of type(ref_instance) built from the recursively converted values

        Raises:
            AssertionError: If kwargs is not a mapping or contains keys unknown to ref_instance.
            Exception: If there is an error constructing the new instance.
        """
        if not is_attrs_or_dataclass(ref_instance):
            return kwargs

        ref_fields = set(get_fields(ref_instance))
        assert isinstance(kwargs, (dict, DictConfig)), "kwargs must be a dictionary or a DictConfig"
        keys = set(kwargs.keys())

        # ref_fields must equal to or include all keys
        extra_keys = keys - ref_fields
        assert keys <= ref_fields, (
            f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
        )

        resolved_kwargs: Dict[str, Any] = {
            name: config_from_dict(getattr(ref_instance, name), kwargs[name]) for name in keys
        }
        try:
            return type(ref_instance)(**resolved_kwargs)
        except Exception as e:
            log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
            log.error(e)
            raise e

    # Convert Config object to a DictConfig object
    config_omegaconf = DictConfig(content=attrs.asdict(config), flags={"allow_objects": True})
    # Enforce "--" separator between the script arguments and overriding configs.
    if overrides:
        if overrides[0] != "--":
            raise ValueError(
                f'Hydra config overrides must be separated with a "--" token. but got overrides={overrides}, and overrides[0]={overrides[0]}'
            )
        overrides = overrides[1:]
    # Use Hydra to handle overrides
    ConfigStore.instance().store(name="config", node=config_omegaconf)
    if GlobalHydra().is_initialized():
        config_omegaconf = compose(config_name="config", overrides=overrides)
        OmegaConf.resolve(config_omegaconf)
    else:
        # Hydra has not been set up yet: initialize it just for this composition.
        with initialize(version_base=None):
            config_omegaconf = compose(config_name="config", overrides=overrides)
            OmegaConf.resolve(config_omegaconf)

    return config_from_dict(config, config_omegaconf)
144
-
145
-
146
def get_config_module(config_file: str) -> str:
    """Convert the path of a Python config file into a dotted module name.

    Args:
        config_file: Path to the config file, using "/" as separator
            (relative to the Imaginaire4 root, per the error message below).

    Returns:
        str: The path with a trailing ".py" removed and every "/" replaced by ".".

    Raises:
        AssertionError: If ``config_file`` is not an existing file.
    """
    if not config_file.endswith(".py"):
        log.error("Config file cannot be specified as module.")
        log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).")
    assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found."
    # Convert to importable module format.
    # Strip only the trailing extension: a blanket replace(".py", "") also mangles
    # directories whose name starts with "py" once "/" has become "." ("a/pytorch/c.py" -> "atorch.c").
    module_path = config_file.removesuffix(".py")
    return module_path.replace("/", ".")
154
-
155
-
156
def import_module(full_module_name: str, reload: bool = False):
    """
    Import a module by name, or reload it when requested and already imported.

    Args:
        full_module_name: The fully qualified name of the module to import.
        reload: If True, reload the module if it's already imported.
    """
    if reload and full_module_name in sys.modules:
        importlib.reload(sys.modules[full_module_name])
        return
    importlib.import_module(full_module_name)
168
-
169
-
170
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
    """
    Import (or reload) every module found under a package, descending into sub-packages.

    This function is typically used in conjunction with Hydra to ensure that all modules
    within a specified package are imported, which is necessary for registering configurations.

    Example usage:
    ```python
    import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False)
    ```

    Args:
        package_path (str): The dotted path to the package from which to import all modules.
        reload (bool): Flag to determine whether to reload modules if they're already imported.
        skip_underscore (bool): If True, skips importing modules that start with an underscore.
    """
    log.info(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
    root_package = importlib.import_module(package_path)
    search_roots = root_package.__path__

    def import_modules_recursively(directory: str, prefix: str) -> None:
        """
        Import or reload each module located in ``directory`` and recurse into packages.

        Args:
            directory (str): The file system path to the current package directory.
            prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config').
        """
        for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
            hidden = module_name.startswith("_")
            if skip_underscore and hidden:
                log.debug(f"Skipping module {module_name} as it starts with an underscore")
                continue

            full_module_name = f"{prefix}.{module_name}"
            log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
            import_module(full_module_name, reload=reload)

            if not is_pkg:
                continue
            # Sub-package: walk its directory with the extended dotted prefix.
            import_modules_recursively(os.path.join(directory, module_name), full_module_name)

    for search_root in search_roots:
        import_modules_recursively(search_root, package_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/context_managers.py DELETED
@@ -1,55 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from contextlib import ExitStack, contextmanager
17
- from typing import Generator
18
-
19
- from lyra_2._ext.imaginaire.utils.misc import timer
20
-
21
-
22
@contextmanager
def data_loader_init() -> Generator[list, None, None]:
    """
    Wrap the data loader initialization with the context managers used for telemetry.

    Currently the only wrapped context is the ``init_data_loader`` timer.

    Yields:
        list: The values returned by entering each wrapped context manager, in order.
    """
    contexts = [
        timer("init_data_loader"),
    ]
    # ExitStack exits every entered context when the caller's ``with`` body finishes or raises.
    with ExitStack() as stack:
        yield [stack.enter_context(cm) for cm in contexts]
32
-
33
-
34
@contextmanager
def model_init(set_barrier: bool = False) -> Generator[list, None, None]:
    """
    Wrap the instantiation of the model with the context managers used for telemetry.

    Currently the only wrapped context is the ``init_model`` timer.

    Args:
        set_barrier (bool): Accepted for interface compatibility; not used by this implementation.

    Yields:
        list: The values returned by entering each wrapped context manager, in order.
    """
    contexts = [
        timer("init_model"),
    ]
    # ExitStack exits every entered context when the caller's ``with`` body finishes or raises.
    with ExitStack() as stack:
        yield [stack.enter_context(cm) for cm in contexts]
44
-
45
-
46
@contextmanager
def distributed_init() -> Generator[list, None, None]:
    """
    Wrap the distributed initialization with the context managers used for telemetry and timers.

    Currently the only wrapped context is the ``init_distributed`` timer.

    Yields:
        list: The values returned by entering each wrapped context manager, in order.
    """
    contexts = [
        timer("init_distributed"),
    ]
    # ExitStack exits every entered context when the caller's ``with`` body finishes or raises.
    with ExitStack() as stack:
        yield [stack.enter_context(cm) for cm in contexts]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/count_params.py DELETED
@@ -1,29 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from torch import nn
17
-
18
-
19
def disabled_train(self, mode: bool = True):
    """No-op replacement for ``model.train``.

    Assign it over a model's ``train`` method so later train/eval switches
    have no effect; ``mode`` is ignored and the receiver is returned as is.
    """
    return self
23
-
24
-
25
def count_params(model: nn.Module, verbose=False) -> int:
    """Count the elements of all parameters of ``model`` that require gradients.

    Args:
        model: Module whose parameters are inspected.
        verbose: When truthy, print the count in millions.

    Returns:
        int: Total number of elements across parameters with ``requires_grad`` set.
    """
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += param.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
    return total_params
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/device.py DELETED
@@ -1,114 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gc
17
- import math
18
- import os
19
-
20
- import pynvml
21
- from loguru import logger as logging
22
-
23
-
24
def get_gpu_architecture():
    """
    Classify the GPUs visible to NVML by their model name.

    Devices are inspected in index order; the label of the first device whose
    name contains a known marker is returned.

    Returns:
        str: "H100" (also returned for H200), "A100", "L40S" or "B200";
        "Other" when no device matches or NVML reports an error.
    """
    # (substring searched in the device name, label returned), checked in order.
    known_models = (
        ("H100", "H100"),
        ("H200", "H100"),
        ("A100", "A100"),
        ("L40S", "L40S"),
        ("B200", "B200"),
    )
    nvml_ready = False
    try:
        pynvml.nvmlInit()
        nvml_ready = True
        device_count = pynvml.nvmlDeviceGetCount()
        for i in range(device_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            model_name = pynvml.nvmlDeviceGetName(handle)
            if isinstance(model_name, bytes):
                model_name = model_name.decode("utf-8")
            print(f"GPU {i}: Model: {model_name}")

            for marker, label in known_models:
                if marker in model_name:
                    return label
    except pynvml.NVMLError as error:
        print(f"Failed to get GPU info: {error}")
    finally:
        # Only shut down what was actually initialized: calling nvmlShutdown()
        # after a failed nvmlInit() raises, which would escape from this
        # `finally` and prevent the "Other" fallback from being returned.
        if nvml_ready:
            try:
                pynvml.nvmlShutdown()
            except pynvml.NVMLError as error:
                print(f"Failed to shut down NVML: {error}")

    # Return "Other" in case of an unrecognized model or an NVML error.
    return "Other"
57
-
58
-
59
class GPUArchitectureNotSupported(Exception):
    """Custom exception raised when the expected GPU architecture is not supported."""
65
-
66
-
67
def print_gpu_mem(str=None):
    """Log used / total / free memory of GPU 0 in MiB, prefixed with ``str``.

    The figures come from NVML and are written with the module's logger.
    NVML errors are printed instead of being raised.
    """
    try:
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        used_mib = meminfo.used / 1024 / 1024
        total_mib = meminfo.total / 1024 / 1024
        free_mib = meminfo.free / 1024 / 1024
        logging.info(f"{str}: {used_mib}/{total_mib}MiB used ({free_mib}MiB free)")
    except pynvml.NVMLError as error:
        print(f"Failed to get GPU memory info: {error}")
76
-
77
-
78
def force_gc():
    """Run Python garbage collection, release cached CUDA memory, and report usage.

    GPU 0 memory is reported via ``print_gpu_mem()`` three times: before the
    collection, after it, and after the CUDA cache has been emptied.
    """
    print_gpu_mem()
    print("gc()")
    gc.collect()
    print_gpu_mem()
    print("empty cuda cache")
    # The step used to be announced but never performed. torch is imported
    # lazily because this module's top-level imports do not include it.
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()
    print_gpu_mem()
86
-
87
-
88
def gpu0_has_80gb_or_less():
    """Report whether GPU 0 has at most 80 GiB of total memory according to NVML.

    Returns:
        bool: result of the comparison, or None (after printing the error)
        when NVML raises.
    """
    try:
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        total_gib = pynvml.nvmlDeviceGetMemoryInfo(handle).total / 1024 / 1024 / 1024
    except pynvml.NVMLError as error:
        print(f"Failed to get GPU memory info: {error}")
        return None
    return total_gib <= 80
95
-
96
-
97
class Device:
    """Thin wrapper around an NVML device handle.

    The constructor does not initialize NVML itself; callers are expected to
    have called ``pynvml.nvmlInit()`` beforehand.
    """

    # Number of 64-bit words requested from NVML for the CPU affinity mask.
    # os.cpu_count() may return None; fall back to a single word in that case.
    _nvml_affinity_elements = math.ceil((os.cpu_count() or 1) / 64)

    def __init__(self, device_idx: int):
        """Look up the NVML handle of the device with index ``device_idx``."""
        super().__init__()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)

    def get_name(self) -> str:
        """Return the device model name as text."""
        name = pynvml.nvmlDeviceGetName(self.handle)
        # Some pynvml versions return bytes; normalize it the same way
        # get_gpu_architecture() does.
        if isinstance(name, bytes):
            name = name.decode("utf-8")
        return name

    @staticmethod
    def _mask_to_cores(words) -> list[int]:
        """Expand 64-bit mask words (word 0 = cores 0..63) into sorted core indices."""
        cores = []
        for word_idx, word in enumerate(words):
            for bit in range(64):
                if (word >> bit) & 1:
                    cores.append(word_idx * 64 + bit)
        return cores

    def get_cpu_affinity(self) -> list[int]:
        """Return the indices of the CPU cores NVML reports as local to this device."""
        # assume nvml returns list of 64 bit ints
        words = pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements)
        return Device._mask_to_cores(words)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/distributed.py DELETED
@@ -1,443 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from __future__ import annotations
17
-
18
- import collections
19
- import collections.abc
20
- import ctypes
21
- import functools
22
- import os
23
- from contextlib import contextmanager
24
- from datetime import timedelta
25
- from typing import TYPE_CHECKING, Any, Callable, Container, Optional
26
-
27
- import pynvml
28
- import torch
29
- import torch.distributed as dist
30
- from torch.distributed import get_process_group_ranks
31
-
32
- from lyra_2._ext.imaginaire.utils.device import Device
33
-
34
- if dist.is_available():
35
- from torch.distributed.distributed_c10d import _get_default_group
36
- from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes
37
-
38
- from lyra_2._ext.imaginaire.utils import log
39
-
40
- if TYPE_CHECKING:
41
- from lyra_2._ext.imaginaire.config import DDPConfig
42
-
43
- try:
44
- from megatron.core import parallel_state
45
- except ImportError:
46
- print("Megatron-core is not installed.")
47
-
48
-
49
def init() -> int | None:
    """Initialize distributed training.

    Steps, in order: bind the process to the CPU cores NVML reports for its
    GPU (best effort), set the NCCL environment variables, create the default
    NCCL process group, and adjust a CUDA device limit through libcudart
    (best effort).

    Returns:
        The current CUDA device index if the process group was already
        initialized; otherwise None once initialization has been performed.
    """
    if dist.is_initialized():
        return torch.cuda.current_device()

    # Set GPU affinity.
    pynvml.nvmlInit()
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    try:
        device = Device(local_rank)
        os.sched_setaffinity(0, device.get_cpu_affinity())
    except Exception as e:
        log.warning(f"Failed to set device affinity: {e}")
    # Set up NCCL communication.
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    if dist.is_available():
        torch.cuda.set_device(local_rank)
        # Get the timeout value from environment variable
        timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
        # Convert the timeout to an integer (if it isn't already) and then to a timedelta
        timeout_timedelta = timedelta(seconds=int(timeout_seconds))
        dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
        log.info(
            f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
            rank0_only=False,
        )
    # Increase the L2 fetch granularity for faster speed.
    # This is only a performance tweak, so an unloadable libcudart must not
    # abort the whole initialization.
    try:
        _libcudart = ctypes.CDLL("libcudart.so")
    except OSError as e:
        log.warning(f"Failed to load libcudart.so, skipping L2 fetch granularity tuning: {e}")
    else:
        # Set device limit on the current device.
        p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
        _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
        _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
    log.info(f"Training with {get_world_size()} GPUs.")
83
-
84
-
85
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Args:
        group: Process group passed through to ``dist.get_rank``.

    Returns:
        rank (int): The rank of the worker, or 0 when torch.distributed is
            unavailable or not initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank(group)
95
-
96
-
97
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get world size. How many GPUs are available in this job.

    Args:
        group: Process group passed through to ``dist.get_world_size``.

    Returns:
        world_size (int): The size reported by torch.distributed, or 1 when
            torch.distributed is unavailable or not initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group)
107
-
108
-
109
def is_rank0() -> bool:
    """Check if current process is the master GPU.

    Returns:
        (bool): True when ``get_rank()`` reports rank 0, else False.
    """
    current_rank = get_rank()
    return current_rank == 0
116
-
117
-
118
def is_local_rank0() -> bool:
    """Check if current process is the local master GPU in the current node.

    Returns:
        (bool): True when the current CUDA device index is 0, else False.
    """
    device_index = torch.cuda.current_device()
    return device_index == 0
125
-
126
-
127
def rank0_only(func: Callable) -> Callable:
    """Restrict a function so it only runs on the master GPU.

    Example usage:
        @rank0_only
        def func(x):
            return x + 3

    Args:
        func (Callable): a function.

    Returns:
        (Callable): A wrapper that calls ``func`` on rank 0 and returns None
            without calling it on every other rank.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        if not is_rank0():
            return None
        return func(*args, **kwargs)

    return wrapper
150
-
151
-
152
def barrier() -> None:
    """Barrier for all GPUs; does nothing without an initialized process group."""
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.barrier()
156
-
157
-
158
def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on the other ranks.

    Rank 0 calls ``func`` before the barrier; every other rank calls it after
    the barrier. Each rank returns the result of its own call.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        runs_before_barrier = is_rank0()
        if runs_before_barrier:
            result = func(*args, **kwargs)
        barrier()
        runs_after_barrier = not is_rank0()
        if runs_after_barrier:
            result = func(*args, **kwargs)
        return result

    return wrapper
171
-
172
-
173
def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
    """Wrap the model for data-parallel training across multiple GPU devices.

    Args:
        config_ddp (DDPConfig): The data parallel config.
        model (torch.nn.Module): The PyTorch module.

    Returns:
        model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
            if distributed environment is available, otherwise return the original model.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return model

    local_rank = int(os.getenv("LOCAL_RANK", 0))
    # Any failure to obtain the data-parallel group from parallel_state
    # falls back to the default process group.
    try:
        ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
    except Exception as e:
        log.info(e)
        log.info("parallel_state not initialized, treating all GPUs equally for DDP")
        ddp_group = None

    ddp_kwargs = dict(
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=config_ddp.find_unused_parameters,
        static_graph=config_ddp.static_graph,
        broadcast_buffers=config_ddp.broadcast_buffers,
        process_group=ddp_group,
    )
    return DistributedDataParallel(model, **ddp_kwargs)
203
-
204
-
205
class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
    """This extends torch.nn.parallel.DistributedDataParallel with .training_step().

    This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an ImaginaireModel such that
    model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
    model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
    training_step), allowing us to preserve the function names and signatures.
    """

    def __init__(self, model: torch.nn.Module, *args, **kwargs):
        """Wrap ``model``; remaining arguments are forwarded to torch's DistributedDataParallel."""
        super().__init__(model, *args, **kwargs)
        # One-shot flag read and cleared by ddp_sync_grad() so that its
        # static_graph warning is logged only once per wrapper.
        self.show_sync_grad_static_graph_warning = True

    def training_step(self, *args, **kwargs) -> Any:
        """Run ``self.module.training_step(*args, **kwargs)`` through the DDP call path.

        Returns:
            Whatever the wrapped module's ``training_step`` returns.
        """
        # Cache the original model.forward() method.
        original_forward = self.module.forward

        def wrapped_training_step(*_args, **_kwargs):  # noqa: ANN202
            # Unpatch immediately before calling training_step() because itself may want to call the real forward.
            self.module.forward = original_forward
            # The actual .training_step().
            return self.module.training_step(*_args, **_kwargs)

        # Patch the original_module's forward so we can redirect the arguments back to the real method.
        self.module.forward = wrapped_training_step
        try:
            # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
            # Without calling self.forward() or model.forward() explicitly, implicit hooks are also executed.
            return self(*args, **kwargs)
        finally:
            # If the call fails before wrapped_training_step runs, the patch
            # would otherwise stay in place and later model() calls would be
            # redirected to training_step(). Restoring again is harmless.
            self.module.forward = original_forward
233
-
234
-
235
@contextmanager
def ddp_sync_grad(model, enabled):
    r"""
    Context manager to enable/disable gradient synchronizations across DDP processes for DDP model.
    Modified from:
    https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
    Note that this is incompatible with static_graph=True and will be a no-op if static_graph=True.

    Within this context, gradients will be accumulated on module
    variables, which will later be synchronized in the first
    forward-backward pass exiting the context.

    For models that are not this module's DistributedDataParallel wrapper the
    context manager changes nothing.

    .. warning::
        The forward pass should be included inside the context manager, or
        else gradients will still be synchronized.
    """
    assert isinstance(model, torch.nn.Module)
    is_ddp = isinstance(model, DistributedDataParallel)
    if is_ddp:
        previous_sync = model.require_backward_grad_sync
        blocked_by_static_graph = model.static_graph and previous_sync != enabled
        if not blocked_by_static_graph:
            model.require_backward_grad_sync = enabled
        elif model.show_sync_grad_static_graph_warning:
            # Warn once per wrapper, then stay silent.
            log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
            model.show_sync_grad_static_graph_warning = False
    try:
        yield
    finally:
        if is_ddp:
            model.require_backward_grad_sync = previous_sync
265
-
266
-
267
def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
    """Aggregate the list of data batches from all devices and process the results.

    This is used for gathering validation data batches with lyra_2._ext.imaginaire.utils.dataloader.DistributedEvalSampler.
    It is intended to return the data/output of the entire validation set in its original index order: the per-rank
    tensors are stacked on dim 1 and flattened, which interleaves the samples of the ranks. The sizes of data_batches
    in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
    created before calling dist.all_gather() and dropped again from the end of the result.

    Args:
        data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
            leaf entries are tensors. The type of the first element selects the handling.

    Returns:
        data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
            leaf entries are concatenated tensors.

    Raises:
        TypeError: if the first element is neither a tensor nor a mapping.
    """
    first = data_batches[0]
    if isinstance(first, torch.Tensor):
        # Concatenate the local data batches.
        data_concat = torch.cat(data_batches, dim=0)  # type: ignore
        # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
        max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
        dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
        if len(data_concat) < max_num_local_samples:
            assert len(data_concat) + 1 == max_num_local_samples
            dummy = torch.empty_like(data_concat[:1])
            data_concat = torch.cat([data_concat, dummy], dim=0)
            dummy_count = torch.tensor(1, device="cuda")
        else:
            dummy_count = torch.tensor(0, device="cuda")
        # Get all concatenated batches from all ranks and concatenate again.
        dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
        data_concat = all_gather_tensor(data_concat.contiguous())
        data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
        # Remove the dummy samples.
        if dummy_count > 0:
            data_collate = data_collate[:-dummy_count]
    elif isinstance(first, collections.abc.Mapping):
        # Recurse per key; every rank must visit the keys in the same order
        # because each recursion issues collectives.
        data_collate = dict()
        for key in first.keys():
            data_collate[key] = collate_batches([data[key] for data in data_batches])  # type: ignore
    else:
        raise TypeError(
            f"collate_batches expects tensors or mappings of tensors, got {type(first).__name__}"
        )
    return data_collate
310
-
311
-
312
@torch.no_grad()
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
    """Gather the corresponding tensor from all GPU devices to a list.

    Args:
        tensor (torch.Tensor): Pytorch tensor.

    Returns:
        tensor_list (list[torch.Tensor]): One tensor per rank of the default
            group, filled by ``dist.all_gather``.
    """
    world_size = get_world_size()
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.zeros_like(tensor))
    dist.all_gather(gathered, tensor)
    return gathered
325
-
326
-
327
def broadcast(tensor, src, group=None, async_op=False):
    """Broadcast ``tensor`` from rank ``src`` via ``dist.broadcast``.

    With fewer than two processes no collective is issued.

    Args:
        tensor: Tensor passed to ``dist.broadcast``.
        src: Source rank.
        group: Process group forwarded to ``dist.broadcast``.
        async_op: Forwarded to ``dist.broadcast``.

    Returns:
        ``tensor`` itself, except when a collective is issued with
        ``async_op=True``: then the value returned by ``dist.broadcast`` (the
        async work handle) is returned so the caller can wait on it.
    """
    if get_world_size() < 2:
        return tensor
    # The handle used to be dropped, leaving async callers nothing to wait on.
    work = dist.broadcast(tensor, src=src, group=group, async_op=async_op)
    return work if async_op else tensor
332
-
333
-
334
def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
    r"""Reduce ``tensor`` onto rank ``rank`` (default 0) with ``dist.reduce``.

    Args:
        tensor: Tensor to reduce; modified in place.
        rank: Destination rank of the reduction.
        reduce: "sum" keeps the reduced value, "mean" additionally divides it
            by the world size on the destination rank.

    Returns:
        The same ``tensor`` object. With fewer than two processes it is
        returned untouched.

    Raises:
        NotImplementedError: if ``reduce`` is neither "mean" nor "sum" and a
            reduction would actually be performed.
    """
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    # Validate before communicating so that every rank fails together,
    # instead of only the destination rank raising after the collective.
    if reduce not in ("mean", "sum"):
        raise NotImplementedError(f"Unsupported reduce mode: {reduce!r}")
    with torch.no_grad():
        dist.reduce(tensor, dst=rank)
        if get_rank() == rank and reduce == "mean":
            tensor /= world_size
    return tensor
349
-
350
-
351
def sync_model_states(
    model: torch.nn.Module,
    process_group: Optional[dist.ProcessGroup] = None,
    src: int = 0,
    params_and_buffers_to_ignore: Optional[Container[str]] = None,
    broadcast_buffers: bool = True,
):
    """
    Modified from the DDP source code.
    Synchronizes the parameters and buffers of a model across different processes in a distributed setting.

    This function ensures that all processes in the specified process group have the same initial parameters and
    buffers from the source rank, typically rank 0. It is useful when different processes start with different model
    states and a synchronization is required to ensure consistency across all ranks.

    Nothing happens when torch.distributed is unavailable or uninitialized, or when no parameter is left after
    filtering.

    Args:
        model (nn.Module): The model whose parameters and buffers are to be synchronized.
        process_group (dist.ProcessGroup, optional): The process group for communication. If None,
            the default group is used. Defaults to None.
        src (int, optional): The source rank from which parameters and buffers will be broadcasted.
            Defaults to 0.
        params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
            names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
            included.
        broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.

    Side Effects:
        This function modifies the state of the model in-place to synchronize it with the source rank's model state.

    Raises:
        RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.

    Examples:
        >>> # download the model weights from s3 on rank 0 only instead of on every rank, saving network bandwidth
        >>> # useful and saves time when model weights are huge
        >>> if dist.get_rank() == 0:
        >>>     model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
        >>> dist.barrier()
        >>> sync_model_states(model)  # sync rank0 weights to other ranks
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if process_group is None:
        process_group = _get_default_group()
    if not params_and_buffers_to_ignore:
        params_and_buffers_to_ignore = set()

    log.info(
        f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
    )

    # Collect every non-ignored parameter exactly once, in module traversal order.
    # module.named_parameters(recurse=False) is used rather than parameters(module);
    # the latter is only needed in the single-process multi-device case, where it
    # accesses replicated parameters through _former_parameters.
    seen = set()
    parameters = []
    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            if f"{module_name}.{param_name}" in params_and_buffers_to_ignore:
                continue
            # Skip parameters shared across child modules that were already collected.
            if param in seen:
                continue
            seen.add(param)
            parameters.append(param)

    if not parameters:
        return

    _verify_param_shape_across_processes(process_group, parameters)

    _sync_module_states(
        module=model,
        process_group=process_group,
        broadcast_bucket_size=int(250 * 1024 * 1024),
        src=src,
        params_and_buffers_to_ignore=params_and_buffers_to_ignore,
        broadcast_buffers=broadcast_buffers,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py DELETED
@@ -1,34 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
17
- from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_backend import Boto3Backend
18
- from lyra_2._ext.imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
19
- from lyra_2._ext.imaginaire.utils.easy_io.backends.local_backend import LocalBackend
20
- from lyra_2._ext.imaginaire.utils.easy_io.backends.registry_utils import (
21
- backends,
22
- prefix_to_backends,
23
- register_backend,
24
- )
25
-
26
- __all__ = [
27
- "BaseStorageBackend",
28
- "LocalBackend",
29
- "HTTPBackend",
30
- "Boto3Backend",
31
- "register_backend",
32
- "backends",
33
- "prefix_to_backends",
34
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py DELETED
@@ -1,64 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import contextlib
17
- import json
18
- from typing import Any, Optional
19
-
20
- from lyra_2._ext.imaginaire.utils import log
21
- from lyra_2._ext.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS, CRED_ENVS_DICT
22
-
23
- DEPLOYMENT_ENVS = ["prod", "dev", "stg"]
24
-
25
-
26
- # context manager to open a file or read from env variable
27
@contextlib.contextmanager
def open_auth(s3_credential_path: Optional[Any], mode: str):
    """Yield credentials for ``s3_credential_path`` from the environment or from the file.

    Yields:
        - None when no path is given;
        - the mapping from ``get_creds_from_env`` when ``CRED_ENVS.APP_ENV`` is a
          deployment environment and ``PROD_<FILE STEM UPPERCASED>`` is a key of
          ``CRED_ENVS_DICT``;
        - otherwise the file object opened with ``mode`` (closed on exit).

    Raises:
        ValueError: if no name can be derived from the path.
    """
    if not s3_credential_path:
        log.info(f"No credential file provided {s3_credential_path}.")
        yield None
        return

    file_name = s3_credential_path.split("/")[-1]
    name = file_name.split(".")[0]
    if not name:
        raise ValueError(f"Could not parse into env var: {s3_credential_path}")
    cred_env_name = f"PROD_{name.upper()}"

    use_env = CRED_ENVS.APP_ENV in DEPLOYMENT_ENVS and cred_env_name in CRED_ENVS_DICT
    if not use_env:
        log.info(f"using credential file: {s3_credential_path}")
        with open(s3_credential_path, mode) as f:
            yield f
        return

    object_storage_config = get_creds_from_env(cred_env_name)
    log.info(f"using ENV vars for {cred_env_name}")
    yield object_storage_config
47
-
48
-
49
def get_creds_from_env(cred_env_name: str) -> dict[str, str]:
    """Return the credential mapping stored under ``cred_env_name`` in ``CRED_ENVS_DICT``.

    Raises:
        ValueError: if the name is not a key of ``CRED_ENVS_DICT``, or if any
            value of the mapping is an empty string (the upper-cased keys of
            the empty entries are passed as the second exception argument).
    """
    try:
        object_storage_config = CRED_ENVS_DICT[cred_env_name]
    except KeyError as err:
        # Chain explicitly so the traceback shows the failed lookup as the cause.
        raise ValueError(f"Could not find {cred_env_name} in CRED_ENVS") from err
    empty_args = {key.upper() for key, value in object_storage_config.items() if value == ""}
    if empty_args:
        raise ValueError(f"Some required environment variable(s) were not provided for {cred_env_name}", empty_args)
    return object_storage_config
58
-
59
-
60
def json_load_auth(f):
    """Turn what ``open_auth`` yielded into a credential mapping.

    Outside a deployment environment ``f`` is parsed with ``json.load``. In a
    deployment environment ``f`` is returned as-is, or an empty dict if it is
    falsy.
    """
    if CRED_ENVS.APP_ENV not in DEPLOYMENT_ENVS:
        return json.load(f)
    return f or {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py DELETED
@@ -1,60 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- import os.path as osp
18
- from abc import ABCMeta, abstractmethod
19
-
20
-
21
def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` (with parents) unless it already exists.

    An empty string is ignored. A leading ``~`` is expanded before creation.
    """
    if dir_name != "":
        expanded = osp.expanduser(dir_name)
        os.makedirs(expanded, mode=mode, exist_ok=True)
26
-
27
-
28
def has_method(obj, method):
    """Return True if ``obj`` has an attribute named ``method`` that is callable."""
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))
30
-
31
-
32
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract base class for storage backends.

    Every backend has to implement two APIs, :meth:`get()` and
    :meth:`get_text()`:

    - :meth:`get()` reads the file as a byte stream.
    - :meth:`get_text()` reads the file as texts.
    """

    # a flag to indicate whether the backend can create a symlink for a file
    # This attribute will be deprecated in future.
    _allow_symlink = False

    @property
    def name(self):
        """Class name of the concrete backend."""
        return self.__class__.__name__

    @property
    def allow_symlink(self):
        """Value of the ``_allow_symlink`` flag."""
        return self._allow_symlink

    @abstractmethod
    def get(self, filepath):
        """Read ``filepath``; to be implemented by subclasses."""

    @abstractmethod
    def get_text(self, filepath):
        """Read ``filepath`` as text; to be implemented by subclasses."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py DELETED
@@ -1,841 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import io
17
- import os
18
- import re
19
- import tempfile
20
- from contextlib import contextmanager
21
- from pathlib import Path
22
- from shutil import SameFileError
23
- from typing import Generator, Iterator, Optional, Tuple, Union
24
-
25
- from lyra_2._ext.imaginaire.utils import log
26
- from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import (
27
- BaseStorageBackend,
28
- has_method,
29
- mkdir_or_exist,
30
- )
31
- from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_client import Boto3Client
32
-
33
-
34
- class Boto3Backend(BaseStorageBackend):
35
- """boto3 storage backend (for internal usage).
36
-
37
- Boto3Backend supports reading and writing data to multiple clusters.
38
- If the file path contains the cluster name, Boto3Backend will read data
39
- from specified cluster or write data to it. Otherwise, Boto3Backend will
40
- access the default cluster.
41
-
42
- Args:
43
- path_mapping (dict, optional): Path mapping dict from local path to
44
- Boto3 path. When ``path_mapping={'src': 'dst'}``, ``src`` in
45
- ``filepath`` will be replaced by ``dst``. Defaults to None.
46
- s3_credential_path (str, optional): Config path of Boto3 client. Default: None.
47
- `New in version 0.3.3`.
48
-
49
- Examples:
50
- >>> backend = Boto3Backend()
51
- >>> filepath1 = 's3://path/of/file'
52
- >>> filepath2 = 'cluster-name:s3://path/of/file'
53
- >>> backend.get(filepath1) # get data from default cluster
54
- >>> client.get(filepath2) # get data from 'cluster-name' cluster
55
- """
56
-
57
- def __init__(
58
- self,
59
- s3_credential_path: str = "",
60
- path_mapping: Optional[dict] = None,
61
- ):
62
- self._client = Boto3Client(s3_credential_path=s3_credential_path)
63
- assert isinstance(path_mapping, dict) or path_mapping is None
64
- self.path_mapping = path_mapping
65
- if path_mapping:
66
- for k, v in path_mapping.items():
67
- log.critical(f"Path mapping: {k} -> {v}", rank0_only=False)
68
-
69
- def _map_path(self, filepath: Union[str, Path]) -> str:
70
- """Map ``filepath`` to a string path whose prefix will be replaced by
71
- :attr:`self.path_mapping`.
72
-
73
- Args:
74
- filepath (str or Path): Path to be mapped.
75
- """
76
- filepath = str(filepath)
77
- if self.path_mapping is not None:
78
- for k, v in self.path_mapping.items():
79
- filepath = filepath.replace(k, v, 1)
80
- return filepath
81
-
82
- def _format_path(self, filepath: str) -> str:
83
- """Convert a ``filepath`` to standard format of s3 oss.
84
-
85
- If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
86
- environment, the ``filepath`` will be the format of
87
- 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
88
- above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
89
-
90
- Args:
91
- filepath (str): Path to be formatted.
92
- """
93
- return re.sub(r"\\+", "/", filepath)
94
-
95
- def _replace_prefix(self, filepath: Union[str, Path]) -> str:
96
- filepath = str(filepath)
97
- return filepath
98
- # return filepath.replace('s3://', 's3://')
99
-
100
- def get(self, filepath: Union[str, Path]) -> bytes:
101
- """Read bytes from a given ``filepath`` with 'rb' mode.
102
-
103
- Args:
104
- filepath (str or Path): Path to read data.
105
-
106
- Returns:
107
- bytes: Return bytes read from filepath.
108
-
109
- Examples:
110
- >>> backend = Boto3Backend()
111
- >>> filepath = 's3://path/of/file'
112
- >>> backend.get(filepath)
113
- b'hello world'
114
- """
115
- filepath = self._map_path(filepath)
116
- filepath = self._format_path(filepath)
117
- filepath = self._replace_prefix(filepath)
118
- value = self._client.get(filepath)
119
- return value
120
-
121
- def get_text(
122
- self,
123
- filepath: Union[str, Path],
124
- encoding: str = "utf-8",
125
- ) -> str:
126
- """Read text from a given ``filepath`` with 'r' mode.
127
-
128
- Args:
129
- filepath (str or Path): Path to read data.
130
- encoding (str): The encoding format used to open the ``filepath``.
131
- Defaults to 'utf-8'.
132
-
133
- Returns:
134
- str: Expected text reading from ``filepath``.
135
-
136
- Examples:
137
- >>> backend = Boto3Backend()
138
- >>> filepath = 's3://path/of/file'
139
- >>> backend.get_text(filepath)
140
- 'hello world'
141
- """
142
- return str(self.get(filepath), encoding=encoding)
143
-
144
- def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None:
145
- """Write bytes to a given ``filepath``.
146
-
147
- Args:
148
- obj (bytes): Data to be saved.
149
- filepath (str or Path): Path to write data.
150
-
151
- Examples:
152
- >>> backend = Boto3Backend()
153
- >>> filepath = 's3://path/of/file'
154
- >>> backend.put(b'hello world', filepath)
155
- """
156
- filepath = self._map_path(filepath)
157
- filepath = self._format_path(filepath)
158
- filepath = self._replace_prefix(filepath)
159
- self._client.put(obj, filepath)
160
-
161
- def fast_put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path], num_processes: int = 32) -> None:
162
- """Write bytes to a given ``filepath`` with multiple processes and async"""
163
- assert num_processes > 1
164
- filepath = self._map_path(filepath)
165
- filepath = self._format_path(filepath)
166
- filepath = self._replace_prefix(filepath)
167
- self._client.fast_put(obj, filepath, num_processes=num_processes)
168
-
169
- def put_text(
170
- self,
171
- obj: str,
172
- filepath: Union[str, Path],
173
- encoding: str = "utf-8",
174
- ) -> None:
175
- """Write text to a given ``filepath``.
176
-
177
- Args:
178
- obj (str): Data to be written.
179
- filepath (str or Path): Path to write data.
180
- encoding (str): The encoding format used to encode the ``obj``.
181
- Defaults to 'utf-8'.
182
-
183
- Examples:
184
- >>> backend = Boto3Backend()
185
- >>> filepath = 's3://path/of/file'
186
- >>> backend.put_text('hello world', filepath)
187
- """
188
- self.put(bytes(obj, encoding=encoding), filepath)
189
-
190
- def exists(self, filepath: Union[str, Path]) -> bool:
191
- """Check whether a file path exists.
192
-
193
- Args:
194
- filepath (str or Path): Path to be checked whether exists.
195
-
196
- Returns:
197
- bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
198
-
199
- Examples:
200
- >>> backend = Boto3Backend()
201
- >>> filepath = 's3://path/of/file'
202
- >>> backend.exists(filepath)
203
- True
204
- """
205
- if not (has_method(self._client, "contains") and has_method(self._client, "isdir")):
206
- raise NotImplementedError(
207
- "Current version of Boto3 Python SDK has not supported "
208
- "the `contains` and `isdir` methods, please use a higher"
209
- "version or dev branch instead."
210
- )
211
-
212
- filepath = self._map_path(filepath)
213
- filepath = self._format_path(filepath)
214
- filepath = self._replace_prefix(filepath)
215
- return self._client.contains(filepath) or self._client.isdir(filepath)
216
-
217
- def isdir(self, filepath: Union[str, Path]) -> bool:
218
- """Check whether a file path is a directory.
219
-
220
- Args:
221
- filepath (str or Path): Path to be checked whether it is a
222
- directory.
223
-
224
- Returns:
225
- bool: Return ``True`` if ``filepath`` points to a directory,
226
- ``False`` otherwise.
227
-
228
- Examples:
229
- >>> backend = Boto3Backend()
230
- >>> filepath = 's3://path/of/dir'
231
- >>> backend.isdir(filepath)
232
- True
233
- """
234
- if not has_method(self._client, "isdir"):
235
- raise NotImplementedError(
236
- "Current version of Boto3 Python SDK has not supported "
237
- "the `isdir` method, please use a higher version or dev"
238
- " branch instead."
239
- )
240
-
241
- filepath = self._map_path(filepath)
242
- filepath = self._format_path(filepath)
243
- filepath = self._replace_prefix(filepath)
244
- return self._client.isdir(filepath)
245
-
246
- def isfile(self, filepath: Union[str, Path]) -> bool:
247
- """Check whether a file path is a file.
248
-
249
- Args:
250
- filepath (str or Path): Path to be checked whether it is a file.
251
-
252
- Returns:
253
- bool: Return ``True`` if ``filepath`` points to a file, ``False``
254
- otherwise.
255
-
256
- Examples:
257
- >>> backend = Boto3Backend()
258
- >>> filepath = 's3://path/of/file'
259
- >>> backend.isfile(filepath)
260
- True
261
- """
262
- if not has_method(self._client, "contains"):
263
- raise NotImplementedError(
264
- "Current version of Boto3 Python SDK has not supported "
265
- "the `contains` method, please use a higher version or "
266
- "dev branch instead."
267
- )
268
-
269
- filepath = self._map_path(filepath)
270
- filepath = self._format_path(filepath)
271
- filepath = self._replace_prefix(filepath)
272
- return self._client.contains(filepath)
273
-
274
- def join_path(
275
- self,
276
- filepath: Union[str, Path],
277
- *filepaths: Union[str, Path],
278
- ) -> str:
279
- r"""Concatenate all file paths.
280
-
281
- Join one or more filepath components intelligently. The return value
282
- is the concatenation of filepath and any members of \*filepaths.
283
-
284
- Args:
285
- filepath (str or Path): Path to be concatenated.
286
-
287
- Returns:
288
- str: The result after concatenation.
289
-
290
- Examples:
291
- >>> backend = Boto3Backend()
292
- >>> filepath = 's3://path/of/file'
293
- >>> backend.join_path(filepath, 'another/path')
294
- 's3://path/of/file/another/path'
295
- >>> backend.join_path(filepath, '/another/path')
296
- 's3://path/of/file/another/path'
297
- """
298
- filepath = self._format_path(self._map_path(filepath))
299
- if filepath.endswith("/"):
300
- filepath = filepath[:-1]
301
- formatted_paths = [filepath]
302
- for path in filepaths:
303
- formatted_path = self._format_path(self._map_path(path))
304
- formatted_paths.append(formatted_path.lstrip("/"))
305
-
306
- return "/".join(formatted_paths)
307
-
308
- @contextmanager
309
- def get_local_path(
310
- self,
311
- filepath: Union[str, Path],
312
- ) -> Generator[Union[str, Path], None, None]:
313
- """Download a file from ``filepath`` to a local temporary directory,
314
- and return the temporary path.
315
-
316
- ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
317
- can be called with ``with`` statement, and when exists from the
318
- ``with`` statement, the temporary path will be released.
319
-
320
- Args:
321
- filepath (str or Path): Download a file from ``filepath``.
322
-
323
- Yields:
324
- Iterable[str]: Only yield one temporary path.
325
-
326
- Examples:
327
- >>> backend = Boto3Backend()
328
- >>> # After existing from the ``with`` clause,
329
- >>> # the path will be removed
330
- >>> filepath = 's3://path/of/file'
331
- >>> with backend.get_local_path(filepath) as path:
332
- ... # do something here
333
- """
334
- assert self.isfile(filepath)
335
- try:
336
- f = tempfile.NamedTemporaryFile(delete=False)
337
- f.write(self.get(filepath))
338
- f.close()
339
- yield f.name
340
- finally:
341
- os.remove(f.name)
342
-
343
- def copyfile(
344
- self,
345
- src: Union[str, Path],
346
- dst: Union[str, Path],
347
- ) -> str:
348
- """Copy a file src to dst and return the destination file.
349
-
350
- src and dst should have the same prefix. If dst specifies a directory,
351
- the file will be copied into dst using the base filename from src. If
352
- dst specifies a file that already exists, it will be replaced.
353
-
354
- Args:
355
- src (str or Path): A file to be copied.
356
- dst (str or Path): Copy file to dst.
357
-
358
- Returns:
359
- str: The destination file.
360
-
361
- Raises:
362
- SameFileError: If src and dst are the same file, a SameFileError
363
- will be raised.
364
-
365
- Examples:
366
- >>> backend = Boto3Backend()
367
- >>> # dst is a file
368
- >>> src = 's3://path/of/file'
369
- >>> dst = 's3://path/of/file1'
370
- >>> backend.copyfile(src, dst)
371
- 's3://path/of/file1'
372
-
373
- >>> # dst is a directory
374
- >>> dst = 's3://path/of/dir'
375
- >>> backend.copyfile(src, dst)
376
- 's3://path/of/dir/file'
377
- """
378
- src = self._format_path(self._map_path(src))
379
- dst = self._format_path(self._map_path(dst))
380
- if self.isdir(dst):
381
- dst = self.join_path(dst, src.split("/")[-1])
382
-
383
- if src == dst:
384
- raise SameFileError("src and dst should not be same")
385
-
386
- self.put(self.get(src), dst)
387
- return dst
388
-
389
- def copytree(
390
- self,
391
- src: Union[str, Path],
392
- dst: Union[str, Path],
393
- ) -> str:
394
- """Recursively copy an entire directory tree rooted at src to a
395
- directory named dst and return the destination directory.
396
-
397
- src and dst should have the same prefix.
398
-
399
- Args:
400
- src (str or Path): A directory to be copied.
401
- dst (str or Path): Copy directory to dst.
402
- backend_args (dict, optional): Arguments to instantiate the
403
- prefix of uri corresponding backend. Defaults to None.
404
-
405
- Returns:
406
- str: The destination directory.
407
-
408
- Raises:
409
- FileExistsError: If dst had already existed, a FileExistsError will
410
- be raised.
411
-
412
- Examples:
413
- >>> backend = Boto3Backend()
414
- >>> src = 's3://path/of/dir'
415
- >>> dst = 's3://path/of/dir1'
416
- >>> backend.copytree(src, dst)
417
- 's3://path/of/dir1'
418
- """
419
- src = self._format_path(self._map_path(src))
420
- dst = self._format_path(self._map_path(dst))
421
-
422
- if self.exists(dst):
423
- raise FileExistsError("dst should not exist")
424
-
425
- for path in self.list_dir_or_file(src, list_dir=False, recursive=True):
426
- src_path = self.join_path(src, path)
427
- dst_path = self.join_path(dst, path)
428
- self.put(self.get(src_path), dst_path)
429
-
430
- return dst
431
-
432
- def copyfile_from_local(
433
- self,
434
- src: Union[str, Path],
435
- dst: Union[str, Path],
436
- ) -> str:
437
- """Upload a local file src to dst and return the destination file.
438
-
439
- Args:
440
- src (str or Path): A local file to be copied.
441
- dst (str or Path): Copy file to dst.
442
- backend_args (dict, optional): Arguments to instantiate the
443
- prefix of uri corresponding backend. Defaults to None.
444
-
445
- Returns:
446
- str: If dst specifies a directory, the file will be copied into dst
447
- using the base filename from src.
448
-
449
- Examples:
450
- >>> backend = Boto3Backend()
451
- >>> # dst is a file
452
- >>> src = 'path/of/your/file'
453
- >>> dst = 's3://path/of/file1'
454
- >>> backend.copyfile_from_local(src, dst)
455
- 's3://path/of/file1'
456
-
457
- >>> # dst is a directory
458
- >>> dst = 's3://path/of/dir'
459
- >>> backend.copyfile_from_local(src, dst)
460
- 's3://path/of/dir/file'
461
- """
462
- dst = self._format_path(self._map_path(dst))
463
- if self.isdir(dst):
464
- dst = self.join_path(dst, os.path.basename(src))
465
-
466
- with open(src, "rb") as f:
467
- self.put(f.read(), dst)
468
-
469
- return dst
470
-
471
- def copytree_from_local(
472
- self,
473
- src: Union[str, Path],
474
- dst: Union[str, Path],
475
- ) -> str:
476
- """Recursively copy an entire directory tree rooted at src to a
477
- directory named dst and return the destination directory.
478
-
479
- Args:
480
- src (str or Path): A local directory to be copied.
481
- dst (str or Path): Copy directory to dst.
482
-
483
- Returns:
484
- str: The destination directory.
485
-
486
- Raises:
487
- FileExistsError: If dst had already existed, a FileExistsError will
488
- be raised.
489
-
490
- Examples:
491
- >>> backend = Boto3Backend()
492
- >>> src = 'path/of/your/dir'
493
- >>> dst = 's3://path/of/dir1'
494
- >>> backend.copytree_from_local(src, dst)
495
- 's3://path/of/dir1'
496
- """
497
- dst = self._format_path(self._map_path(dst))
498
- if self.exists(dst):
499
- raise FileExistsError("dst should not exist")
500
-
501
- src = str(src)
502
-
503
- for cur_dir, _, files in os.walk(src):
504
- for f in files:
505
- src_path = os.path.join(cur_dir, f)
506
- dst_path = self.join_path(dst, src_path.replace(src, ""))
507
- self.copyfile_from_local(src_path, dst_path)
508
-
509
- return dst
510
-
511
- def copyfile_to_local(
512
- self,
513
- src: Union[str, Path],
514
- dst: Union[str, Path],
515
- dst_type: str, # Choose from ["file", "dir"]
516
- ) -> Union[str, Path]:
517
- """Copy the file src to local dst and return the destination file.
518
-
519
- If dst specifies a directory, the file will be copied into dst using
520
- the base filename from src. If dst specifies a file that already
521
- exists, it will be replaced.
522
-
523
- Args:
524
- src (str or Path): A file to be copied.
525
- dst (str or Path): Copy file to to local dst.
526
-
527
- Returns:
528
- str: If dst specifies a directory, the file will be copied into dst
529
- using the base filename from src.
530
-
531
- Examples:
532
- >>> backend = Boto3Backend()
533
- >>> # dst is a file
534
- >>> src = 's3://path/of/file'
535
- >>> dst = 'path/of/your/file'
536
- >>> backend.copyfile_to_local(src, dst)
537
- 'path/of/your/file'
538
-
539
- >>> # dst is a directory
540
- >>> dst = 'path/of/your/dir'
541
- >>> backend.copyfile_to_local(src, dst)
542
- 'path/of/your/dir/file'
543
- """
544
- assert dst_type in ["file", "dir"]
545
- # There is no good way to detect whether dst is a directory or a file, so we make dst_type required
546
- if dst_type == "dir":
547
- basename = os.path.basename(src)
548
- if isinstance(dst, str):
549
- dst = os.path.join(dst, basename)
550
- else:
551
- assert isinstance(dst, Path)
552
- dst = dst / basename
553
-
554
- # Create parent directory if it doesn't exist
555
- parent_dir = os.path.dirname(dst)
556
- os.makedirs(parent_dir, exist_ok=True)
557
-
558
- try:
559
- with open(dst, "wb") as f:
560
- data = self.get(src)
561
- f.write(data)
562
- except Exception as e:
563
- log.error(f"Failed to write file: {e}")
564
- raise
565
-
566
- return dst
567
-
568
- def copytree_to_local(
569
- self,
570
- src: Union[str, Path],
571
- dst: Union[str, Path],
572
- ) -> Union[str, Path]:
573
- """Recursively copy an entire directory tree rooted at src to a local
574
- directory named dst and return the destination directory.
575
-
576
- Args:
577
- src (str or Path): A directory to be copied.
578
- dst (str or Path): Copy directory to local dst.
579
- backend_args (dict, optional): Arguments to instantiate the
580
- prefix of uri corresponding backend. Defaults to None.
581
-
582
- Returns:
583
- str: The destination directory.
584
-
585
- Examples:
586
- >>> backend = Boto3Backend()
587
- >>> src = 's3://path/of/dir'
588
- >>> dst = 'path/of/your/dir'
589
- >>> backend.copytree_to_local(src, dst)
590
- 'path/of/your/dir'
591
- """
592
- for path in self.list_dir_or_file(src, list_dir=False, recursive=True):
593
- dst_path = os.path.join(dst, path)
594
- mkdir_or_exist(os.path.dirname(dst_path))
595
- with open(dst_path, "wb") as f:
596
- f.write(self.get(self.join_path(src, path)))
597
-
598
- return dst
599
-
600
- def remove(self, filepath: Union[str, Path]) -> None:
601
- """Remove a file.
602
-
603
- Args:
604
- filepath (str or Path): Path to be removed.
605
-
606
- Raises:
607
- FileNotFoundError: If filepath does not exist, an FileNotFoundError
608
- will be raised.
609
- IsADirectoryError: If filepath is a directory, an IsADirectoryError
610
- will be raised.
611
-
612
- Examples:
613
- >>> backend = Boto3Backend()
614
- >>> filepath = 's3://path/of/file'
615
- >>> backend.remove(filepath)
616
- """
617
- if not has_method(self._client, "delete"):
618
- raise NotImplementedError(
619
- "Current version of Boto3 Python SDK has not supported "
620
- "the `delete` method, please use a higher version or dev "
621
- "branch instead."
622
- )
623
-
624
- if not self.exists(filepath):
625
- raise FileNotFoundError(f"filepath {filepath} does not exist")
626
-
627
- if self.isdir(filepath):
628
- raise IsADirectoryError("filepath should be a file")
629
-
630
- filepath = self._map_path(filepath)
631
- filepath = self._format_path(filepath)
632
- filepath = self._replace_prefix(filepath)
633
- self._client.delete(filepath)
634
-
635
- def rmtree(self, dir_path: Union[str, Path]) -> None:
636
- """Recursively delete a directory tree.
637
-
638
- Args:
639
- dir_path (str or Path): A directory to be removed.
640
-
641
- Examples:
642
- >>> backend = Boto3Backend()
643
- >>> dir_path = 's3://path/of/dir'
644
- >>> backend.rmtree(dir_path)
645
- """
646
- for path in self.list_dir_or_file(dir_path, list_dir=False, recursive=True):
647
- filepath = self.join_path(dir_path, path)
648
- self.remove(filepath)
649
-
650
- def copy_if_symlink_fails(
651
- self,
652
- src: Union[str, Path],
653
- dst: Union[str, Path],
654
- ) -> bool:
655
- """Create a symbolic link pointing to src named dst.
656
-
657
- Directly copy src to dst because PetrelBacekend does not support create
658
- a symbolic link.
659
-
660
- Args:
661
- src (str or Path): A file or directory to be copied.
662
- dst (str or Path): Copy a file or directory to dst.
663
- backend_args (dict, optional): Arguments to instantiate the
664
- prefix of uri corresponding backend. Defaults to None.
665
-
666
- Returns:
667
- bool: Return False because Boto3Backend does not support create
668
- a symbolic link.
669
-
670
- Examples:
671
- >>> backend = Boto3Backend()
672
- >>> src = 's3://path/of/file'
673
- >>> dst = 's3://path/of/your/file'
674
- >>> backend.copy_if_symlink_fails(src, dst)
675
- False
676
- >>> src = 's3://path/of/dir'
677
- >>> dst = 's3://path/of/your/dir'
678
- >>> backend.copy_if_symlink_fails(src, dst)
679
- False
680
- """
681
- if self.isfile(src):
682
- self.copyfile(src, dst)
683
- else:
684
- self.copytree(src, dst)
685
- return False
686
-
687
- def list_dir(self, dir_path: Union[str, Path]):
688
- """List all folders in an S3 bucket with a given prefix.
689
-
690
- Args:
691
- dir_path (str | Path): Path of the directory.
692
-
693
- Examples:
694
- >>> backend = Boto3Backend()
695
- >>> dir_path = 's3://path/of/dir'
696
- >>> backend.list_dir(dir_path)
697
- """
698
- dir_path = self._map_path(dir_path)
699
- dir_path = self._format_path(dir_path)
700
- dir_path = self._replace_prefix(dir_path)
701
- return self._client.ls_dir(dir_path)
702
-
703
- def list_dir_or_file( # pylint: disable=too-many-arguments
704
- self,
705
- dir_path: Union[str, Path],
706
- list_dir: bool = True,
707
- list_file: bool = True,
708
- suffix: Optional[Union[str, Tuple[str]]] = None,
709
- recursive: bool = False,
710
- ) -> Iterator[str]:
711
- """Scan a directory to find the interested directories or files in
712
- arbitrary order.
713
-
714
- Note:
715
- Boto3 has no concept of directories but it simulates the directory
716
- hierarchy in the filesystem through public prefixes. In addition,
717
- if the returned path ends with '/', it means the path is a public
718
- prefix which is a logical directory.
719
-
720
- Note:
721
- :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
722
- In addition, the returned path of directory will not contains the
723
- suffix '/' which is consistent with other backends.
724
-
725
- Args:
726
- dir_path (str | Path): Path of the directory.
727
- list_dir (bool): List the directories. Defaults to True.
728
- list_file (bool): List the path of files. Defaults to True.
729
- suffix (str or tuple[str], optional): File suffix
730
- that we are interested in. Defaults to None.
731
- recursive (bool): If set to True, recursively scan the
732
- directory. Defaults to False.
733
-
734
- Yields:
735
- Iterable[str]: A relative path to ``dir_path``.
736
-
737
- Examples:
738
- >>> backend = Boto3Backend()
739
- >>> dir_path = 's3://path/of/dir'
740
- >>> # list those files and directories in current directory
741
- >>> for file_path in backend.list_dir_or_file(dir_path):
742
- ... print(file_path)
743
- >>> # only list files
744
- >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False):
745
- ... print(file_path)
746
- >>> # only list directories
747
- >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False):
748
- ... print(file_path)
749
- >>> # only list files ending with specified suffixes
750
- >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'):
751
- ... print(file_path)
752
- >>> # list all files and directory recursively
753
- >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
754
- ... print(file_path)
755
- """ # noqa: E501
756
- if not has_method(self._client, "list"):
757
- raise NotImplementedError(
758
- "Current version of Boto3 Python SDK has not supported "
759
- "the `list` method, please use a higher version or dev"
760
- " branch instead."
761
- )
762
-
763
- dir_path = self._map_path(dir_path)
764
- dir_path = self._format_path(dir_path)
765
- dir_path = self._replace_prefix(dir_path)
766
- if list_dir and suffix is not None:
767
- raise TypeError("`list_dir` should be False when `suffix` is not None")
768
-
769
- if list_dir and not list_file and not recursive:
770
- raise TypeError(
771
- "Please use `list_dir` instead of `list_dir_or_file` when you only want to list the first level directories."
772
- )
773
-
774
- if (suffix is not None) and not isinstance(suffix, (str, tuple)):
775
- raise TypeError("`suffix` must be a string or tuple of strings")
776
-
777
- # Boto3's simulated directory hierarchy assumes that directory paths
778
- # should end with `/`
779
- if not dir_path.endswith("/"):
780
- dir_path += "/"
781
-
782
- root = dir_path
783
-
784
- def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
785
- # Keep track of directories we've already yielded to avoid duplicates
786
- yielded_dirs = set() if list_dir else None
787
-
788
- for path in self._client.list(dir_path):
789
- # All paths returned by S3 list are file paths, never directory paths
790
- absolute_path = self.join_path(dir_path, path)
791
- rel_path = absolute_path[len(root) :]
792
-
793
- # If we want directories, extract directory prefixes from file paths
794
- # boto3 client actually never return dir, it only return file paths
795
- if list_dir and "/" in rel_path:
796
- if not recursive:
797
- # Non-recursive: only yield immediate child directory (first level)
798
- first_slash_pos = rel_path.find("/")
799
- immediate_child_dir = rel_path[:first_slash_pos]
800
-
801
- if immediate_child_dir not in yielded_dirs:
802
- yielded_dirs.add(immediate_child_dir)
803
- yield immediate_child_dir
804
- else:
805
- # Recursive: yield all directory levels
806
- path_parts = rel_path.split("/")[:-1] # Exclude filename
807
- current_dir = ""
808
- for part in path_parts:
809
- if current_dir:
810
- current_dir += "/" + part
811
- else:
812
- current_dir = part
813
-
814
- if current_dir not in yielded_dirs:
815
- yielded_dirs.add(current_dir)
816
- yield current_dir
817
-
818
- # Handle file listing
819
- if (suffix is None or rel_path.endswith(suffix)) and list_file:
820
- yield rel_path
821
-
822
- return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
823
-
824
- def generate_presigned_url(self, url: str, client_method: str = "get_object", expires_in: int = 3600) -> str:
825
- """Generate the presigned url of video stream which can be passed to
826
- mmcv.VideoReader. Now only work on Boto3 backend.
827
-
828
- Note:
829
- Now only work on Boto3 backend.
830
-
831
- Args:
832
- url (str): Url of video stream.
833
- client_method (str): Method of client, 'get_object' or
834
- 'put_object'. Default: 'get_object'.
835
- expires_in (int): expires, in seconds. Default: 3600.
836
-
837
- Returns:
838
- str: Generated presigned url.
839
- """
840
- raise NotImplementedError("generate_presigned_url is not supported in Boto3Backend")
841
- return self._client.generate_presigned_url(url, client_method, expires_in)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py DELETED
@@ -1,565 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import asyncio
17
- import concurrent.futures
18
- import io
19
- import os
20
- import time
21
- from math import ceil
22
- from multiprocessing import shared_memory
23
- from typing import Any, Dict, Generator, List, Tuple
24
-
25
- import boto3
26
- import numpy as np
27
- from botocore.config import Config as S3Config
28
- from botocore.exceptions import ClientError
29
-
30
- import lyra_2._ext.imaginaire.utils.easy_io.backends.auto_auth as auto
31
- from lyra_2._ext.imaginaire.utils import log
32
- from lyra_2._ext.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS
33
-
34
- try:
35
- import aioboto3
36
- import aioboto3.session
37
- from aiobotocore.config import AioConfig
38
- from aiobotocore.session import AioSession
39
- except ImportError:
40
- aioboto3 = None
41
- AioSession = None
42
-
43
- MAX_RETRIES = 5
44
- RETRY_DELAY = 1 # seconds
45
-
46
-
47
async def upload_single_part_async(
    s3: AioSession, bucket: str, key: str, part_number: int, data: bytes, upload_id: str
) -> Dict[str, Any]:
    """
    Upload one part of a multipart upload, retrying with exponential backoff.

    Args:
        s3 (S3): The (async) S3 client used to issue ``upload_part``.
        bucket (str): The S3 bucket name.
        key (str): The S3 key (file path).
        part_number (int): The part number sent to S3 for this chunk.
        data (bytes): The payload of this part.
        upload_id (str): The upload ID of the enclosing multipart upload.

    Returns:
        Dict[str, Any]: ``{"PartNumber": ..., "ETag": ...}`` for the uploaded part.

    Raises:
        Exception: Whatever the last attempt raised, once ``MAX_RETRIES`` attempts have failed.
    """
    final_attempt = MAX_RETRIES - 1
    for attempt in range(MAX_RETRIES):
        try:
            response = await s3.upload_part(
                Body=data, Bucket=bucket, Key=key, PartNumber=part_number, UploadId=upload_id
            )
            # Building the result stays inside the try so a malformed response is retried too.
            return {"PartNumber": part_number, "ETag": response["ETag"]}
        except Exception as error:  # ClientError and asyncio.TimeoutError are Exception subclasses
            log.warning(f"Attempt {attempt + 1} failed for part {part_number}: {str(error)}", rank0_only=False)
            if attempt >= final_attempt:
                log.error(f"Failed to upload part {part_number} after {MAX_RETRIES} attempts", rank0_only=False)
                raise
            # Exponential backoff: RETRY_DELAY, 2x, 4x, ...
            await asyncio.sleep(RETRY_DELAY * (2**attempt))
77
-
78
-
79
async def upload_parts_async(
    part_size: int,
    part_numbers: range,
    upload_id: str,
    data: bytes,
    bucket: str,
    key: str,
    client_config: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """
    Uploads multiple parts of a file asynchronously to S3.

    Args:
        part_size (int): The size of each part in bytes.
        part_numbers (range): The zero-based, contiguous part indices to upload.
            The S3 part number of each part is its index plus one.
        upload_id (str): The upload ID for the multipart upload.
        data (bytes): The bytes covering exactly these parts; the first byte of
            ``data`` belongs to ``part_numbers[0]``.
        bucket (str): The S3 bucket name.
        key (str): The S3 key (file path).
        client_config (Dict[str, Any]): Keyword arguments forwarded to ``session.client("s3", ...)``.

    Returns:
        List[Dict[str, Any]]: Part descriptors (``PartNumber``/``ETag``) sorted by part number.

    Raises:
        Exception: If at least one part could not be uploaded.
    """
    # Nothing to do; also avoids an IndexError on part_numbers[0] below.
    if len(part_numbers) == 0:
        return []

    session = aioboto3.Session()
    config = AioConfig(retries={"max_attempts": 3, "mode": "adaptive"}, connect_timeout=5, read_timeout=10)
    start_idx = part_numbers[0]
    async with session.client("s3", config=config, **client_config) as s3:
        tasks = []
        for part_number in part_numbers:
            # Offsets are relative to the slice this worker received, not to the whole file.
            start = (part_number - start_idx) * part_size
            end = min(start + part_size, len(data))
            tasks.append(upload_single_part_async(s3, bucket, key, part_number + 1, data[start:end], upload_id))

        results = await asyncio.gather(*tasks, return_exceptions=True)

    successful_parts = []
    failed_parts = []
    # gather() preserves task order, so results line up with part_numbers.
    for part_number, result in zip(part_numbers, results):
        # BaseException (not Exception): a cancelled task yields CancelledError,
        # which must be reported as a failure rather than treated as a part descriptor.
        if isinstance(result, BaseException):
            failed_parts.append(part_number + 1)
        else:
            successful_parts.append(result)

    if failed_parts:
        log.error(f"Failed to upload parts: {failed_parts}", rank0_only=False)
        raise Exception(f"Failed to upload {len(failed_parts)} parts")

    successful_parts.sort(key=lambda part: part["PartNumber"])
    return successful_parts
130
-
131
-
132
def upload_parts_to_s3(args: Tuple[range, str, int, bytes, str, str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Uploads parts of a file to S3 using a new event loop.

    Takes a single tuple so it can be used directly with ``Executor.map``.

    Args:
        args (Tuple[range, str, int, bytes, str, str, Dict[str, Any]]): The arguments for uploading parts, including:
            part_numbers (range): The range of part numbers to upload.
            upload_id (str): The upload ID for the multipart upload.
            part_size (int): The size of each part in bytes.
            data (bytes): The data to upload.
            bucket (str): The S3 bucket name.
            key (str): The S3 key (file path).
            client_config (Dict[str, Any]): The S3 client configuration.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing part numbers and ETags.
    """
    part_numbers, upload_id, part_size, data, bucket, key, client_config = args
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            upload_parts_async(part_size, part_numbers, upload_id, data, bucket, key, client_config)
        )
    finally:
        # Close the loop even when the upload raises, so its resources are not leaked.
        loop.close()
157
-
158
-
159
async def download_single_part_async(
    s3, bucket: str, key: str, part_number: int, start: int, end: int, shm_name: str, part_size: int
) -> None:
    """
    Downloads a single part of a file asynchronously and writes it to shared memory.

    The part is fetched with an HTTP ``Range`` request and copied into the shared
    memory block at offset ``part_number * part_size``. Failed attempts are retried
    with exponential backoff.

    Args:
        s3 (S3): The S3 client.
        bucket (str): The S3 bucket name.
        key (str): The S3 key (file path).
        part_number (int): The part number.
        start (int): The start byte of the part.
        end (int): The end byte of the part (inclusive, as used in the ``Range`` header).
        shm_name (str): The name of the shared memory block.
        part_size (int): The size of each part in bytes.

    Raises:
        Exception: Whatever the last attempt raised, once ``MAX_RETRIES`` attempts have failed.
    """
    for attempt in range(MAX_RETRIES):
        try:
            range_header = f"bytes={start}-{end}"
            response = await s3.get_object(Bucket=bucket, Key=key, Range=range_header)
            data = await response["Body"].read()

            shm = shared_memory.SharedMemory(name=shm_name)
            try:
                offset = part_number * part_size
                shm.buf[offset : offset + len(data)] = data
            finally:
                # Always release this handle; otherwise every failed write leaks one per retry.
                shm.close()
            return
        except Exception as e:  # ClientError and asyncio.TimeoutError are Exception subclasses
            log.warning(f"Attempt {attempt + 1} failed for part {part_number}: {str(e)}", rank0_only=False)
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(RETRY_DELAY * (2**attempt))  # Exponential backoff
            else:
                log.error(f"Failed to download part {part_number} after {MAX_RETRIES} attempts", rank0_only=False)
                raise
193
-
194
-
195
async def download_parts_async(
    part_size: int, part_numbers: range, bucket: str, key: str, client_config: Dict[str, Any], shm_name: str
) -> None:
    """
    Downloads multiple parts of a file asynchronously and writes them to shared memory.

    Args:
        part_size (int): The size of each part in bytes.
        part_numbers (range): The range of part numbers to download.
        bucket (str): The S3 bucket name.
        key (str): The S3 key (file path).
        client_config (Dict[str, Any]): Keyword arguments forwarded to ``session.client("s3", ...)``.
        shm_name (str): The name of the shared memory block.

    Raises:
        Exception: If at least one part could not be downloaded.
    """
    session = aioboto3.Session()
    config = AioConfig(retries={"max_attempts": 5, "mode": "adaptive"}, connect_timeout=10, read_timeout=30)
    async with session.client("s3", config=config, **client_config) as s3:
        tasks = [
            download_single_part_async(
                s3,
                bucket,
                key,
                part_number,
                part_number * part_size,  # first byte of the part
                (part_number + 1) * part_size - 1,  # last byte of the part (inclusive)
                shm_name,
                part_size,
            )
            for part_number in part_numbers
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

    # BaseException (not Exception): a cancelled task yields CancelledError, and
    # treating it as success would leave a hole in the shared-memory buffer.
    failed_parts = [part for part, result in zip(part_numbers, results) if isinstance(result, BaseException)]

    if failed_parts:
        log.error(f"Failed to download parts: {failed_parts}", rank0_only=False)
        raise Exception(f"Failed to download {len(failed_parts)} parts")
231
-
232
-
233
def download_parts_to_s3(args: Tuple[range, int, str, str, Dict[str, Any], str]) -> None:
    """
    Downloads parts of a file using a new event loop.

    Takes a single tuple so it can be used directly with ``Executor.map``. The
    downloaded bytes are written into the named shared memory block; nothing is
    returned.

    Args:
        args (Tuple[range, int, str, str, Dict[str, Any], str]): The arguments for downloading parts, including:
            part_numbers (range): The range of part numbers to download.
            part_size (int): The size of each part in bytes.
            bucket (str): The S3 bucket name.
            key (str): The S3 key (file path).
            client_config (Dict[str, Any]): The S3 client configuration.
            shm_name (str): The name of the shared memory block that receives the data.
    """
    part_numbers, part_size, bucket, key, client_config, shm_name = args
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(download_parts_async(part_size, part_numbers, bucket, key, client_config, shm_name))
    finally:
        # Close the loop even when the download raises, so its resources are not leaked.
        loop.close()
253
-
254
-
255
class Boto3Client:
    """Thin S3 client built on boto3 with retrying reads, writes and listings.

    All public methods take paths of the form ``s3://<bucket>/<key>``.
    """

    def __init__(
        self,
        s3_credential_path: str,
        max_attempt: int = 3,
    ):
        """
        Create the underlying boto3 S3 client from a credential file.

        Args:
            s3_credential_path (str): Path of the credential file; its parsed content is
                passed as keyword arguments to ``boto3.client("s3", ...)``.
            max_attempt (int): Number of attempts made by ``get`` and ``put``.
        """
        self.max_attempt = max_attempt
        assert s3_credential_path, "s3_credential_path is required"
        assert os.path.exists(s3_credential_path) or CRED_ENVS.APP_ENV in [
            "prod",
            "dev",
            "stg",
        ], f"Credential file not found: {s3_credential_path}"
        with auto.open_auth(s3_credential_path, "r") as f:
            conf = auto.json_load_auth(f)

        s3_config = S3Config(
            signature_version="s3v4",
            s3={"addressing_style": "virtual"},
            response_checksum_validation="when_required",
            request_checksum_calculation="when_required",
        )

        self._client = boto3.client("s3", **conf, config=s3_config)
        # Kept so worker processes in fast_put can build their own clients.
        self._s3_cred_info = conf
        # Optional key/value cache consulted by get(); disabled by default.
        self._mc_kv_store = None

    @staticmethod
    def _split_bucket_key(filepath: str) -> Tuple[str, str]:
        """Split ``<bucket>/<key>`` (scheme already stripped) into ``(bucket, key)``."""
        bucket, _, key = filepath.partition("/")
        return bucket, key

    def get(self, filepath):
        """
        Read an object and return its content as bytes.

        The KV cache, when configured and available, is consulted first and filled
        after a successful download.

        Args:
            filepath (str): The s3 path of the object, must start with "s3://".

        Returns:
            bytes: The content of the object.

        Raises:
            ConnectionError: If every one of ``max_attempt`` attempts failed.
        """
        filepath = self._check_path(filepath)
        bucket, key = self._split_bucket_key(filepath)

        if self._mc_kv_store and self._mc_kv_store.available:
            if self._mc_kv_store.has(filepath):
                return self._mc_kv_store.get(filepath)

        attempt = 0
        while attempt < self.max_attempt:
            try:
                buffer = io.BytesIO()
                self._client.download_fileobj(
                    Bucket=bucket,
                    Key=key,
                    Fileobj=buffer,
                )
                # Take the payload once: reading the buffer for the cache and then
                # again for the return value would hand an empty result to the caller.
                data = buffer.getvalue()
                if self._mc_kv_store and self._mc_kv_store.available:
                    self._mc_kv_store.put(filepath, data)

                return data
            except Exception as e:
                attempt += 1
                log.error(f"Got an exception: attempt={attempt} - {e} - {filepath}", rank0_only=False)

        raise ConnectionError("Unable to read {} from. {} attempts tried.".format(filepath, attempt))

    def _get_file_size(self, bucket, key, max_retries=10):
        """
        Return the size in bytes of an object, retrying on errors.

        Args:
            bucket (str): The S3 bucket name.
            key (str): The S3 key (file path).
            max_retries (int): Maximum number of attempts.

        Returns:
            int: The ``ContentLength`` reported by ``head_object``.

        Raises:
            Exception: The last error, once ``max_retries`` attempts have failed.
        """
        retries = 0
        while retries < max_retries:
            try:
                # Try to get the file size
                file_size = self._client.head_object(Bucket=bucket, Key=key)["ContentLength"]
                return file_size  # Return file size if successful
            except ClientError as e:
                retries += 1
                log.error(f"Attempt {retries} failed for s3://{bucket}/{key}: {e}", rank0_only=False)
                if retries >= max_retries:
                    raise  # Re-raise the exception after max retries
                time.sleep(2)  # Wait for 2 seconds before retrying
            except Exception as e:
                retries += 1
                log.error(
                    f"Attempt {retries} failed for s3://{bucket}/{key}: due to an unexpected error: {e}",
                    rank0_only=False,
                )
                if retries >= max_retries:
                    raise  # Re-raise the exception after max retries
                time.sleep(2)  # Wait for 2 seconds before retrying

    def put(self, obj, filepath):
        """
        Upload ``obj`` to ``filepath``.

        Args:
            obj (str | io.BytesIO | bytes): A path to an existing local file, an in-memory
                buffer, or raw bytes.
            filepath (str): The destination s3 path, must start with "s3://".

        Raises:
            ValueError: If ``obj`` is of an unsupported type.
            ConnectionError: If every one of ``max_attempt`` attempts failed with a ``ClientError``.
        """
        filepath = self._check_path(filepath)
        bucket_name, key = self._split_bucket_key(filepath)
        attempt = 0
        while attempt < self.max_attempt:
            try:
                # If obj is a string path to a local file, use upload_file instead
                if isinstance(obj, str) and os.path.isfile(obj):
                    self._client.upload_file(Filename=obj, Bucket=bucket_name, Key=key)
                    return
                if isinstance(obj, io.BytesIO):
                    # Rewind on every attempt so a retry re-sends the whole buffer.
                    obj.seek(0)
                    self._client.upload_fileobj(obj, Bucket=bucket_name, Key=key)
                    return
                if isinstance(obj, bytes):
                    self._client.put_object(Body=obj, Bucket=bucket_name, Key=key)
                    return
                # Not a ClientError, so this propagates to the caller immediately.
                raise ValueError("Unsupported object type for upload")
            except ClientError as e:
                attempt += 1
                log.error(f"Got an exception: attempt={attempt} - {e} - {filepath}", rank0_only=False)

        raise ConnectionError("Unable to write {} to. {} attempts tried.".format(filepath, attempt))

    def fast_put(self, obj, filepath, num_processes: int = 32):
        """
        Upload ``obj`` with a parallel multipart upload spread over worker processes.

        Payloads of at most ``16 MB * num_processes`` are delegated to ``put``. Larger
        payloads are cut into 16 MB parts that are distributed over ``num_processes``
        processes. If anything fails after the multipart upload was created, the
        upload is aborted so the already-uploaded parts are not left behind.

        Args:
            obj (bytes | str | io.BytesIO): Raw bytes, a path to an existing local file,
                or an in-memory buffer.
            filepath (str): The destination s3 path, must start with "s3://".
            num_processes (int): Number of worker processes.

        Raises:
            ValueError: If ``obj`` is of an unsupported type.
        """
        assert aioboto3 is not None, "aioboto3 is required for fast_put"
        original_filepath = filepath
        filepath = self._check_path(filepath)
        bucket, key = self._split_bucket_key(filepath)
        part_size = 16 * 1024 * 1024  # 16 MB part size

        if isinstance(obj, bytes):
            data = obj
        elif isinstance(obj, str) and os.path.isfile(obj):
            with open(obj, "rb") as f:
                data = f.read()
        elif isinstance(obj, io.BytesIO):
            obj.seek(0)
            data = obj.read()
        else:
            raise ValueError("Unsupported object type for upload")

        file_size = len(data)
        if file_size <= part_size * num_processes:
            return self.put(data, original_filepath)
        num_parts = ceil(file_size / part_size)
        upload_id = self._client.create_multipart_upload(Bucket=bucket, Key=key)["UploadId"]

        try:
            # num_parts > num_processes here, so every chunk holds at least one part index.
            part_numbers = np.array_split(np.arange(num_parts), num_processes)

            with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor:
                args = []
                for i in range(num_processes):
                    cur_parts = part_numbers[i].tolist()
                    # Bytes covering exactly the contiguous parts handled by this worker.
                    cur_data = data[cur_parts[0] * part_size : min(cur_parts[-1] * part_size + part_size, file_size)]
                    args.append((cur_parts, upload_id, part_size, cur_data, bucket, key, self._s3_cred_info))
                results = executor.map(upload_parts_to_s3, args)
                parts = []
                for result in results:
                    parts.extend(result)

            parts = sorted(parts, key=lambda part: part["PartNumber"])
            self._client.complete_multipart_upload(
                Bucket=bucket, Key=key, UploadId=upload_id, MultipartUpload={"Parts": parts}
            )
        except Exception:
            # Best-effort cleanup; the original error is what the caller needs to see.
            try:
                self._client.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id)
            except Exception as abort_error:
                log.error(
                    f"Failed to abort multipart upload {upload_id} for {filepath}: {abort_error}",
                    rank0_only=False,
                )
            raise

    def contains(self, filepath: str, max_retries=10) -> bool:
        """
        Checks if the specified object exists in the S3 bucket with retry logic for errors.

        Args:
            filepath (str): The s3 path of the file to check, must start with "s3://".
            max_retries (int): Maximum number of attempts.

        Returns:
            bool: True if the object exists in the S3 bucket, False otherwise.

        Raises:
            ClientError: If an error response other than "404 Not Found" is still returned
                from the S3 service on the last attempt.
        """
        filepath = self._check_path(filepath)
        bucket, key = self._split_bucket_key(filepath)

        retries = 0
        while retries < max_retries:
            try:
                # Try to check if the object exists
                self._client.head_object(Bucket=bucket, Key=key)
                return True  # Object exists
            except ClientError as e:
                if e.response["Error"]["Code"] == "404":
                    return False  # Object does not exist
                retries += 1
                log.error(f"Attempt {retries} failed with error: {e}", rank0_only=False)
                if retries >= max_retries:
                    raise  # Re-raise the exception if max retries are reached
                time.sleep(2)  # Wait for 2 seconds before retrying
            except Exception as e:
                retries += 1
                log.error(f"Attempt {retries} failed due to an unexpected error: {e}", rank0_only=False)
                if retries >= max_retries:
                    raise  # Re-raise the exception if max retries are reached
                time.sleep(2)  # Wait for 2 seconds before retrying

    def isdir(self, filepath: str, max_retries=10) -> bool:
        """
        Determines if the specified path corresponds to a directory in S3 with retry logic.

        A directory in S3 is implied if there are any objects stored with the given prefix,
        which means this function checks for the existence of any objects at or under the specified path.

        Args:
            filepath (str): The s3 path to check, must start with "s3://".
            max_retries (int): Maximum number of attempts.

        Returns:
            bool: True if the specified path corresponds to a directory in S3, False otherwise
                (including when every attempt failed; errors are logged, not raised).
        """
        filepath = self._check_path(filepath)
        if not filepath.endswith("/"):
            filepath += "/"

        bucket, prefix = self._split_bucket_key(filepath)

        retries = 0
        while retries < max_retries:
            try:
                # Try to check if any objects exist with the given prefix (i.e., directory in S3)
                resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter="/", MaxKeys=1)
                # Check if any content or prefixes exist under the given path
                return "CommonPrefixes" in resp or "Contents" in resp
            except ClientError as e:
                retries += 1
                log.error(f"Attempt {retries} failed: {e}", rank0_only=False)
                if retries >= max_retries:
                    return False  # Return False if maximum retries are reached
                time.sleep(2)  # Wait for 2 seconds before retrying
            except Exception as e:
                retries += 1
                log.error(f"Attempt {retries} failed due to an unexpected error: {e}", rank0_only=False)
                if retries >= max_retries:
                    return False  # Return False if maximum retries are reached
                time.sleep(2)  # Wait for 2 seconds before retrying

    def delete(self, filepath):
        """
        Delete a single object.

        Args:
            filepath (str): The s3 path of the object to delete, must start with "s3://".
        """
        filepath = self._check_path(filepath)
        bucket, key = self._split_bucket_key(filepath)
        self._client.delete_object(Bucket=bucket, Key=key)

    def ls_dir(self, filepath: str) -> Generator[str, None, None]:
        """
        List all folders in an S3 bucket with a given prefix.

        Args:
            filepath (str): The S3 path of the folder to list.

        Yields:
            str: The folder names relative to ``filepath`` (each taken from a
                ``CommonPrefixes`` entry with the listing prefix stripped).
        """
        filepath = self._check_path(filepath)
        bucket, prefix = self._split_bucket_key(filepath)
        continuation_token = None
        if prefix and not prefix.endswith("/"):
            prefix += "/"

        while True:
            if continuation_token:
                resp = self._client.list_objects_v2(
                    Bucket=bucket, Prefix=prefix, Delimiter="/", ContinuationToken=continuation_token
                )
            else:
                resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter="/")

            if "CommonPrefixes" in resp:
                for item in resp["CommonPrefixes"]:
                    yield item["Prefix"][len(prefix) :]

            # Check if there are more keys to retrieve
            if resp.get("IsTruncated"):  # If IsTruncated is True, there are more keys
                continuation_token = resp.get("NextContinuationToken")
            else:
                break

    def list(self, filepath: str, exclude_prefix: str = None) -> Generator[str, None, None]:
        """
        List all keys in an S3 bucket with a given prefix, excluding files that start with
        specified prefix.

        Args:
            filepath (str): The S3 path of the file to list.
            exclude_prefix (str): Keys whose remainder (after the listing prefix) starts with
                this string are skipped. Defaults to None, which excludes nothing.

        Yields:
            str: The keys, with the listing prefix stripped, that don't start with exclude_prefix.
        """
        filepath = self._check_path(filepath)
        bucket, prefix = self._split_bucket_key(filepath)

        continuation_token = None

        while True:
            if continuation_token:
                resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, ContinuationToken=continuation_token)
            else:
                resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix)

            if "Contents" in resp:
                for item in resp["Contents"]:
                    key = item["Key"][len(prefix) :]
                    # Skip files that start with the excluded prefix
                    if exclude_prefix is None or not key.startswith(exclude_prefix):
                        yield key

            # Check if there are more keys to retrieve
            if resp.get("IsTruncated"):  # If IsTruncated is True, there are more keys
                continuation_token = resp.get("NextContinuationToken")
            else:
                break

    def _check_path(self, filepath: str):
        """Validate that ``filepath`` starts with "s3://" and return it without the scheme."""
        # Kept as an assert so the exception type seen by existing callers is unchanged.
        assert filepath.startswith("s3://"), f"Expected a path starting with 's3://', got: {filepath}"
        filepath = filepath[5:]
        return filepath
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py DELETED
@@ -1,91 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- import tempfile
18
- from contextlib import contextmanager
19
- from pathlib import Path
20
- from typing import Generator, Union
21
- from urllib.request import urlopen
22
-
23
- from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
24
-
25
-
26
class HTTPBackend(BaseStorageBackend):
    """HTTP and HTTPS storage backend."""

    def get(self, filepath: str) -> bytes:
        """Read bytes from a given ``filepath``.

        Args:
            filepath (str): Path to read data.

        Returns:
            bytes: Expected bytes object.

        Examples:
            >>> backend = HTTPBackend()
            >>> backend.get('http://path/of/file')
            b'hello world'
        """
        # Close the response explicitly instead of leaving the connection to the GC.
        with urlopen(filepath) as response:
            return response.read()

    def get_text(self, filepath, encoding="utf-8") -> str:
        """Read text from a given ``filepath``.

        Args:
            filepath (str): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Defaults to 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.

        Examples:
            >>> backend = HTTPBackend()
            >>> backend.get_text('http://path/of/file')
            'hello world'
        """
        with urlopen(filepath) as response:
            return response.read().decode(encoding)

    @contextmanager
    def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` to a local temporary directory,
        and return the temporary path.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
        can be called with ``with`` statement, and when exits from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Yields:
            Iterable[str]: Only yield one temporary path.

        Examples:
            >>> backend = HTTPBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with backend.get_local_path('http://path/of/file') as path:
            ...     # do something here
        """
        # Download first: if it fails there is no temporary file to clean up.
        payload = self.get(filepath)
        # Create the file outside the try so ``f`` is always bound in ``finally``.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            # Closes the handle even if the write fails, before the file is removed.
            with f:
                f.write(payload)
            yield f.name
        finally:
            os.remove(f.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py DELETED
@@ -1,550 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import io
17
- import os
18
- import os.path as osp
19
- import shutil
20
- from contextlib import contextmanager
21
- from pathlib import Path
22
- from typing import Generator, Iterator, Optional, Tuple, Union
23
-
24
- from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist
25
-
26
-
27
class LocalBackend(BaseStorageBackend):
    """Storage backend that works directly on the local file system."""

    _allow_symlink = True

    def get(self, filepath: Union[str, Path]) -> bytes:
        """Return the content of ``filepath`` read in binary ('rb') mode.

        Args:
            filepath (str or Path): File to read.

        Returns:
            bytes: The raw content of the file.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.get('/path/of/file')
            b'hello world'
        """
        with open(filepath, "rb") as stream:
            return stream.read()

    def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
        """Return the content of ``filepath`` read in text mode.

        Args:
            filepath (str or Path): File to read.
            encoding (str): Encoding used to open ``filepath``.
                Defaults to 'utf-8'.

        Returns:
            str: The decoded content of the file.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.get_text('/path/of/file')
            'hello world'
        """
        with open(filepath, encoding=encoding) as stream:
            return stream.read()

    def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None:
        """Write bytes to ``filepath`` in binary ('wb') mode.

        The parent directory of ``filepath`` is created first via
        ``mkdir_or_exist``.

        Args:
            obj (bytes or io.BytesIO): Data to write. A ``BytesIO`` is rewound
                to position 0 and its whole buffer is written.
            filepath (str or Path): Destination file.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.put(b'hello world', '/path/of/file')
        """
        mkdir_or_exist(osp.dirname(filepath))
        if isinstance(obj, io.BytesIO):
            obj.seek(0)
            payload = obj.getvalue()
        else:
            payload = obj
        with open(filepath, "wb") as stream:
            stream.write(payload)

    def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
        """Write text to ``filepath`` in text ('w') mode.

        The parent directory of ``filepath`` is created first via
        ``mkdir_or_exist``.

        Args:
            obj (str): Text to write.
            filepath (str or Path): Destination file.
            encoding (str): Encoding used to open ``filepath``.
                Defaults to 'utf-8'.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.put_text('hello world', '/path/of/file')
        """
        mkdir_or_exist(osp.dirname(filepath))
        with open(filepath, "w", encoding=encoding) as stream:
            stream.write(obj)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Tell whether ``filepath`` exists.

        Args:
            filepath (str or Path): Path to check.

        Returns:
            bool: ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return osp.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Tell whether ``filepath`` is a directory.

        Args:
            filepath (str or Path): Path to check.

        Returns:
            bool: ``True`` if ``filepath`` points to a directory, ``False``
            otherwise.
        """
        return osp.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Tell whether ``filepath`` is a regular file.

        Args:
            filepath (str or Path): Path to check.

        Returns:
            bool: ``True`` if ``filepath`` points to a file, ``False``
            otherwise.
        """
        return osp.isfile(filepath)

    def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str:
        """Join ``filepath`` with every component in ``filepaths``.

        Thin wrapper around ``os.path.join``.

        Args:
            filepath (str or Path): First path component.
            *filepaths (str or Path): Components appended to ``filepath``.

        Returns:
            str: The joined path.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.join_path('/path/of/dir1', 'dir2', 'path/of/file')
            '/path/of/dir1/dir2/path/of/file'
        """
        return osp.join(filepath, *filepaths)

    @contextmanager
    def get_local_path(
        self,
        filepath: Union[str, Path],
    ) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; exists only for a unified backend API.

        Args:
            filepath (str or Path): A path that is already local.

        Examples:
            >>> backend = LocalBackend()
            >>> with backend.get_local_path('/path/of/abc.jpg') as path:
            ...     # do something here
        """
        yield filepath

    def copyfile(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
    ) -> str:
        """Copy the file ``src`` to ``dst`` with ``shutil.copy``.

        If ``dst`` is a directory, the file is copied into it under the base
        name of ``src``. If ``dst`` is an existing file, it is replaced.

        Args:
            src (str or Path): File to copy.
            dst (str or Path): Destination file or directory.

        Returns:
            str: The destination file.

        Raises:
            SameFileError: If ``src`` and ``dst`` are the same file.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copyfile('/path/of/file', '/path1/of/file1')
            '/path1/of/file1'
            >>> backend.copyfile('/path/of/file', '/path1/of/dir')
            '/path1/of/dir/file'
        """
        return shutil.copy(src, dst)

    def copytree(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
    ) -> str:
        """Recursively copy the directory ``src`` to a new directory ``dst``
        with ``shutil.copytree``.

        Args:
            src (str or Path): Directory to copy.
            dst (str or Path): Destination directory; must not exist yet.

        Returns:
            str: The destination directory.

        Raises:
            FileExistsError: If ``dst`` already exists.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copytree('/path/of/dir1', '/path/of/dir2')
            '/path/of/dir2'
        """
        return shutil.copytree(src, dst)

    def copyfile_from_local(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
    ) -> str:
        """Copy a local file ``src`` to ``dst``; delegates to :meth:`copyfile`.

        Args:
            src (str or Path): Local file to copy.
            dst (str or Path): Destination file or directory.

        Returns:
            str: The destination file.

        Raises:
            SameFileError: If ``src`` and ``dst`` are the same file.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copyfile_from_local('/path/of/file', '/path1/of/dir')
            '/path1/of/dir/file'
        """
        return self.copyfile(src, dst)

    def copytree_from_local(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
    ) -> str:
        """Recursively copy a local directory ``src`` to ``dst``; delegates to
        :meth:`copytree`.

        Args:
            src (str or Path): Local directory to copy.
            dst (str or Path): Destination directory.

        Returns:
            str: The destination directory.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copytree_from_local('/path/of/dir1', '/path/of/dir2')
            '/path/of/dir2'
        """
        return self.copytree(src, dst)

    def copyfile_to_local(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
        dst_type: Optional[str] = None,
    ) -> str:
        """Copy the file ``src`` to local ``dst``; delegates to
        :meth:`copyfile`.

        Args:
            src (str or Path): File to copy.
            dst (str or Path): Local destination file or directory.
            dst_type (str, optional): Accepted but not used by this backend.

        Returns:
            str: The destination file.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copyfile_to_local('/path/of/file', '/path1/of/dir')
            '/path1/of/dir/file'
        """
        return self.copyfile(src, dst)

    def copytree_to_local(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
    ) -> str:
        """Recursively copy the directory ``src`` to local ``dst``; delegates
        to :meth:`copytree`.

        Args:
            src (str or Path): Directory to copy.
            dst (str or Path): Local destination directory.

        Returns:
            str: The destination directory.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copytree_to_local('/path/of/dir1', '/path/of/dir2')
            '/path/of/dir2'
        """
        return self.copytree(src, dst)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove the file at ``filepath``.

        Args:
            filepath (str or Path): File to remove.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            IsADirectoryError: If ``filepath`` is a directory.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.remove('/path/of/file')
        """
        missing = not self.exists(filepath)
        if missing:
            raise FileNotFoundError(f"filepath {filepath} does not exist")
        if self.isdir(filepath):
            raise IsADirectoryError("filepath should be a file")

        os.remove(filepath)

    def rmtree(self, dir_path: Union[str, Path]) -> None:
        """Recursively delete the directory tree at ``dir_path``.

        Args:
            dir_path (str or Path): Directory to remove.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.rmtree('/path/of/dir')
        """
        shutil.rmtree(dir_path)

    def copy_if_symlink_fails(
        self,
        src: Union[str, Path],
        dst: Union[str, Path],
    ) -> bool:
        """Create a symbolic link ``dst`` pointing to ``src``, copying ``src``
        to ``dst`` instead when the link cannot be created.

        Args:
            src (str or Path): Target of the symbolic link.
            dst (str or Path): Name of the symbolic link (or of the copy).

        Returns:
            bool: ``True`` if the symbolic link was created, ``False`` if the
            method fell back to copying.

        Examples:
            >>> backend = LocalBackend()
            >>> backend.copy_if_symlink_fails('/path/of/file', '/path1/of/file1')
            True
        """
        try:
            os.symlink(src, dst)
        except Exception:
            # Best-effort fallback: files are copied, anything else is
            # treated as a directory tree.
            fallback = self.copyfile if self.isfile(src) else self.copytree
            fallback(src, dst)
            return False
        return True

    def list_dir_or_file(
        self,
        dir_path: Union[str, Path],
        list_dir: bool = True,
        list_file: bool = True,
        suffix: Optional[Union[str, Tuple[str]]] = None,
        recursive: bool = False,
    ) -> Iterator[str]:
        """Scan ``dir_path`` and lazily yield the directories and/or files
        found, in arbitrary order, as paths relative to ``dir_path``.

        Files whose name starts with '.' are never yielded. Arguments are
        validated when the method is called; the scan itself only runs once
        the returned iterator is consumed.

        Args:
            dir_path (str or Path): Directory to scan.
            list_dir (bool): Yield directories. Defaults to True.
            list_file (bool): Yield files. Defaults to True.
            suffix (str or tuple[str], optional): Only yield files ending
                with this suffix. Must be None when ``list_dir`` is True.
                Defaults to None.
            recursive (bool): Descend into sub-directories.
                Defaults to False.

        Returns:
            Iterator[str]: Paths relative to ``dir_path``.

        Raises:
            TypeError: If ``suffix`` is given together with ``list_dir=True``
                or if ``suffix`` is neither a string nor a tuple.

        Examples:
            >>> backend = LocalBackend()
            >>> dir_path = '/path/of/dir'
            >>> # only list files ending with a given suffix
            >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False, suffix='.txt'):
            ...     print(file_path)
            >>> # list all files and directories recursively
            >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
            ...     print(file_path)
        """  # noqa: E501
        if list_dir and suffix is not None:
            raise TypeError("`suffix` should be None when `list_dir` is True")

        if suffix is not None and not isinstance(suffix, (str, tuple)):
            raise TypeError("`suffix` must be a string or tuple of strings")

        root = dir_path

        def _walk(current):
            # Filters are taken from the enclosing call; only the directory
            # being scanned changes during recursion.
            for entry in os.scandir(current):
                if not entry.name.startswith(".") and entry.is_file():
                    rel_path = osp.relpath(entry.path, root)
                    if (suffix is None or rel_path.endswith(suffix)) and list_file:
                        yield rel_path
                elif osp.isdir(entry.path):
                    if list_dir:
                        yield osp.relpath(entry.path, root)
                    if recursive:
                        yield from _walk(entry.path)

        return _walk(dir_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py DELETED
@@ -1,130 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import inspect
17
- from typing import Optional, Type, Union
18
-
19
- from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
20
- from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_backend import Boto3Backend
21
- from lyra_2._ext.imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
22
- from lyra_2._ext.imaginaire.utils.easy_io.backends.local_backend import LocalBackend
23
-
24
# Maps a registered backend name to its backend class.
backends: dict = {}
# Maps a URI prefix to the backend class registered for it.
prefix_to_backends: dict = {}
26
-
27
-
28
def _register_backend(
    name: str,
    backend: Type[BaseStorageBackend],
    force: bool = False,
    prefixes: Union[str, list, tuple, None] = None,
):
    """Register a backend.

    All validation happens before the registries are touched, so a rejected
    call leaves ``backends`` and ``prefix_to_backends`` unchanged.

    Args:
        name (str): The name of the registered backend.
        backend (BaseStorageBackend): The backend class to be registered,
            which must be a subclass of :class:`BaseStorageBackend`.
        force (bool): Whether to override the backend if the name or one of
            the prefixes has already been registered. Defaults to False.
        prefixes (str or list[str] or tuple[str], optional): The prefix
            of the registered storage backend. Defaults to None.

    Raises:
        TypeError: If ``name`` is not a string, ``backend`` is not a subclass
            of :class:`BaseStorageBackend`, or ``prefixes`` is not a string,
            list or tuple.
        ValueError: If ``name`` or one of ``prefixes`` is already registered
            and ``force`` is False.
    """
    if not isinstance(name, str):
        raise TypeError(f"the backend name should be a string, but got {type(name)}")

    if not inspect.isclass(backend):
        raise TypeError(f"backend should be a class, but got {type(backend)}")
    if not issubclass(backend, BaseStorageBackend):
        raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")

    if name in backends and not force:
        raise ValueError(
            f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
        )

    # Normalise ``prefixes`` to a sequence. An explicit raise (instead of an
    # ``assert``) keeps the check active under ``python -O``.
    if prefixes is None:
        prefixes = ()
    elif isinstance(prefixes, str):
        prefixes = (prefixes,)
    elif not isinstance(prefixes, (list, tuple)):
        raise TypeError(f"prefixes should be a string, list or tuple, but got {type(prefixes)}")

    if not force:
        for prefix in prefixes:
            if prefix in prefix_to_backends:
                raise ValueError(
                    f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it'
                )

    # Every check passed: only now mutate the registries.
    backends[name] = backend
    for prefix in prefixes:
        prefix_to_backends[prefix] = backend
74
-
75
-
76
def register_backend(
    name: str,
    backend: Optional[Type[BaseStorageBackend]] = None,
    force: bool = False,
    prefixes: Union[str, list, tuple, None] = None,
):
    """Register a storage backend, either by a direct call or as a class
    decorator.

    Args:
        name (str): The name of the registered backend.
        backend (class, optional): The backend class to be registered,
            which must be a subclass of :class:`BaseStorageBackend`.
            Leave it as None to use this function as a decorator.
            Defaults to None.
        force (bool): Whether to override the backend if the name has already
            been registered. Defaults to False.
        prefixes (str or list[str] or tuple[str], optional): The prefix
            of the registered storage backend. Defaults to None.

    Returns:
        None when ``backend`` is given; otherwise a decorator that registers
        the decorated class and returns it unchanged.

    Examples:

        >>> class NewBackend(BaseStorageBackend):
        ...     def get(self, filepath):
        ...         return filepath
        ...
        ...     def get_text(self, filepath):
        ...         return filepath
        >>> register_backend('new', NewBackend)

        >>> @register_backend('new')
        ... class NewBackend(BaseStorageBackend):
        ...     def get(self, filepath):
        ...         return filepath
        ...
        ...     def get_text(self, filepath):
        ...         return filepath
    """
    if backend is None:
        # Decorator form: defer the registration until the class is known.
        def _decorator(backend_cls):
            _register_backend(name, backend_cls, force=force, prefixes=prefixes)
            return backend_cls

        return _decorator

    _register_backend(name, backend, force=force, prefixes=prefixes)
    return None
124
-
125
-
126
# Register the built-in backends at import time. The empty prefix is the one
# used for paths without a '://' scheme, so those resolve to LocalBackend.
register_backend("local", LocalBackend, prefixes="")
# To avoid breaking backward compatibility, 's3' is also used as a
# prefix for Boto3Backend
register_backend("s3", Boto3Backend, prefixes=["s3"])
register_backend("http", HTTPBackend, prefixes=["http", "https"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py DELETED
@@ -1,1085 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import json
17
- import warnings
18
- from contextlib import contextmanager
19
- from io import BytesIO, StringIO
20
- from pathlib import Path
21
- from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union
22
-
23
- from lyra_2._ext.imaginaire.utils.easy_io.backends import backends, prefix_to_backends
24
- from lyra_2._ext.imaginaire.utils.easy_io.file_client import FileClient
25
- from lyra_2._ext.imaginaire.utils.easy_io.handlers import file_handlers
26
-
27
# Cache of instantiated backends, keyed either by a caller-supplied backend
# key or by a "<prefix>:<json of backend_args>" string.
backend_instances: dict = {}
28
-
29
-
30
def is_filepath(filepath):
    """Return True when ``filepath`` is a ``str`` or a ``pathlib.Path``."""
    path_types = (str, Path)
    return isinstance(filepath, path_types)
32
-
33
-
34
- def _parse_uri_prefix(uri: Union[str, Path]) -> str:
35
- """Parse the prefix of uri.
36
-
37
- Args:
38
- uri (str or Path): Uri to be parsed that contains the file prefix.
39
-
40
- Examples:
41
- >>> _parse_uri_prefix('/home/path/of/your/file')
42
- ''
43
- >>> _parse_uri_prefix('s3://path/of/your/file')
44
- 's3'
45
- >>> _parse_uri_prefix('clusterName:s3://path/of/your/file')
46
- 's3'
47
-
48
- Returns:
49
- str: Return the prefix of uri if the uri contains '://'. Otherwise,
50
- return ''.
51
- """
52
- assert is_filepath(uri)
53
- uri = str(uri)
54
- # if uri does not contains '://', the uri will be handled by
55
- # LocalBackend by default
56
- if "://" not in uri:
57
- return ""
58
- else:
59
- prefix, _ = uri.split("://")
60
- # In the case of Boto3Backend, the prefix may contain the cluster
61
- # name like clusterName:s3://path/of/your/file
62
- if ":" in prefix:
63
- _, prefix = prefix.split(":")
64
- return prefix
65
-
66
-
67
def _get_file_backend(prefix: str, backend_args: dict):
    """Instantiate a file backend chosen by ``backend_args`` or by ``prefix``.

    When ``backend_args`` contains a ``"backend"`` entry, that name selects
    the class from ``backends`` and takes priority over ``prefix``; the
    remaining entries are passed to the constructor. Otherwise the class is
    looked up in ``prefix_to_backends``.

    Args:
        prefix (str): Prefix of uri.
        backend_args (dict): Arguments to instantiate the corresponding
            backend. The dict itself is never modified.

    Returns:
        The newly created backend instance.
    """
    if "backend" not in backend_args:
        return prefix_to_backends[prefix](**backend_args)

    # Work on a copy so the caller's dict keeps its "backend" entry.
    init_kwargs = backend_args.copy()
    backend_cls = backends[init_kwargs.pop("backend")]
    return backend_cls(**init_kwargs)
84
-
85
-
86
def set_s3_backend(
    key: str = "s3:{}",
    backend_args: Optional[dict] = None,
):
    """Create a backend and store it in ``backend_instances`` under ``key``.

    Args:
        key (str): The key under which the backend instance is stored.
            Defaults to 's3:{}'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.

    Returns:
        The backend instance that was created and stored.
    """
    # NOTE(review): ``key`` is also handed to ``_get_file_backend`` as the uri
    # prefix, so ``backend_args`` presumably needs a "backend" entry unless
    # ``key`` itself is a registered prefix -- confirm against callers.
    args = {} if backend_args is None else backend_args
    instance = _get_file_backend(key, args)
    backend_instances[key] = instance
    return instance
103
-
104
-
105
def get_file_backend(
    uri: Union[str, Path, None] = None,
    *,
    backend_args: Optional[dict] = None,
    enable_singleton: bool = False,
    backend_key: Optional[str] = None,
):
    """Return a file backend based on the prefix of uri or backend_args.

    Args:
        uri (str or Path): Uri to be parsed that contains the file prefix.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        enable_singleton (bool): Whether to enable the singleton pattern.
            If it is True, the backend created will be reused if the
            prefix and ``backend_args`` are the same as a previous call.
            Defaults to False.
        backend_key (str, optional): A key naming a cached backend. If a
            backend is already stored under this key it is returned directly;
            otherwise, when ``enable_singleton`` is True, the resolved backend
            is stored under it. Defaults to None.

    Returns:
        BaseStorageBackend: Instantiated Backend object.

    Raises:
        ValueError: If ``uri`` is None, ``backend_args`` has no "backend"
            entry and ``backend_key`` is None.

    Examples:
        >>> # get file backend based on the prefix of uri
        >>> uri = 's3://path/of/your/file'
        >>> backend = get_file_backend(uri)
        >>> # get file backend based on the backend_args
        >>> backend = get_file_backend(backend_args={'backend': 's3'})
        >>> # backend name has a higher priority if 'backend' in backend_args
        >>> backend = get_file_backend(uri, backend_args={'backend': 's3'})
    """
    if backend_key is not None and backend_key in backend_instances:
        return backend_instances[backend_key]

    if backend_args is None:
        backend_args = {}

    if uri is None and "backend" not in backend_args and backend_key is None:
        raise ValueError('uri should not be None when "backend" does not exist in backend_args and backend_key is None')

    prefix = _parse_uri_prefix(uri) if uri is not None else ""

    if not enable_singleton:
        return _get_file_backend(prefix, backend_args)

    # ``sort_keys`` makes the cache key independent of the insertion order of
    # ``backend_args``, so equal arguments always map to the same instance.
    unique_key = f"{prefix}:{json.dumps(backend_args, sort_keys=True)}"
    backend = backend_instances.get(unique_key)
    if backend is None:
        backend = _get_file_backend(prefix, backend_args)
        backend_instances[unique_key] = backend
    # Register the alias on a cache hit as well as on a miss, so a later
    # lookup by ``backend_key`` finds the same instance.
    if backend_key is not None:
        backend_instances[backend_key] = backend
    return backend
164
-
165
-
166
def get(
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> bytes:
    """Read the bytes stored at ``filepath`` through the matching backend.

    The backend is resolved with :func:`get_file_backend` in singleton mode.

    Args:
        filepath (str or Path): Path to read data.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        backend_key (str, optional): Key of a cached backend to reuse.

    Returns:
        bytes: The content of ``filepath``.

    Examples:
        >>> filepath = '/path/of/file'
        >>> get(filepath)
        b'hello world'
    """
    resolved = get_file_backend(
        filepath,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    content = resolved.get(filepath)
    return content
194
-
195
-
196
def get_text(
    filepath: Union[str, Path],
    encoding="utf-8",
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> str:
    """Read the text stored at ``filepath`` through the matching backend.

    The backend is resolved with :func:`get_file_backend` in singleton mode.

    Args:
        filepath (str or Path): Path to read data.
        encoding (str): The encoding passed on to the backend.
            Defaults to 'utf-8'.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        backend_key (str, optional): Key of a cached backend to reuse.

    Returns:
        str: The text read from ``filepath``.

    Examples:
        >>> filepath = '/path/of/file'
        >>> get_text(filepath)
        'hello world'
    """
    resolved = get_file_backend(
        filepath,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    text = resolved.get_text(filepath, encoding)
    return text
227
-
228
-
229
def put(
    obj: bytes,
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> None:
    """Write bytes to ``filepath`` through the matching backend.

    The backend is resolved with :func:`get_file_backend` in singleton mode.

    Note:
        The backend's ``put`` should create the directory of ``filepath``
        if it does not exist.

    Args:
        obj (bytes): Data to be written.
        filepath (str or Path): Path to write data.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        backend_key (str, optional): Key of a cached backend to reuse.

    Examples:
        >>> filepath = '/path/of/file'
        >>> put(b'hello world', filepath)
    """
    resolved = get_file_backend(
        filepath,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    resolved.put(obj, filepath)
259
-
260
-
261
def put_text(
    obj: str,
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> None:
    """Store the string ``obj`` at ``filepath``.

    The write is delegated to the ``put_text`` method of the backend returned
    by ``get_file_backend`` (singleton mode enabled). No encoding is passed
    along, so whatever default the backend applies is used.

    Note:
        Backends are expected to create the parent directory of
        ``filepath`` when it does not exist yet.

    Args:
        obj (str): Text to write.
        filepath (str or Path): Destination of the text.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Examples:
        >>> filepath = '/path/of/file'
        >>> put_text('hello world', filepath)
    """
    storage = get_file_backend(
        filepath, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    )
    storage.put_text(obj, filepath)
293
-
294
-
295
def exists(
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> bool:
    """Report whether ``filepath`` exists according to its backend.

    Args:
        filepath (str or Path): Path to probe.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        bool: Whatever the backend's ``exists`` method reports; ``True``
        means the path exists.

    Examples:
        >>> filepath = '/path/of/file'
        >>> exists(filepath)
        True
    """
    return get_file_backend(
        filepath, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    ).exists(filepath)
323
-
324
-
325
def isdir(
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> bool:
    """Report whether ``filepath`` is a directory according to its backend.

    Args:
        filepath (str or Path): Path to probe.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        bool: Result of the backend's ``isdir`` method; ``True`` when
        ``filepath`` points to a directory.

    Examples:
        >>> filepath = '/path/of/dir'
        >>> isdir(filepath)
        True
    """
    lookup = dict(backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
    storage = get_file_backend(filepath, **lookup)
    return storage.isdir(filepath)
355
-
356
-
357
def isfile(
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> bool:
    """Report whether ``filepath`` is a regular file according to its backend.

    Args:
        filepath (str or Path): Path to probe.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        bool: Result of the backend's ``isfile`` method; ``True`` when
        ``filepath`` points to a file.

    Examples:
        >>> filepath = '/path/of/file'
        >>> isfile(filepath)
        True
    """
    storage = get_file_backend(
        filepath, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    )
    is_regular_file = storage.isfile(filepath)
    return is_regular_file
386
-
387
-
388
def join_path(
    filepath: Union[str, Path],
    *filepaths: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Join ``filepath`` with every entry of ``filepaths``.

    The backend is chosen from the first component (``filepath``) and the
    actual joining is delegated to that backend's ``join_path`` method.

    Args:
        filepath (str or Path): Leading path component.
        *filepaths (str or Path): Further components appended in order.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The joined path.

    Examples:
        >>> filepath1 = '/path/of/dir1'
        >>> filepath2 = 'dir2'
        >>> filepath3 = 'path/of/file'
        >>> join_path(filepath1, filepath2, filepath3)
        '/path/of/dir1/dir2/path/of/file'
    """
    joiner = get_file_backend(
        filepath, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    )
    return joiner.join_path(filepath, *filepaths)
423
-
424
-
425
@contextmanager
def get_local_path(
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Generator[Union[str, Path], None, None]:
    """Context manager yielding a local path for ``filepath``.

    This function is wrapped with :func:`contextlib.contextmanager` and simply
    re-yields the path produced by the backend's own ``get_local_path``
    context manager, which is entered with ``str(filepath)``. Leaving the
    ``with`` block also leaves the backend's context manager, which is where
    any temporary copy is released.

    Note:
        According to the original documentation, a ``filepath`` that is
        already local is yielded unchanged and is not removed afterwards.

    Args:
        filepath (str or Path): Path whose data should be made available
            locally.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Yields:
        Iterable[str]: Exactly one path.

    Examples:
        >>> with get_local_path('s3://bucket/abc.jpg') as path:
        ...     # do something here
    """
    storage = get_file_backend(
        filepath, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    )
    with storage.get_local_path(str(filepath)) as path_on_disk:
        yield path_on_disk
461
-
462
-
463
def copyfile(
    src: Union[str, Path],
    dst: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Copy the file ``src`` to ``dst`` within one backend.

    The backend is chosen from ``src``; ``src`` and ``dst`` should share the
    same prefix. When ``dst`` is a directory the file keeps the base name of
    ``src``; an existing file at ``dst`` is replaced.

    Args:
        src (str or Path): File to copy.
        dst (str or Path): Target file or directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The destination file, as returned by the backend.

    Raises:
        SameFileError: If ``src`` and ``dst`` are the same file.

    Examples:
        >>> # dst is a file
        >>> src = '/path/of/file'
        >>> dst = '/path1/of/file1'
        >>> # src will be copied to '/path1/of/file1'
        >>> copyfile(src, dst)
        '/path1/of/file1'

        >>> # dst is a directory
        >>> dst = '/path1/of/dir'
        >>> # src will be copied to '/path1/of/dir/file'
        >>> copyfile(src, dst)
        '/path1/of/dir/file'
    """
    lookup = dict(backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
    return get_file_backend(src, **lookup).copyfile(src, dst)
504
-
505
-
506
def copytree(
    src: Union[str, Path],
    dst: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Recursively copy the directory tree ``src`` to ``dst``.

    The backend is chosen from ``src``; ``src`` and ``dst`` should share the
    same prefix and ``dst`` must not exist beforehand.

    Args:
        src (str or Path): Directory to copy.
        dst (str or Path): Target directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The destination directory, as returned by the backend.

    Raises:
        FileExistsError: If ``dst`` already exists.

    Examples:
        >>> src = '/path/of/dir1'
        >>> dst = '/path/of/dir2'
        >>> copytree(src, dst)
        '/path/of/dir2'
    """
    storage = get_file_backend(
        src,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    copied_to = storage.copytree(src, dst)
    return copied_to
539
-
540
-
541
def copyfile_from_local(
    src: Union[str, Path],
    dst: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Upload the local file ``src`` to ``dst``.

    The backend is chosen from the destination ``dst`` and the transfer is
    delegated to its ``copyfile_from_local`` method.

    Note:
        With a ``LocalBackend`` this is documented to behave like
        :func:`copyfile`.

    Args:
        src (str or Path): Local file to copy.
        dst (str or Path): Target file or directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The destination file. When ``dst`` is a directory the file is
        placed inside it under the base name of ``src``.

    Examples:
        >>> # dst is a file
        >>> src = '/path/of/file'
        >>> dst = 's3://openmmlab/mmengine/file1'
        >>> # src will be copied to 's3://openmmlab/mmengine/file1'
        >>> copyfile_from_local(src, dst)
        's3://openmmlab/mmengine/file1'

        >>> # dst is a directory
        >>> dst = 's3://openmmlab/mmengine'
        >>> # src will be copied to 's3://openmmlab/mmengine/file'
        >>> copyfile_from_local(src, dst)
        's3://openmmlab/mmengine/file'
    """
    lookup = dict(backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
    target_backend = get_file_backend(dst, **lookup)
    return target_backend.copyfile_from_local(src, dst)
579
-
580
-
581
def copytree_from_local(
    src: Union[str, Path],
    dst: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Recursively upload the local directory ``src`` to ``dst``.

    The backend is chosen from the destination ``dst`` and the transfer is
    delegated to its ``copytree_from_local`` method.

    Note:
        With a ``LocalBackend`` this is documented to behave like
        :func:`copytree`.

    Args:
        src (str or Path): Local directory to copy.
        dst (str or Path): Target directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The destination directory, as returned by the backend.

    Examples:
        >>> src = '/path/of/dir'
        >>> dst = 's3://openmmlab/mmengine/dir'
        >>> copytree_from_local(src, dst)
        's3://openmmlab/mmengine/dir'
    """
    return get_file_backend(
        dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    ).copytree_from_local(src, dst)
611
-
612
-
613
def copyfile_to_local(
    src: Union[str, Path],
    dst: Union[str, Path],
    dst_type: str,  # Choose from ["file", "dir"]
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Download the file ``src`` to the local path ``dst``.

    The parent directory of ``dst`` is created first (existing directories
    are accepted), then the backend chosen from ``src`` performs the copy
    through its ``copyfile_to_local`` method.

    Note:
        With a ``LocalBackend`` this is documented to behave like
        :func:`copyfile`.

    Args:
        src (str or Path): File to copy.
        dst (str or Path): Local target file or directory.
        dst_type (str): Either ``"file"`` or ``"dir"``; checked with an
            ``assert`` and forwarded to the backend. Presumably states
            whether ``dst`` names a file or a directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The destination file. When ``dst`` is a directory the file is
        placed inside it under the base name of ``src``.

    Examples:
        >>> # dst is a file
        >>> src = 's3://openmmlab/mmengine/file'
        >>> dst = '/path/of/file'
        >>> # src will be copied to '/path/of/file'
        >>> copyfile_to_local(src, dst, dst_type='file')
        '/path/of/file'

        >>> # dst is a directory
        >>> dst = '/path/of/dir'
        >>> # src will be copied to '/path/of/dir/file'
        >>> copyfile_to_local(src, dst, dst_type='dir')
        '/path/of/dir/file'
    """
    assert dst_type in ("file", "dir")
    local_parent = Path(dst).parent
    local_parent.mkdir(parents=True, exist_ok=True)
    source_backend = get_file_backend(
        src,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    return source_backend.copyfile_to_local(src, dst, dst_type=dst_type)
658
-
659
-
660
def copytree_to_local(
    src: Union[str, Path],
    dst: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Union[str, Path]:
    """Recursively copy the directory tree ``src`` to the local path ``dst``.

    The parent directory of ``dst`` is created first (existing directories
    are accepted), then the copy is delegated to the backend's
    ``copytree_to_local`` method.

    Note:
        With a ``LocalBackend`` this is documented to behave like
        :func:`copytree`.

    Args:
        src (str or Path): Directory to copy.
        dst (str or Path): Local target directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The destination directory, as returned by the backend.

    Examples:
        >>> src = 's3://openmmlab/mmengine/dir'
        >>> dst = '/path/of/dir'
        >>> copytree_to_local(src, dst)
        '/path/of/dir'
    """
    Path(dst).parent.mkdir(parents=True, exist_ok=True)
    # The backend has to be resolved from the *source*: ``dst`` is always a
    # local path, so resolving from it would pick the backend for the local
    # destination rather than for the storage that holds ``src``. This
    # mirrors ``copyfile_to_local``.
    backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
    return backend.copytree_to_local(src, dst)
691
-
692
-
693
def remove(
    filepath: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> None:
    """Delete the file at ``filepath`` through its backend.

    Args:
        filepath (str, Path): Path to delete.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        IsADirectoryError: If ``filepath`` is a directory.

    Examples:
        >>> filepath = '/path/of/file'
        >>> remove(filepath)
    """
    lookup = dict(backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
    get_file_backend(filepath, **lookup).remove(filepath)
722
-
723
-
724
def rmtree(
    dir_path: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> None:
    """Recursively delete the directory tree at ``dir_path``.

    The deletion is delegated to the ``rmtree`` method of the backend
    returned by ``get_file_backend`` (singleton mode enabled).

    Args:
        dir_path (str or Path): Directory to delete.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Examples:
        >>> dir_path = '/path/of/dir'
        >>> rmtree(dir_path)
    """
    storage = get_file_backend(
        dir_path, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    )
    storage.rmtree(dir_path)
747
-
748
-
749
def copy_if_symlink_fails(
    src: Union[str, Path],
    dst: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> bool:
    """Symlink ``dst`` to ``src``, copying instead when linking fails.

    The work is delegated to the ``copy_if_symlink_fails`` method of the
    backend chosen from ``src``.

    Args:
        src (str or Path): Target the symbolic link should point to.
        dst (str or Path): Name of the symbolic link to create.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        bool: ``True`` when the symbolic link was created, ``False`` when the
        fallback copy was used.

    Examples:
        >>> src = '/path/of/file'
        >>> dst = '/path1/of/file1'
        >>> copy_if_symlink_fails(src, dst)
        True
        >>> src = '/path/of/dir'
        >>> dst = '/path1/of/dir1'
        >>> copy_if_symlink_fails(src, dst)
        True
    """
    storage = get_file_backend(
        src,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    linked = storage.copy_if_symlink_fails(src, dst)
    return linked
782
-
783
-
784
def list_dir(
    dir_path: Union[str, Path],
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
):
    """List the entries under ``dir_path`` via the backend's ``list_dir``.

    ``dir_path`` is converted to ``str`` and given a trailing ``/`` when it
    lacks one before being handed to the backend. The original documentation
    describes the result as all folders in an S3 bucket with the given prefix.

    Args:
        dir_path (str | Path): Path of the directory.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        Whatever the backend's ``list_dir`` returns for the normalised path.

    Examples:
        >>> dir_path = '/path/of/dir'
        >>> for file_path in list_dir(dir_path):
        ...     print(file_path)
    """
    # ``Path`` has neither ``endswith`` nor ``+=`` with a str, so normalise to
    # ``str`` first; a ``str`` argument is unaffected by this conversion.
    # NOTE(review): ``str(Path('s3://bucket/x'))`` collapses the double slash,
    # so URIs should still be passed as plain strings.
    dir_path = str(dir_path)
    if not dir_path.endswith("/"):
        dir_path += "/"
    backend = get_file_backend(
        dir_path,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )

    return backend.list_dir(dir_path)
809
-
810
-
811
def list_dir_or_file(
    dir_path: Union[str, Path],
    list_dir: bool = True,
    list_file: bool = True,
    suffix: Optional[Union[str, Tuple[str]]] = None,
    recursive: bool = False,
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> Iterator[str]:
    """Iterate over directories and/or files found under ``dir_path``.

    This is a generator: the backend is only looked up once iteration
    starts, and every item comes from the backend's ``list_dir_or_file``.
    Entries are produced in arbitrary order.

    Note:
        Paths are reported relative to ``dir_path``.

    Args:
        dir_path (str or Path): Directory to scan.
        list_dir (bool): Include directories. Defaults to True.
        list_file (bool): Include files. Defaults to True.
        suffix (str or tuple[str], optional): Only report files with this
            suffix (or one of these suffixes). Defaults to None.
        recursive (bool): Descend into sub-directories. Defaults to False.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Yields:
        Iterable[str]: A path relative to ``dir_path``.

    Examples:
        >>> dir_path = '/path/of/dir'
        >>> # list files and directories in the current directory
        >>> for file_path in list_dir_or_file(dir_path):
        ...     print(file_path)
        >>> # only list files
        >>> for file_path in list_dir_or_file(dir_path, list_dir=False):
        ...     print(file_path)
        >>> # only list directories
        >>> for file_path in list_dir_or_file(dir_path, list_file=False):
        ...     print(file_path)
        >>> # only list files ending with specified suffixes
        >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'):
        ...     print(file_path)
        >>> # list all files and directories recursively
        >>> for file_path in list_dir_or_file(dir_path, recursive=True):
        ...     print(file_path)
    """
    scanner = get_file_backend(
        dir_path, backend_args=backend_args, enable_singleton=True, backend_key=backend_key
    )
    entries = scanner.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
    yield from entries
867
-
868
-
869
def generate_presigned_url(
    url: str,
    client_method: str = "get_object",
    expires_in: int = 3600,
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
) -> str:
    """Create a presigned URL for ``url`` through its backend.

    The original documentation describes the result as a video-stream URL
    that can be passed to ``mmcv.VideoReader``.

    Note:
        Documented to work only with the s3 backend.

    Args:
        url (str): URL of the video stream.
        client_method (str): Client method, 'get_object' or 'put_object'.
            Defaults to 'get_object'.
        expires_in (int): Lifetime of the URL in seconds. Defaults to 3600.
        backend_args (dict, optional): Arguments used to instantiate the
            backend. Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.

    Returns:
        str: The generated presigned URL.
    """
    signer = get_file_backend(
        url,
        backend_args=backend_args,
        enable_singleton=True,
        backend_key=backend_key,
    )
    return signer.generate_presigned_url(url, client_method, expires_in)
895
-
896
-
897
def load(
    file: Union[str, Path, IO[Any]],
    file_format: Optional[str] = None,
    file_client_args: Optional[dict] = None,
    fast_backend: bool = False,
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
    **kwargs,
):
    """Load data from a serialized file (json/yaml/pickle, ...).

    Provides one entry point for deserializing files that may live in
    different storage backends. The handler is looked up in ``file_handlers``
    by the lower-cased ``file_format``.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a file-like
            object. A ``Path`` is converted to ``str``.
        file_format (str, optional): If not specified and ``file`` is a path,
            the format is taken from the text after the last ``.`` in the
            path; otherwise the given value is used. Matching is
            case-insensitive. Documented formats include "json", "yaml/yml"
            and "pickle/pkl".
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Deprecated (a ``DeprecationWarning`` is emitted);
            use ``backend_args`` instead. Defaults to None.
        fast_backend (bool): For non-text handlers, read through the
            backend's ``fast_get`` when it has one; otherwise a warning is
            emitted and the regular ``get`` is used. Defaults to False.
        backend_args (dict, optional): Arguments to instantiate the backend.
            Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.
        **kwargs: Forwarded to the handler's ``load_from_fileobj``.

    Returns:
        The content from the file.

    Raises:
        TypeError: If the format is unsupported or cannot be determined
            (e.g. a file-like object without ``file_format``), or if ``file``
            is neither a path nor an object with a ``read`` attribute.
        ValueError: If both ``file_client_args`` and ``backend_args`` are set.

    Examples:
        >>> load('/path/of/your/file')  # file is stored on disk
        >>> load('https://path/of/your/file')  # file is stored on the Internet
        >>> load('s3://path/of/your/file')  # file is stored in s3
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and isinstance(file, str):
        file_format = file.split(".")[-1]
    # Normalise case only when a format is known. For a file-like object
    # passed without ``file_format`` the value is still None here, and it must
    # reach the "Unsupported format" TypeError below instead of crashing with
    # AttributeError on ``None.lower()``.
    if file_format is not None:
        file_format = file_format.lower()
    if file_format not in file_handlers:
        raise TypeError(f"Unsupported format: {file_format}")

    if file_client_args is not None:
        warnings.warn(
            '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
            DeprecationWarning,
        )
        if backend_args is not None:
            raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')

    handler = file_handlers[file_format]
    if isinstance(file, str):
        if file_client_args is not None:
            file_backend = FileClient.infer_client(file_client_args, file)
        else:
            file_backend = get_file_backend(
                file,
                backend_args=backend_args,
                backend_key=backend_key,
                enable_singleton=True,
            )

        if handler.str_like:
            with StringIO(file_backend.get_text(file)) as f:
                return handler.load_from_fileobj(f, **kwargs)

        # Binary handlers: pick the reader once, then share a single
        # BytesIO/load path for the fast and the regular case.
        read_bytes = file_backend.get
        if fast_backend:
            if hasattr(file_backend, "fast_get"):
                read_bytes = file_backend.fast_get
            else:
                warnings.warn(
                    f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get"
                )
        with BytesIO(read_bytes(file)) as f:
            return handler.load_from_fileobj(f, **kwargs)

    if hasattr(file, "read"):
        return handler.load_from_fileobj(file, **kwargs)
    raise TypeError('"file" must be a filepath str or a file-object')
989
-
990
-
991
def dump(
    obj: Any,
    file: Union[str, Path, IO[Any], None] = None,
    file_format: Optional[str] = None,
    file_client_args: Optional[dict] = None,
    fast_backend: bool = False,
    backend_args: Optional[dict] = None,
    backend_key: Optional[str] = None,
    **kwargs,
):
    """Dump data to a json/yaml/pickle string or file.

    Provides one entry point for serializing ``obj`` either to a string or
    to a file that may live in different storage backends. The handler is
    looked up in ``file_handlers`` by the lower-cased ``file_format``.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, the object is dumped to a string; otherwise to the
            given filename or file-like object. A ``Path`` is converted to
            ``str``.
        file_format (str, optional): Same as :func:`load`. Required when
            ``file`` is None or a file-like object.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Deprecated (a ``DeprecationWarning`` is emitted);
            use ``backend_args`` instead. Defaults to None.
        fast_backend (bool): For non-text handlers, write through the
            backend's ``fast_put`` when it has one; otherwise a warning is
            emitted and the regular ``put`` is used. Defaults to False.
        backend_args (dict, optional): Arguments to instantiate the backend.
            Defaults to None.
        backend_key (str, optional): Key used to look the backend up in the
            register. Defaults to None.
        **kwargs: Forwarded to the handler's dump method.

    Returns:
        The result of the handler's ``dump_to_str`` when ``file`` is None;
        otherwise None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None, or if both
            ``file_client_args`` and ``backend_args`` are set.
        TypeError: If the format is unsupported or cannot be determined
            (e.g. a file-like object without ``file_format``), or if ``file``
            is neither a path nor an object with a ``write`` attribute.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 's3://path/of/your/file')  # ceph or s3
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if isinstance(file, str):
            file_format = file.split(".")[-1]
        elif file is None:
            raise ValueError("file_format must be specified since file is None")
    # Normalise case only when a format is known. For a file-like object
    # passed without ``file_format`` the value is still None here, and it must
    # reach the "Unsupported format" TypeError below instead of crashing with
    # AttributeError on ``None.lower()``.
    if file_format is not None:
        file_format = file_format.lower()
    if file_format not in file_handlers:
        raise TypeError(f"Unsupported format: {file_format}")

    if file_client_args is not None:
        warnings.warn(
            '"file_client_args" will be deprecated in future. Please use "backend_args" instead',
            DeprecationWarning,
        )
        if backend_args is not None:
            raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')

    handler = file_handlers[file_format]
    if file is None:
        return handler.dump_to_str(obj, **kwargs)
    elif isinstance(file, str):
        if file_client_args is not None:
            file_backend = FileClient.infer_client(file_client_args, file)
        else:
            file_backend = get_file_backend(
                file,
                backend_args=backend_args,
                backend_key=backend_key,
                enable_singleton=True,
            )

        if handler.str_like:
            with StringIO() as f:
                handler.dump_to_fileobj(obj, f, **kwargs)
                file_backend.put_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump_to_fileobj(obj, f, **kwargs)
                # Pick the writer once; the buffer object itself is handed to
                # the backend while it is still open, exactly as before.
                write_buffer = file_backend.put
                if fast_backend:
                    if hasattr(file_backend, "fast_put"):
                        write_buffer = file_backend.fast_put
                    else:
                        warnings.warn("fast_backend is not supported by the backend, fallback to normal put")
                write_buffer(f, file)
    elif hasattr(file, "write"):
        handler.dump_to_fileobj(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/file_client.py DELETED
@@ -1,458 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import inspect
17
- from contextlib import contextmanager
18
- from pathlib import Path
19
- from typing import Any, Generator, Iterator, Optional, Tuple, Union
20
-
21
- from lyra_2._ext.imaginaire.utils.easy_io.backends import (
22
- BaseStorageBackend,
23
- Boto3Backend,
24
- HTTPBackend,
25
- LocalBackend,
26
- )
27
-
28
-
29
def is_filepath(filepath):
    """Return ``True`` when ``filepath`` is a ``str`` or a ``pathlib.Path``."""
    path_like_types = (str, Path)
    return isinstance(filepath, path_like_types)
31
-
32
-
33
class HardDiskBackend(LocalBackend):
    """Raw hard disks storage backend."""

    @property
    def name(self):
        """Backend name, taken from the name of the instance's class."""
        backend_cls = self.__class__
        return backend_cls.__name__
39
-
40
-
41
class FileClient:
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary or text file. There are two ways to choose a
    backend: the name of the backend and the prefix of the path. Although
    both can be used to choose a storage backend, ``backend`` has the higher
    priority: if both are set, the storage backend is chosen by the
    ``backend`` argument. If both are ``None``, the disk backend is chosen.
    Other backend accessors can be registered with a given name, prefixes
    and backend class (see :meth:`register_backend`). The singleton pattern
    is used to avoid repeated object creation: if the arguments are the
    same, the same object is returned.

    Warning:
        `FileClient` will be deprecated in future. Please use io functions
        in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io

    Args:
        backend (str, optional): The storage backend type. Built-in options
            are "disk", "s3" and "http". Defaults to None.
        prefix (str, optional): The prefix of the registered storage backend.
            Built-in options are "s3", "http", "https". Defaults to None.
        **kwargs: Forwarded to the constructor of the selected backend class.
            They are also part of the singleton cache key.

    Raises:
        ValueError: If ``backend`` or ``prefix`` is not registered.

    Examples:
        >>> # only set backend
        >>> file_client = FileClient(backend='s3')
        >>> # only set prefix
        >>> file_client = FileClient(prefix='s3')
        >>> # set both backend and prefix but use backend to choose client
        >>> file_client = FileClient(backend='s3', prefix='s3')
        >>> # if the arguments are the same, the same object is returned
        >>> file_client1 = FileClient(backend='s3')
        >>> file_client1 is file_client
        True

    Attributes:
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    # Backend name -> backend class.
    _backends = {
        "disk": HardDiskBackend,
        "s3": Boto3Backend,
        "http": HTTPBackend,
    }

    # URI prefix -> backend class.
    _prefix_to_backends: dict = {
        "s3": Boto3Backend,
        "http": HTTPBackend,
        "https": HTTPBackend,
    }

    # Singleton cache: argument key -> FileClient instance.
    _instances: dict = {}

    client: Any

    def __new__(cls, backend=None, prefix=None, **kwargs):
        if backend is None and prefix is None:
            backend = "disk"
        if backend is not None and backend not in cls._backends:
            raise ValueError(
                f"Backend {backend} is not supported. Currently supported ones are {list(cls._backends.keys())}"
            )
        if prefix is not None and prefix not in cls._prefix_to_backends:
            raise ValueError(
                f"prefix {prefix} is not supported. Currently supported ones are {list(cls._prefix_to_backends.keys())}"
            )

        # Concatenate the arguments into a unique key for determining whether
        # an object with the same arguments was already created.
        arg_key = f"{backend}:{prefix}"
        for key, value in kwargs.items():
            arg_key += f":{key}:{value}"

        # Entries of overridden backends are evicted by `_register_backend`,
        # so a cache hit always refers to the currently registered class.
        if arg_key in cls._instances:
            _instance = cls._instances[arg_key]
        else:
            # Create a new object and remember it in `_instances`.
            _instance = super().__new__(cls)
            if backend is not None:
                _instance.client = cls._backends[backend](**kwargs)
            else:
                _instance.client = cls._prefix_to_backends[prefix](**kwargs)

            cls._instances[arg_key] = _instance

        return _instance

    @property
    def name(self):
        """Name reported by the wrapped backend object."""
        return self.client.name

    @property
    def allow_symlink(self):
        """``allow_symlink`` value reported by the wrapped backend object."""
        return self.client.allow_symlink

    @staticmethod
    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
        """Parse the prefix of a uri.

        Only the text before the first ``://`` is considered, so URIs that
        embed another ``://`` later on (for example in a query string) are
        parsed instead of raising ``ValueError``.

        Args:
            uri (str | Path): Uri to be parsed that contains the file prefix.

        Examples:
            >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
            's3'

        Returns:
            str | None: Return the prefix of uri if the uri contains '://' else
            ``None``.
        """
        assert is_filepath(uri)
        uri = str(uri)
        if "://" not in uri:
            return None

        prefix, _, _ = uri.partition("://")
        # In the case of Boto3Backend, the prefix may contain the cluster
        # name like clusterName:s3; keep only the part after the last colon.
        if ":" in prefix:
            prefix = prefix.rsplit(":", 1)[-1]
        return prefix

    @classmethod
    def infer_client(
        cls,
        file_client_args: Optional[dict] = None,
        uri: Optional[Union[str, Path]] = None,
    ) -> "FileClient":
        """Infer a suitable file client based on the URI and arguments.

        ``file_client_args`` takes precedence: ``uri`` is only inspected when
        ``file_client_args`` is ``None``.

        Args:
            file_client_args (dict, optional): Arguments to instantiate a
                FileClient. Defaults to None.
            uri (str | Path, optional): Uri to be parsed that contains the file
                prefix. Defaults to None.

        Examples:
            >>> uri = 's3://path/of/your/file'
            >>> file_client = FileClient.infer_client(uri=uri)
            >>> file_client_args = {'backend': 's3'}
            >>> file_client = FileClient.infer_client(file_client_args)

        Returns:
            FileClient: Instantiated FileClient object.
        """
        assert file_client_args is not None or uri is not None
        if file_client_args is None:
            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
            return cls(prefix=file_prefix)
        return cls(**file_client_args)

    @classmethod
    def _register_backend(cls, name, backend, force=False, prefixes=None):
        """Validate and store ``backend`` under ``name`` and ``prefixes``.

        Cached instances whose client is an instance of a backend class that
        gets overridden are evicted from ``_instances``.

        Raises:
            TypeError: If ``name`` is not a string or ``backend`` is not a
                subclass of ``BaseStorageBackend``.
            KeyError: If ``name`` or one of ``prefixes`` is already registered
                and ``force`` is false.
        """
        if not isinstance(name, str):
            raise TypeError(f"the backend name should be a string, but got {type(name)}")
        if not inspect.isclass(backend):
            raise TypeError(f"backend should be a class but got {type(backend)}")
        if not issubclass(backend, BaseStorageBackend):
            raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
        if not force and name in cls._backends:
            raise KeyError(
                f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
            )

        if name in cls._backends and force:
            for arg_key, instance in list(cls._instances.items()):
                if isinstance(instance.client, cls._backends[name]):
                    cls._instances.pop(arg_key)
        cls._backends[name] = backend

        if prefixes is not None:
            if isinstance(prefixes, str):
                prefixes = [prefixes]
            else:
                assert isinstance(prefixes, (list, tuple))
            for prefix in prefixes:
                if prefix not in cls._prefix_to_backends:
                    cls._prefix_to_backends[prefix] = backend
                elif force:
                    overridden_backend = cls._prefix_to_backends[prefix]
                    for arg_key, instance in list(cls._instances.items()):
                        if isinstance(instance.client, overridden_backend):
                            cls._instances.pop(arg_key)
                    # Actually install the new backend for this prefix;
                    # evicting the cache alone would leave the old mapping.
                    cls._prefix_to_backends[prefix] = backend
                else:
                    raise KeyError(
                        f"{prefix} is already registered as a storage backend,"
                        ' add "force=True" if you want to override it'
                    )

    @classmethod
    def register_backend(cls, name, backend=None, force=False, prefixes=None):
        """Register a backend to FileClient.

        This method can be used as a normal class method or a decorator.

        .. code-block:: python

            class NewBackend(BaseStorageBackend):

                def get(self, filepath):
                    return filepath

                def get_text(self, filepath):
                    return filepath

            FileClient.register_backend('new', NewBackend)

        or

        .. code-block:: python

            @FileClient.register_backend('new')
            class NewBackend(BaseStorageBackend):

                def get(self, filepath):
                    return filepath

                def get_text(self, filepath):
                    return filepath

        Args:
            name (str): The name of the registered backend.
            backend (class, optional): The backend class to be registered,
                which must be a subclass of :class:`BaseStorageBackend`.
                When this method is used as a decorator, backend is None.
                Defaults to None.
            force (bool, optional): Whether to override the backend if the name
                has already been registered. Defaults to False.
            prefixes (str or list[str] or tuple[str], optional): The prefixes
                of the registered storage backend. Defaults to None.
                `New in version 1.3.15.`

        Returns:
            None when ``backend`` is given; otherwise a class decorator that
            registers the decorated class and returns it unchanged.
        """
        if backend is not None:
            cls._register_backend(name, backend, force=force, prefixes=prefixes)
            return

        def _register(backend_cls):
            cls._register_backend(name, backend_cls, force=force, prefixes=prefixes)
            return backend_cls

        return _register

    def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
        """Read data from a given ``filepath`` with 'rb' mode.

        Note:
            There are two types of return values for ``get``, one is ``bytes``
            and the other is ``memoryview``. The advantage of using memoryview
            is that you can avoid copying, and if you want to convert it to
            ``bytes``, you can use ``.tobytes()``.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes | memoryview: Expected bytes object or a memory view of the
            bytes object.
        """
        return self.client.get(filepath)

    def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Defaults to 'utf-8'.

        Returns:
            str: Expected text reading from ``filepath``.
        """
        return self.client.get_text(filepath, encoding)

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` should create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self.client.put(obj, filepath)

    def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        No encoding is passed on, so the backend's own default encoding for
        ``put_text`` applies.

        Note:
            ``put_text`` should create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self.client.put_text(obj, filepath)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str, Path): Path to be removed.
        """
        self.client.remove(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return self.client.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
            ``False`` otherwise.
        """
        return self.client.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
            otherwise.
        """
        return self.client.isfile(filepath)

    def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str:
        r"""Concatenate all file paths.

        Join one or more filepath components intelligently. The return value
        is the concatenation of filepath and any members of \*filepaths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.
        """
        return self.client.join_path(filepath, *filepaths)

    @contextmanager
    def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
        """Download data from ``filepath`` and write the data to local path.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Note:
            If the ``filepath`` is a local path, just return itself.

        .. warning::
            ``get_local_path`` is an experimental interface that may change in
            the future.

        Args:
            filepath (str or Path): Path to be read data.

        Examples:
            >>> file_client = FileClient(prefix='s3')
            >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
            ...     # do something here

        Yields:
            Iterable[str]: Only yield one path.
        """
        # The backend always receives the path as a string.
        with self.client.get_local_path(str(filepath)) as local_path:
            yield local_path

    def list_dir_or_file(  # pylint: disable=too-many-arguments
        self,
        dir_path: Union[str, Path],
        list_dir: bool = True,
        list_file: bool = True,
        suffix: Optional[Union[str, Tuple[str, ...]]] = None,
        recursive: bool = False,
    ) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Defaults to True.
            list_file (bool): List the path of files. Defaults to True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Defaults to None.
            recursive (bool): If set to True, recursively scan the
                directory. Defaults to False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py DELETED
@@ -1,29 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
17
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
18
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
19
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler
20
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler
21
-
22
- __all__ = [
23
- "BaseFileHandler",
24
- "JsonHandler",
25
- "PickleHandler",
26
- "YamlHandler",
27
- "register_handler",
28
- "file_handlers",
29
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py DELETED
@@ -1,44 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from abc import ABCMeta, abstractmethod
17
-
18
-
19
class BaseFileHandler(metaclass=ABCMeta):
    """Abstract interface for (de)serializing objects to file objects and paths.

    Subclasses implement the three abstract methods; the path-based helpers
    open the file and delegate to the file-object methods.
    """

    # `str_like` tells callers whether the handler works on str-like (text)
    # or bytes-like file objects. Pickle only processes bytes-like objects
    # while json only processes str-like ones. For str-like handlers a
    # `StringIO` buffer is used by callers to hold the data.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        """Deserialize and return an object read from the open ``file``."""

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize ``obj`` into the open ``file``."""

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        """Serialize ``obj`` and return the result as a string."""

    def load_from_path(self, filepath, mode="r", **kwargs):
        """Open ``filepath`` with ``mode`` and return what ``load_from_fileobj`` yields."""
        with open(filepath, mode) as stream:
            loaded = self.load_from_fileobj(stream, **kwargs)
        return loaded

    def dump_to_path(self, obj, filepath, mode="w", **kwargs):
        """Open ``filepath`` with ``mode`` and write ``obj`` via ``dump_to_fileobj``."""
        with open(filepath, mode) as stream:
            self.dump_to_fileobj(obj, stream, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py DELETED
@@ -1,39 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import IO
17
-
18
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
19
-
20
-
21
class ByteHandler(BaseFileHandler):
    """Handler that passes raw bytes through unchanged."""

    # Works on bytes-like file objects.
    str_like = False

    def load_from_fileobj(self, file: IO[bytes], **kwargs):
        """Rewind ``file`` and return its entire content."""
        file.seek(0)
        # Extract all bytes and return them.
        payload = file.read()
        return payload

    def dump_to_fileobj(self, obj: bytes, file: IO[bytes], **kwargs):
        """Write ``obj`` to ``file`` as-is."""
        # Write all bytes to the file.
        file.write(obj)

    def dump_to_str(self, obj, **kwargs):
        """Not supported for raw bytes; always raises ``NotImplementedError``."""
        raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py DELETED
@@ -1,42 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import csv
17
- from io import StringIO
18
-
19
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
20
-
21
-
22
class CsvHandler(BaseFileHandler):
    """Handler that reads and writes rows of CSV data with the ``csv`` module."""

    def load_from_fileobj(self, file, **kwargs):
        """Return every row parsed from ``file`` as a list; ``kwargs`` are ignored."""
        del kwargs
        rows = list(csv.reader(file))
        return rows

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write the rows in ``obj`` to ``file``; ``kwargs`` are ignored.

        Raises:
            ValueError: If any row in ``obj`` is not a ``list``.
        """
        del kwargs
        csv_writer = csv.writer(file)
        if any(not isinstance(row, list) for row in obj):
            raise ValueError("Each row must be a list")
        csv_writer.writerows(obj)

    def dump_to_str(self, obj, **kwargs):
        """Return the rows in ``obj`` rendered as CSV text; ``kwargs`` are ignored.

        Raises:
            ValueError: If any row in ``obj`` is not a ``list``.
        """
        del kwargs
        buffer = StringIO()
        csv_writer = csv.writer(buffer)
        if any(not isinstance(row, list) for row in obj):
            raise ValueError("Each row must be a list")
        csv_writer.writerows(obj)
        return buffer.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py DELETED
@@ -1,33 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import gzip
17
- import pickle
18
- from io import BytesIO
19
- from typing import Any
20
-
21
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
22
-
23
-
24
class GzipHandler(PickleHandler):
    """Handler for gzip-compressed pickle payloads."""

    # Works on bytes-like file objects.
    str_like = False

    def load_from_fileobj(self, file: BytesIO, **kwargs):
        """Decompress ``file`` with gzip and return the unpickled object.

        NOTE: unpickling can execute arbitrary code; only load data from
        trusted sources.
        """
        gz_stream = gzip.GzipFile(fileobj=file, mode="rb")
        with gz_stream:
            payload = pickle.load(gz_stream)
        return payload

    def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
        """Pickle ``obj`` and write it gzip-compressed into ``file``."""
        gz_stream = gzip.GzipFile(fileobj=file, mode="wb")
        with gz_stream:
            pickle.dump(obj, gz_stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py DELETED
@@ -1,168 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import IO, Any, Dict, Tuple
17
-
18
- import imageio
19
- import imageio.v3 as iio_v3
20
- import numpy as np
21
- import torch
22
-
23
- from lyra_2._ext.imaginaire.utils import log
24
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
25
-
26
-
27
class ImageioVideoHandler(BaseFileHandler):
    """Handler that reads and writes videos through imageio."""

    # Works on bytes-like file objects.
    str_like = False

    def load_from_fileobj(
        self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs
    ) -> Tuple[np.ndarray, Dict[str, Any]]:
        """
        Load video from a file-like object using imageio.v3 with specified color mode.

        Parameters:
            file (IO[bytes]): A file-like object containing video data.
            format (str): Format of the video file (default 'mp4'). Accepted for
                interface compatibility; this method does not use it, the reader
                is selected through the ``plugin`` keyword instead.
            mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb').
            **kwargs: ``plugin`` (default 'pyav') selects the imageio plugin; all
                remaining keywords are forwarded to ``imageio.v3.imread``.

        Returns:
            tuple: A tuple containing an array of video frames and metadata about the video.
        """
        file.seek(0)

        # The plugin argument in v3 replaces the format argument in v2
        plugin = kwargs.pop("plugin", "pyav")

        # Load all frames at once using v3 API
        video_frames = iio_v3.imread(file, plugin=plugin, **kwargs)

        # Handle grayscale conversion if needed
        if mode == "gray":
            import cv2

            if len(video_frames.shape) == 4:  # (frames, height, width, channels)
                # Convert frame by frame and keep a trailing channel axis of
                # size 1 so the result stays 4-D.
                video_frames = np.array(
                    [np.expand_dims(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY), axis=2) for frame in video_frames]
                )

        # Extract metadata
        # Note: iio_v3.imread doesn't return metadata directly like v2 did
        # We need to extract it separately
        file.seek(0)
        metadata = self._extract_metadata(file, plugin=plugin)

        return video_frames, metadata

    def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> Dict[str, Any]:
        """
        Extract metadata from a video file.

        This is best-effort: if anything fails, the error is logged at debug
        level and a dictionary of placeholder values is returned instead.

        Parameters:
            file (IO[bytes]): File-like object containing video data.
            plugin (str): Plugin to use for reading.

        Returns:
            dict: Video metadata.
        """
        try:
            metadata = iio_v3.immeta(file, plugin=plugin)

            # Add some standard fields similar to v2 metadata format
            # NOTE(review): this branch keys on a missing "fps" yet only fills
            # in "size"; the condition may have been meant to test for a
            # missing "size" instead. Behavior kept as-is; please confirm.
            if "fps" not in metadata and "duration" in metadata:
                # Read the first frame to get shape information
                file.seek(0)
                first_frame = iio_v3.imread(file, plugin=plugin, index=0)
                metadata["size"] = first_frame.shape[1::-1]  # (width, height)
                metadata["source_size"] = metadata["size"]

            # Create a consistent metadata structure with v2
            metadata["plugin"] = plugin
            if "codec" not in metadata:
                metadata["codec"] = "unknown"
            if "pix_fmt" not in metadata:
                metadata["pix_fmt"] = "unknown"

            # Calculate nframes if possible
            if "fps" in metadata and "duration" in metadata:
                metadata["nframes"] = int(metadata["fps"] * metadata["duration"])
            else:
                metadata["nframes"] = float("inf")

            return metadata

        except Exception as e:  # deliberate best-effort boundary
            # Do not hide the failure completely: record why the placeholder
            # metadata is being returned.
            log.debug(f"Failed to extract video metadata with plugin {plugin!r}: {e!r}")
            # Fallback to basic metadata
            return {
                "plugin": plugin,
                "nframes": float("inf"),
                "codec": "unknown",
                "fps": 30.0,  # Default values
                "duration": 0,
                "size": (0, 0),
            }

    def dump_to_fileobj(
        self,
        obj: np.ndarray | torch.Tensor,
        file: IO[bytes],
        format: str = "mp4",  # pylint: disable=redefined-builtin
        fps: int = 17,
        quality: int = 5,
        ffmpeg_params=None,
        **kwargs,
    ):
        """
        Save an array of video frames to a file-like object using imageio.

        Parameters:
            obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video.
                Axes 1 and 2 are read as height and width, so both
                (frames, height, width, channels) and (frames, height, width)
                inputs are accepted. Tensors must be ``torch.uint8``.
            file (IO[bytes]): A file-like object to which the video data will be written.
            format (str): Format of the video file (default 'mp4').
            fps (int): Frames per second of the output video (default 17).
            quality (int): Quality of the video (0-10, default 5).
            ffmpeg_params (list): Additional parameters to pass to ffmpeg. When given,
                they replace the default ``-s <width>x<height>`` parameters.
            **kwargs: Extra keyword arguments for ``imageio.mimsave``; they override
                the defaults assembled here.

        """
        if isinstance(obj, torch.Tensor):
            assert obj.dtype == torch.uint8, "Tensor must be of type uint8"
            obj = obj.cpu().numpy()
        # Slice axes 1..2 explicitly (rather than 1..-1) so that channel-less
        # grayscale stacks unpack correctly as well.
        h, w = obj.shape[1:3]

        # Default ffmpeg params that ensure width and height are set
        default_ffmpeg_params = ["-s", f"{w}x{h}"]

        # Use provided ffmpeg_params if any, otherwise use defaults
        final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params

        mimsave_kwargs = {
            "fps": fps,
            "quality": quality,
            "macro_block_size": 1,
            "ffmpeg_params": final_ffmpeg_params,
            "output_params": ["-f", "mp4"],
        }
        # Update with any other kwargs
        mimsave_kwargs.update(kwargs)
        log.debug(f"mimsave_kwargs: {mimsave_kwargs}")

        imageio.mimsave(file, obj, format, **mimsave_kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Not supported for videos; always raises ``NotImplementedError``."""
        raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py DELETED
@@ -1,49 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import json
17
-
18
- import numpy as np
19
-
20
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
21
-
22
-
23
def set_default(obj):
    """Fallback converter for values the ``json`` encoder cannot serialize.

    Intended to be passed as the ``default`` argument of ``json.dump`` /
    ``json.dumps``.

    * ``set`` and ``range`` become a ``list``.
    * ``np.ndarray`` becomes a (possibly nested) ``list`` via ``tolist()``.
    * ``np.generic`` scalars (``np.int32``, ``np.float32``, ...) become the
      matching built-in Python scalar via ``item()``.

    Raises:
        TypeError: If ``obj`` is none of the supported types.
    """
    # The three type families are disjoint, so the check order is irrelevant.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (set, range)):
        return list(obj)
    raise TypeError(f"{type(obj)} is unsupported for json dump")
37
-
38
-
39
class JsonHandler(BaseFileHandler):
    """File handler that reads and writes JSON with the standard ``json`` module."""

    def load_from_fileobj(self, file):
        """Decode the JSON document contained in ``file`` and return the result."""
        return json.load(file)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize ``obj`` as JSON into ``file``.

        ``kwargs`` are forwarded to ``json.dump``. Unless the caller supplies
        its own ``default``, ``set_default`` is used as the fallback converter.
        """
        if "default" not in kwargs:
            kwargs["default"] = set_default
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Serialize ``obj`` to a JSON string and return it.

        ``kwargs`` are forwarded to ``json.dumps``. Unless the caller supplies
        its own ``default``, ``set_default`` is used as the fallback converter.
        """
        # A caller-provided "default" overrides the fallback placed first.
        options = {"default": set_default, **kwargs}
        return json.dumps(obj, **options)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py DELETED
@@ -1,80 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import json
17
- from typing import IO
18
-
19
- import numpy as np
20
-
21
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
22
-
23
-
24
def set_default(obj):
    """Convert a value the ``json`` encoder rejects into a serializable one.

    Meant to be used as the ``default`` hook of ``json.dumps``. ``set`` and
    ``range`` are turned into a ``list``, ``np.ndarray`` into nested lists,
    and ``np.generic`` scalars (``np.int32``, ``np.float32``, ...) into plain
    built-in Python numbers.

    Raises:
        TypeError: If ``obj`` is not one of the types listed above.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if not isinstance(obj, np.generic):
        raise TypeError(f"{type(obj)} is unsupported for json dump")
    return obj.item()
38
-
39
-
40
class JsonlHandler(BaseFileHandler):
    """Handler for JSON lines (JSONL) files: one JSON document per line."""

    def load_from_fileobj(self, file: IO):
        """Load JSON objects from a newline-delimited JSON (JSONL) file object.

        Args:
            file: An iterable of lines. Each line may be ``str`` or ``bytes``,
                since ``json.loads`` accepts both.

        Returns:
            A list with one decoded Python object per non-blank line, in
            file order.
        """
        data = []
        for line in file:
            line = line.strip()
            if not line:
                continue  # skip blank lines (e.g. a trailing newline)
            data.append(json.loads(line))
        return data

    def dump_to_fileobj(self, obj, file: IO[str], **kwargs):
        """Dump a list of objects to a newline-delimited JSON (JSONL) file object.

        Args:
            obj: A list (or any iterable) of objects; each item becomes one line.
            file: Writable text stream. The encoded lines are ``str``, so a
                binary stream would reject them.
            **kwargs: Forwarded to ``json.dumps``. ``default`` falls back to
                ``set_default`` when the caller does not provide one.
        """
        kwargs.setdefault("default", set_default)
        for item in obj:
            # Every record, including the last one, is newline-terminated.
            file.write(json.dumps(item, **kwargs) + "\n")

    def dump_to_str(self, obj, **kwargs):
        """Dump a list of objects to a newline-delimited JSON (JSONL) string.

        Unlike ``dump_to_fileobj``, the returned string has no trailing
        newline: the lines are joined with ``"\\n"`` only between records.

        Args:
            obj: A list (or any iterable) of objects; each item becomes one line.
            **kwargs: Forwarded to ``json.dumps``. ``default`` falls back to
                ``set_default`` when the caller does not provide one.

        Returns:
            The JSONL text.
        """
        kwargs.setdefault("default", set_default)
        lines = [json.dumps(item, **kwargs) for item in obj]
        return "\n".join(lines)
72
-
73
-
74
if __name__ == "__main__":
    from lyra_2._ext.imaginaire.utils.easy_io import easy_io

    # Manual smoke test: write each sample to test.jsonl as JSONL, then read
    # it back and print what was loaded.
    samples = (
        [1, 2, 3],
        [{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}],
    )
    for sample in samples:
        easy_io.dump(sample, "test.jsonl", file_format="jsonl")
        print(easy_io.load("test.jsonl"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py DELETED
@@ -1,89 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from io import BytesIO
17
- from typing import IO, Any
18
-
19
- import numpy as np
20
-
21
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
22
-
23
-
24
class NumpyHandler(BaseFileHandler):
    """Handler that loads and saves NumPy arrays with ``np.load`` / ``np.save``."""

    # NOTE(review): presumably tells the base class / backends that this
    # handler works with bytes rather than text -- confirm in BaseFileHandler.
    str_like = False

    def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any:
        """
        Load a NumPy array from a file-like object.

        Parameters:
            file (IO[bytes]): The file-like object containing the NumPy array data.
            **kwargs: Additional keyword arguments passed to `np.load`.

        Returns:
            Whatever `np.load` returns for the given data.
        """
        return np.load(file, **kwargs)

    def load_from_path(self, filepath: str, **kwargs) -> Any:
        """
        Load a NumPy array from a file path.

        Delegates to the base class implementation, forcing binary read mode.

        Parameters:
            filepath (str): The path to the file to load.
            **kwargs: Additional keyword arguments forwarded to the base class
                `load_from_path`.

        Returns:
            Whatever the base class `load_from_path` returns.
        """
        return super().load_from_path(filepath, mode="rb", **kwargs)

    def dump_to_str(self, obj: np.ndarray, **kwargs) -> bytes:
        """
        Serialize a NumPy array to its binary `np.save` representation.

        Despite the method name (fixed by the handler interface), the result
        is `bytes`, not `str`: it is the content of an in-memory `BytesIO`
        buffer that `np.save` wrote into.

        Parameters:
            obj (np.ndarray): The NumPy array to serialize.
            **kwargs: Additional keyword arguments passed to `np.save`.

        Returns:
            bytes: The serialized NumPy array.
        """
        with BytesIO() as f:
            np.save(f, obj, **kwargs)
            # Must be read before the buffer is closed on leaving the block.
            return f.getvalue()

    def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs):
        """
        Dump a NumPy array to a file-like object.

        Parameters:
            obj (np.ndarray): The NumPy array to dump.
            file (IO[bytes]): The file-like object to which the array is dumped.
            **kwargs: Additional keyword arguments passed to `np.save`.
        """
        np.save(file, obj, **kwargs)

    def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs):
        """
        Dump a NumPy array to a file path.

        The file is opened directly with the built-in `open` in binary write
        mode, so `filepath` is written exactly as given.

        Parameters:
            obj (np.ndarray): The NumPy array to dump.
            filepath (str): The file path where the array should be saved.
            **kwargs: Additional keyword arguments passed to `np.save`.
        """
        with open(filepath, "wb") as f:
            np.save(f, obj, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py DELETED
@@ -1,31 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import pandas as pd
17
-
18
- from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip
19
-
20
-
21
class PandasHandler(BaseFileHandler):
    """CSV handler that reads with ``pd.read_csv`` and writes with ``obj.to_csv``."""

    # NOTE(review): presumably marks the payload as non-text for the base
    # class / backends -- confirm in BaseFileHandler.
    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        """Parse CSV data from ``file``; ``kwargs`` are forwarded to ``pd.read_csv``."""
        frame = pd.read_csv(file, **kwargs)
        return frame

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` to ``file`` as CSV by calling ``obj.to_csv(file, **kwargs)``.

        ``obj`` is presumably a pandas ``DataFrame``; anything exposing a
        compatible ``to_csv`` method works.
        """
        obj.to_csv(file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Unsupported for this handler.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("PandasHandler does not support dumping to str")