Spaces:
Running on Zero
Running on Zero
Delete lyra_2
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- lyra_2/__init__.py +0 -14
- lyra_2/_ext/__init__.py +0 -4
- lyra_2/_ext/imaginaire/__init__.py +0 -14
- lyra_2/_ext/imaginaire/checkpointer/__init__.py +0 -14
- lyra_2/_ext/imaginaire/checkpointer/base.py +0 -177
- lyra_2/_ext/imaginaire/checkpointer/dcp.py +0 -1003
- lyra_2/_ext/imaginaire/checkpointer/dummy.py +0 -47
- lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py +0 -323
- lyra_2/_ext/imaginaire/config.py +0 -445
- lyra_2/_ext/imaginaire/flags.py +0 -52
- lyra_2/_ext/imaginaire/functional/batch_ops.py +0 -61
- lyra_2/_ext/imaginaire/functional/lr_scheduler.py +0 -178
- lyra_2/_ext/imaginaire/lazy_config/__init__.py +0 -60
- lyra_2/_ext/imaginaire/lazy_config/file_io.py +0 -24
- lyra_2/_ext/imaginaire/lazy_config/instantiate.py +0 -120
- lyra_2/_ext/imaginaire/lazy_config/lazy.py +0 -430
- lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py +0 -65
- lyra_2/_ext/imaginaire/lazy_config/registry.py +0 -74
- lyra_2/_ext/imaginaire/model.py +0 -129
- lyra_2/_ext/imaginaire/trainer.py +0 -379
- lyra_2/_ext/imaginaire/types/denoise_prediction.py +0 -28
- lyra_2/_ext/imaginaire/utils/__init__.py +0 -14
- lyra_2/_ext/imaginaire/utils/callback.py +0 -524
- lyra_2/_ext/imaginaire/utils/checkpointer.py +0 -372
- lyra_2/_ext/imaginaire/utils/config_helper.py +0 -214
- lyra_2/_ext/imaginaire/utils/context_managers.py +0 -55
- lyra_2/_ext/imaginaire/utils/count_params.py +0 -29
- lyra_2/_ext/imaginaire/utils/device.py +0 -114
- lyra_2/_ext/imaginaire/utils/distributed.py +0 -443
- lyra_2/_ext/imaginaire/utils/easy_io/__init__.py +0 -14
- lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py +0 -34
- lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py +0 -64
- lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py +0 -60
- lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py +0 -841
- lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py +0 -565
- lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py +0 -91
- lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py +0 -550
- lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py +0 -130
- lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py +0 -1085
- lyra_2/_ext/imaginaire/utils/easy_io/file_client.py +0 -458
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py +0 -29
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py +0 -44
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py +0 -39
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py +0 -42
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py +0 -33
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py +0 -168
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py +0 -49
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py +0 -80
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py +0 -89
- lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py +0 -31
lyra_2/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/__init__.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copied from cosmos repository, with necessary modifications.
|
| 3 |
-
Original source: https://github.com/nvidia-cosmos/cosmos-predict2.5/
|
| 4 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/checkpointer/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/checkpointer/base.py
DELETED
|
@@ -1,177 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from abc import ABC, abstractmethod
|
| 18 |
-
from typing import Optional
|
| 19 |
-
|
| 20 |
-
import torch
|
| 21 |
-
|
| 22 |
-
from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig
|
| 23 |
-
from lyra_2._ext.imaginaire.model import ImaginaireModel
|
| 24 |
-
from lyra_2._ext.imaginaire.utils import callback
|
| 25 |
-
from lyra_2._ext.imaginaire.utils.easy_io import easy_io
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class AbstractCheckpointer(ABC):
|
| 29 |
-
"""The checkpointer class. Supports checkpoint saving/loading to both local disk or object store."""
|
| 30 |
-
|
| 31 |
-
def __init__(
|
| 32 |
-
self,
|
| 33 |
-
config_checkpoint: CheckpointConfig,
|
| 34 |
-
config_job: JobConfig,
|
| 35 |
-
callbacks: Optional[callback.CallBackGroup] = None,
|
| 36 |
-
):
|
| 37 |
-
"""Constructor of the checkpointer.
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
config_checkpoint (CheckpointConfig): The config object for the checkpointer.
|
| 41 |
-
"""
|
| 42 |
-
self.config_checkpoint = config_checkpoint
|
| 43 |
-
# Set the callback functions.
|
| 44 |
-
self.callbacks = callbacks
|
| 45 |
-
self.save_to_object_store = config_checkpoint.save_to_object_store.enabled
|
| 46 |
-
self.load_from_object_store = config_checkpoint.load_from_object_store.enabled
|
| 47 |
-
|
| 48 |
-
# Set checkpoint directories for local and object store paths
|
| 49 |
-
self._local_dirname = os.path.join(config_job.path_local, "checkpoints")
|
| 50 |
-
self._object_store_dirname = os.path.join(config_job.path, "checkpoints")
|
| 51 |
-
|
| 52 |
-
self.strict_resume = config_checkpoint.strict_resume
|
| 53 |
-
self.load_path = config_checkpoint.load_path or None
|
| 54 |
-
self.load_training_state = config_checkpoint.load_training_state
|
| 55 |
-
self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
|
| 56 |
-
self.save_thread = None
|
| 57 |
-
self.verbose = config_checkpoint.verbose
|
| 58 |
-
self.keys_not_to_resume = config_checkpoint.keys_not_to_resume
|
| 59 |
-
self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem
|
| 60 |
-
# Create the object store client interface.
|
| 61 |
-
if config_checkpoint.load_from_object_store.enabled:
|
| 62 |
-
self.load_s3_backend_key = "_ckpt_s3_loader"
|
| 63 |
-
easy_io.set_s3_backend(
|
| 64 |
-
key="_ckpt_s3_loader",
|
| 65 |
-
backend_args={
|
| 66 |
-
"backend": "s3",
|
| 67 |
-
"path_mapping": {
|
| 68 |
-
"s3://ckpt/": f"s3://{config_checkpoint.load_from_object_store.bucket}/",
|
| 69 |
-
},
|
| 70 |
-
"s3_credential_path": config_checkpoint.load_from_object_store.credentials,
|
| 71 |
-
},
|
| 72 |
-
)
|
| 73 |
-
else:
|
| 74 |
-
self.load_s3_backend_key = None
|
| 75 |
-
|
| 76 |
-
if config_checkpoint.save_to_object_store.enabled:
|
| 77 |
-
self.save_s3_backend_key = "_ckpt_s3_saver"
|
| 78 |
-
easy_io.set_s3_backend(
|
| 79 |
-
key="_ckpt_s3_saver",
|
| 80 |
-
backend_args={
|
| 81 |
-
"backend": "s3",
|
| 82 |
-
"path_mapping": {
|
| 83 |
-
"s3://ckpt/": f"s3://{config_checkpoint.save_to_object_store.bucket}/",
|
| 84 |
-
},
|
| 85 |
-
"s3_credential_path": config_checkpoint.save_to_object_store.credentials,
|
| 86 |
-
},
|
| 87 |
-
)
|
| 88 |
-
else:
|
| 89 |
-
self.save_s3_backend_key = None
|
| 90 |
-
|
| 91 |
-
@abstractmethod
|
| 92 |
-
def save(
|
| 93 |
-
self,
|
| 94 |
-
model: ImaginaireModel,
|
| 95 |
-
optimizer: torch.optim.Optimizer,
|
| 96 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 97 |
-
grad_scaler: torch.amp.GradScaler,
|
| 98 |
-
iteration: int,
|
| 99 |
-
) -> None:
|
| 100 |
-
pass
|
| 101 |
-
|
| 102 |
-
@abstractmethod
|
| 103 |
-
def load(
|
| 104 |
-
self,
|
| 105 |
-
model: ImaginaireModel,
|
| 106 |
-
optimizer: Optional[torch.optim.Optimizer] = None,
|
| 107 |
-
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
|
| 108 |
-
grad_scaler: Optional[torch.amp.GradScaler] = None,
|
| 109 |
-
) -> int:
|
| 110 |
-
pass
|
| 111 |
-
|
| 112 |
-
@property
|
| 113 |
-
def save_bucket(self):
|
| 114 |
-
"""Get the bucket name for saving checkpoints."""
|
| 115 |
-
return self.config_checkpoint.save_to_object_store.bucket if self.save_to_object_store else None
|
| 116 |
-
|
| 117 |
-
@property
|
| 118 |
-
def load_bucket(self):
|
| 119 |
-
"""Get the bucket name for loading checkpoints."""
|
| 120 |
-
return self.config_checkpoint.load_from_object_store.bucket if self.load_from_object_store else None
|
| 121 |
-
|
| 122 |
-
@property
|
| 123 |
-
def save_dirname(self):
|
| 124 |
-
return (
|
| 125 |
-
f"s3://{self.save_bucket}/{self._object_store_dirname}"
|
| 126 |
-
if self.save_to_object_store
|
| 127 |
-
else self._local_dirname
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
@property
|
| 131 |
-
def load_dirname(self):
|
| 132 |
-
return (
|
| 133 |
-
f"s3://{self.load_bucket}/{self._object_store_dirname}"
|
| 134 |
-
if self.load_from_object_store
|
| 135 |
-
else self._local_dirname
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
def finalize(self) -> None:
|
| 139 |
-
"""Finalize the checkpointer."""
|
| 140 |
-
if self.save_thread:
|
| 141 |
-
self.save_thread.join()
|
| 142 |
-
|
| 143 |
-
def _read_latest_checkpoint_file(self) -> str | None:
|
| 144 |
-
"""Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
|
| 145 |
-
|
| 146 |
-
Returns:
|
| 147 |
-
checkpoint_file (str | None): file name of the latest saved checkpoint.
|
| 148 |
-
"""
|
| 149 |
-
checkpoint_file = None
|
| 150 |
-
checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt")
|
| 151 |
-
if easy_io.exists(f"{checkpoint_path}", backend_key=self.load_s3_backend_key):
|
| 152 |
-
checkpoint_file = easy_io.load(f"{checkpoint_path}", backend_key=self.load_s3_backend_key).strip()
|
| 153 |
-
|
| 154 |
-
return checkpoint_file
|
| 155 |
-
|
| 156 |
-
def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
|
| 157 |
-
"""Track the file name of the latest saved checkpoint.
|
| 158 |
-
|
| 159 |
-
Args:
|
| 160 |
-
checkpoint_file (str): file name of the latest saved checkpoint.
|
| 161 |
-
"""
|
| 162 |
-
content = f"{checkpoint_file}\n"
|
| 163 |
-
checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt")
|
| 164 |
-
easy_io.dump(
|
| 165 |
-
content,
|
| 166 |
-
checkpoint_path,
|
| 167 |
-
backend_key=self.save_s3_backend_key,
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
|
| 171 |
-
"""If the file checkpoint_path does not exist, raise an error.
|
| 172 |
-
|
| 173 |
-
Args:
|
| 174 |
-
checkpoint_path (str): full path to the checkpoint.
|
| 175 |
-
"""
|
| 176 |
-
if not easy_io.exists(f"{checkpoint_path}", backend_key=self.load_s3_backend_key):
|
| 177 |
-
raise FileNotFoundError(f"File not found (object store): {checkpoint_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/checkpointer/dcp.py
DELETED
|
@@ -1,1003 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
"""
|
| 18 |
-
Distributed checkpoint (DCP) directory structure and storage backends.
|
| 19 |
-
|
| 20 |
-
The checkpointer saves model state in a sharded format across multiple processes:
|
| 21 |
-
|
| 22 |
-
self.save_dirname/
|
| 23 |
-
├── iter_000000005/ # Checkpoint at iteration 5
|
| 24 |
-
│ ├── model/ # Model state shards
|
| 25 |
-
│ │ ├── __0_0.distcp # Shard 0 from rank 0
|
| 26 |
-
│ │ └── __1_0.distcp # Shard 1 from rank 1
|
| 27 |
-
│ ├── optim/ # Optimizer state shards
|
| 28 |
-
│ │ ├── __0_0.distcp # Shard 0 from rank 0
|
| 29 |
-
│ │ └── __1_0.distcp # Shard 1 from rank 1
|
| 30 |
-
│ ├── scheduler/ # Learning rate scheduler state
|
| 31 |
-
│ │ ├── __0_0.distcp # Shard 0 from rank 0
|
| 32 |
-
│ │ └── __1_0.distcp # Shard 1 from rank 1
|
| 33 |
-
│ └── trainer/ # Additional training state
|
| 34 |
-
│ ├── __0_0.distcp # Shard 0 from rank 0
|
| 35 |
-
│ └── __1_0.distcp # Shard 1 from rank 1
|
| 36 |
-
└── latest_checkpoint.txt # Points to most recent checkpoint folder, e.g. iter_000000005
|
| 37 |
-
|
| 38 |
-
Storage path format:
|
| 39 |
-
self.save_dirname = "{config_job.path_local}/checkpoints"
|
| 40 |
-
|
| 41 |
-
The sharded format enables efficient distributed saving/loading by:
|
| 42 |
-
1. Parallelizing I/O across processes
|
| 43 |
-
2. Reducing memory usage per process
|
| 44 |
-
3. Supporting both local and cloud storage backends
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
import enum
|
| 48 |
-
import functools
|
| 49 |
-
import multiprocessing
|
| 50 |
-
import os
|
| 51 |
-
import queue
|
| 52 |
-
import re
|
| 53 |
-
import time
|
| 54 |
-
import warnings
|
| 55 |
-
from collections import namedtuple
|
| 56 |
-
from multiprocessing import get_context
|
| 57 |
-
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
|
| 58 |
-
|
| 59 |
-
import torch
|
| 60 |
-
import torch.distributed
|
| 61 |
-
import torch.distributed.checkpoint as dcp
|
| 62 |
-
from torch import nn
|
| 63 |
-
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
|
| 64 |
-
from torch.distributed.checkpoint._storage_utils import _storage_setup
|
| 65 |
-
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
|
| 66 |
-
from torch.distributed.checkpoint.logger import _dcp_method_logger
|
| 67 |
-
from torch.distributed.checkpoint.state_dict import (
|
| 68 |
-
StateDictOptions,
|
| 69 |
-
get_model_state_dict,
|
| 70 |
-
get_optimizer_state_dict,
|
| 71 |
-
set_model_state_dict,
|
| 72 |
-
set_optimizer_state_dict,
|
| 73 |
-
)
|
| 74 |
-
from torch.distributed.checkpoint.stateful import Stateful
|
| 75 |
-
from torch.distributed.checkpoint.storage import StorageReader
|
| 76 |
-
from torch.distributed.checkpoint.utils import _api_bc_check, _DistWrapper, _profile
|
| 77 |
-
from torch.distributed.tensor import distribute_tensor
|
| 78 |
-
|
| 79 |
-
from lyra_2._ext.imaginaire.checkpointer.base import AbstractCheckpointer
|
| 80 |
-
from lyra_2._ext.imaginaire.checkpointer.s3_filesystem import S3StorageReader, S3StorageWriter
|
| 81 |
-
from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig
|
| 82 |
-
from lyra_2._ext.imaginaire.model import ImaginaireModel
|
| 83 |
-
from lyra_2._ext.imaginaire.utils import callback, distributed, log, misc
|
| 84 |
-
from lyra_2._ext.imaginaire.utils.easy_io import easy_io
|
| 85 |
-
from lyra_2._src.models.wan_t2v_model import WANDiffusionModel as DiffusionModel
|
| 86 |
-
|
| 87 |
-
try:
|
| 88 |
-
"""
|
| 89 |
-
We override the default load function to skip loadding _extra_state keys created by transformer-engine.
|
| 90 |
-
"""
|
| 91 |
-
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner as _DefaultLoadPlanner
|
| 92 |
-
from torch.distributed.checkpoint.default_planner import (
|
| 93 |
-
DTensor,
|
| 94 |
-
LoadPlan,
|
| 95 |
-
_create_read_items,
|
| 96 |
-
_version,
|
| 97 |
-
flatten_state_dict,
|
| 98 |
-
)
|
| 99 |
-
from torch.distributed.checkpoint.metadata import Metadata, TensorStorageMetadata
|
| 100 |
-
|
| 101 |
-
def create_default_local_load_plan(
|
| 102 |
-
state_dict: dict[str, Any], metadata: Metadata, strict: bool = True, dcp_allow_mismatched_size: bool = False
|
| 103 |
-
) -> LoadPlan:
|
| 104 |
-
requests = []
|
| 105 |
-
"""
|
| 106 |
-
Create the ``LoadPlan`` used by DefaultLoadPlanner.
|
| 107 |
-
|
| 108 |
-
It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
|
| 109 |
-
|
| 110 |
-
The default behavior is to match key exactly between state_dict and metadata.
|
| 111 |
-
It handles resharding by issuing multiple read requests against storage in order to match
|
| 112 |
-
load requirements.
|
| 113 |
-
"""
|
| 114 |
-
|
| 115 |
-
for fqn, obj in state_dict.items():
|
| 116 |
-
if fqn.endswith("._extra_state"): # dirty TE attention package!
|
| 117 |
-
continue
|
| 118 |
-
# ignore state_dict keys which do not exist in `state_dict` if strict=False
|
| 119 |
-
if fqn not in metadata.state_dict_metadata:
|
| 120 |
-
if strict:
|
| 121 |
-
raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
|
| 122 |
-
else:
|
| 123 |
-
log.warning(f"Local load plan: Missing key in checkpoint state_dict: {fqn}.")
|
| 124 |
-
continue
|
| 125 |
-
|
| 126 |
-
md = metadata.state_dict_metadata[fqn]
|
| 127 |
-
|
| 128 |
-
if not dcp_allow_mismatched_size:
|
| 129 |
-
if (
|
| 130 |
-
isinstance(md, TensorStorageMetadata)
|
| 131 |
-
and getattr(obj, "size", None) is not None
|
| 132 |
-
and md.size != obj.size()
|
| 133 |
-
):
|
| 134 |
-
if not strict:
|
| 135 |
-
log.critical(f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}")
|
| 136 |
-
continue
|
| 137 |
-
else:
|
| 138 |
-
raise ValueError(
|
| 139 |
-
f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
|
| 140 |
-
)
|
| 141 |
-
# Since DTensor supports submesh, adding extra check to ensure _create_read_items()
|
| 142 |
-
# gets called only when the current rank is part of the mesh for the corresponding DTensor.
|
| 143 |
-
if isinstance(obj, DTensor):
|
| 144 |
-
if obj.device_mesh.get_coordinate() is not None:
|
| 145 |
-
requests += _create_read_items(fqn, md, obj)
|
| 146 |
-
else:
|
| 147 |
-
requests += _create_read_items(fqn, md, obj)
|
| 148 |
-
|
| 149 |
-
return LoadPlan(requests)
|
| 150 |
-
|
| 151 |
-
class DefaultLoadPlanner(_DefaultLoadPlanner):
|
| 152 |
-
def set_partial_channel_weight(self, dcp_allow_mismatched_size: bool):
|
| 153 |
-
self.dcp_allow_mismatched_size = dcp_allow_mismatched_size
|
| 154 |
-
|
| 155 |
-
def create_local_plan(self) -> LoadPlan:
|
| 156 |
-
assert self.metadata is not None
|
| 157 |
-
if self.flatten_state_dict:
|
| 158 |
-
# To support checkpoints that are saved before v2.4, we have to
|
| 159 |
-
# differentiate if the missing keys are due to old checkpoints.
|
| 160 |
-
# The contracts are:
|
| 161 |
-
# 1. There are 3 cases when we found a missing key.
|
| 162 |
-
# 1.1 Actual missing key, but allow_partial_load is False
|
| 163 |
-
# 1.2 Actual missing key, but allow_partial load is True
|
| 164 |
-
# 1.3 Old checkpoint, but allow_partial_load is False
|
| 165 |
-
# 1.4 Old checkpoint, but allow_partial_load is True
|
| 166 |
-
# 2. If we found a missing key, we first convert the keys back to
|
| 167 |
-
# the key format of v2.3
|
| 168 |
-
# 3. If the previous missing keys are in the v2.3 keys, we assume
|
| 169 |
-
# this is a old checkpoint.
|
| 170 |
-
# 4. Pass the state_dict to `create_default_local_load_plan()`,
|
| 171 |
-
# which has the logic to check missing for allow_partial_load.
|
| 172 |
-
# So for 1.2 and 1.4 cases, we delegate allow_partial_load check to
|
| 173 |
-
# `create_default_local_load_plan()`. The logic here is to determine
|
| 174 |
-
# whether the checkpoint belong to 2.3 (or before) or 2.4 (or after).
|
| 175 |
-
current_keys = set(self.state_dict.keys())
|
| 176 |
-
load_keys = set(self.metadata.state_dict_metadata.keys())
|
| 177 |
-
missing_keys = load_keys - current_keys
|
| 178 |
-
if missing_keys:
|
| 179 |
-
_version._derived_version = "2_3"
|
| 180 |
-
old_state_dict, old_mappings = flatten_state_dict(self.original_state_dict)
|
| 181 |
-
old_keys = set(old_state_dict.keys())
|
| 182 |
-
if old_keys & missing_keys:
|
| 183 |
-
self.state_dict, self.mappings = old_state_dict, old_mappings
|
| 184 |
-
# _derived_version is only used by flatten_state_dict now.
|
| 185 |
-
# Set it back to None so that later we can save to a new version.
|
| 186 |
-
_version._derived_version = None
|
| 187 |
-
|
| 188 |
-
return create_default_local_load_plan(
|
| 189 |
-
self.state_dict,
|
| 190 |
-
self.metadata,
|
| 191 |
-
not self.allow_partial_load,
|
| 192 |
-
getattr(self, "dcp_allow_mismatched_size", False),
|
| 193 |
-
)
|
| 194 |
-
|
| 195 |
-
log.info("for the back comptiable pytorch! New DefaultLoadPlanner class is created.")
|
| 196 |
-
|
| 197 |
-
@_dcp_method_logger(log_exceptions=True)
|
| 198 |
-
@_api_bc_check
|
| 199 |
-
def load(
|
| 200 |
-
state_dict: dict[str, Any],
|
| 201 |
-
*,
|
| 202 |
-
checkpoint_id: Union[str, os.PathLike, None] = None,
|
| 203 |
-
storage_reader: Optional[StorageReader] = None,
|
| 204 |
-
planner: Optional[DefaultLoadPlanner] = None,
|
| 205 |
-
process_group: Optional[torch.distributed.ProcessGroup] = None,
|
| 206 |
-
no_dist: bool = False,
|
| 207 |
-
) -> None:
|
| 208 |
-
"""
|
| 209 |
-
We override the default load function to perform a load plan check for mismatched/missing keys.
|
| 210 |
-
==========================Original Doc string=====================================
|
| 211 |
-
Load a checkpoint into a distributed state dict in SPMD style.
|
| 212 |
-
|
| 213 |
-
Each rank must have the same keys in their ``state_dict`` provided to this
|
| 214 |
-
API. Mismatched keys may result in hangs or errors. If unsure, you can use
|
| 215 |
-
the ``utils._assert_same_keys`` API to check (but may incur communication
|
| 216 |
-
costs).
|
| 217 |
-
|
| 218 |
-
Each rank will try to read the least amount of data necessary
|
| 219 |
-
to fullfill the requested `state_dict`. When loading :class:`ShardedTensor`
|
| 220 |
-
or :class:`DTensor` instances, each rank only reads data for their local shards.
|
| 221 |
-
|
| 222 |
-
For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
|
| 223 |
-
load will first call ``state_dict`` before attempting deserialization, followed by
|
| 224 |
-
``load_state_dict`` once the deserialization is complete.
|
| 225 |
-
For each non-``Stateful`` object, load will deserailize the object, and then replace
|
| 226 |
-
it in the ``state_dict`` with the deserialized object.
|
| 227 |
-
|
| 228 |
-
.. warning::
|
| 229 |
-
All tensors in ``state_dict`` must be allocated on their
|
| 230 |
-
destination device *prior to* calling this function.
|
| 231 |
-
|
| 232 |
-
All non-tensor data is loaded using `torch.load()` and modified in place
|
| 233 |
-
on state_dict.
|
| 234 |
-
|
| 235 |
-
.. warning::
|
| 236 |
-
Users must call `load_state_dict` on the root module to ensure load
|
| 237 |
-
pos-processing and non-tensor data properly propagates.
|
| 238 |
-
|
| 239 |
-
.. note:
|
| 240 |
-
If no process group is initialized, this function will assume the intent
|
| 241 |
-
is to load a checkpoint into the local process. This can be useful in the
|
| 242 |
-
case of local inference, and when using regular Tensors (as opposed to DTensor
|
| 243 |
-
or ShardedTensor)
|
| 244 |
-
|
| 245 |
-
.. note:
|
| 246 |
-
Rank 0 is assumed to be the coordinator rank.
|
| 247 |
-
|
| 248 |
-
Args:
|
| 249 |
-
state_dict (Dict[str, Any]): The state_dict to load the checkpoint into.
|
| 250 |
-
checkpoint_id (Union[str, os.PathLike, None]):
|
| 251 |
-
The ID of this checkpoint instance. The meaning of the checkpoint_id
|
| 252 |
-
depends on the storage. It can be a path to a folder or to a file.
|
| 253 |
-
It can also be a key if the storage is a key-value store.
|
| 254 |
-
(Default: ``None``)
|
| 255 |
-
storage_reader (Optional[StorageReader]):
|
| 256 |
-
Instance of StorageWriter used to perform reads. If this is not
|
| 257 |
-
specified, DCP will automatically infer the reader based on the
|
| 258 |
-
checkpoint_id. If checkpoint_id is also None, an exception will
|
| 259 |
-
be raised. (Default: ``None``)
|
| 260 |
-
planner (Optional[LoadPlanner]):
|
| 261 |
-
Instance of LoadPlanner. If this is not specificed, the default
|
| 262 |
-
planner will be used. (Default: ``None``)
|
| 263 |
-
process_group (Optional[ProcessGroup]):
|
| 264 |
-
ProcessGroup to be used for cross-rank synchronization.
|
| 265 |
-
(Default: ``None``)
|
| 266 |
-
no_dist (bool): If ``True``, this function will assume the intent is to load
|
| 267 |
-
a checkpoint without using cross-rank synchronization. (Default: ``False``)
|
| 268 |
-
Returns:
|
| 269 |
-
None.
|
| 270 |
-
|
| 271 |
-
Examples
|
| 272 |
-
>>> # xdoctest: +SKIP
|
| 273 |
-
>>> my_model = MyModule()
|
| 274 |
-
>>> optimizer = Adagrad(my_model.parameters())
|
| 275 |
-
>>> model_state_dict = my_model.state_dict()
|
| 276 |
-
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(
|
| 277 |
-
... "/checkpoint/1"
|
| 278 |
-
... )
|
| 279 |
-
|
| 280 |
-
>>> torch.distributed.checkpoint.load_state_dict(
|
| 281 |
-
>>> state_dict=model_state_dict,
|
| 282 |
-
>>> storage_reader=fs_storage_reader,
|
| 283 |
-
>>> )
|
| 284 |
-
|
| 285 |
-
>>> # module.load_state_dict() function might have customized steps
|
| 286 |
-
>>> # to flush the state_dict, must call it to
|
| 287 |
-
>>> # ensure correct behavior.
|
| 288 |
-
>>> my_model.load_state_dict(model_state_dict)
|
| 289 |
-
|
| 290 |
-
.. note::
|
| 291 |
-
load_state_dict uses collectives to coordinate reads across ranks.
|
| 292 |
-
For NCCL-based process groups, internal tensor representations of
|
| 293 |
-
objects must be moved to the GPU device before communication takes place.
|
| 294 |
-
In this case, the device used is given by ``torch.cuda.current_device()``
|
| 295 |
-
and it is the user's responsibility to ensure that this is set so that each
|
| 296 |
-
rank has an individual GPU, via ``torch.cuda.set_device()``.
|
| 297 |
-
"""
|
| 298 |
-
|
| 299 |
-
no_dist = no_dist or (not torch.distributed.is_available()) or (not torch.distributed.is_initialized())
|
| 300 |
-
if no_dist:
|
| 301 |
-
warnings.warn(
|
| 302 |
-
"torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process."
|
| 303 |
-
)
|
| 304 |
-
|
| 305 |
-
with _profile():
|
| 306 |
-
storage_reader = cast(StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True))
|
| 307 |
-
|
| 308 |
-
# All ranks must have the same keys in their `state_dict` provided to
|
| 309 |
-
# this API. See documentation for more details.
|
| 310 |
-
# Here we simply sort the keys to ensure that all ranks load values in
|
| 311 |
-
# the same order.
|
| 312 |
-
keys = sorted(state_dict.keys())
|
| 313 |
-
|
| 314 |
-
statetful_sd = {}
|
| 315 |
-
for key in keys:
|
| 316 |
-
if key not in state_dict:
|
| 317 |
-
continue
|
| 318 |
-
elem = state_dict[key]
|
| 319 |
-
statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
|
| 320 |
-
|
| 321 |
-
_load_state_dict(
|
| 322 |
-
state_dict=statetful_sd,
|
| 323 |
-
storage_reader=storage_reader,
|
| 324 |
-
process_group=process_group,
|
| 325 |
-
no_dist=no_dist,
|
| 326 |
-
planner=planner,
|
| 327 |
-
)
|
| 328 |
-
for key in keys:
|
| 329 |
-
if key not in state_dict:
|
| 330 |
-
continue
|
| 331 |
-
elem = state_dict[key]
|
| 332 |
-
if isinstance(elem, Stateful):
|
| 333 |
-
# If the state_dict is a Stateful object,
|
| 334 |
-
# DCP does an in-place load in the original state dict.
|
| 335 |
-
elem.load_state_dict(statetful_sd[key])
|
| 336 |
-
else:
|
| 337 |
-
# Otherwise, replace the state_dict with the loaded state_dict.
|
| 338 |
-
state_dict[key] = statetful_sd[key]
|
| 339 |
-
|
| 340 |
-
def _load_state_dict(
|
| 341 |
-
state_dict: dict[str, Any],
|
| 342 |
-
storage_reader: StorageReader,
|
| 343 |
-
process_group: Optional[torch.distributed.ProcessGroup] = None,
|
| 344 |
-
coordinator_rank: int = 0,
|
| 345 |
-
no_dist: bool = False,
|
| 346 |
-
planner: Optional[DefaultLoadPlanner] = None,
|
| 347 |
-
) -> None:
|
| 348 |
-
torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
|
| 349 |
-
|
| 350 |
-
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
|
| 351 |
-
if planner is None:
|
| 352 |
-
planner = DefaultLoadPlanner()
|
| 353 |
-
|
| 354 |
-
ckpt_kwargs = {}
|
| 355 |
-
if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None:
|
| 356 |
-
ckpt_kwargs["checkpoint_id"] = ckpt_id
|
| 357 |
-
ckpt_kwargs["process_group"] = distW.group
|
| 358 |
-
|
| 359 |
-
@_dcp_method_logger(**ckpt_kwargs)
|
| 360 |
-
def local_step():
|
| 361 |
-
assert planner is not None
|
| 362 |
-
metadata = storage_reader.read_metadata()
|
| 363 |
-
planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
|
| 364 |
-
storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
|
| 365 |
-
|
| 366 |
-
local_plan = planner.create_local_plan()
|
| 367 |
-
local_plan = storage_reader.prepare_local_plan(local_plan)
|
| 368 |
-
return local_plan
|
| 369 |
-
|
| 370 |
-
@_dcp_method_logger(**ckpt_kwargs)
|
| 371 |
-
def global_step(all_local_plans):
|
| 372 |
-
assert planner is not None
|
| 373 |
-
all_local_plans = planner.create_global_plan(all_local_plans)
|
| 374 |
-
all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
|
| 375 |
-
return all_local_plans
|
| 376 |
-
|
| 377 |
-
central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
|
| 378 |
-
if distW.is_coordinator:
|
| 379 |
-
# Compare central_plan items with storage_reader.storage_data keys
|
| 380 |
-
dest_fqns = set()
|
| 381 |
-
storage_fqns = set()
|
| 382 |
-
|
| 383 |
-
# Extract FQNs from central_plan items
|
| 384 |
-
for item in central_plan.items:
|
| 385 |
-
if hasattr(item, "dest_index") and hasattr(item.dest_index, "fqn"):
|
| 386 |
-
dest_fqns.add(item.dest_index.fqn)
|
| 387 |
-
if hasattr(item, "storage_index") and hasattr(item.storage_index, "fqn"):
|
| 388 |
-
storage_fqns.add(item.storage_index.fqn)
|
| 389 |
-
|
| 390 |
-
# Get storage data keys
|
| 391 |
-
storage_data_keys = set()
|
| 392 |
-
if hasattr(storage_reader, "storage_data") and storage_reader.storage_data is not None:
|
| 393 |
-
storage_data_keys = set(item[0].fqn for item in storage_reader.storage_data.items())
|
| 394 |
-
state_dict_keys = set(state_dict.keys())
|
| 395 |
-
# Compare sets and log differences
|
| 396 |
-
# Remove any item that has "_extra_state" as substring in the sets
|
| 397 |
-
state_dict_keys = {fqn for fqn in state_dict_keys if "_extra_state" not in fqn}
|
| 398 |
-
dest_fqns = {fqn for fqn in dest_fqns if "_extra_state" not in fqn}
|
| 399 |
-
storage_fqns = {fqn for fqn in storage_fqns if "_extra_state" not in fqn}
|
| 400 |
-
storage_data_keys = {fqn for fqn in storage_data_keys if "_extra_state" not in fqn}
|
| 401 |
-
|
| 402 |
-
log.info("=== Load Plan FQN Analysis ===")
|
| 403 |
-
log.info(f"State Dict FQNs count: {len(state_dict_keys)}")
|
| 404 |
-
log.info(f"Destination FQNs count (without _extra_state): {len(dest_fqns)}")
|
| 405 |
-
log.info(f"Loaded FQNs count (without _extra_state): {len(storage_fqns)}")
|
| 406 |
-
log.info(f"In Storage keys count (without _extra_state): {len(storage_data_keys)}")
|
| 407 |
-
|
| 408 |
-
# Find missing keys in each direction
|
| 409 |
-
state_dict_missing_from_dest = state_dict_keys - dest_fqns
|
| 410 |
-
storage_data_missing_from_storage_fqns = storage_data_keys - storage_fqns
|
| 411 |
-
|
| 412 |
-
if state_dict_missing_from_dest:
|
| 413 |
-
log.info(
|
| 414 |
-
f"State Dict FQNs missing from load plan ({len(state_dict_missing_from_dest)} items): {sorted(state_dict_missing_from_dest)}"
|
| 415 |
-
)
|
| 416 |
-
else:
|
| 417 |
-
log.info("✓ All State Dict FQNs found in storage_data")
|
| 418 |
-
|
| 419 |
-
if storage_data_missing_from_storage_fqns:
|
| 420 |
-
# If there are more than 100 "net_ema" keys in storage_data_missing_from_storage_fqns, summarize them
|
| 421 |
-
net_ema_keys = {k for k in storage_data_missing_from_storage_fqns if "net_ema" in k}
|
| 422 |
-
if len(net_ema_keys) > 100:
|
| 423 |
-
storage_data_missing_from_storage_fqns = storage_data_missing_from_storage_fqns - net_ema_keys
|
| 424 |
-
storage_data_missing_from_storage_fqns = set(
|
| 425 |
-
storage_data_missing_from_storage_fqns
|
| 426 |
-
) # ensure set type
|
| 427 |
-
storage_data_missing_from_storage_fqns.add("net_ema")
|
| 428 |
-
log.info(
|
| 429 |
-
f"Summarized {len(net_ema_keys)} 'net_ema' keys as 'net_ema' in missing storage data keys."
|
| 430 |
-
)
|
| 431 |
-
log.info(
|
| 432 |
-
f"Storage data keys not loaded by load plan ({len(storage_data_missing_from_storage_fqns)} items): {sorted(storage_data_missing_from_storage_fqns)}"
|
| 433 |
-
)
|
| 434 |
-
else:
|
| 435 |
-
log.info("✓ All storage data keys found in Loaded FQNs")
|
| 436 |
-
|
| 437 |
-
log.info("=== End Load Plan FQN Analysis ===")
|
| 438 |
-
|
| 439 |
-
@_dcp_method_logger(**ckpt_kwargs)
|
| 440 |
-
def read_data():
|
| 441 |
-
assert planner is not None
|
| 442 |
-
final_local_plan = planner.finish_plan(central_plan)
|
| 443 |
-
all_reads = storage_reader.read_data(final_local_plan, planner)
|
| 444 |
-
|
| 445 |
-
all_reads.wait()
|
| 446 |
-
return None
|
| 447 |
-
|
| 448 |
-
_ = distW.all_gather("read", read_data)
|
| 449 |
-
|
| 450 |
-
except ImportError as e:
|
| 451 |
-
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
|
| 452 |
-
|
| 453 |
-
log.critical(f"{e}, using default planner")
|
| 454 |
-
|
| 455 |
-
StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"])
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
class ModelWrapper(Stateful):
|
| 459 |
-
"""Wrapper for model state dict handling"""
|
| 460 |
-
|
| 461 |
-
def __init__(self, model: Union[nn.Module, List[nn.Module]], load_ema_to_reg: bool = False):
|
| 462 |
-
self.model = [model] if isinstance(model, nn.Module) else model
|
| 463 |
-
self.load_ema_to_reg = load_ema_to_reg
|
| 464 |
-
if self.load_ema_to_reg:
|
| 465 |
-
assert isinstance(model, DiffusionModel)
|
| 466 |
-
|
| 467 |
-
def state_dict(self) -> Dict[str, Any]:
|
| 468 |
-
_state_dict = {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
|
| 469 |
-
if self.load_ema_to_reg:
|
| 470 |
-
if not self.model[0].config.ema.enabled:
|
| 471 |
-
all_keys = list(_state_dict.keys())
|
| 472 |
-
assert all(k.startswith("net.") for k in all_keys), "All keys must start with net."
|
| 473 |
-
for k in all_keys:
|
| 474 |
-
_state_dict[k.replace("net.", "net_ema.")] = _state_dict.pop(k)
|
| 475 |
-
else:
|
| 476 |
-
log.warning("EMA is enabled, will only load EMA weights from checkpoint file.")
|
| 477 |
-
all_keys = list(_state_dict.keys())
|
| 478 |
-
for k in all_keys:
|
| 479 |
-
if k.startswith("net_ema."):
|
| 480 |
-
break
|
| 481 |
-
else:
|
| 482 |
-
raise ValueError("No EMA keys found in state_dict")
|
| 483 |
-
# do not load .net keys, since we do not need them anyway.
|
| 484 |
-
_state_dict = {k: _state_dict[k] for k in all_keys if not k.startswith("net.")}
|
| 485 |
-
|
| 486 |
-
if hasattr(self.model[0].config, "lora_config") and self.model[0].config.lora_config.enabled:
|
| 487 |
-
"""
|
| 488 |
-
When using LoRA, `inject_adapter_in_model` modifies the target modules in place.
|
| 489 |
-
For example, `blocks[0].attn.q_proj.weight` will be modified to `blocks[0].attn.q_proj.base_layer.weight`.
|
| 490 |
-
This means that the model will have the key `blocks[0].attn.q_proj.base_layer.weight`,
|
| 491 |
-
but the checkpoint will have the key `blocks[0].attn.q_proj.weight`.
|
| 492 |
-
We need to map the model key to the checkpoint key.
|
| 493 |
-
"""
|
| 494 |
-
self.checkpoint_to_model_key = {}
|
| 495 |
-
mapping_keys = {"base_layer.": "", "base_model.model.": ""}
|
| 496 |
-
keys_to_update = []
|
| 497 |
-
for k in _state_dict.keys():
|
| 498 |
-
new_key = k
|
| 499 |
-
for from_key, to_key in mapping_keys.items():
|
| 500 |
-
new_key = new_key.replace(from_key, to_key)
|
| 501 |
-
if new_key != k:
|
| 502 |
-
keys_to_update.append((k, new_key))
|
| 503 |
-
self.checkpoint_to_model_key[new_key] = k
|
| 504 |
-
for k, new_key in keys_to_update:
|
| 505 |
-
_state_dict[new_key] = _state_dict.pop(k)
|
| 506 |
-
|
| 507 |
-
return _state_dict
|
| 508 |
-
|
| 509 |
-
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 510 |
-
if hasattr(self.model[0].config, "lora_config") and self.model[0].config.lora_config.enabled:
|
| 511 |
-
if hasattr(self, "checkpoint_to_model_key"):
|
| 512 |
-
for checkpoint_key, model_key in self.checkpoint_to_model_key.items():
|
| 513 |
-
state_dict[model_key] = state_dict.pop(checkpoint_key)
|
| 514 |
-
else:
|
| 515 |
-
raise ValueError("checkpoint_to_model_key is not set by `state_dict`")
|
| 516 |
-
|
| 517 |
-
if self.load_ema_to_reg:
|
| 518 |
-
if not self.model[0].config.ema.enabled:
|
| 519 |
-
all_keys = list(state_dict.keys())
|
| 520 |
-
assert all(k.startswith("net_ema.") for k in all_keys), "All keys must start with net_ema."
|
| 521 |
-
for k in all_keys:
|
| 522 |
-
state_dict[k.replace("net_ema.", "net.")] = state_dict.pop(k)
|
| 523 |
-
else:
|
| 524 |
-
log.warning("EMA is enabled, will load EMA weights to regular model weights")
|
| 525 |
-
all_keys = list(state_dict.keys())
|
| 526 |
-
assert all(not k.startswith("net.") for k in all_keys), "No .net keys should be in state_dict"
|
| 527 |
-
for k in all_keys:
|
| 528 |
-
if k.startswith("net_ema."):
|
| 529 |
-
state_dict[k.replace("net_ema.", "net.")] = torch.clone(state_dict[k])
|
| 530 |
-
func = functools.partial(
|
| 531 |
-
set_model_state_dict,
|
| 532 |
-
model_state_dict=state_dict,
|
| 533 |
-
options=StateDictOptions(strict=False),
|
| 534 |
-
)
|
| 535 |
-
list(map(func, self.model))
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
class OptimizerWrapper(Stateful):
|
| 539 |
-
def __init__(
|
| 540 |
-
self,
|
| 541 |
-
model: Union[nn.Module, List[nn.Module]],
|
| 542 |
-
optim: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
|
| 543 |
-
) -> None:
|
| 544 |
-
self.model = [model] if isinstance(model, nn.Module) else model
|
| 545 |
-
self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
|
| 546 |
-
|
| 547 |
-
def state_dict(self) -> Dict[str, Any]:
|
| 548 |
-
func = functools.partial(
|
| 549 |
-
get_optimizer_state_dict,
|
| 550 |
-
options=StateDictOptions(flatten_optimizer_state_dict=True),
|
| 551 |
-
)
|
| 552 |
-
return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()}
|
| 553 |
-
|
| 554 |
-
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
| 555 |
-
func = functools.partial(
|
| 556 |
-
set_optimizer_state_dict,
|
| 557 |
-
optim_state_dict=state_dict,
|
| 558 |
-
options=StateDictOptions(flatten_optimizer_state_dict=True),
|
| 559 |
-
)
|
| 560 |
-
list(map(func, self.model, self.optim))
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
class AsyncMode(str, enum.Enum):
|
| 564 |
-
DISABLED = "disabled"
|
| 565 |
-
ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem"
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
class Terminate:
|
| 569 |
-
pass
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
class SaveDone:
|
| 573 |
-
def __init__(self, iteration: int, elapsed_time: float, succeeded: bool):
|
| 574 |
-
self.iteration = iteration
|
| 575 |
-
self.elapsed_time = elapsed_time
|
| 576 |
-
self.succeeded = succeeded
|
| 577 |
-
|
| 578 |
-
def __str__(self):
|
| 579 |
-
return f"SaveDone(iteration={self.iteration}, elapsed_time={self.elapsed_time}, succeeded={self.succeeded})"
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
def save_checkpoint_in_background(
|
| 583 |
-
receiver_queue: multiprocessing.Queue,
|
| 584 |
-
sender_queue: multiprocessing.Queue,
|
| 585 |
-
checkpoint_config: CheckpointConfig,
|
| 586 |
-
job_config: JobConfig,
|
| 587 |
-
) -> None:
|
| 588 |
-
"""
|
| 589 |
-
Handles model checkpoint saving in a separate background process using PyTorch's distributed functionality.
|
| 590 |
-
This function runs in a dedicated process to avoid blocking the main training loop.
|
| 591 |
-
|
| 592 |
-
Args:
|
| 593 |
-
receiver_queue: Queue to receive state dictionaries and commands from the main process
|
| 594 |
-
sender_queue: Queue to send completion signals back to the main process
|
| 595 |
-
checkpoint_config: Configuration settings for checkpoint saving behavior
|
| 596 |
-
job_config: Configuration settings for the training job
|
| 597 |
-
|
| 598 |
-
Flow:
|
| 599 |
-
1. Initializes distributed processing environment
|
| 600 |
-
2. Continuously waits for state dictionaries to save
|
| 601 |
-
3. Saves checkpoints asynchronously
|
| 602 |
-
4. Signals completion back to main process
|
| 603 |
-
5. Terminates when receiving a Terminate signal
|
| 604 |
-
|
| 605 |
-
Raises:
|
| 606 |
-
AssertionError: If received object is neither Terminate signal nor valid state dict tuple
|
| 607 |
-
|
| 608 |
-
Note:
|
| 609 |
-
- Uses a different port than the main process to avoid conflicts
|
| 610 |
-
- Disables TorchElastic agent store for checkpoint operations
|
| 611 |
-
- Automatically cleans up distributed process group on exit
|
| 612 |
-
"""
|
| 613 |
-
# Configure distributed environment
|
| 614 |
-
os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2)
|
| 615 |
-
os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False"
|
| 616 |
-
|
| 617 |
-
# Set up GPU device and distributed processing
|
| 618 |
-
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
| 619 |
-
distributed.init()
|
| 620 |
-
|
| 621 |
-
# Initialize checkpointing mechanism
|
| 622 |
-
checkpoint_handler = DistributedCheckpointer(checkpoint_config, job_config, None, disable_async=True)
|
| 623 |
-
|
| 624 |
-
try:
|
| 625 |
-
while True:
|
| 626 |
-
log.debug("Checkpoint background process is ready for next task, waiting for new state_dict")
|
| 627 |
-
received_data = receiver_queue.get()
|
| 628 |
-
log.debug("Received new state_dict")
|
| 629 |
-
|
| 630 |
-
if isinstance(received_data, Terminate):
|
| 631 |
-
log.info("Received termination signal in checkpoint background process, closing sender queue")
|
| 632 |
-
sender_queue.put(Terminate())
|
| 633 |
-
sender_queue.close()
|
| 634 |
-
return
|
| 635 |
-
|
| 636 |
-
assert isinstance(received_data, tuple), "Received data must be a tuple of (state_dict, checkpoint_path)"
|
| 637 |
-
state_dict, checkpoint_path = received_data
|
| 638 |
-
|
| 639 |
-
# Save checkpoint and measure time taken
|
| 640 |
-
start_time = time.monotonic()
|
| 641 |
-
iteration = state_dict["trainer"][0]["iteration"]
|
| 642 |
-
elapsed_time = 0
|
| 643 |
-
succeeded = False
|
| 644 |
-
try:
|
| 645 |
-
checkpoint_handler.save_state_dict_worker(state_dict, checkpoint_path)
|
| 646 |
-
elapsed_time = time.monotonic() - start_time
|
| 647 |
-
log.info(
|
| 648 |
-
f"Checkpoint saved successfully in background process. Time taken: {elapsed_time:.2f} seconds, iteration: {iteration}"
|
| 649 |
-
)
|
| 650 |
-
succeeded = True
|
| 651 |
-
except Exception as e:
|
| 652 |
-
log.error(f"Error saving checkpoint to {checkpoint_path}: {e}")
|
| 653 |
-
# continue because if the thread exits, the main thread keeps on adding to the queue
|
| 654 |
-
finally:
|
| 655 |
-
if elapsed_time == 0:
|
| 656 |
-
elapsed_time = time.monotonic() - start_time
|
| 657 |
-
sender_queue.put(SaveDone(iteration, elapsed_time, succeeded))
|
| 658 |
-
|
| 659 |
-
finally:
|
| 660 |
-
log.info("Cleaning up: destroying distributed process group")
|
| 661 |
-
torch.distributed.destroy_process_group()
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
class DistributedCheckpointer(AbstractCheckpointer):
|
| 665 |
-
KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"]
|
| 666 |
-
|
| 667 |
-
def __init__(
|
| 668 |
-
self,
|
| 669 |
-
config_checkpoint: CheckpointConfig,
|
| 670 |
-
config_job: JobConfig,
|
| 671 |
-
callbacks: Optional[callback.CallBackGroup] = None,
|
| 672 |
-
disable_async: bool = False,
|
| 673 |
-
):
|
| 674 |
-
super().__init__(config_checkpoint, config_job, callbacks)
|
| 675 |
-
self.config_checkpoint = config_checkpoint
|
| 676 |
-
if config_checkpoint.dcp_async_mode_enabled:
|
| 677 |
-
self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM
|
| 678 |
-
else:
|
| 679 |
-
self.async_mode = AsyncMode.DISABLED
|
| 680 |
-
|
| 681 |
-
if disable_async:
|
| 682 |
-
self.async_mode = AsyncMode.DISABLED
|
| 683 |
-
|
| 684 |
-
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
|
| 685 |
-
ctx = get_context("spawn")
|
| 686 |
-
self.mp_queue_send = ctx.Queue()
|
| 687 |
-
self.mp_queue_recv = ctx.Queue()
|
| 688 |
-
self.mp = ctx.Process(
|
| 689 |
-
target=save_checkpoint_in_background,
|
| 690 |
-
args=(
|
| 691 |
-
self.mp_queue_send,
|
| 692 |
-
self.mp_queue_recv,
|
| 693 |
-
config_checkpoint,
|
| 694 |
-
config_job,
|
| 695 |
-
),
|
| 696 |
-
daemon=True,
|
| 697 |
-
)
|
| 698 |
-
self.mp.start()
|
| 699 |
-
self.cpu_offload_state_dict = None
|
| 700 |
-
self.staging = False
|
| 701 |
-
self.staging_ckpt_file = None
|
| 702 |
-
self.staging_stream = torch.cuda.Stream()
|
| 703 |
-
|
| 704 |
-
def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]:
|
| 705 |
-
latest_checkpoint_file = self._read_latest_checkpoint_file()
|
| 706 |
-
|
| 707 |
-
resume_keys = []
|
| 708 |
-
|
| 709 |
-
if latest_checkpoint_file is not None:
|
| 710 |
-
# 1. Resume training from latest_checkpoint.txt under the same name.
|
| 711 |
-
checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file)
|
| 712 |
-
resume_keys.extend(self.KEYS_TO_SAVE)
|
| 713 |
-
else:
|
| 714 |
-
if self.load_path and not self.load_path.endswith(".pth"):
|
| 715 |
-
# 2. Load the module weights specified by config_checkpoint.path.
|
| 716 |
-
checkpoint_path = self.load_path
|
| 717 |
-
if self.load_s3_backend_key:
|
| 718 |
-
checkpoint_path = f"s3://{self.config_checkpoint.load_from_object_store.bucket}/{checkpoint_path}"
|
| 719 |
-
if not re.search(r"/checkpoints/iter_\d{9}/?$", checkpoint_path):
|
| 720 |
-
old_ckpt_path = checkpoint_path
|
| 721 |
-
# If path doesn't end with specific checkpoint, read latest checkpoint file
|
| 722 |
-
latest_ckpt_path = os.path.join(checkpoint_path, "checkpoints/latest_checkpoint.txt")
|
| 723 |
-
if easy_io.exists(latest_ckpt_path, backend_key=self.load_s3_backend_key):
|
| 724 |
-
checkpoint_file = easy_io.load(
|
| 725 |
-
latest_ckpt_path, backend_key=self.load_s3_backend_key
|
| 726 |
-
).strip()
|
| 727 |
-
checkpoint_path = f"{checkpoint_path}/checkpoints/{checkpoint_file}"
|
| 728 |
-
else:
|
| 729 |
-
log.warning(
|
| 730 |
-
f"Latest checkpoint file {latest_ckpt_path} not found, load from {old_ckpt_path}"
|
| 731 |
-
)
|
| 732 |
-
checkpoint_path = old_ckpt_path
|
| 733 |
-
if self.load_training_state:
|
| 734 |
-
resume_keys.extend(self.KEYS_TO_SAVE)
|
| 735 |
-
else:
|
| 736 |
-
resume_keys.append("model")
|
| 737 |
-
if self.only_load_scheduler_state:
|
| 738 |
-
resume_keys.append("scheduler")
|
| 739 |
-
elif self.load_path and self.load_path.endswith(".pth"):
|
| 740 |
-
checkpoint_path = self.load_path
|
| 741 |
-
else:
|
| 742 |
-
checkpoint_path = None
|
| 743 |
-
if len(self.keys_not_to_resume) > 0:
|
| 744 |
-
for key in self.keys_not_to_resume:
|
| 745 |
-
assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}"
|
| 746 |
-
resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume]
|
| 747 |
-
return set(resume_keys), checkpoint_path
|
| 748 |
-
|
| 749 |
-
@misc.timer("checkpoint loading")
|
| 750 |
-
def load(
|
| 751 |
-
self,
|
| 752 |
-
model: ImaginaireModel,
|
| 753 |
-
optimizer: torch.optim.Optimizer | None = None,
|
| 754 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
|
| 755 |
-
grad_scaler: torch.amp.GradScaler | None = None,
|
| 756 |
-
) -> int:
|
| 757 |
-
if self.callbacks is not None:
|
| 758 |
-
self.callbacks.on_load_checkpoint_start(model)
|
| 759 |
-
|
| 760 |
-
resume_keys, checkpoint_path = self.keys_to_resume_during_load()
|
| 761 |
-
resume_keys = sorted(resume_keys)
|
| 762 |
-
log.info(f"Resuming ckpt {checkpoint_path} with keys: {resume_keys}")
|
| 763 |
-
|
| 764 |
-
iteration = 0
|
| 765 |
-
|
| 766 |
-
if checkpoint_path is not None and not checkpoint_path.endswith(".pth"):
|
| 767 |
-
self._check_checkpoint_exists(checkpoint_path)
|
| 768 |
-
for key in resume_keys:
|
| 769 |
-
load_planner = DefaultLoadPlanner(allow_partial_load=True)
|
| 770 |
-
if hasattr(load_planner, "set_partial_channel_weight"):
|
| 771 |
-
log.info(f"set_partial_channel_weight: {self.config_checkpoint.dcp_allow_mismatched_size}")
|
| 772 |
-
load_planner.set_partial_channel_weight(self.config_checkpoint.dcp_allow_mismatched_size)
|
| 773 |
-
cur_key_ckpt_full_path = os.path.join(checkpoint_path, key)
|
| 774 |
-
log.info(f"Start loading checkpoint from {checkpoint_path}")
|
| 775 |
-
storage_reader = self.get_storage_reader(cur_key_ckpt_full_path)
|
| 776 |
-
torch.distributed.barrier()
|
| 777 |
-
log.info(f"starting {cur_key_ckpt_full_path}", rank0_only=False)
|
| 778 |
-
if key == "model":
|
| 779 |
-
log.info("- Loading the model...")
|
| 780 |
-
_model_wrapper = ModelWrapper(model)
|
| 781 |
-
_state_dict = _model_wrapper.state_dict()
|
| 782 |
-
load(_state_dict, storage_reader=storage_reader, planner=load_planner)
|
| 783 |
-
_model_wrapper.load_state_dict(_state_dict)
|
| 784 |
-
elif key == "optim":
|
| 785 |
-
log.info("- Loading the optimizer...")
|
| 786 |
-
_optim_wrapper = OptimizerWrapper(model, optimizer)
|
| 787 |
-
_state_dict = _optim_wrapper.state_dict()
|
| 788 |
-
dcp.load(
|
| 789 |
-
_state_dict,
|
| 790 |
-
storage_reader=storage_reader,
|
| 791 |
-
planner=load_planner,
|
| 792 |
-
)
|
| 793 |
-
_optim_wrapper.load_state_dict(_state_dict)
|
| 794 |
-
elif key == "scheduler":
|
| 795 |
-
log.info("- Loading the scheduler...")
|
| 796 |
-
_state_dict = scheduler.state_dict()
|
| 797 |
-
dcp.load(
|
| 798 |
-
_state_dict,
|
| 799 |
-
storage_reader=storage_reader,
|
| 800 |
-
planner=load_planner,
|
| 801 |
-
)
|
| 802 |
-
scheduler.load_state_dict(_state_dict)
|
| 803 |
-
elif key == "trainer":
|
| 804 |
-
log.info("- Loading the trainer...")
|
| 805 |
-
_state_dict = {
|
| 806 |
-
"grad_scaler": grad_scaler.state_dict(),
|
| 807 |
-
"iteration": iteration,
|
| 808 |
-
}
|
| 809 |
-
dcp.load(
|
| 810 |
-
_state_dict,
|
| 811 |
-
storage_reader=storage_reader,
|
| 812 |
-
planner=load_planner,
|
| 813 |
-
)
|
| 814 |
-
grad_scaler.load_state_dict(_state_dict["grad_scaler"])
|
| 815 |
-
iteration = _state_dict["iteration"]
|
| 816 |
-
else:
|
| 817 |
-
raise ValueError(f"Invalid key: {key}. not support to resume.")
|
| 818 |
-
if self.callbacks is not None:
|
| 819 |
-
self.callbacks.on_load_checkpoint(model, state_dict=_state_dict)
|
| 820 |
-
log.info(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}")
|
| 821 |
-
elif checkpoint_path is not None and checkpoint_path.endswith(".pth"):
|
| 822 |
-
state = easy_io.load(checkpoint_path)
|
| 823 |
-
model_state = model.net.state_dict()
|
| 824 |
-
|
| 825 |
-
for k, v in list(state.items()):
|
| 826 |
-
tgt = model_state.get(k, None)
|
| 827 |
-
if tgt is None:
|
| 828 |
-
continue
|
| 829 |
-
# If target param/buffer is a DTensor and checkpoint value is not, distribute it
|
| 830 |
-
if isinstance(tgt, DTensor) and not isinstance(v, DTensor):
|
| 831 |
-
# Match device, dtype, and placements from the target DTensor
|
| 832 |
-
v = v.to(tgt.device, dtype=tgt.dtype, copy=False)
|
| 833 |
-
v = distribute_tensor(v, tgt.device_mesh, tgt.placements)
|
| 834 |
-
state[k] = v
|
| 835 |
-
# If target is a plain Tensor but checkpoint is a DTensor, bring it local
|
| 836 |
-
if not isinstance(tgt, DTensor) and isinstance(v, DTensor):
|
| 837 |
-
state[k] = v.to_local().to(tgt.device, dtype=tgt.dtype, copy=False)
|
| 838 |
-
|
| 839 |
-
model.load_state_dict(state, strict=False, pretrain_copy=True)
|
| 840 |
-
# Clear unused reserved memory from fp32
|
| 841 |
-
torch.cuda.empty_cache()
|
| 842 |
-
log.critical(f"Loaded checkpoint from {checkpoint_path}. **** This only happen at iteration 0 **** ")
|
| 843 |
-
else:
|
| 844 |
-
log.info("Training from scratch.")
|
| 845 |
-
torch.cuda.empty_cache()
|
| 846 |
-
|
| 847 |
-
if self.callbacks is not None:
|
| 848 |
-
self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
|
| 849 |
-
return iteration
|
| 850 |
-
|
| 851 |
-
def _async_with_pinned_memory(self, checkpoint_file: str, state_dict: Dict[str, Tuple[Any, str]]) -> None:
|
| 852 |
-
try:
|
| 853 |
-
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
|
| 854 |
-
except ImportError as e:
|
| 855 |
-
raise ImportError(
|
| 856 |
-
"Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
|
| 857 |
-
) from e
|
| 858 |
-
if self.cpu_offload_state_dict is None:
|
| 859 |
-
log.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
|
| 860 |
-
self.cpu_offload_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True, share_memory=True)
|
| 861 |
-
|
| 862 |
-
log.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
|
| 863 |
-
with torch.cuda.stream(self.staging_stream):
|
| 864 |
-
self.cpu_offload_state_dict = _copy_state_dict(
|
| 865 |
-
state_dict,
|
| 866 |
-
self.cpu_offload_state_dict,
|
| 867 |
-
non_blocking=True,
|
| 868 |
-
)
|
| 869 |
-
self.staging = True
|
| 870 |
-
self.staging_ckpt_file = checkpoint_file
|
| 871 |
-
|
| 872 |
-
self.maybe_wait_for_staging()
|
| 873 |
-
|
| 874 |
-
def maybe_wait_for_staging(self) -> None:
|
| 875 |
-
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM and self.staging:
|
| 876 |
-
if not self.staging_stream.query():
|
| 877 |
-
self.staging_stream.synchronize()
|
| 878 |
-
|
| 879 |
-
def sync_func():
|
| 880 |
-
self.mp_queue_send.put_nowait((self.cpu_offload_state_dict, self.staging_ckpt_file))
|
| 881 |
-
|
| 882 |
-
sync_func()
|
| 883 |
-
self.staging = False
|
| 884 |
-
|
| 885 |
-
def get_previous_checkpoint_results(self, wait_for: int = 0) -> None:
|
| 886 |
-
"""Get the results of previously submitted checkpoints and pass them to callbacks if checkpoint succeeded"""
|
| 887 |
-
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
|
| 888 |
-
try:
|
| 889 |
-
start_time = time.monotonic()
|
| 890 |
-
while not self.mp_queue_recv.empty() or wait_for > 0:
|
| 891 |
-
try:
|
| 892 |
-
ret = self.mp_queue_recv.get(timeout=1)
|
| 893 |
-
if isinstance(ret, Terminate):
|
| 894 |
-
log.info("Received termination event from checkpoint background process")
|
| 895 |
-
break
|
| 896 |
-
save_done: SaveDone = ret
|
| 897 |
-
log.logger.info(f"Received checkpoint save result: {save_done}")
|
| 898 |
-
if self.callbacks is not None and save_done.succeeded:
|
| 899 |
-
self.callbacks.on_save_checkpoint_success(
|
| 900 |
-
iteration=save_done.iteration, elapsed_time=save_done.elapsed_time
|
| 901 |
-
)
|
| 902 |
-
except queue.Empty:
|
| 903 |
-
elapsed_time = time.monotonic() - start_time
|
| 904 |
-
if elapsed_time > wait_for:
|
| 905 |
-
break
|
| 906 |
-
except (EOFError, BrokenPipeError):
|
| 907 |
-
log.info("Queue was closed by checkpoint background process")
|
| 908 |
-
|
| 909 |
-
def get_storage_writer(self, checkpoint_path: str) -> Union[S3StorageWriter, FileSystemWriter]:
|
| 910 |
-
if self.save_to_object_store:
|
| 911 |
-
return S3StorageWriter(
|
| 912 |
-
credential_path=self.config_checkpoint.save_to_object_store.credentials,
|
| 913 |
-
path=checkpoint_path,
|
| 914 |
-
)
|
| 915 |
-
return FileSystemWriter(path=checkpoint_path)
|
| 916 |
-
|
| 917 |
-
def get_storage_reader(self, checkpoint_path: str) -> Union[S3StorageReader, FileSystemReader]:
|
| 918 |
-
if self.load_from_object_store:
|
| 919 |
-
return S3StorageReader(
|
| 920 |
-
credential_path=self.config_checkpoint.load_from_object_store.credentials,
|
| 921 |
-
path=checkpoint_path,
|
| 922 |
-
)
|
| 923 |
-
return FileSystemReader(checkpoint_path)
|
| 924 |
-
|
| 925 |
-
def save_state_dict_worker(self, to_save_dict: Dict[str, Tuple[Any, str]], checkpoint_file: str) -> None:
|
| 926 |
-
for k, (v, full_checkpoint_path) in to_save_dict.items():
|
| 927 |
-
storage_writer = self.get_storage_writer(full_checkpoint_path)
|
| 928 |
-
dcp.save(
|
| 929 |
-
v,
|
| 930 |
-
storage_writer=storage_writer,
|
| 931 |
-
planner=DefaultSavePlanner(dedup_save_to_lowest_rank=True),
|
| 932 |
-
)
|
| 933 |
-
|
| 934 |
-
if distributed.is_rank0():
|
| 935 |
-
print(f"Saving last checkpoint file {checkpoint_file}")
|
| 936 |
-
self._write_latest_checkpoint_file(checkpoint_file)
|
| 937 |
-
|
| 938 |
-
log.critical(f"Saved checkpoint to {os.path.join(self.save_dirname, checkpoint_file)}", rank0_only=True)
|
| 939 |
-
|
| 940 |
-
def save(
|
| 941 |
-
self,
|
| 942 |
-
model: ImaginaireModel,
|
| 943 |
-
optimizer: torch.optim.Optimizer,
|
| 944 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 945 |
-
grad_scaler: torch.amp.GradScaler,
|
| 946 |
-
iteration: int,
|
| 947 |
-
) -> None:
|
| 948 |
-
"""Save network weights, optimizer parameters, scheduler parameters to a checkpoint.
|
| 949 |
-
|
| 950 |
-
Args:
|
| 951 |
-
model (ImaginaireModel): The PyTorch model.
|
| 952 |
-
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 953 |
-
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 954 |
-
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
|
| 955 |
-
iteration (int): Current iteration number.
|
| 956 |
-
"""
|
| 957 |
-
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
|
| 958 |
-
self.get_previous_checkpoint_results(wait_for=0)
|
| 959 |
-
|
| 960 |
-
if self.callbacks is not None:
|
| 961 |
-
self.callbacks.on_save_checkpoint_start(model, iteration)
|
| 962 |
-
|
| 963 |
-
checkpoint_file = f"iter_{iteration:09}"
|
| 964 |
-
to_save_dict = {
|
| 965 |
-
"model": ModelWrapper(model).state_dict(),
|
| 966 |
-
"optim": OptimizerWrapper(model, optimizer).state_dict(),
|
| 967 |
-
"scheduler": scheduler.state_dict(),
|
| 968 |
-
"trainer": {
|
| 969 |
-
"grad_scaler": grad_scaler.state_dict(),
|
| 970 |
-
"iteration": iteration,
|
| 971 |
-
},
|
| 972 |
-
}
|
| 973 |
-
for k in to_save_dict.keys():
|
| 974 |
-
output_dirname = os.path.join(self.save_dirname, f"iter_{iteration:09}/{k}")
|
| 975 |
-
to_save_dict[k] = (to_save_dict[k], output_dirname)
|
| 976 |
-
|
| 977 |
-
if self.callbacks is not None:
|
| 978 |
-
self.callbacks.on_save_checkpoint(model, state_dict=to_save_dict)
|
| 979 |
-
|
| 980 |
-
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
|
| 981 |
-
self._async_with_pinned_memory(checkpoint_file, to_save_dict)
|
| 982 |
-
else:
|
| 983 |
-
start_time = time.monotonic()
|
| 984 |
-
try:
|
| 985 |
-
self.save_state_dict_worker(to_save_dict, checkpoint_file)
|
| 986 |
-
finally:
|
| 987 |
-
if self.callbacks is not None:
|
| 988 |
-
self.callbacks.on_save_checkpoint_success(
|
| 989 |
-
iteration=iteration, elapsed_time=time.monotonic() - start_time
|
| 990 |
-
)
|
| 991 |
-
|
| 992 |
-
# This measures exposed (synchronous) checkpoint time, on_save_checkpoint_success()
|
| 993 |
-
# is instead called to measure the entire duration for asynchronous checkpoint for the async case too.
|
| 994 |
-
if self.callbacks is not None:
|
| 995 |
-
self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)
|
| 996 |
-
|
| 997 |
-
def finalize(self) -> None:
|
| 998 |
-
super().finalize()
|
| 999 |
-
if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
|
| 1000 |
-
if self.mp and self.mp.is_alive():
|
| 1001 |
-
self.mp_queue_send.put(Terminate())
|
| 1002 |
-
self.get_previous_checkpoint_results(wait_for=60)
|
| 1003 |
-
self.mp.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/checkpointer/dummy.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from typing import Optional
|
| 17 |
-
|
| 18 |
-
import torch
|
| 19 |
-
import torch.distributed
|
| 20 |
-
|
| 21 |
-
from lyra_2._ext.imaginaire.checkpointer.base import AbstractCheckpointer
|
| 22 |
-
from lyra_2._ext.imaginaire.model import ImaginaireModel
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class Checkpointer(AbstractCheckpointer):
|
| 26 |
-
"""
|
| 27 |
-
A dummy checkpointer that does not save or load anything. This is useful for debugging jobs or share workload with collobrators.
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
def save(
|
| 31 |
-
self,
|
| 32 |
-
model: ImaginaireModel,
|
| 33 |
-
optimizer: torch.optim.Optimizer,
|
| 34 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 35 |
-
grad_scaler: torch.amp.GradScaler,
|
| 36 |
-
iteration: int,
|
| 37 |
-
) -> None:
|
| 38 |
-
pass
|
| 39 |
-
|
| 40 |
-
def load(
|
| 41 |
-
self,
|
| 42 |
-
model: ImaginaireModel,
|
| 43 |
-
optimizer: Optional[torch.optim.Optimizer] = None,
|
| 44 |
-
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
|
| 45 |
-
grad_scaler: Optional[torch.amp.GradScaler] = None,
|
| 46 |
-
) -> int:
|
| 47 |
-
return 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py
DELETED
|
@@ -1,323 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import io
|
| 17 |
-
import json
|
| 18 |
-
import os
|
| 19 |
-
import time
|
| 20 |
-
from contextlib import contextmanager
|
| 21 |
-
from typing import Generator, Union
|
| 22 |
-
from urllib.parse import urlparse
|
| 23 |
-
|
| 24 |
-
import boto3
|
| 25 |
-
from botocore.config import Config
|
| 26 |
-
from botocore.exceptions import ClientError
|
| 27 |
-
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
|
| 28 |
-
from torch.distributed.checkpoint.filesystem import FileSystemBase
|
| 29 |
-
|
| 30 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class S3Stream(io.BytesIO):
|
| 34 |
-
"""
|
| 35 |
-
Workaround for PyTorch manually closing the stream before we can upload it to S3. We override the close() as noop
|
| 36 |
-
and instead call our own _true_close() method to close the stream after we are done using it.
|
| 37 |
-
The commit at fault is https://github.com/pytorch/pytorch/commit/9c909bf3bb122db2cce95e2eb7459bbe50dfa15a
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
def close(self):
|
| 41 |
-
self.flush()
|
| 42 |
-
# No close
|
| 43 |
-
|
| 44 |
-
def _true_close(self):
|
| 45 |
-
super().close()
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class S3FileSystem(FileSystemBase):
|
| 49 |
-
"""Implementation of FileSystemBase for AWS S3 storage."""
|
| 50 |
-
|
| 51 |
-
def __init__(
|
| 52 |
-
self,
|
| 53 |
-
credential_path: str,
|
| 54 |
-
max_attempts: int = 20,
|
| 55 |
-
initial_backoff: float = 1.0,
|
| 56 |
-
max_backoff: float = 30.0,
|
| 57 |
-
backoff_factor: float = 2.0,
|
| 58 |
-
) -> None:
|
| 59 |
-
"""
|
| 60 |
-
Initialize S3FileSystem with retry configuration.
|
| 61 |
-
|
| 62 |
-
Args:
|
| 63 |
-
credential_path: Path to AWS credentials JSON file
|
| 64 |
-
max_attempts: Maximum number of retry attempts
|
| 65 |
-
initial_backoff: Initial backoff time in seconds
|
| 66 |
-
max_backoff: Maximum backoff time in seconds
|
| 67 |
-
backoff_factor: Multiplicative factor for backoff time
|
| 68 |
-
"""
|
| 69 |
-
with open(credential_path, "r") as f:
|
| 70 |
-
conf = json.load(f)
|
| 71 |
-
|
| 72 |
-
# Configure boto3 with retry settings
|
| 73 |
-
config = Config(
|
| 74 |
-
retries=dict(max_attempts=max_attempts, mode="adaptive"), # Adaptive mode automatically handles throttling
|
| 75 |
-
connect_timeout=60,
|
| 76 |
-
read_timeout=60,
|
| 77 |
-
request_checksum_calculation="when_required", # Data integrity check for uploads and downloads
|
| 78 |
-
response_checksum_validation="when_required", # Data integrity check for uploads and downloads
|
| 79 |
-
)
|
| 80 |
-
|
| 81 |
-
self.s3_client = boto3.client("s3", config=config, **conf)
|
| 82 |
-
self.max_attempts = max_attempts
|
| 83 |
-
self.initial_backoff = initial_backoff
|
| 84 |
-
self.max_backoff = max_backoff
|
| 85 |
-
self.backoff_factor = backoff_factor
|
| 86 |
-
|
| 87 |
-
def _retry_with_backoff(self, operation_func, *args, **kwargs):
|
| 88 |
-
"""
|
| 89 |
-
Execute an operation with exponential backoff retry logic.
|
| 90 |
-
|
| 91 |
-
Args:
|
| 92 |
-
operation_func: Function to execute
|
| 93 |
-
*args: Positional arguments for the function
|
| 94 |
-
**kwargs: Keyword arguments for the function
|
| 95 |
-
|
| 96 |
-
Returns:
|
| 97 |
-
Result of the operation function
|
| 98 |
-
|
| 99 |
-
Raises:
|
| 100 |
-
Exception: If all retry attempts fail
|
| 101 |
-
"""
|
| 102 |
-
last_exception = None
|
| 103 |
-
backoff = self.initial_backoff
|
| 104 |
-
|
| 105 |
-
for attempt in range(self.max_attempts):
|
| 106 |
-
try:
|
| 107 |
-
return operation_func(*args, **kwargs)
|
| 108 |
-
except ClientError as e:
|
| 109 |
-
error_code = e.response.get("Error", {}).get("Code", "")
|
| 110 |
-
log.info(f"S3 Filesystem: Received ClientError: {error_code}", rank0_only=False)
|
| 111 |
-
|
| 112 |
-
# Handle specific error cases
|
| 113 |
-
if error_code in ["SlowDown", "ThrottlingException", "RequestLimitExceeded", "InternalError"]:
|
| 114 |
-
last_exception = e
|
| 115 |
-
if attempt < self.max_attempts - 1: # Don't sleep on last attempt
|
| 116 |
-
current_backoff = min(backoff, self.max_backoff)
|
| 117 |
-
log.info(f"S3 Filesystem: Retrying in {current_backoff} seconds", rank0_only=False)
|
| 118 |
-
time.sleep(current_backoff)
|
| 119 |
-
backoff *= self.backoff_factor
|
| 120 |
-
continue
|
| 121 |
-
# For other client errors, raise immediately
|
| 122 |
-
raise
|
| 123 |
-
except Exception as e:
|
| 124 |
-
log.info(f"S3 Filesystem: Received Exception: {str(e)}", rank0_only=False)
|
| 125 |
-
last_exception = e
|
| 126 |
-
if attempt < self.max_attempts - 1:
|
| 127 |
-
current_backoff = min(backoff, self.max_backoff)
|
| 128 |
-
log.info(f"S3 Filesystem: Retrying in {current_backoff} seconds", rank0_only=False)
|
| 129 |
-
time.sleep(current_backoff)
|
| 130 |
-
backoff *= self.backoff_factor
|
| 131 |
-
continue
|
| 132 |
-
|
| 133 |
-
raise last_exception
|
| 134 |
-
|
| 135 |
-
@contextmanager
|
| 136 |
-
def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]:
|
| 137 |
-
"""Create a stream for reading from or writing to S3 with retry logic."""
|
| 138 |
-
path_str = str(path)
|
| 139 |
-
bucket, key = self._parse_s3_uri(path_str)
|
| 140 |
-
log.info(f"S3 Filesystem: Creating stream for {key} in bucket {bucket}", rank0_only=False)
|
| 141 |
-
|
| 142 |
-
if mode == "rb":
|
| 143 |
-
stream = io.BytesIO()
|
| 144 |
-
try:
|
| 145 |
-
|
| 146 |
-
def download_operation():
|
| 147 |
-
self.s3_client.download_fileobj(bucket, key, stream)
|
| 148 |
-
stream.seek(0)
|
| 149 |
-
|
| 150 |
-
log.info(f"S3 Filesystem: Downloading {key} from bucket {bucket}", rank0_only=False)
|
| 151 |
-
self._retry_with_backoff(download_operation)
|
| 152 |
-
log.info("S3 Filesystem: Download complete", rank0_only=False)
|
| 153 |
-
yield stream
|
| 154 |
-
finally:
|
| 155 |
-
stream.close()
|
| 156 |
-
elif mode == "wb":
|
| 157 |
-
stream = S3Stream()
|
| 158 |
-
try:
|
| 159 |
-
yield stream
|
| 160 |
-
|
| 161 |
-
def upload_operation():
|
| 162 |
-
stream.seek(0)
|
| 163 |
-
self.s3_client.upload_fileobj(stream, bucket, key)
|
| 164 |
-
|
| 165 |
-
log.info(f"S3 Filesystem: Uploading {key} to bucket {bucket}", rank0_only=False)
|
| 166 |
-
self._retry_with_backoff(upload_operation)
|
| 167 |
-
log.info("S3 Filesystem: Upload complete", rank0_only=False)
|
| 168 |
-
finally:
|
| 169 |
-
stream._true_close()
|
| 170 |
-
else:
|
| 171 |
-
raise ValueError(f"Unsupported mode: {mode}")
|
| 172 |
-
|
| 173 |
-
def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]:
|
| 174 |
-
"""Concatenate S3 path with suffix."""
|
| 175 |
-
path_str = str(path)
|
| 176 |
-
if path_str.endswith("/"):
|
| 177 |
-
return f"{path_str}{suffix}"
|
| 178 |
-
return f"{path_str}/{suffix}"
|
| 179 |
-
|
| 180 |
-
def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
|
| 181 |
-
"""Initialize and validate S3 path."""
|
| 182 |
-
path_str = str(path)
|
| 183 |
-
if not path_str.startswith("s3://"):
|
| 184 |
-
raise ValueError(f"Invalid S3 URI: {path_str}. Must start with 's3://'")
|
| 185 |
-
return path_str
|
| 186 |
-
|
| 187 |
-
def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None:
|
| 188 |
-
"""Rename (move) an object in S3 with retry logic."""
|
| 189 |
-
src_bucket, src_key = self._parse_s3_uri(str(path))
|
| 190 |
-
dst_bucket, dst_key = self._parse_s3_uri(str(new_path))
|
| 191 |
-
|
| 192 |
-
def copy_operation():
|
| 193 |
-
copy_source = {"Bucket": src_bucket, "Key": src_key}
|
| 194 |
-
self.s3_client.copy(copy_source, dst_bucket, dst_key)
|
| 195 |
-
|
| 196 |
-
self._retry_with_backoff(copy_operation)
|
| 197 |
-
|
| 198 |
-
def delete_operation():
|
| 199 |
-
self.s3_client.delete_object(Bucket=src_bucket, Key=src_key)
|
| 200 |
-
|
| 201 |
-
self._retry_with_backoff(delete_operation)
|
| 202 |
-
|
| 203 |
-
def mkdir(self, path: Union[str, os.PathLike]) -> None:
|
| 204 |
-
"""
|
| 205 |
-
Create a "directory" in S3.
|
| 206 |
-
|
| 207 |
-
Note: S3 doesn't have real directories, but we can create an empty object
|
| 208 |
-
with a trailing slash to simulate a directory.
|
| 209 |
-
"""
|
| 210 |
-
# Creating same buckets from different ranks can cause rate limit issues in GCP.
|
| 211 |
-
# In object store, we don't need to create a directory.
|
| 212 |
-
pass
|
| 213 |
-
|
| 214 |
-
@classmethod
|
| 215 |
-
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
| 216 |
-
"""Validate if the checkpoint_id is a valid S3 URI."""
|
| 217 |
-
checkpoint_id_str = str(checkpoint_id)
|
| 218 |
-
try:
|
| 219 |
-
if not checkpoint_id_str.startswith("s3://"):
|
| 220 |
-
return False
|
| 221 |
-
parsed = urlparse(checkpoint_id_str)
|
| 222 |
-
return bool(parsed.netloc and parsed.path) # Must have bucket and key
|
| 223 |
-
except Exception:
|
| 224 |
-
return False
|
| 225 |
-
|
| 226 |
-
def exists(self, path: Union[str, os.PathLike]) -> bool:
|
| 227 |
-
"""Check if an object exists in S3 with retry logic."""
|
| 228 |
-
bucket, key = self._parse_s3_uri(str(path))
|
| 229 |
-
try:
|
| 230 |
-
|
| 231 |
-
def head_operation():
|
| 232 |
-
self.s3_client.head_object(Bucket=bucket, Key=key)
|
| 233 |
-
|
| 234 |
-
self._retry_with_backoff(head_operation)
|
| 235 |
-
return True
|
| 236 |
-
except ClientError as e:
|
| 237 |
-
if e.response.get("Error", {}).get("Code", "") == "404":
|
| 238 |
-
return False
|
| 239 |
-
raise
|
| 240 |
-
|
| 241 |
-
def rm_file(self, path: Union[str, os.PathLike]) -> None:
|
| 242 |
-
"""Remove a file from S3 with retry logic."""
|
| 243 |
-
bucket, key = self._parse_s3_uri(str(path))
|
| 244 |
-
|
| 245 |
-
def delete_operation():
|
| 246 |
-
self.s3_client.delete_object(Bucket=bucket, Key=key)
|
| 247 |
-
|
| 248 |
-
self._retry_with_backoff(delete_operation)
|
| 249 |
-
|
| 250 |
-
def _parse_s3_uri(self, uri: str) -> tuple[str, str]:
|
| 251 |
-
"""
|
| 252 |
-
Parse an S3 URI into bucket and key.
|
| 253 |
-
|
| 254 |
-
Args:
|
| 255 |
-
uri: S3 URI in the format s3://bucket-name/key
|
| 256 |
-
|
| 257 |
-
Returns:
|
| 258 |
-
Tuple of (bucket_name, key)
|
| 259 |
-
|
| 260 |
-
Raises:
|
| 261 |
-
ValueError: If the URI is invalid
|
| 262 |
-
"""
|
| 263 |
-
uri = uri if isinstance(uri, str) else str(uri)
|
| 264 |
-
if not uri.startswith("s3://"):
|
| 265 |
-
raise ValueError(f"Invalid S3 URI: {uri}. Must start with 's3://'")
|
| 266 |
-
|
| 267 |
-
parsed = urlparse(uri)
|
| 268 |
-
bucket = parsed.netloc
|
| 269 |
-
|
| 270 |
-
# Remove leading slash from key
|
| 271 |
-
key = parsed.path.lstrip("/")
|
| 272 |
-
|
| 273 |
-
if not bucket:
|
| 274 |
-
raise ValueError(f"Invalid S3 URI: {uri}. No bucket specified")
|
| 275 |
-
|
| 276 |
-
return bucket, key
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
class S3StorageWriter(FileSystemWriter):
|
| 280 |
-
def __init__(
|
| 281 |
-
self,
|
| 282 |
-
credential_path: str,
|
| 283 |
-
path: str,
|
| 284 |
-
**kwargs,
|
| 285 |
-
) -> None:
|
| 286 |
-
"""
|
| 287 |
-
Initialize an S3 writer for distributed checkpointing.
|
| 288 |
-
|
| 289 |
-
Args:
|
| 290 |
-
region (str): The AWS region for S3.
|
| 291 |
-
path (str): The S3 URI to write checkpoints to.
|
| 292 |
-
kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`.
|
| 293 |
-
"""
|
| 294 |
-
super().__init__(
|
| 295 |
-
path=path,
|
| 296 |
-
sync_files=False,
|
| 297 |
-
**kwargs,
|
| 298 |
-
)
|
| 299 |
-
self.fs = S3FileSystem(credential_path) # type: ignore
|
| 300 |
-
self.path = self.fs.init_path(path)
|
| 301 |
-
|
| 302 |
-
@classmethod
|
| 303 |
-
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
| 304 |
-
return S3FileSystem.validate_checkpoint_id(checkpoint_id)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
class S3StorageReader(FileSystemReader):
|
| 308 |
-
def __init__(self, credential_path: str, path: Union[str, os.PathLike]) -> None:
|
| 309 |
-
"""
|
| 310 |
-
Initialize an S3 reader for distributed checkpointing.
|
| 311 |
-
|
| 312 |
-
Args:
|
| 313 |
-
region (str): The AWS region for S3.
|
| 314 |
-
path (Union[str, os.PathLike]): The S3 path to read checkpoints from.
|
| 315 |
-
"""
|
| 316 |
-
super().__init__(path)
|
| 317 |
-
self.fs = S3FileSystem(credential_path) # type: ignore
|
| 318 |
-
self.path = self.fs.init_path(path)
|
| 319 |
-
self.sync_files = False
|
| 320 |
-
|
| 321 |
-
@classmethod
|
| 322 |
-
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
|
| 323 |
-
return S3FileSystem.validate_checkpoint_id(checkpoint_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/config.py
DELETED
|
@@ -1,445 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
"""Training config system for Imaginare4"""
|
| 17 |
-
|
| 18 |
-
from __future__ import annotations
|
| 19 |
-
|
| 20 |
-
import os
|
| 21 |
-
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
| 22 |
-
|
| 23 |
-
import attrs
|
| 24 |
-
import torch
|
| 25 |
-
import torch.utils.data
|
| 26 |
-
import torch.utils.data.distributed
|
| 27 |
-
|
| 28 |
-
try:
|
| 29 |
-
from megatron.core import ModelParallelConfig
|
| 30 |
-
|
| 31 |
-
USE_MEGATRON = True
|
| 32 |
-
except ImportError:
|
| 33 |
-
USE_MEGATRON = False
|
| 34 |
-
print("Megatron-core is not installed.")
|
| 35 |
-
|
| 36 |
-
from lyra_2._ext.imaginaire.lazy_config import LazyCall as L
|
| 37 |
-
from lyra_2._ext.imaginaire.lazy_config import LazyDict
|
| 38 |
-
from lyra_2._ext.imaginaire.utils import callback, distributed
|
| 39 |
-
from lyra_2._ext.imaginaire.utils.misc import Color
|
| 40 |
-
|
| 41 |
-
T = TypeVar("T")
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def _is_attrs_instance(obj: object) -> bool:
|
| 45 |
-
"""
|
| 46 |
-
Helper function to check if an object is an instance of an attrs-defined class.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
obj: The object to check.
|
| 50 |
-
|
| 51 |
-
Returns:
|
| 52 |
-
bool: True if the object is an instance of an attrs-defined class, False otherwise.
|
| 53 |
-
"""
|
| 54 |
-
return hasattr(obj, "__attrs_attrs__")
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def make_freezable(cls: T) -> T:
|
| 58 |
-
"""
|
| 59 |
-
A decorator that adds the capability to freeze instances of an attrs-defined class.
|
| 60 |
-
|
| 61 |
-
NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
|
| 62 |
-
to hack on a "_is_frozen" attribute.
|
| 63 |
-
|
| 64 |
-
This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
|
| 65 |
-
Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
|
| 66 |
-
any attrs-defined objects that are attributes of the class.
|
| 67 |
-
|
| 68 |
-
Usage:
|
| 69 |
-
@make_freezable
|
| 70 |
-
@attrs.define(slots=False)
|
| 71 |
-
class MyClass:
|
| 72 |
-
attribute1: int
|
| 73 |
-
attribute2: str
|
| 74 |
-
|
| 75 |
-
obj = MyClass(1, 'a')
|
| 76 |
-
obj.freeze() # Freeze the instance
|
| 77 |
-
obj.attribute1 = 2 # Raises AttributeError
|
| 78 |
-
|
| 79 |
-
Args:
|
| 80 |
-
cls: The class to be decorated.
|
| 81 |
-
|
| 82 |
-
Returns:
|
| 83 |
-
The decorated class with added freezing capability.
|
| 84 |
-
"""
|
| 85 |
-
|
| 86 |
-
if not hasattr(cls, "__dict__"):
|
| 87 |
-
raise TypeError(
|
| 88 |
-
"make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
|
| 89 |
-
"class was defined with `@attrs.define(slots=False)`"
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
original_setattr = cls.__setattr__
|
| 93 |
-
|
| 94 |
-
def setattr_override(self, key, value) -> None: # noqa: ANN001
|
| 95 |
-
"""
|
| 96 |
-
Override __setattr__ to allow modifications during initialization
|
| 97 |
-
and prevent modifications once the instance is frozen.
|
| 98 |
-
"""
|
| 99 |
-
if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
|
| 100 |
-
raise AttributeError("Cannot modify frozen instance")
|
| 101 |
-
original_setattr(self, key, value) # type: ignore
|
| 102 |
-
|
| 103 |
-
cls.__setattr__ = setattr_override # type: ignore
|
| 104 |
-
|
| 105 |
-
def freeze(self: object) -> None:
|
| 106 |
-
"""
|
| 107 |
-
Freeze the instance and all its attrs-defined attributes.
|
| 108 |
-
"""
|
| 109 |
-
for _, value in attrs.asdict(self, recurse=False).items():
|
| 110 |
-
if _is_attrs_instance(value) and hasattr(value, "freeze"):
|
| 111 |
-
value.freeze()
|
| 112 |
-
self._is_frozen = True # type: ignore
|
| 113 |
-
|
| 114 |
-
cls.freeze = freeze # type: ignore
|
| 115 |
-
|
| 116 |
-
return cls
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
|
| 120 |
-
"""
|
| 121 |
-
Recursively pretty prints attrs objects with color.
|
| 122 |
-
"""
|
| 123 |
-
|
| 124 |
-
assert attrs.has(obj.__class__)
|
| 125 |
-
|
| 126 |
-
lines: list[str] = []
|
| 127 |
-
for attribute in attrs.fields(obj.__class__):
|
| 128 |
-
value = getattr(obj, attribute.name)
|
| 129 |
-
if attrs.has(value.__class__):
|
| 130 |
-
if use_color:
|
| 131 |
-
lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
|
| 132 |
-
else:
|
| 133 |
-
lines.append(" " * indent + "* " + attribute.name + ":")
|
| 134 |
-
lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
|
| 135 |
-
else:
|
| 136 |
-
if use_color:
|
| 137 |
-
lines.append(
|
| 138 |
-
" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
|
| 139 |
-
)
|
| 140 |
-
else:
|
| 141 |
-
lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
|
| 142 |
-
return "\n".join(lines)
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str:
|
| 146 |
-
"""
|
| 147 |
-
Pretty prints overrides.
|
| 148 |
-
"""
|
| 149 |
-
|
| 150 |
-
lines: list[str] = []
|
| 151 |
-
lines.append(Color.cyan("* ") + Color.green("overrides") + ": ")
|
| 152 |
-
for override in overrides:
|
| 153 |
-
if override == "--":
|
| 154 |
-
continue
|
| 155 |
-
if override.startswith("~"):
|
| 156 |
-
attribute_name = override[1:]
|
| 157 |
-
attribute_value = None
|
| 158 |
-
else:
|
| 159 |
-
attribute_name, attribute_value = override.split("=")
|
| 160 |
-
if use_color:
|
| 161 |
-
lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value))
|
| 162 |
-
else:
|
| 163 |
-
lines.append(" " + "* " + attribute_name + ": " + str(attribute_value))
|
| 164 |
-
|
| 165 |
-
return "\n".join(lines)
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
@make_freezable
|
| 169 |
-
@attrs.define(slots=False) # slots=False is required for make_freezable. See the make_freezable notes for more info.
|
| 170 |
-
class ObjectStoreConfig:
|
| 171 |
-
# Whether the file I/O is from object store instead of local disk.
|
| 172 |
-
enabled: bool = False
|
| 173 |
-
# Path to the object store credentials file.
|
| 174 |
-
credentials: str = ""
|
| 175 |
-
# Object store bucket to read from / write to the objects.
|
| 176 |
-
bucket: str = ""
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
@make_freezable
|
| 180 |
-
@attrs.define(slots=False)
|
| 181 |
-
class JobConfig:
|
| 182 |
-
# Project name.
|
| 183 |
-
project: str = ""
|
| 184 |
-
# Experiment name.
|
| 185 |
-
group: str = ""
|
| 186 |
-
# Run/job name.
|
| 187 |
-
name: str = ""
|
| 188 |
-
@property
|
| 189 |
-
def path(self) -> str:
|
| 190 |
-
return f"{self.project}/{self.group}/{self.name}"
|
| 191 |
-
|
| 192 |
-
@property
|
| 193 |
-
def path_local(self) -> str:
|
| 194 |
-
local_root = os.environ.get("IMAGINAIRE_OUTPUT_ROOT", "/tmp/imaginaire4-output")
|
| 195 |
-
return f"{local_root}/{self.path}"
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
@make_freezable
|
| 199 |
-
@attrs.define(slots=False)
|
| 200 |
-
class EMAConfig:
|
| 201 |
-
# Enable tracking a set of exponential moving average (EMA) weights.
|
| 202 |
-
enabled: bool = False
|
| 203 |
-
# EMA decay rate.
|
| 204 |
-
beta: float = 0.9999
|
| 205 |
-
# Enable removing "_orig_mod-" from buffer names that is added by torch.compile
|
| 206 |
-
torch_compile_buffer_renaming: bool = False
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
@make_freezable
|
| 210 |
-
@attrs.define(slots=False)
|
| 211 |
-
class PowerEMAConfig:
|
| 212 |
-
# Enable tracking a set of exponential moving average (EMA) weights.
|
| 213 |
-
enabled: bool = False
|
| 214 |
-
# EDM2 paper EMA decay rate.
|
| 215 |
-
s: float = 0.1
|
| 216 |
-
# Enable removing "_orig_mod-" from buffer names that is added by torch.compile
|
| 217 |
-
torch_compile_buffer_renaming: bool = False
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
@make_freezable
|
| 221 |
-
@attrs.define(slots=False)
|
| 222 |
-
class DDPConfig:
|
| 223 |
-
# Traverse the computation graph to find parameters that don't receive gradients.
|
| 224 |
-
find_unused_parameters: bool = False
|
| 225 |
-
# Set to True if the computation graph does not change during the whole training loop.
|
| 226 |
-
static_graph: bool = True
|
| 227 |
-
# Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere.
|
| 228 |
-
broadcast_buffers: bool = True
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
@make_freezable
|
| 232 |
-
@attrs.define(slots=False)
|
| 233 |
-
class CuDNNConfig:
|
| 234 |
-
# Set to True for better reproducibility of the results (only using deterministic cudnn functions).
|
| 235 |
-
deterministic: bool = False
|
| 236 |
-
# If set to True, cudnn will benchmark several algorithms and pick the fastest one.
|
| 237 |
-
benchmark: bool = True
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
@make_freezable
|
| 241 |
-
@attrs.define(slots=False)
|
| 242 |
-
class JITConfig:
|
| 243 |
-
# Enable exporting a JIT compiled model.
|
| 244 |
-
enabled: bool = False
|
| 245 |
-
# Input tensor shape, for example input.
|
| 246 |
-
input_shape: Union[list[int], None] = None
|
| 247 |
-
# Device to compile onto.
|
| 248 |
-
device: str = "cuda"
|
| 249 |
-
# # Data type to compile onto.
|
| 250 |
-
dtype: str = "bfloat16"
|
| 251 |
-
# Strict mode for PyTorch JIT.
|
| 252 |
-
strict: bool = True
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
@make_freezable
|
| 256 |
-
@attrs.define(slots=False)
|
| 257 |
-
class CheckpointConfig:
|
| 258 |
-
# possible checkpoint class
|
| 259 |
-
type: Optional[Dict] = None
|
| 260 |
-
# for dcp, whether to use async mode
|
| 261 |
-
dcp_async_mode_enabled: bool = False
|
| 262 |
-
# Configs for saving the checkpoints to object store.
|
| 263 |
-
save_to_object_store: ObjectStoreConfig = attrs.field(factory=ObjectStoreConfig)
|
| 264 |
-
# Save the checkpoint every N iterations.
|
| 265 |
-
save_iter: int = 999999999
|
| 266 |
-
# Configs for loading the checkpoints from object store.
|
| 267 |
-
load_from_object_store: ObjectStoreConfig = attrs.field(factory=ObjectStoreConfig)
|
| 268 |
-
# Path of model weights to resume the checkpoint from.
|
| 269 |
-
load_path: str = ""
|
| 270 |
-
# Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path.
|
| 271 |
-
load_training_state: bool = False
|
| 272 |
-
# Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored.
|
| 273 |
-
only_load_scheduler_state: bool = False
|
| 274 |
-
# Load state_dict to the models in strict mode.
|
| 275 |
-
strict_resume: bool = True
|
| 276 |
-
# Configs for JIT compiling EMA model.
|
| 277 |
-
jit: JITConfig = attrs.field(factory=JITConfig)
|
| 278 |
-
# Print detailed information during checkpoint saving/loading.
|
| 279 |
-
verbose: bool = True
|
| 280 |
-
# keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"]
|
| 281 |
-
keys_not_to_resume: list[str] = []
|
| 282 |
-
# Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer).
|
| 283 |
-
broadcast_via_filesystem: bool = False
|
| 284 |
-
load_ema_to_reg: bool = False # used in inference, load EMA weights to regular model
|
| 285 |
-
# In dcp planner, skip the weight shape check, load weights into the model even weight shape is different
|
| 286 |
-
dcp_allow_mismatched_size: bool = False
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
@make_freezable
|
| 290 |
-
@attrs.define(slots=False)
|
| 291 |
-
class NVTXConfig:
|
| 292 |
-
"""Config for NVTX ranges used in the main training loop.
|
| 293 |
-
|
| 294 |
-
See tutorials/nanogpt for more details on how to integrate profiling into your model."""
|
| 295 |
-
|
| 296 |
-
# Enable the NVTX ranges.
|
| 297 |
-
enabled: bool = False
|
| 298 |
-
# Synchronize everything in each NVTX range.
|
| 299 |
-
cuda_synchronize: bool = False
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
@make_freezable
|
| 303 |
-
@attrs.define(slots=False)
|
| 304 |
-
class StragglerDetectionConfig:
|
| 305 |
-
"""Config for Straggler detection tool."""
|
| 306 |
-
|
| 307 |
-
# Enable the Straggler Detection.
|
| 308 |
-
enabled: bool = False
|
| 309 |
-
# How frequently should the Straggler reports be generated.
|
| 310 |
-
report_freq: int = 100
|
| 311 |
-
# How frequently iterations should be profiled
|
| 312 |
-
profile_freq: int = 1
|
| 313 |
-
# What is the maximum relative difference between GPUs after they are considered stragglers
|
| 314 |
-
max_diff: float = 2.0
|
| 315 |
-
# Should the error be raised when straggler is detected
|
| 316 |
-
raise_error: bool = True
|
| 317 |
-
# Analyze kernels in the forward pass.
|
| 318 |
-
analyze_forward: bool = True
|
| 319 |
-
# Analyze kernels in the backward pass.
|
| 320 |
-
analyze_backward: bool = True
|
| 321 |
-
# Analyze kernels in the optimizer.
|
| 322 |
-
analyze_optimizer: bool = True
|
| 323 |
-
# Analyze dataloading time.
|
| 324 |
-
analyze_dataloading: bool = True
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
@make_freezable
|
| 328 |
-
@attrs.define(slots=False)
|
| 329 |
-
class Profiling:
|
| 330 |
-
enable_profiling: bool = False
|
| 331 |
-
enable_memory_snapshot: bool = False
|
| 332 |
-
save_s3: bool = False
|
| 333 |
-
profile_freq: int = 1
|
| 334 |
-
# Target ranks for profiling, each entry must be >=0 and < world_size.
|
| 335 |
-
target_ranks: list[int] = list(range(8))
|
| 336 |
-
# Set `record_shape` and `profile_memory` to False to reduce profile size.
|
| 337 |
-
record_shape: bool = False
|
| 338 |
-
profile_memory: bool = False
|
| 339 |
-
with_stack: bool = True
|
| 340 |
-
with_modules: bool = True
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
@make_freezable
|
| 344 |
-
@attrs.define(slots=False)
|
| 345 |
-
class TrainerConfig:
|
| 346 |
-
from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer
|
| 347 |
-
|
| 348 |
-
type: Type[ImaginaireTrainer] = ImaginaireTrainer
|
| 349 |
-
# Set the callback class.
|
| 350 |
-
# Defaults to the callbacks below.
|
| 351 |
-
callbacks: LazyDict = LazyDict(
|
| 352 |
-
dict(
|
| 353 |
-
ema=L(callback.EMAModelCallback)(),
|
| 354 |
-
progress_bar=L(callback.ProgressBarCallback)(),
|
| 355 |
-
)
|
| 356 |
-
)
|
| 357 |
-
# distributed parallelism strategy
|
| 358 |
-
distributed_parallelism: str = "ddp"
|
| 359 |
-
# Distributed data parallel configs.
|
| 360 |
-
ddp: DDPConfig = attrs.field(factory=DDPConfig)
|
| 361 |
-
# cuDNN configs.
|
| 362 |
-
cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
|
| 363 |
-
# Set the random seed.
|
| 364 |
-
seed: int = 0
|
| 365 |
-
# Set the random seed based on current timestamp
|
| 366 |
-
timestamp_seed: bool = False
|
| 367 |
-
# Gradient scaler arguments (for torch.amp.GradScaler).
|
| 368 |
-
grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False))
|
| 369 |
-
# Maximum number of iterations to train the model.
|
| 370 |
-
max_iter: int = 999999999
|
| 371 |
-
# Maximum number of iterations to validate the model. If None, validate on the entire dataset.
|
| 372 |
-
max_val_iter: int | None = None
|
| 373 |
-
# How often we log the training stats.
|
| 374 |
-
logging_iter: int = 100
|
| 375 |
-
# Whether we want to run the validation routines.
|
| 376 |
-
run_validation: bool = True
|
| 377 |
-
# How often we evaluate on the validation set.
|
| 378 |
-
validation_iter: int = 999999999
|
| 379 |
-
# Kill the process after N seconds since the last iteration (usually means dead job).
|
| 380 |
-
timeout_period: int = 999999999
|
| 381 |
-
# Tensor memory organization format.
|
| 382 |
-
memory_format: torch.memory_format = torch.preserve_format
|
| 383 |
-
# Gradient accumulation (update step every N iteration).
|
| 384 |
-
grad_accum_iter: int = 1
|
| 385 |
-
# Straggler Detection config
|
| 386 |
-
straggler_detection: StragglerDetectionConfig = attrs.field(factory=StragglerDetectionConfig)
|
| 387 |
-
# Profiling config
|
| 388 |
-
profiling: Profiling = attrs.field(factory=Profiling)
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
@make_freezable
|
| 392 |
-
@attrs.define(slots=False)
|
| 393 |
-
class Config:
|
| 394 |
-
"""Config for an imaginaire4 job.
|
| 395 |
-
|
| 396 |
-
See /README.md/Configuration System for more info.
|
| 397 |
-
"""
|
| 398 |
-
|
| 399 |
-
# Model configs.
|
| 400 |
-
model: LazyDict
|
| 401 |
-
# Optimizer configs.
|
| 402 |
-
optimizer: LazyDict
|
| 403 |
-
# Scheduler configs.
|
| 404 |
-
scheduler: LazyDict
|
| 405 |
-
# Training data configs.
|
| 406 |
-
dataloader_train: LazyDict
|
| 407 |
-
# Validation data configs.
|
| 408 |
-
dataloader_val: LazyDict
|
| 409 |
-
|
| 410 |
-
# Training job configs.
|
| 411 |
-
job: JobConfig = attrs.field(factory=JobConfig)
|
| 412 |
-
|
| 413 |
-
# Trainer configs.
|
| 414 |
-
trainer: TrainerConfig = attrs.field(factory=TrainerConfig)
|
| 415 |
-
|
| 416 |
-
if USE_MEGATRON:
|
| 417 |
-
# Megatron-Core configs
|
| 418 |
-
model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig)
|
| 419 |
-
else:
|
| 420 |
-
model_parallel: None = None
|
| 421 |
-
|
| 422 |
-
# Checkpointer configs.
|
| 423 |
-
checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig)
|
| 424 |
-
|
| 425 |
-
# enable upload reproducible setup to s3
|
| 426 |
-
upload_reproducible_setup: bool = False
|
| 427 |
-
|
| 428 |
-
def pretty_print(self, use_color: bool = False) -> str:
|
| 429 |
-
return _pretty_print_attrs_instance(self, 0, use_color)
|
| 430 |
-
|
| 431 |
-
def to_dict(self) -> dict[str, Any]:
|
| 432 |
-
return attrs.asdict(self)
|
| 433 |
-
|
| 434 |
-
def validate(self) -> None:
|
| 435 |
-
"""Validate that the config has all required fields."""
|
| 436 |
-
|
| 437 |
-
# broadcast job.name across all ranks to make sure it is consistent
|
| 438 |
-
# otherwise, unaligned job names leads unaligned path to save checkpoints
|
| 439 |
-
job_name_tensor = torch.ByteTensor(bytearray(self.job.name, "utf-8")).cuda()
|
| 440 |
-
distributed.broadcast(job_name_tensor, 0)
|
| 441 |
-
self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8")
|
| 442 |
-
|
| 443 |
-
assert self.job.project != ""
|
| 444 |
-
assert self.job.group != ""
|
| 445 |
-
assert self.job.name != ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/flags.py
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
"""Feature flags."""
|
| 17 |
-
|
| 18 |
-
import os
|
| 19 |
-
from dataclasses import dataclass
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _parse_bool(value: str) -> bool:
|
| 23 |
-
"""Parse string to a boolean."""
|
| 24 |
-
return value.lower() in ["true", "1", "yes", "y"]
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
INTERNAL = _parse_bool(os.environ.get("COSMOS_INTERNAL", "0"))
|
| 28 |
-
"""Whether to enable internal features."""
|
| 29 |
-
|
| 30 |
-
SMOKE = _parse_bool(os.environ.get("COSMOS_SMOKE", "0"))
|
| 31 |
-
"""Whether to enable smoke test.
|
| 32 |
-
|
| 33 |
-
Disables expensive operations such as checkpoint loading.
|
| 34 |
-
"""
|
| 35 |
-
|
| 36 |
-
VERBOSE = _parse_bool(os.environ.get("COSMOS_VERBOSE", "0"))
|
| 37 |
-
"""Whether to enable verbose output."""
|
| 38 |
-
|
| 39 |
-
EXPERIMENTAL_CHECKPOINTS = _parse_bool(os.environ.get("COSMOS_EXPERIMENTAL_CHECKPOINTS", "0"))
|
| 40 |
-
"""Whether to enable experimental checkpoints."""
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
@dataclass
|
| 44 |
-
class Flags:
|
| 45 |
-
internal: bool = INTERNAL
|
| 46 |
-
smoke: bool = SMOKE
|
| 47 |
-
verbose: bool = VERBOSE
|
| 48 |
-
experimental_checkpoints: bool = EXPERIMENTAL_CHECKPOINTS
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
FLAGS = Flags()
|
| 52 |
-
"""Convenience object for accessing flags."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/functional/batch_ops.py
DELETED
|
@@ -1,61 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
# Functions for performing operations with broadcasting to the right axis
|
| 17 |
-
#
|
| 18 |
-
# Example
|
| 19 |
-
# input1: tensor of size (N1, N2)
|
| 20 |
-
# input2: tensor of size (N1, N2, N3, N4)
|
| 21 |
-
# batch_mul(input1, input2) = input1[:, :, None, None] * input2
|
| 22 |
-
#
|
| 23 |
-
# If the common dimensions don't match, we raise an assertion error.
|
| 24 |
-
|
| 25 |
-
from torch import Tensor
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
|
| 29 |
-
ndims1 = x.ndim
|
| 30 |
-
ndims2 = y.ndim
|
| 31 |
-
|
| 32 |
-
common_ndims = min(ndims1, ndims2)
|
| 33 |
-
for axis in range(common_ndims):
|
| 34 |
-
assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis)
|
| 35 |
-
|
| 36 |
-
if ndims1 < ndims2:
|
| 37 |
-
x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
|
| 38 |
-
elif ndims2 < ndims1:
|
| 39 |
-
y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))
|
| 40 |
-
|
| 41 |
-
return x, y
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def batch_add(x: Tensor, y: Tensor) -> Tensor:
|
| 45 |
-
x, y = common_broadcast(x, y)
|
| 46 |
-
return x + y
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def batch_mul(x: Tensor, y: Tensor) -> Tensor:
|
| 50 |
-
x, y = common_broadcast(x, y)
|
| 51 |
-
return x * y
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def batch_sub(x: Tensor, y: Tensor) -> Tensor:
|
| 55 |
-
x, y = common_broadcast(x, y)
|
| 56 |
-
return x - y
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def batch_div(x: Tensor, y: Tensor) -> Tensor:
|
| 60 |
-
x, y = common_broadcast(x, y)
|
| 61 |
-
return x / y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/functional/lr_scheduler.py
DELETED
|
@@ -1,178 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from typing import Optional
|
| 17 |
-
|
| 18 |
-
import numpy as np
|
| 19 |
-
|
| 20 |
-
from lyra_2._ext.imaginaire.utils import distributed, log
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class TeroPolyScheduler:
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
total_Mimg: int,
|
| 27 |
-
batch_size: int,
|
| 28 |
-
ref_Mimg: Optional[int] = None,
|
| 29 |
-
ref_batches: float = 70e3 / 1024,
|
| 30 |
-
max_lr_ratio: Optional[float] = 1.0,
|
| 31 |
-
min_lr_ratio: Optional[float] = None,
|
| 32 |
-
rampup_Mimg: float = 0,
|
| 33 |
-
rampdown_Mimg: int = 0,
|
| 34 |
-
verbosity_interval: int = 0,
|
| 35 |
-
formula: str = "poly",
|
| 36 |
-
poly_exp: float = 0.5,
|
| 37 |
-
):
|
| 38 |
-
self.total_Mimg = total_Mimg
|
| 39 |
-
self.batch_size = batch_size * distributed.get_world_size()
|
| 40 |
-
self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6
|
| 41 |
-
self.ref_batches = ref_batches
|
| 42 |
-
self.max_lr_ratio = max_lr_ratio
|
| 43 |
-
self.min_lr_ratio = min_lr_ratio
|
| 44 |
-
self.rampup_Mimg = rampup_Mimg
|
| 45 |
-
self.rampdown_Mimg = rampdown_Mimg
|
| 46 |
-
self.verbosity_interval = verbosity_interval
|
| 47 |
-
self.formula = formula
|
| 48 |
-
self.poly_exp = poly_exp
|
| 49 |
-
|
| 50 |
-
self._model = None
|
| 51 |
-
|
| 52 |
-
@property
|
| 53 |
-
def model(self):
|
| 54 |
-
return self._model
|
| 55 |
-
|
| 56 |
-
@model.setter
|
| 57 |
-
def model(self, model):
|
| 58 |
-
self._model = model
|
| 59 |
-
|
| 60 |
-
def schedule(self, n, **kwargs):
|
| 61 |
-
cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6
|
| 62 |
-
|
| 63 |
-
if self.formula == "constant":
|
| 64 |
-
lr = 1.0
|
| 65 |
-
elif self.formula == "poly":
|
| 66 |
-
lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp
|
| 67 |
-
else:
|
| 68 |
-
raise ValueError(f'Invalid learning rate formula "{self.formula}"')
|
| 69 |
-
|
| 70 |
-
if self.max_lr_ratio is not None:
|
| 71 |
-
lr = min(lr, self.max_lr_ratio)
|
| 72 |
-
if self.min_lr_ratio is not None:
|
| 73 |
-
lr = max(lr, self.min_lr_ratio)
|
| 74 |
-
|
| 75 |
-
if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg:
|
| 76 |
-
lr *= cur_Mimg / self.rampup_Mimg
|
| 77 |
-
if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg:
|
| 78 |
-
lr *= (self.total_Mimg - cur_Mimg) / self.rampdown_Mimg
|
| 79 |
-
|
| 80 |
-
return lr
|
| 81 |
-
|
| 82 |
-
def __call__(self, n, **kwargs):
|
| 83 |
-
return self.schedule(n, **kwargs)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
class LambdaWarmUpCosineScheduler:
|
| 87 |
-
"""
|
| 88 |
-
A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles.
|
| 89 |
-
It supports different configurations for each cycle, including the number of warm-up steps, minimum
|
| 90 |
-
and maximum scaling factors for the learning rate.
|
| 91 |
-
|
| 92 |
-
The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning
|
| 93 |
-
rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler.
|
| 94 |
-
|
| 95 |
-
Parameters:
|
| 96 |
-
warm_up_steps (list[int]): List of integers where each element represents the number of warm-up
|
| 97 |
-
steps for the corresponding cycle.
|
| 98 |
-
f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up.
|
| 99 |
-
f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle.
|
| 100 |
-
f_start (list[float]): List of starting scaling factors for each warm-up phase.
|
| 101 |
-
cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps.
|
| 102 |
-
verbosity_interval (int, optional): Interval of training steps at which to print current step and
|
| 103 |
-
scaling factor information. Set to 0 by default to disable verbosity.
|
| 104 |
-
|
| 105 |
-
Examples:
|
| 106 |
-
>>> scheduler = LambdaWarmUpCosineScheduler2(
|
| 107 |
-
warm_up_steps=[10, 10],
|
| 108 |
-
f_min=[0.1, 0.1],
|
| 109 |
-
f_max=[1.0, 1.0],
|
| 110 |
-
f_start=[0.01, 0.01],
|
| 111 |
-
cycle_lengths=[50, 50],
|
| 112 |
-
verbosity_interval=10)
|
| 113 |
-
>>> for step in range(100):
|
| 114 |
-
>>> lr_multiplier = scheduler(step)
|
| 115 |
-
>>> print(f"Step {step}: LR Multiplier = {lr_multiplier}")
|
| 116 |
-
"""
|
| 117 |
-
|
| 118 |
-
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
| 119 |
-
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
| 120 |
-
self.lr_warm_up_steps = warm_up_steps
|
| 121 |
-
self.f_start = f_start
|
| 122 |
-
self.f_min = f_min
|
| 123 |
-
self.f_max = f_max
|
| 124 |
-
self.cycle_lengths = cycle_lengths
|
| 125 |
-
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
| 126 |
-
self.last_f = 0.0
|
| 127 |
-
self.verbosity_interval = verbosity_interval
|
| 128 |
-
|
| 129 |
-
def find_in_interval(self, n):
|
| 130 |
-
interval = 0
|
| 131 |
-
for cl in self.cum_cycles[1:]:
|
| 132 |
-
if n <= cl:
|
| 133 |
-
return interval
|
| 134 |
-
interval += 1
|
| 135 |
-
|
| 136 |
-
def schedule(self, n, **kwargs):
|
| 137 |
-
cycle = self.find_in_interval(n)
|
| 138 |
-
n = n - self.cum_cycles[cycle]
|
| 139 |
-
if self.verbosity_interval > 0:
|
| 140 |
-
if n % self.verbosity_interval == 0:
|
| 141 |
-
log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
|
| 142 |
-
if n < self.lr_warm_up_steps[cycle]:
|
| 143 |
-
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
| 144 |
-
self.last_f = f
|
| 145 |
-
return f
|
| 146 |
-
else:
|
| 147 |
-
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
| 148 |
-
t = min(t, 1.0)
|
| 149 |
-
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
|
| 150 |
-
self.last_f = f
|
| 151 |
-
return f
|
| 152 |
-
|
| 153 |
-
def __call__(self, n, **kwargs):
|
| 154 |
-
return self.schedule(n, **kwargs)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler):
|
| 158 |
-
"""
|
| 159 |
-
Linear instead of cosine decay for the main part of the cycle.
|
| 160 |
-
"""
|
| 161 |
-
|
| 162 |
-
def schedule(self, n, **kwargs):
|
| 163 |
-
cycle = self.find_in_interval(n)
|
| 164 |
-
n = n - self.cum_cycles[cycle]
|
| 165 |
-
if self.verbosity_interval > 0:
|
| 166 |
-
if n % self.verbosity_interval == 0:
|
| 167 |
-
log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
|
| 168 |
-
|
| 169 |
-
if n < self.lr_warm_up_steps[cycle]:
|
| 170 |
-
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
| 171 |
-
self.last_f = f
|
| 172 |
-
return f
|
| 173 |
-
else:
|
| 174 |
-
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (
|
| 175 |
-
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
| 176 |
-
)
|
| 177 |
-
self.last_f = f
|
| 178 |
-
return f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/lazy_config/__init__.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
from omegaconf import DictConfig, OmegaConf
|
| 5 |
-
|
| 6 |
-
from lyra_2._ext.imaginaire.lazy_config.instantiate import instantiate
|
| 7 |
-
from lyra_2._ext.imaginaire.lazy_config.lazy import LazyCall, LazyConfig
|
| 8 |
-
from lyra_2._ext.imaginaire.lazy_config.omegaconf_patch import to_object
|
| 9 |
-
|
| 10 |
-
OmegaConf.to_object = to_object
|
| 11 |
-
|
| 12 |
-
PLACEHOLDER = None
|
| 13 |
-
LazyDict = DictConfig
|
| 14 |
-
|
| 15 |
-
__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"]
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def fixup_module_metadata(module_name, namespace, keys=None):
|
| 22 |
-
"""
|
| 23 |
-
Fix the __qualname__ of module members to be their exported api name, so
|
| 24 |
-
when they are referenced in docs, sphinx can find them. Reference:
|
| 25 |
-
https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241
|
| 26 |
-
"""
|
| 27 |
-
if not DOC_BUILDING:
|
| 28 |
-
return
|
| 29 |
-
seen_ids = set()
|
| 30 |
-
|
| 31 |
-
def fix_one(qualname, name, obj):
|
| 32 |
-
# avoid infinite recursion (relevant when using
|
| 33 |
-
# typing.Generic, for example)
|
| 34 |
-
if id(obj) in seen_ids:
|
| 35 |
-
return
|
| 36 |
-
seen_ids.add(id(obj))
|
| 37 |
-
|
| 38 |
-
mod = getattr(obj, "__module__", None)
|
| 39 |
-
if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")):
|
| 40 |
-
obj.__module__ = module_name
|
| 41 |
-
# Modules, unlike everything else in Python, put fully-qualitied
|
| 42 |
-
# names into their __name__ attribute. We check for "." to avoid
|
| 43 |
-
# rewriting these.
|
| 44 |
-
if hasattr(obj, "__name__") and "." not in obj.__name__:
|
| 45 |
-
obj.__name__ = name
|
| 46 |
-
obj.__qualname__ = qualname
|
| 47 |
-
if isinstance(obj, type):
|
| 48 |
-
for attr_name, attr_value in obj.__dict__.items():
|
| 49 |
-
fix_one(objname + "." + attr_name, attr_name, attr_value)
|
| 50 |
-
|
| 51 |
-
if keys is None:
|
| 52 |
-
keys = namespace.keys()
|
| 53 |
-
for objname in keys:
|
| 54 |
-
if not objname.startswith("_"):
|
| 55 |
-
obj = namespace[objname]
|
| 56 |
-
fix_one(objname, objname, obj)
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
fixup_module_metadata(__name__, globals(), __all__)
|
| 60 |
-
del fixup_module_metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/lazy_config/file_io.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
|
| 17 |
-
from iopath.common.file_io import PathManager as PathManagerBase
|
| 18 |
-
|
| 19 |
-
__all__ = ["PathManager", "PathHandler"]
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
PathManager = PathManagerBase()
|
| 23 |
-
PathManager.register_handler(HTTPURLHandler())
|
| 24 |
-
PathManager.register_handler(OneDrivePathHandler())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/lazy_config/instantiate.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import collections.abc as abc
|
| 17 |
-
import dataclasses
|
| 18 |
-
import logging
|
| 19 |
-
from typing import Any
|
| 20 |
-
|
| 21 |
-
import attrs
|
| 22 |
-
|
| 23 |
-
from lyra_2._ext.imaginaire.lazy_config.registry import _convert_target_to_string, locate
|
| 24 |
-
|
| 25 |
-
__all__ = ["dump_dataclass", "instantiate"]
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def is_dataclass_or_attrs(target):
|
| 29 |
-
return dataclasses.is_dataclass(target) or attrs.has(target)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def dump_dataclass(obj: Any):
|
| 33 |
-
"""
|
| 34 |
-
Dump a dataclass recursively into a dict that can be later instantiated.
|
| 35 |
-
|
| 36 |
-
Args:
|
| 37 |
-
obj: a dataclass object
|
| 38 |
-
|
| 39 |
-
Returns:
|
| 40 |
-
dict
|
| 41 |
-
"""
|
| 42 |
-
assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), (
|
| 43 |
-
"dump_dataclass() requires an instance of a dataclass."
|
| 44 |
-
)
|
| 45 |
-
ret = {"_target_": _convert_target_to_string(type(obj))}
|
| 46 |
-
for f in dataclasses.fields(obj):
|
| 47 |
-
v = getattr(obj, f.name)
|
| 48 |
-
if dataclasses.is_dataclass(v):
|
| 49 |
-
v = dump_dataclass(v)
|
| 50 |
-
if isinstance(v, (list, tuple)):
|
| 51 |
-
v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
|
| 52 |
-
ret[f.name] = v
|
| 53 |
-
return ret
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def instantiate(cfg, *args, **kwargs):
|
| 57 |
-
"""
|
| 58 |
-
Recursively instantiate objects defined in dictionaries by
|
| 59 |
-
"_target_" and arguments.
|
| 60 |
-
|
| 61 |
-
Args:
|
| 62 |
-
cfg: a dict-like object with "_target_" that defines the caller, and
|
| 63 |
-
other keys that define the arguments
|
| 64 |
-
args: Optional positional parameters pass-through.
|
| 65 |
-
kwargs: Optional named parameters pass-through.
|
| 66 |
-
|
| 67 |
-
Returns:
|
| 68 |
-
object instantiated by cfg
|
| 69 |
-
"""
|
| 70 |
-
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 71 |
-
|
| 72 |
-
if isinstance(cfg, ListConfig):
|
| 73 |
-
lst = [instantiate(x) for x in cfg]
|
| 74 |
-
return ListConfig(lst, flags={"allow_objects": True})
|
| 75 |
-
if isinstance(cfg, list):
|
| 76 |
-
# Specialize for list, because many classes take
|
| 77 |
-
# list[objects] as arguments, such as ResNet, DatasetMapper
|
| 78 |
-
return [instantiate(x) for x in cfg]
|
| 79 |
-
|
| 80 |
-
# If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config),
|
| 81 |
-
# instantiate it to the actual dataclass.
|
| 82 |
-
if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type):
|
| 83 |
-
return OmegaConf.to_object(cfg)
|
| 84 |
-
|
| 85 |
-
if isinstance(cfg, abc.Mapping) and "_target_" in cfg:
|
| 86 |
-
# conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
|
| 87 |
-
# but faster: https://github.com/facebookresearch/hydra/issues/1200
|
| 88 |
-
is_recursive = getattr(cfg, "_recursive_", True)
|
| 89 |
-
if is_recursive:
|
| 90 |
-
cfg = {k: instantiate(v) for k, v in cfg.items()}
|
| 91 |
-
else:
|
| 92 |
-
cfg = {k: v for k, v in cfg.items()}
|
| 93 |
-
# pop the _recursive_ key to avoid passing it as a parameter
|
| 94 |
-
if "_recursive_" in cfg:
|
| 95 |
-
cfg.pop("_recursive_")
|
| 96 |
-
cls = cfg.pop("_target_")
|
| 97 |
-
cls = instantiate(cls)
|
| 98 |
-
|
| 99 |
-
if isinstance(cls, str):
|
| 100 |
-
cls_name = cls
|
| 101 |
-
cls = locate(cls_name)
|
| 102 |
-
assert cls is not None, cls_name
|
| 103 |
-
else:
|
| 104 |
-
try:
|
| 105 |
-
cls_name = cls.__module__ + "." + cls.__qualname__
|
| 106 |
-
except Exception:
|
| 107 |
-
# target could be anything, so the above could fail
|
| 108 |
-
cls_name = str(cls)
|
| 109 |
-
assert callable(cls), f"_target_ {cls} does not define a callable object"
|
| 110 |
-
try:
|
| 111 |
-
# override config with kwargs
|
| 112 |
-
instantiate_kwargs = {}
|
| 113 |
-
instantiate_kwargs.update(cfg)
|
| 114 |
-
instantiate_kwargs.update(kwargs)
|
| 115 |
-
return cls(*args, **instantiate_kwargs)
|
| 116 |
-
except TypeError:
|
| 117 |
-
logger = logging.getLogger(__name__)
|
| 118 |
-
logger.error(f"Error when instantiating {cls_name}!")
|
| 119 |
-
raise
|
| 120 |
-
return cfg # return as-is if don't know what to do
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/lazy_config/lazy.py
DELETED
|
@@ -1,430 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import ast
|
| 17 |
-
import builtins
|
| 18 |
-
import collections.abc as abc
|
| 19 |
-
import importlib
|
| 20 |
-
import inspect
|
| 21 |
-
import logging
|
| 22 |
-
import os
|
| 23 |
-
import pickle
|
| 24 |
-
import uuid
|
| 25 |
-
from collections import OrderedDict
|
| 26 |
-
from contextlib import contextmanager
|
| 27 |
-
from copy import deepcopy
|
| 28 |
-
from dataclasses import is_dataclass
|
| 29 |
-
from typing import Any, Dict, List, Tuple, Union
|
| 30 |
-
|
| 31 |
-
import attrs
|
| 32 |
-
import yaml
|
| 33 |
-
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
import dill as dill_pickle
|
| 37 |
-
except ImportError:
|
| 38 |
-
dill_pickle = None
|
| 39 |
-
|
| 40 |
-
try:
|
| 41 |
-
import cloudpickle
|
| 42 |
-
except ImportError:
|
| 43 |
-
cloudpickle = None
|
| 44 |
-
|
| 45 |
-
from lyra_2._ext.imaginaire.lazy_config.file_io import PathManager
|
| 46 |
-
from lyra_2._ext.imaginaire.lazy_config.registry import _convert_target_to_string
|
| 47 |
-
|
| 48 |
-
__all__ = ["LazyCall", "LazyConfig"]
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]:
|
| 52 |
-
return OrderedDict(sorted(d.items(), key=lambda x: x[0]))
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode:
|
| 56 |
-
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]:
|
| 60 |
-
if isinstance(obj, dict):
|
| 61 |
-
return sort_dict({k: sort_recursive(v) for k, v in obj.items()})
|
| 62 |
-
elif isinstance(obj, list):
|
| 63 |
-
return [sort_recursive(item) for item in obj]
|
| 64 |
-
return obj
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
yaml.add_representer(OrderedDict, dict_representer)
|
| 68 |
-
|
| 69 |
-
OmegaConf.register_new_resolver("add", lambda *vals: sum(vals))
|
| 70 |
-
OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:]))
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def get_default_params(cls_or_func):
|
| 74 |
-
if callable(cls_or_func):
|
| 75 |
-
# inspect signature for function
|
| 76 |
-
signature = inspect.signature(cls_or_func)
|
| 77 |
-
else:
|
| 78 |
-
# inspect signature for class
|
| 79 |
-
signature = inspect.signature(cls_or_func.__init__)
|
| 80 |
-
params = signature.parameters
|
| 81 |
-
default_params = {
|
| 82 |
-
name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty
|
| 83 |
-
}
|
| 84 |
-
return default_params
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
class LazyCall:
|
| 88 |
-
"""
|
| 89 |
-
Wrap a callable so that when it's called, the call will not be executed,
|
| 90 |
-
but returns a dict that describes the call.
|
| 91 |
-
|
| 92 |
-
LazyCall object has to be called with only keyword arguments. Positional
|
| 93 |
-
arguments are not yet supported.
|
| 94 |
-
|
| 95 |
-
Examples:
|
| 96 |
-
::
|
| 97 |
-
from detectron2.config import instantiate, LazyCall
|
| 98 |
-
|
| 99 |
-
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
|
| 100 |
-
layer_cfg.out_channels = 64 # can edit it afterwards
|
| 101 |
-
layer = instantiate(layer_cfg)
|
| 102 |
-
"""
|
| 103 |
-
|
| 104 |
-
def __init__(self, target):
|
| 105 |
-
if not (callable(target) or isinstance(target, (str, abc.Mapping))):
|
| 106 |
-
raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}")
|
| 107 |
-
self._target = target
|
| 108 |
-
|
| 109 |
-
def __call__(self, **kwargs):
|
| 110 |
-
if is_dataclass(self._target) or attrs.has(self._target):
|
| 111 |
-
# omegaconf object cannot hold dataclass type
|
| 112 |
-
# https://github.com/omry/omegaconf/issues/784
|
| 113 |
-
target = _convert_target_to_string(self._target)
|
| 114 |
-
else:
|
| 115 |
-
target = self._target
|
| 116 |
-
kwargs["_target_"] = target
|
| 117 |
-
|
| 118 |
-
_final_params = get_default_params(self._target)
|
| 119 |
-
_final_params.update(kwargs)
|
| 120 |
-
|
| 121 |
-
return DictConfig(content=_final_params, flags={"allow_objects": True})
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def _visit_dict_config(cfg, func):
|
| 125 |
-
"""
|
| 126 |
-
Apply func recursively to all DictConfig in cfg.
|
| 127 |
-
"""
|
| 128 |
-
if isinstance(cfg, DictConfig):
|
| 129 |
-
func(cfg)
|
| 130 |
-
for v in cfg.values():
|
| 131 |
-
_visit_dict_config(v, func)
|
| 132 |
-
elif isinstance(cfg, ListConfig):
|
| 133 |
-
for v in cfg:
|
| 134 |
-
_visit_dict_config(v, func)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
def _validate_py_syntax(filename):
|
| 138 |
-
# see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
|
| 139 |
-
with PathManager.open(filename, "r") as f:
|
| 140 |
-
content = f.read()
|
| 141 |
-
try:
|
| 142 |
-
ast.parse(content)
|
| 143 |
-
except SyntaxError as e:
|
| 144 |
-
raise SyntaxError(f"Config file {filename} has syntax error!") from e
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
def _cast_to_config(obj):
|
| 148 |
-
# if given a dict, return DictConfig instead
|
| 149 |
-
if isinstance(obj, dict):
|
| 150 |
-
return DictConfig(obj, flags={"allow_objects": True})
|
| 151 |
-
return obj
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
|
| 155 |
-
"""
|
| 156 |
-
A namespace to put all imported config into.
|
| 157 |
-
"""
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def _random_package_name(filename):
|
| 161 |
-
# generate a random package name when loading config files
|
| 162 |
-
return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
@contextmanager
|
| 166 |
-
def _patch_import():
|
| 167 |
-
"""
|
| 168 |
-
Enhance relative import statements in config files, so that they:
|
| 169 |
-
1. locate files purely based on relative location, regardless of packages.
|
| 170 |
-
e.g. you can import file without having __init__
|
| 171 |
-
2. do not cache modules globally; modifications of module states has no side effect
|
| 172 |
-
3. support other storage system through PathManager, so config files can be in the cloud
|
| 173 |
-
4. imported dict are turned into omegaconf.DictConfig automatically
|
| 174 |
-
"""
|
| 175 |
-
old_import = builtins.__import__
|
| 176 |
-
|
| 177 |
-
def find_relative_file(original_file, relative_import_path, level):
|
| 178 |
-
# NOTE: "from . import x" is not handled. Because then it's unclear
|
| 179 |
-
# if such import should produce `x` as a python module or DictConfig.
|
| 180 |
-
# This can be discussed further if needed.
|
| 181 |
-
relative_import_err = """
|
| 182 |
-
Relative import of directories is not allowed within config files.
|
| 183 |
-
Within a config file, relative import can only import other config files.
|
| 184 |
-
""".replace("\n", " ")
|
| 185 |
-
if not len(relative_import_path):
|
| 186 |
-
raise ImportError(relative_import_err)
|
| 187 |
-
|
| 188 |
-
cur_file = os.path.dirname(original_file)
|
| 189 |
-
for _ in range(level - 1):
|
| 190 |
-
cur_file = os.path.dirname(cur_file)
|
| 191 |
-
cur_name = relative_import_path.lstrip(".")
|
| 192 |
-
for part in cur_name.split("."):
|
| 193 |
-
cur_file = os.path.join(cur_file, part)
|
| 194 |
-
if not cur_file.endswith(".py"):
|
| 195 |
-
cur_file += ".py"
|
| 196 |
-
if not PathManager.isfile(cur_file):
|
| 197 |
-
cur_file_no_suffix = cur_file[: -len(".py")]
|
| 198 |
-
if PathManager.isdir(cur_file_no_suffix):
|
| 199 |
-
raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
|
| 200 |
-
else:
|
| 201 |
-
raise ImportError(
|
| 202 |
-
f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist."
|
| 203 |
-
)
|
| 204 |
-
return cur_file
|
| 205 |
-
|
| 206 |
-
def new_import(name, globals=None, locals=None, fromlist=(), level=0):
|
| 207 |
-
if (
|
| 208 |
-
# Only deal with relative imports inside config files
|
| 209 |
-
level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
|
| 210 |
-
):
|
| 211 |
-
cur_file = find_relative_file(globals["__file__"], name, level)
|
| 212 |
-
_validate_py_syntax(cur_file)
|
| 213 |
-
spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file)
|
| 214 |
-
module = importlib.util.module_from_spec(spec)
|
| 215 |
-
module.__file__ = cur_file
|
| 216 |
-
with PathManager.open(cur_file) as f:
|
| 217 |
-
content = f.read()
|
| 218 |
-
exec(compile(content, cur_file, "exec"), module.__dict__)
|
| 219 |
-
for name in fromlist: # turn imported dict into DictConfig automatically
|
| 220 |
-
val = _cast_to_config(module.__dict__[name])
|
| 221 |
-
module.__dict__[name] = val
|
| 222 |
-
return module
|
| 223 |
-
return old_import(name, globals, locals, fromlist=fromlist, level=level)
|
| 224 |
-
|
| 225 |
-
builtins.__import__ = new_import
|
| 226 |
-
yield new_import
|
| 227 |
-
builtins.__import__ = old_import
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
class LazyConfig:
|
| 231 |
-
"""
|
| 232 |
-
Provide methods to save, load, and overrides an omegaconf config object
|
| 233 |
-
which may contain definition of lazily-constructed objects.
|
| 234 |
-
"""
|
| 235 |
-
|
| 236 |
-
@staticmethod
|
| 237 |
-
def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
|
| 238 |
-
"""
|
| 239 |
-
Similar to :meth:`load()`, but load path relative to the caller's
|
| 240 |
-
source file.
|
| 241 |
-
|
| 242 |
-
This has the same functionality as a relative import, except that this method
|
| 243 |
-
accepts filename as a string, so more characters are allowed in the filename.
|
| 244 |
-
"""
|
| 245 |
-
caller_frame = inspect.stack()[1]
|
| 246 |
-
caller_fname = caller_frame[0].f_code.co_filename
|
| 247 |
-
assert caller_fname != "<string>", "load_rel Unable to find caller"
|
| 248 |
-
caller_dir = os.path.dirname(caller_fname)
|
| 249 |
-
filename = os.path.join(caller_dir, filename)
|
| 250 |
-
return LazyConfig.load(filename, keys)
|
| 251 |
-
|
| 252 |
-
@staticmethod
|
| 253 |
-
def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
|
| 254 |
-
"""
|
| 255 |
-
Load a config file.
|
| 256 |
-
|
| 257 |
-
Args:
|
| 258 |
-
filename: absolute path or relative path w.r.t. the current working directory
|
| 259 |
-
keys: keys to load and return. If not given, return all keys
|
| 260 |
-
(whose values are config objects) in a dict.
|
| 261 |
-
"""
|
| 262 |
-
has_keys = keys is not None
|
| 263 |
-
filename = filename.replace("/./", "/") # redundant
|
| 264 |
-
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
|
| 265 |
-
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
|
| 266 |
-
if filename.endswith(".py"):
|
| 267 |
-
_validate_py_syntax(filename)
|
| 268 |
-
|
| 269 |
-
with _patch_import():
|
| 270 |
-
# Record the filename
|
| 271 |
-
module_namespace = {
|
| 272 |
-
"__file__": filename,
|
| 273 |
-
"__package__": _random_package_name(filename),
|
| 274 |
-
}
|
| 275 |
-
with PathManager.open(filename) as f:
|
| 276 |
-
content = f.read()
|
| 277 |
-
# Compile first with filename to:
|
| 278 |
-
# 1. make filename appears in stacktrace
|
| 279 |
-
# 2. make load_rel able to find its parent's (possibly remote) location
|
| 280 |
-
exec(compile(content, filename, "exec"), module_namespace)
|
| 281 |
-
|
| 282 |
-
ret = module_namespace
|
| 283 |
-
else:
|
| 284 |
-
with PathManager.open(filename) as f:
|
| 285 |
-
obj = yaml.unsafe_load(f)
|
| 286 |
-
ret = OmegaConf.create(obj, flags={"allow_objects": True})
|
| 287 |
-
|
| 288 |
-
if has_keys:
|
| 289 |
-
if isinstance(keys, str):
|
| 290 |
-
return _cast_to_config(ret[keys])
|
| 291 |
-
else:
|
| 292 |
-
return tuple(_cast_to_config(ret[a]) for a in keys)
|
| 293 |
-
else:
|
| 294 |
-
if filename.endswith(".py"):
|
| 295 |
-
# when not specified, only load those that are config objects
|
| 296 |
-
ret = DictConfig(
|
| 297 |
-
{
|
| 298 |
-
name: _cast_to_config(value)
|
| 299 |
-
for name, value in ret.items()
|
| 300 |
-
if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_")
|
| 301 |
-
},
|
| 302 |
-
flags={"allow_objects": True},
|
| 303 |
-
)
|
| 304 |
-
return ret
|
| 305 |
-
|
| 306 |
-
@staticmethod
|
| 307 |
-
def save_pkl(cfg, filename: str) -> str:
|
| 308 |
-
"""
|
| 309 |
-
Saves a Config object to a file using pickle serialization. This method is typically used
|
| 310 |
-
when the configuration object contains complex objects, such as lambdas, that are not supported by
|
| 311 |
-
simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration
|
| 312 |
-
object before serialization to ensure that the original object remains unmodified.
|
| 313 |
-
|
| 314 |
-
Args:
|
| 315 |
-
cfg: A Config object to be serialized and saved.
|
| 316 |
-
filename: The path and name of the file where the configuration should be saved. The function
|
| 317 |
-
assumes the file extension indicates a pickle format (e.g., .pkl).
|
| 318 |
-
|
| 319 |
-
Returns:
|
| 320 |
-
str: The filename to which the configuration was saved. This can be used to verify the file location
|
| 321 |
-
or log the outcome.
|
| 322 |
-
|
| 323 |
-
Notes:
|
| 324 |
-
- The function logs a warning if the configuration is successfully saved using pickle.
|
| 325 |
-
- If saving fails, an error is logged with the exception details.
|
| 326 |
-
"""
|
| 327 |
-
logger = logging.getLogger(__name__)
|
| 328 |
-
try:
|
| 329 |
-
cfg = deepcopy(cfg)
|
| 330 |
-
except Exception:
|
| 331 |
-
pass
|
| 332 |
-
|
| 333 |
-
try:
|
| 334 |
-
with PathManager.open(filename, "wb") as f:
|
| 335 |
-
pickle.dump(cfg, f)
|
| 336 |
-
logger.warning(f"Config is saved using pickle at {filename}.")
|
| 337 |
-
except Exception as e:
|
| 338 |
-
logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead")
|
| 339 |
-
if dill_pickle:
|
| 340 |
-
try:
|
| 341 |
-
with PathManager.open(filename, "wb") as f:
|
| 342 |
-
pickle.dump(dill_pickle.dumps(cfg, recurse=True), f)
|
| 343 |
-
logger.warning(f"Config is saved using dill at {filename}.")
|
| 344 |
-
except Exception as e:
|
| 345 |
-
logger.error(f"Failed to save config to {filename}: {e}.")
|
| 346 |
-
if cloudpickle:
|
| 347 |
-
try:
|
| 348 |
-
with PathManager.open(filename, "wb") as f:
|
| 349 |
-
pickle.dump(cloudpickle.dumps(cfg), f)
|
| 350 |
-
logger.warning(f"Config is saved using cloudpickle at {filename}.")
|
| 351 |
-
except Exception as e:
|
| 352 |
-
logger.error(f"Failed to save config to {filename}: {e}.")
|
| 353 |
-
else:
|
| 354 |
-
logger.error("cloudpickle is not available. Cannot save the config.")
|
| 355 |
-
raise e
|
| 356 |
-
|
| 357 |
-
return filename
|
| 358 |
-
|
| 359 |
-
@staticmethod
|
| 360 |
-
def save_yaml(cfg, filename: str) -> str:
|
| 361 |
-
"""
|
| 362 |
-
Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization.
|
| 363 |
-
|
| 364 |
-
Args:
|
| 365 |
-
cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types.
|
| 366 |
-
filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'.
|
| 367 |
-
|
| 368 |
-
Returns:
|
| 369 |
-
str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome.
|
| 370 |
-
|
| 371 |
-
Notes:
|
| 372 |
-
- The function logs a warning if the configuration is successfully saved using YAML.
|
| 373 |
-
- If saving fails, an error is logged with the exception details.
|
| 374 |
-
"""
|
| 375 |
-
logger = logging.getLogger(__name__)
|
| 376 |
-
try:
|
| 377 |
-
cfg = deepcopy(cfg)
|
| 378 |
-
except Exception:
|
| 379 |
-
pass
|
| 380 |
-
|
| 381 |
-
# Define a function to check if an item is serializable to YAML
|
| 382 |
-
def is_serializable(item):
|
| 383 |
-
try:
|
| 384 |
-
OmegaConf.to_yaml(item)
|
| 385 |
-
return True
|
| 386 |
-
except Exception as e:
|
| 387 |
-
return False
|
| 388 |
-
|
| 389 |
-
# Function to convert unserializable items to strings
|
| 390 |
-
def serialize_config(config):
|
| 391 |
-
if isinstance(config, DictConfig):
|
| 392 |
-
for key, value in config.items():
|
| 393 |
-
if isinstance(value, (DictConfig, ListConfig)):
|
| 394 |
-
try:
|
| 395 |
-
if "_target_" in value:
|
| 396 |
-
default_params = get_default_params(value["_target_"])
|
| 397 |
-
for default_key, default_v in default_params.items():
|
| 398 |
-
if default_key not in value:
|
| 399 |
-
value[default_key] = default_v
|
| 400 |
-
except Exception as e:
|
| 401 |
-
logger.error(f"Failed to add default argument values: {e}")
|
| 402 |
-
|
| 403 |
-
serialize_config(value)
|
| 404 |
-
else:
|
| 405 |
-
if not is_serializable(value) and value is not None:
|
| 406 |
-
config[key] = str(value)
|
| 407 |
-
elif isinstance(config, ListConfig):
|
| 408 |
-
for i, item in enumerate(config):
|
| 409 |
-
if isinstance(item, (DictConfig, ListConfig)):
|
| 410 |
-
serialize_config(item)
|
| 411 |
-
else:
|
| 412 |
-
if not is_serializable(item) and item is not None:
|
| 413 |
-
config[i] = str(item)
|
| 414 |
-
else:
|
| 415 |
-
raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
|
| 416 |
-
return config
|
| 417 |
-
|
| 418 |
-
# Convert Config object to a DictConfig object.
|
| 419 |
-
config_dict = attrs.asdict(cfg)
|
| 420 |
-
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
| 421 |
-
|
| 422 |
-
# Serialize the DictConfig object by converting non-serializable objects to strings.
|
| 423 |
-
config_omegaconf = serialize_config(config_omegaconf)
|
| 424 |
-
|
| 425 |
-
config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)
|
| 426 |
-
sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict)
|
| 427 |
-
with open(filename, "w") as f:
|
| 428 |
-
yaml.dump(sorted_config, f, default_flow_style=False)
|
| 429 |
-
logger.warning(f"Config is saved using omegaconf at {filename}.")
|
| 430 |
-
return filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from typing import Any, Dict, List, Union
|
| 17 |
-
|
| 18 |
-
from omegaconf import OmegaConf
|
| 19 |
-
from omegaconf.base import DictKeyType, SCMode
|
| 20 |
-
from omegaconf.dictconfig import DictConfig # pragma: no cover
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]:
|
| 24 |
-
"""
|
| 25 |
-
Converts an OmegaConf configuration object to a native Python container (dict or list), unless
|
| 26 |
-
the configuration is specifically created by LazyCall, in which case the original configuration
|
| 27 |
-
is returned directly.
|
| 28 |
-
|
| 29 |
-
This function serves as a modification of the original `to_object` method from OmegaConf,
|
| 30 |
-
preventing DictConfig objects created by LazyCall from being automatically converted to Python
|
| 31 |
-
dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended
|
| 32 |
-
structure and behavior.
|
| 33 |
-
|
| 34 |
-
Differences from OmegaConf's original `to_object`:
|
| 35 |
-
- Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall.
|
| 36 |
-
|
| 37 |
-
Reference:
|
| 38 |
-
- Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
cfg (Any): The OmegaConf configuration object to convert.
|
| 42 |
-
|
| 43 |
-
Returns:
|
| 44 |
-
Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if
|
| 45 |
-
`cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`.
|
| 46 |
-
|
| 47 |
-
Examples:
|
| 48 |
-
>>> cfg = DictConfig({"key": "value", "_target_": "Model"})
|
| 49 |
-
>>> to_object(cfg)
|
| 50 |
-
DictConfig({"key": "value", "_target_": "Model"})
|
| 51 |
-
|
| 52 |
-
>>> cfg = DictConfig({"list": [1, 2, 3]})
|
| 53 |
-
>>> to_object(cfg)
|
| 54 |
-
{'list': [1, 2, 3]}
|
| 55 |
-
"""
|
| 56 |
-
if isinstance(cfg, DictConfig) and "_target_" in cfg.keys():
|
| 57 |
-
return cfg
|
| 58 |
-
|
| 59 |
-
return OmegaConf.to_container(
|
| 60 |
-
cfg=cfg,
|
| 61 |
-
resolve=True,
|
| 62 |
-
throw_on_missing=True,
|
| 63 |
-
enum_to_str=False,
|
| 64 |
-
structured_config_mode=SCMode.INSTANTIATE,
|
| 65 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/lazy_config/registry.py
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import pydoc
|
| 17 |
-
from typing import Any
|
| 18 |
-
|
| 19 |
-
from fvcore.common.registry import Registry # for backward compatibility.
|
| 20 |
-
|
| 21 |
-
"""
|
| 22 |
-
``Registry`` and `locate` provide ways to map a string (typically found
|
| 23 |
-
in config files) to callable objects.
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
__all__ = ["Registry", "locate"]
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def _convert_target_to_string(t: Any) -> str:
|
| 30 |
-
"""
|
| 31 |
-
Inverse of ``locate()``.
|
| 32 |
-
|
| 33 |
-
Args:
|
| 34 |
-
t: any object with ``__module__`` and ``__qualname__``
|
| 35 |
-
"""
|
| 36 |
-
module, qualname = t.__module__, t.__qualname__
|
| 37 |
-
|
| 38 |
-
# Compress the path to this object, e.g. ``module.submodule._impl.class``
|
| 39 |
-
# may become ``module.submodule.class``, if the later also resolves to the same
|
| 40 |
-
# object. This simplifies the string, and also is less affected by moving the
|
| 41 |
-
# class implementation.
|
| 42 |
-
module_parts = module.split(".")
|
| 43 |
-
for k in range(1, len(module_parts)):
|
| 44 |
-
prefix = ".".join(module_parts[:k])
|
| 45 |
-
candidate = f"{prefix}.{qualname}"
|
| 46 |
-
try:
|
| 47 |
-
if locate(candidate) is t:
|
| 48 |
-
return candidate
|
| 49 |
-
except ImportError:
|
| 50 |
-
pass
|
| 51 |
-
return f"{module}.{qualname}"
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def locate(name: str) -> Any:
|
| 55 |
-
"""
|
| 56 |
-
Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
|
| 57 |
-
such as "module.submodule.class_name".
|
| 58 |
-
|
| 59 |
-
Raise Exception if it cannot be found.
|
| 60 |
-
"""
|
| 61 |
-
obj = pydoc.locate(name)
|
| 62 |
-
|
| 63 |
-
# Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
|
| 64 |
-
# by pydoc.locate. Try a private function from hydra.
|
| 65 |
-
if obj is None:
|
| 66 |
-
try:
|
| 67 |
-
# from hydra.utils import get_method - will print many errors
|
| 68 |
-
from hydra.utils import _locate
|
| 69 |
-
except ImportError as e:
|
| 70 |
-
raise ImportError(f"Cannot dynamically locate object {name}!") from e
|
| 71 |
-
else:
|
| 72 |
-
obj = _locate(name) # it raises if fails
|
| 73 |
-
|
| 74 |
-
return obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/model.py
DELETED
|
@@ -1,129 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from typing import Any
|
| 17 |
-
|
| 18 |
-
import torch
|
| 19 |
-
|
| 20 |
-
from lyra_2._ext.imaginaire.lazy_config import LazyDict, instantiate
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class ImaginaireModel(torch.nn.Module):
|
| 24 |
-
"""The base model class of Imaginaire. It is inherited from torch.nn.Module.
|
| 25 |
-
|
| 26 |
-
All models in Imaginaire should inherit ImaginaireModel. It should include the implementions for all the
|
| 27 |
-
computation graphs. All inheriting child classes should implement the following methods:
|
| 28 |
-
- training_step(): The training step of the model, including the loss computation.
|
| 29 |
-
- validation_step(): The validation step of the model, including the loss computation.
|
| 30 |
-
- forward(): The computation graph for model inference.
|
| 31 |
-
The following methods have default implementations in ImaginaireModel:
|
| 32 |
-
- init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
def __init__(self) -> None:
|
| 36 |
-
super().__init__()
|
| 37 |
-
|
| 38 |
-
def init_optimizer_scheduler(
|
| 39 |
-
self, optimizer_config: LazyDict, scheduler_config: LazyDict
|
| 40 |
-
) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
|
| 41 |
-
"""Creates the optimizer and scheduler for the model.
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
config_model (ModelConfig): The config object for the model.
|
| 45 |
-
|
| 46 |
-
Returns:
|
| 47 |
-
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 48 |
-
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 49 |
-
"""
|
| 50 |
-
optimizer_config.params = self.parameters()
|
| 51 |
-
optimizer = instantiate(optimizer_config)
|
| 52 |
-
scheduler_config.optimizer = optimizer
|
| 53 |
-
scheduler = instantiate(scheduler_config)
|
| 54 |
-
return optimizer, scheduler
|
| 55 |
-
|
| 56 |
-
def training_step(
|
| 57 |
-
self, data_batch: dict[str, torch.Tensor], iteration: int
|
| 58 |
-
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 59 |
-
"""The training step of the model, including the loss computation.
|
| 60 |
-
|
| 61 |
-
Args:
|
| 62 |
-
data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
|
| 63 |
-
iteration (int): Current iteration number.
|
| 64 |
-
|
| 65 |
-
Returns:
|
| 66 |
-
output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
|
| 67 |
-
loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
|
| 68 |
-
"""
|
| 69 |
-
raise NotImplementedError
|
| 70 |
-
|
| 71 |
-
@torch.no_grad()
|
| 72 |
-
def validation_step(
|
| 73 |
-
self, data_batch: dict[str, torch.Tensor], iteration: int
|
| 74 |
-
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
|
| 75 |
-
"""The validation step of the model, including the loss computation.
|
| 76 |
-
|
| 77 |
-
Args:
|
| 78 |
-
data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
|
| 79 |
-
iteration (int): Current iteration number.
|
| 80 |
-
|
| 81 |
-
Returns:
|
| 82 |
-
output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
|
| 83 |
-
loss (torch.Tensor): The total loss (weighted sum of various losses).
|
| 84 |
-
"""
|
| 85 |
-
raise NotImplementedError
|
| 86 |
-
|
| 87 |
-
@torch.inference_mode()
|
| 88 |
-
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
| 89 |
-
"""The computation graph for model inference.
|
| 90 |
-
|
| 91 |
-
Args:
|
| 92 |
-
*args: Whatever you decide to pass into the forward method.
|
| 93 |
-
**kwargs: Keyword arguments are also possible.
|
| 94 |
-
|
| 95 |
-
Return:
|
| 96 |
-
Your model's output.
|
| 97 |
-
"""
|
| 98 |
-
raise NotImplementedError
|
| 99 |
-
|
| 100 |
-
def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
|
| 101 |
-
"""The model preparation before the training is launched
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
memory_format (torch.memory_format): Memory format of the model.
|
| 105 |
-
"""
|
| 106 |
-
pass
|
| 107 |
-
|
| 108 |
-
def on_before_zero_grad(
|
| 109 |
-
self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
|
| 110 |
-
) -> None:
|
| 111 |
-
"""Hook before zero_grad() is called.
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 115 |
-
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 116 |
-
iteration (int): Current iteration number.
|
| 117 |
-
"""
|
| 118 |
-
pass
|
| 119 |
-
|
| 120 |
-
def on_after_backward(self, iteration: int = 0) -> None:
|
| 121 |
-
"""Hook after loss.backward() is called.
|
| 122 |
-
|
| 123 |
-
This method is called immediately after the backward pass, allowing for custom operations
|
| 124 |
-
or modifications to be performed on the gradients before the optimizer step.
|
| 125 |
-
|
| 126 |
-
Args:
|
| 127 |
-
iteration (int): Current iteration number.
|
| 128 |
-
"""
|
| 129 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/trainer.py
DELETED
|
@@ -1,379 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import functools
|
| 17 |
-
import inspect
|
| 18 |
-
import os
|
| 19 |
-
import signal
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
import torch.distributed as dist
|
| 23 |
-
import torch.utils.data
|
| 24 |
-
|
| 25 |
-
from lyra_2._ext.imaginaire.utils.context_managers import distributed_init
|
| 26 |
-
from lyra_2._ext.imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
|
| 27 |
-
|
| 28 |
-
try:
|
| 29 |
-
from megatron.core import parallel_state
|
| 30 |
-
|
| 31 |
-
USE_MEGATRON = True
|
| 32 |
-
except ImportError:
|
| 33 |
-
USE_MEGATRON = False
|
| 34 |
-
print("Megatron-core is not installed.")
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
from lyra_2._ext.imaginaire.lazy_config import LazyConfig, instantiate
|
| 38 |
-
from lyra_2._ext.imaginaire.model import ImaginaireModel
|
| 39 |
-
from lyra_2._ext.imaginaire.utils import callback, distributed, ema, log, misc
|
| 40 |
-
from lyra_2._ext.imaginaire.utils.checkpointer import Checkpointer
|
| 41 |
-
from lyra_2._ext.imaginaire.utils.misc import StragglerDetectorV2
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
from lyra_2._src.callbacks.smart_stop import TimeoutException
|
| 45 |
-
except ImportError:
|
| 46 |
-
# Define a dummy exception if smart_stop is not available
|
| 47 |
-
class TimeoutException(Exception):
|
| 48 |
-
pass
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class ImaginaireTrainer:
|
| 52 |
-
"""The base trainer class of Imaginaire.
|
| 53 |
-
|
| 54 |
-
All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training
|
| 55 |
-
(particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA),
|
| 56 |
-
mixed-precision training (fp16/bf16).
|
| 57 |
-
|
| 58 |
-
Attributes:
|
| 59 |
-
checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
|
| 60 |
-
training_timer (misc.Timer): Timer object to time code blocks and functions.
|
| 61 |
-
"""
|
| 62 |
-
|
| 63 |
-
def __init__(self, config):
|
| 64 |
-
"""Constructor of the trainer.
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
config (Config): The config object for the Imaginaire codebase.
|
| 68 |
-
"""
|
| 69 |
-
super().__init__()
|
| 70 |
-
self.config = config
|
| 71 |
-
# Set up the distributed computing environment.
|
| 72 |
-
with distributed_init():
|
| 73 |
-
distributed.init()
|
| 74 |
-
# Set up parallel states.
|
| 75 |
-
if hasattr(config.model, "context_parallel_size"):
|
| 76 |
-
if config.model_parallel.context_parallel_size > 1:
|
| 77 |
-
raise ValueError(
|
| 78 |
-
"Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
|
| 79 |
-
"config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
|
| 80 |
-
)
|
| 81 |
-
else:
|
| 82 |
-
log.critical(
|
| 83 |
-
"Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
|
| 84 |
-
)
|
| 85 |
-
config.model_parallel.context_parallel_size = config.model.context_parallel_size
|
| 86 |
-
if USE_MEGATRON:
|
| 87 |
-
if (
|
| 88 |
-
"create_gloo_process_groups"
|
| 89 |
-
in inspect.signature(parallel_state.initialize_model_parallel).parameters
|
| 90 |
-
):
|
| 91 |
-
parallel_state.initialize_model_parallel(
|
| 92 |
-
pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
|
| 93 |
-
tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
|
| 94 |
-
context_parallel_size=config.model_parallel.context_parallel_size,
|
| 95 |
-
create_gloo_process_groups=False,
|
| 96 |
-
)
|
| 97 |
-
else:
|
| 98 |
-
parallel_state.initialize_model_parallel(
|
| 99 |
-
pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
|
| 100 |
-
tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
|
| 101 |
-
context_parallel_size=config.model_parallel.context_parallel_size,
|
| 102 |
-
)
|
| 103 |
-
# `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
|
| 104 |
-
# It is not part of the original `parallel_state` API, so we need to set it manually.
|
| 105 |
-
parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
|
| 106 |
-
if parallel_state.sequence_parallel:
|
| 107 |
-
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
|
| 108 |
-
|
| 109 |
-
# Create the local job directory, save the config file, and pipe to a local log.
|
| 110 |
-
if distributed.is_rank0():
|
| 111 |
-
os.makedirs(config.job.path_local, exist_ok=True)
|
| 112 |
-
# Save the config as .pkl for reproducibility.
|
| 113 |
-
LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
|
| 114 |
-
# Save the config as .yaml for reading or parsing experiment hyperparameters.
|
| 115 |
-
LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
|
| 116 |
-
dist.barrier()
|
| 117 |
-
log.init_loguru_file(f"{config.job.path_local}/stdout.log")
|
| 118 |
-
if distributed.is_rank0():
|
| 119 |
-
# Print important environment variables and the effective config.
|
| 120 |
-
log.info("Config:\n" + config.pretty_print(use_color=True))
|
| 121 |
-
misc.print_environ_variables(["HF_HOME", "IMAGINAIRE_OUTPUT_ROOT"])
|
| 122 |
-
# Set the random seed. If multi-GPU, different ranks are set with different seeds.
|
| 123 |
-
misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
|
| 124 |
-
# Initialize cuDNN.
|
| 125 |
-
torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
|
| 126 |
-
torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
|
| 127 |
-
# Floating-point precision settings.
|
| 128 |
-
torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
|
| 129 |
-
# Initialize the callback functions.
|
| 130 |
-
self.callbacks = callback.CallBackGroup(config=config, trainer=self)
|
| 131 |
-
# Initialize the model checkpointer.
|
| 132 |
-
if config.checkpoint.type is None:
|
| 133 |
-
self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
|
| 134 |
-
else:
|
| 135 |
-
self.checkpointer: Checkpointer = instantiate(
|
| 136 |
-
config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
|
| 137 |
-
)
|
| 138 |
-
# Initialize the timer for speed benchmarking.
|
| 139 |
-
self.training_timer = misc.TrainingTimer()
|
| 140 |
-
# Initialize Straggler Detection
|
| 141 |
-
self.straggler_detector = StragglerDetectorV2(
|
| 142 |
-
enabled=self.config.trainer.straggler_detection.enabled,
|
| 143 |
-
report_freq=self.config.trainer.straggler_detection.report_freq,
|
| 144 |
-
profile_freq=self.config.trainer.straggler_detection.profile_freq,
|
| 145 |
-
max_diff=self.config.trainer.straggler_detection.max_diff,
|
| 146 |
-
raise_error=self.config.trainer.straggler_detection.raise_error,
|
| 147 |
-
)
|
| 148 |
-
self.straggler_detector.initialize()
|
| 149 |
-
# Send a TimeoutError if a training step takes over timeout_period seconds.
|
| 150 |
-
signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore
|
| 151 |
-
|
| 152 |
-
def train(
|
| 153 |
-
self,
|
| 154 |
-
model: ImaginaireModel,
|
| 155 |
-
dataloader_train: torch.utils.data.DataLoader,
|
| 156 |
-
dataloader_val: torch.utils.data.DataLoader,
|
| 157 |
-
) -> None:
|
| 158 |
-
"""The training function.
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
model (ImaginaireModel): The PyTorch model.
|
| 162 |
-
dataloader_train (torch.utils.data.DataLoader): The training data loader.
|
| 163 |
-
dataloader_val (torch.utils.data.DataLoader): The validation data loader.
|
| 164 |
-
"""
|
| 165 |
-
# Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models.
|
| 166 |
-
model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore
|
| 167 |
-
model.on_train_start(self.config.trainer.memory_format)
|
| 168 |
-
|
| 169 |
-
# Initialize the optimizer, scheduler, and grad_scaler.
|
| 170 |
-
self.callbacks.on_optimizer_init_start()
|
| 171 |
-
optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler)
|
| 172 |
-
grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args)
|
| 173 |
-
self.callbacks.on_optimizer_init_end()
|
| 174 |
-
# Load the model checkpoint and get the starting iteration number.
|
| 175 |
-
iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler)
|
| 176 |
-
grad_accum_iter = 0
|
| 177 |
-
log.info(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
|
| 178 |
-
if self.config.trainer.distributed_parallelism == "ddp":
|
| 179 |
-
# Create a DDP model wrapper.
|
| 180 |
-
model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model)
|
| 181 |
-
elif self.config.trainer.distributed_parallelism == "fsdp":
|
| 182 |
-
model_ddp = model
|
| 183 |
-
else:
|
| 184 |
-
raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}")
|
| 185 |
-
|
| 186 |
-
log.info("Starting training...")
|
| 187 |
-
self.callbacks.on_train_start(model, iteration=iteration)
|
| 188 |
-
# Initial validation.
|
| 189 |
-
if self.config.trainer.run_validation and iteration == 0:
|
| 190 |
-
self.validate(model, dataloader_val, iteration=iteration)
|
| 191 |
-
_end_training = False
|
| 192 |
-
_smart_stop_triggered = False
|
| 193 |
-
_oom_triggered = False
|
| 194 |
-
with (
|
| 195 |
-
maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler,
|
| 196 |
-
maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler,
|
| 197 |
-
):
|
| 198 |
-
while True:
|
| 199 |
-
dataloader_train_iter = iter(dataloader_train)
|
| 200 |
-
while True:
|
| 201 |
-
self.callbacks.on_before_dataloading(iteration)
|
| 202 |
-
try:
|
| 203 |
-
with (
|
| 204 |
-
self.training_timer("dataloader_train"),
|
| 205 |
-
self.straggler_detector.profile_section(
|
| 206 |
-
"dataloading",
|
| 207 |
-
self.config.trainer.straggler_detection.analyze_dataloading,
|
| 208 |
-
profile_cuda=False,
|
| 209 |
-
),
|
| 210 |
-
):
|
| 211 |
-
data_batch = next(dataloader_train_iter)
|
| 212 |
-
except StopIteration:
|
| 213 |
-
break
|
| 214 |
-
finally:
|
| 215 |
-
self.callbacks.on_after_dataloading(iteration)
|
| 216 |
-
# If max_iter is reached, exit the training loop.
|
| 217 |
-
if iteration >= self.config.trainer.max_iter:
|
| 218 |
-
_end_training = True
|
| 219 |
-
break
|
| 220 |
-
# Move all tensors in the data batch to GPU device.
|
| 221 |
-
data_batch = misc.to(data_batch, device="cuda")
|
| 222 |
-
# The actual training step.
|
| 223 |
-
self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
|
| 224 |
-
self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration)
|
| 225 |
-
if not model.training:
|
| 226 |
-
model_ddp.train()
|
| 227 |
-
assert model_ddp.training, "model_ddp is not in training mode."
|
| 228 |
-
assert model.training, "model is not in training mode."
|
| 229 |
-
try:
|
| 230 |
-
output_batch, loss, grad_accum_iter = self.training_step(
|
| 231 |
-
model_ddp,
|
| 232 |
-
optimizer,
|
| 233 |
-
scheduler,
|
| 234 |
-
grad_scaler,
|
| 235 |
-
data_batch,
|
| 236 |
-
iteration=iteration,
|
| 237 |
-
grad_accum_iter=grad_accum_iter,
|
| 238 |
-
)
|
| 239 |
-
except torch.OutOfMemoryError as e:
|
| 240 |
-
# CUDA OOM error - save checkpoint and exit gracefully
|
| 241 |
-
_oom_triggered = True
|
| 242 |
-
log.error(f"CUDA Out of Memory error caught: {e}")
|
| 243 |
-
log.info("Saving checkpoint due to CUDA OOM error...")
|
| 244 |
-
self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
| 245 |
-
log.success("Checkpoint saved successfully. Exiting training gracefully due to OOM.")
|
| 246 |
-
_end_training = True
|
| 247 |
-
break
|
| 248 |
-
self.callbacks.on_training_step_batch_end(
|
| 249 |
-
model, data_batch, output_batch, loss, iteration=iteration
|
| 250 |
-
)
|
| 251 |
-
# If the gradients are still being accumulated, continue to load the next training batch.
|
| 252 |
-
if grad_accum_iter != 0:
|
| 253 |
-
continue
|
| 254 |
-
# Do the following when an actual optimizer (update) step has been made.
|
| 255 |
-
iteration += 1
|
| 256 |
-
# Save checkpoint.
|
| 257 |
-
if iteration % self.config.checkpoint.save_iter == 0:
|
| 258 |
-
self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
| 259 |
-
# Call on_training_step_end with SmartStop exception handling
|
| 260 |
-
try:
|
| 261 |
-
self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
|
| 262 |
-
except TimeoutException as e:
|
| 263 |
-
# Smart stop triggered - save checkpoint and exit gracefully
|
| 264 |
-
_smart_stop_triggered = True
|
| 265 |
-
log.warning(f"SmartStop exception caught: {e}")
|
| 266 |
-
log.info("Saving checkpoint due to time limit...")
|
| 267 |
-
self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
| 268 |
-
log.success("Checkpoint saved successfully. Exiting training gracefully.")
|
| 269 |
-
_end_training = True
|
| 270 |
-
break
|
| 271 |
-
# Validation.
|
| 272 |
-
if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
|
| 273 |
-
self.validate(model, dataloader_val, iteration=iteration)
|
| 274 |
-
# This iteration is successful; reset the timeout signal.
|
| 275 |
-
signal.alarm(self.config.trainer.timeout_period)
|
| 276 |
-
self.straggler_detector.generate_report(iteration)
|
| 277 |
-
if torch_profiler:
|
| 278 |
-
torch_profiler.step()
|
| 279 |
-
if memory_profiler:
|
| 280 |
-
memory_profiler.step()
|
| 281 |
-
if _end_training:
|
| 282 |
-
break
|
| 283 |
-
log.success("Done with training.")
|
| 284 |
-
if iteration % self.config.checkpoint.save_iter != 0 and not _smart_stop_triggered and not _oom_triggered:
|
| 285 |
-
self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration)
|
| 286 |
-
self.callbacks.on_train_end(model, iteration=iteration)
|
| 287 |
-
self.checkpointer.finalize()
|
| 288 |
-
distributed.barrier()
|
| 289 |
-
self.callbacks.on_app_end()
|
| 290 |
-
|
| 291 |
-
def training_step(
|
| 292 |
-
self,
|
| 293 |
-
model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
|
| 294 |
-
optimizer: torch.optim.Optimizer,
|
| 295 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 296 |
-
grad_scaler: torch.amp.GradScaler,
|
| 297 |
-
data: dict[str, torch.Tensor],
|
| 298 |
-
iteration: int = 0,
|
| 299 |
-
grad_accum_iter: int = 0,
|
| 300 |
-
) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
|
| 301 |
-
"""The training step.
|
| 302 |
-
|
| 303 |
-
Args:
|
| 304 |
-
model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare
|
| 305 |
-
module, depending on whether distributed training is enabled or not.
|
| 306 |
-
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 307 |
-
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 308 |
-
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
|
| 309 |
-
data (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
|
| 310 |
-
iteration (int): Current iteration number.
|
| 311 |
-
grad_accum_iter (int): Number of gradient accumulation iterations.
|
| 312 |
-
|
| 313 |
-
Returns:
|
| 314 |
-
output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors).
|
| 315 |
-
loss (torch.Tensor): The total loss of the training data batch.
|
| 316 |
-
"""
|
| 317 |
-
# Only let DDP sync gradient at the last iteration of the gradient accumulation window
|
| 318 |
-
with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
|
| 319 |
-
self.callbacks.on_before_forward(iteration=iteration)
|
| 320 |
-
with self.training_timer("forward"):
|
| 321 |
-
with self.straggler_detector.profile_section(
|
| 322 |
-
"fwd", self.config.trainer.straggler_detection.analyze_forward
|
| 323 |
-
):
|
| 324 |
-
output_batch, loss = model_ddp.training_step(data, iteration)
|
| 325 |
-
self.callbacks.on_after_forward(iteration=iteration)
|
| 326 |
-
self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
|
| 327 |
-
with self.training_timer("backward"):
|
| 328 |
-
with self.straggler_detector.profile_section(
|
| 329 |
-
"bwd", self.config.trainer.straggler_detection.analyze_backward
|
| 330 |
-
):
|
| 331 |
-
loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
|
| 332 |
-
loss_scaled.backward()
|
| 333 |
-
if self.config.trainer.distributed_parallelism == "ddp":
|
| 334 |
-
model_ddp.module.on_after_backward()
|
| 335 |
-
else:
|
| 336 |
-
model_ddp.on_after_backward()
|
| 337 |
-
self.callbacks.on_after_backward(model_ddp, iteration=iteration)
|
| 338 |
-
grad_accum_iter += 1
|
| 339 |
-
if grad_accum_iter == self.config.trainer.grad_accum_iter:
|
| 340 |
-
with self.training_timer("optimizer_step"):
|
| 341 |
-
with self.straggler_detector.profile_section(
|
| 342 |
-
"opt", self.config.trainer.straggler_detection.analyze_optimizer
|
| 343 |
-
):
|
| 344 |
-
self.callbacks.on_before_optimizer_step(
|
| 345 |
-
model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
|
| 346 |
-
)
|
| 347 |
-
grad_scaler.step(optimizer)
|
| 348 |
-
grad_scaler.update()
|
| 349 |
-
scheduler.step()
|
| 350 |
-
self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
|
| 351 |
-
if self.config.trainer.distributed_parallelism == "ddp":
|
| 352 |
-
model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
|
| 353 |
-
else:
|
| 354 |
-
model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
|
| 355 |
-
optimizer.zero_grad(set_to_none=True)
|
| 356 |
-
grad_accum_iter = 0
|
| 357 |
-
return output_batch, loss, grad_accum_iter
|
| 358 |
-
|
| 359 |
-
@torch.no_grad()
|
| 360 |
-
def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
|
| 361 |
-
"""Validate on the full validation dataset.
|
| 362 |
-
|
| 363 |
-
Args:
|
| 364 |
-
model (ImaginaireModel): The PyTorch model.
|
| 365 |
-
dataloader_val (torch.utils.data.DataLoader): The validation data loader.
|
| 366 |
-
iteration (int): Current iteration number.
|
| 367 |
-
"""
|
| 368 |
-
self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
|
| 369 |
-
model.eval()
|
| 370 |
-
# Evaluate on the full validation set.
|
| 371 |
-
with ema.ema_scope(model, enabled=model.config.ema.enabled):
|
| 372 |
-
for val_iter, data_batch in enumerate(dataloader_val):
|
| 373 |
-
if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
|
| 374 |
-
break
|
| 375 |
-
data_batch = misc.to(data_batch, device="cuda")
|
| 376 |
-
self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
|
| 377 |
-
output_batch, loss = model.validation_step(data_batch, iteration)
|
| 378 |
-
self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
|
| 379 |
-
self.callbacks.on_validation_end(model, iteration=iteration)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/types/denoise_prediction.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
from dataclasses import dataclass
|
| 19 |
-
from typing import Optional
|
| 20 |
-
|
| 21 |
-
import torch
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@dataclass
|
| 25 |
-
class DenoisePrediction:
|
| 26 |
-
x0: torch.Tensor # clean data prediction
|
| 27 |
-
eps: Optional[torch.Tensor] = None # noise prediction
|
| 28 |
-
logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/callback.py
DELETED
|
@@ -1,524 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import sys
|
| 19 |
-
import time
|
| 20 |
-
import warnings
|
| 21 |
-
from typing import TYPE_CHECKING, Any, Callable, Optional
|
| 22 |
-
|
| 23 |
-
import omegaconf
|
| 24 |
-
import torch
|
| 25 |
-
import torch.distributed as dist
|
| 26 |
-
import torch.utils.data
|
| 27 |
-
import tqdm
|
| 28 |
-
from lyra_2._ext.imaginaire.lazy_config import instantiate
|
| 29 |
-
from lyra_2._ext.imaginaire.utils import distributed, log, misc
|
| 30 |
-
from lyra_2._ext.imaginaire.utils.misc import get_local_tensor_if_DTensor
|
| 31 |
-
|
| 32 |
-
try:
|
| 33 |
-
from megatron.core import parallel_state
|
| 34 |
-
except ImportError:
|
| 35 |
-
parallel_state = None
|
| 36 |
-
print("Megatron-core is not installed.")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
if TYPE_CHECKING:
|
| 40 |
-
from lyra_2._ext.imaginaire.config import Config
|
| 41 |
-
from lyra_2._ext.imaginaire.model import ImaginaireModel
|
| 42 |
-
from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class CallBackGroup:
|
| 46 |
-
"""A class for hosting a collection of callback objects.
|
| 47 |
-
|
| 48 |
-
It is used to execute callback functions of multiple callback objects with the same method name.
|
| 49 |
-
When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs
|
| 50 |
-
self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match.
|
| 51 |
-
|
| 52 |
-
Attributes:
|
| 53 |
-
_callbacks (list[Callback]): List of callback objects.
|
| 54 |
-
"""
|
| 55 |
-
|
| 56 |
-
def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None:
|
| 57 |
-
"""Initializes the list of callback objects.
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
config (Config): The config object for the Imaginaire codebase.
|
| 61 |
-
trainer (ImaginaireTrainer): The main trainer.
|
| 62 |
-
"""
|
| 63 |
-
self._callbacks = []
|
| 64 |
-
callback_configs = config.trainer.callbacks
|
| 65 |
-
if callback_configs:
|
| 66 |
-
if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig):
|
| 67 |
-
warnings.warn(
|
| 68 |
-
"The 'config.trainer.callbacks' parameter should be a dict instead of a list. "
|
| 69 |
-
"Please update your code",
|
| 70 |
-
DeprecationWarning,
|
| 71 |
-
stacklevel=2,
|
| 72 |
-
)
|
| 73 |
-
callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)}
|
| 74 |
-
for callback_name, current_callback_cfg in callback_configs.items():
|
| 75 |
-
if "_target_" not in current_callback_cfg:
|
| 76 |
-
log.critical(
|
| 77 |
-
f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}"
|
| 78 |
-
)
|
| 79 |
-
continue
|
| 80 |
-
log.info(f"Instantiating callback {callback_name}: {current_callback_cfg}")
|
| 81 |
-
_callback = instantiate(current_callback_cfg)
|
| 82 |
-
assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
|
| 83 |
-
_callback.config = config
|
| 84 |
-
_callback.trainer = trainer
|
| 85 |
-
self._callbacks.append(_callback)
|
| 86 |
-
|
| 87 |
-
def __getattr__(self, method_name: str) -> Callable:
|
| 88 |
-
"""Loops through the callback objects to call the corresponding callback function.
|
| 89 |
-
|
| 90 |
-
Args:
|
| 91 |
-
method_name (str): Callback method name.
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
def multi_callback_wrapper(*args, **kwargs) -> None:
|
| 95 |
-
for callback in self._callbacks:
|
| 96 |
-
assert hasattr(callback, method_name)
|
| 97 |
-
method = getattr(callback, method_name)
|
| 98 |
-
assert callable(method)
|
| 99 |
-
_ = method(*args, **kwargs)
|
| 100 |
-
|
| 101 |
-
return multi_callback_wrapper
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class Callback:
|
| 105 |
-
"""The base class for all callbacks.
|
| 106 |
-
|
| 107 |
-
All callbacks should inherit from this class and adhere to the established method names and signatures.
|
| 108 |
-
"""
|
| 109 |
-
|
| 110 |
-
def __init__(self, config: Optional["Config"] = None, trainer: Optional["ImaginaireTrainer"] = None):
|
| 111 |
-
"""Initializes a Callback object.
|
| 112 |
-
|
| 113 |
-
Args:
|
| 114 |
-
config (Optional[Config]): The configuration object for the Imaginaire codebase, if available.
|
| 115 |
-
trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available.
|
| 116 |
-
|
| 117 |
-
Notes:
|
| 118 |
-
The config and trainer parameters are optional to maintain backward compatibility.
|
| 119 |
-
In future releases, these parameters will be removed. Upon using these parameters, a deprecation
|
| 120 |
-
warning will be issued.
|
| 121 |
-
|
| 122 |
-
"""
|
| 123 |
-
if config is not None or trainer is not None:
|
| 124 |
-
warnings.warn(
|
| 125 |
-
"The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. "
|
| 126 |
-
"Please update your code to create Callback instances without these parameters.",
|
| 127 |
-
DeprecationWarning,
|
| 128 |
-
stacklevel=2,
|
| 129 |
-
)
|
| 130 |
-
del config, trainer
|
| 131 |
-
|
| 132 |
-
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 133 |
-
pass
|
| 134 |
-
|
| 135 |
-
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 136 |
-
"""
|
| 137 |
-
Called before the training step, for each batch. This is paired with on_training_step_end() but note that
|
| 138 |
-
when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated,
|
| 139 |
-
this function is called for every batch.
|
| 140 |
-
Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
|
| 141 |
-
for every batch, albeit with the same iteration number.
|
| 142 |
-
FIXME - should this either be deprecated, or called only when a new training step is started after having updated
|
| 143 |
-
the optimizer?
|
| 144 |
-
"""
|
| 145 |
-
pass
|
| 146 |
-
|
| 147 |
-
def on_training_step_batch_start(
|
| 148 |
-
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 149 |
-
) -> None:
|
| 150 |
-
"""
|
| 151 |
-
Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with
|
| 152 |
-
on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation.
|
| 153 |
-
Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for multiple invocations.
|
| 154 |
-
"""
|
| 155 |
-
pass
|
| 156 |
-
|
| 157 |
-
def on_before_forward(self, iteration: int = 0) -> None:
|
| 158 |
-
pass
|
| 159 |
-
|
| 160 |
-
def on_after_forward(self, iteration: int = 0) -> None:
|
| 161 |
-
pass
|
| 162 |
-
|
| 163 |
-
def on_before_backward(
|
| 164 |
-
self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
|
| 165 |
-
) -> None:
|
| 166 |
-
pass
|
| 167 |
-
|
| 168 |
-
def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
|
| 169 |
-
pass
|
| 170 |
-
|
| 171 |
-
def on_before_dataloading(self, iteration: int = 0) -> None:
|
| 172 |
-
pass
|
| 173 |
-
|
| 174 |
-
def on_after_dataloading(self, iteration: int = 0) -> None:
|
| 175 |
-
pass
|
| 176 |
-
|
| 177 |
-
def on_optimizer_init_start(self) -> None:
|
| 178 |
-
pass
|
| 179 |
-
|
| 180 |
-
def on_optimizer_init_end(self) -> None:
|
| 181 |
-
pass
|
| 182 |
-
|
| 183 |
-
def on_before_optimizer_step(
|
| 184 |
-
self,
|
| 185 |
-
model_ddp: distributed.DistributedDataParallel,
|
| 186 |
-
optimizer: torch.optim.Optimizer,
|
| 187 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 188 |
-
grad_scaler: torch.amp.GradScaler,
|
| 189 |
-
iteration: int = 0,
|
| 190 |
-
) -> None:
|
| 191 |
-
pass
|
| 192 |
-
|
| 193 |
-
def on_before_zero_grad(
|
| 194 |
-
self,
|
| 195 |
-
model_ddp: distributed.DistributedDataParallel,
|
| 196 |
-
optimizer: torch.optim.Optimizer,
|
| 197 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 198 |
-
iteration: int = 0,
|
| 199 |
-
) -> None:
|
| 200 |
-
pass
|
| 201 |
-
|
| 202 |
-
def on_training_step_batch_end(
|
| 203 |
-
self,
|
| 204 |
-
model: ImaginaireModel,
|
| 205 |
-
data_batch: dict[str, torch.Tensor],
|
| 206 |
-
output_batch: dict[str, torch.Tensor],
|
| 207 |
-
loss: torch.Tensor,
|
| 208 |
-
iteration: int = 0,
|
| 209 |
-
) -> None:
|
| 210 |
-
"""
|
| 211 |
-
Called at the end of a training step for every batch even when using gradient accumulation.
|
| 212 |
-
This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the optimizer is updated,
|
| 213 |
-
and therefore it may be the same for multiple batches.
|
| 214 |
-
"""
|
| 215 |
-
pass
|
| 216 |
-
|
| 217 |
-
def on_training_step_end(
|
| 218 |
-
self,
|
| 219 |
-
model: ImaginaireModel,
|
| 220 |
-
data_batch: dict[str, torch.Tensor],
|
| 221 |
-
output_batch: dict[str, torch.Tensor],
|
| 222 |
-
loss: torch.Tensor,
|
| 223 |
-
iteration: int = 0,
|
| 224 |
-
) -> None:
|
| 225 |
-
"""
|
| 226 |
-
Called at the end of a training step, but note that when using gradient accumulation, this is only called
|
| 227 |
-
when the optimizer is updated, and the iteration incremented, whereas on_training_step_start is called every time.
|
| 228 |
-
Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called
|
| 229 |
-
for every batch.
|
| 230 |
-
"""
|
| 231 |
-
pass
|
| 232 |
-
|
| 233 |
-
def on_validation_start(
|
| 234 |
-
self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
|
| 235 |
-
) -> None:
|
| 236 |
-
pass
|
| 237 |
-
|
| 238 |
-
def on_validation_step_start(
|
| 239 |
-
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 240 |
-
) -> None:
|
| 241 |
-
pass
|
| 242 |
-
|
| 243 |
-
def on_validation_step_end(
|
| 244 |
-
self,
|
| 245 |
-
model: ImaginaireModel,
|
| 246 |
-
data_batch: dict[str, torch.Tensor],
|
| 247 |
-
output_batch: dict[str, torch.Tensor],
|
| 248 |
-
loss: torch.Tensor,
|
| 249 |
-
iteration: int = 0,
|
| 250 |
-
) -> None:
|
| 251 |
-
pass
|
| 252 |
-
|
| 253 |
-
def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 254 |
-
pass
|
| 255 |
-
|
| 256 |
-
def on_load_checkpoint_start(self, model: ImaginaireModel) -> None:
|
| 257 |
-
pass
|
| 258 |
-
|
| 259 |
-
def on_load_checkpoint_end(
|
| 260 |
-
self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: Optional[str] = None
|
| 261 |
-
) -> None:
|
| 262 |
-
pass
|
| 263 |
-
|
| 264 |
-
def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None:
|
| 265 |
-
"""
|
| 266 |
-
Called when checkpoint loading is about to start, but after on_save_checkpoint_start().
|
| 267 |
-
FIXME - why do we need this callback, can't we just use on_save_checkpoint_start()?
|
| 268 |
-
"""
|
| 269 |
-
pass
|
| 270 |
-
|
| 271 |
-
def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 272 |
-
"""
|
| 273 |
-
Called when checkpoint saving is about to start.
|
| 274 |
-
"""
|
| 275 |
-
pass
|
| 276 |
-
|
| 277 |
-
def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 278 |
-
"""
|
| 279 |
-
Called when the synchronous part of checkpointing is finished, this function can be used
|
| 280 |
-
along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time.
|
| 281 |
-
Note that for asynchronous checkpoint, the checkpoint may still be ongoing, so this function
|
| 282 |
-
does not mean the checkpoint is finished for the asynchronous case, use on_save_checkpoint_success()
|
| 283 |
-
for that.
|
| 284 |
-
"""
|
| 285 |
-
pass
|
| 286 |
-
|
| 287 |
-
def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None:
|
| 288 |
-
"""
|
| 289 |
-
Called when checkpoint saving is fully finished, and succeeded. Not called if checkpoint failed.
|
| 290 |
-
For synchronous checkpoint, it is called at the same time as on_save_checkpoint_end(), but for asynchronous
|
| 291 |
-
checkpoint, it is called after the asynchronous part has also finished. For checkpointers with out-of-process
|
| 292 |
-
checkpointing, this function is called as soon as the notification is received from the checkpointer process,
|
| 293 |
-
which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure
|
| 294 |
-
the full checkpoint duration for the asynchronous part, use the elapsed_time parameter, do not measure it directly
|
| 295 |
-
as this would be a significant overestimate.
|
| 296 |
-
"""
|
| 297 |
-
pass
|
| 298 |
-
|
| 299 |
-
def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None:
|
| 300 |
-
pass
|
| 301 |
-
|
| 302 |
-
def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 303 |
-
pass
|
| 304 |
-
|
| 305 |
-
def on_app_end(self) -> None:
|
| 306 |
-
pass
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
class EMAModelCallback(Callback):
|
| 310 |
-
"""The callback class for tracking EMA model weights."""
|
| 311 |
-
|
| 312 |
-
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 313 |
-
# Set up the EMA model weight tracker.
|
| 314 |
-
if model.config.ema.enabled:
|
| 315 |
-
assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel"
|
| 316 |
-
# EMA model must be kept in FP32 precision.
|
| 317 |
-
model.ema = model.ema.to(dtype=torch.float32)
|
| 318 |
-
else:
|
| 319 |
-
assert not hasattr(model, "ema"), "There should be no EMA initialized."
|
| 320 |
-
|
| 321 |
-
def on_training_step_end(
|
| 322 |
-
self,
|
| 323 |
-
model: ImaginaireModel,
|
| 324 |
-
data_batch: dict[str, torch.Tensor],
|
| 325 |
-
output_batch: dict[str, torch.Tensor],
|
| 326 |
-
loss: torch.Tensor,
|
| 327 |
-
iteration: int = 0,
|
| 328 |
-
) -> None:
|
| 329 |
-
# Update the EMA model with the new regular weights.
|
| 330 |
-
if model.config.ema.enabled:
|
| 331 |
-
model.ema.update_average(model, iteration)
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
class ProgressBarCallback(Callback):
|
| 335 |
-
"""The callback class for visualizing the training/validation progress bar in the console."""
|
| 336 |
-
|
| 337 |
-
@distributed.rank0_only
|
| 338 |
-
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 339 |
-
self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
|
| 340 |
-
|
| 341 |
-
@distributed.rank0_only
|
| 342 |
-
def on_training_step_end(
|
| 343 |
-
self,
|
| 344 |
-
model: ImaginaireModel,
|
| 345 |
-
data_batch: dict[str, torch.Tensor],
|
| 346 |
-
output_batch: dict[str, torch.Tensor],
|
| 347 |
-
loss: torch.Tensor,
|
| 348 |
-
iteration: int = 0,
|
| 349 |
-
) -> None:
|
| 350 |
-
self.train_pbar.update()
|
| 351 |
-
|
| 352 |
-
@distributed.rank0_only
|
| 353 |
-
def on_validation_start(
|
| 354 |
-
self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0
|
| 355 |
-
) -> None:
|
| 356 |
-
if self.config.trainer.max_val_iter is not None:
|
| 357 |
-
num_iter = self.config.trainer.max_val_iter
|
| 358 |
-
else:
|
| 359 |
-
num_iter = len(dataloader_val)
|
| 360 |
-
assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}"
|
| 361 |
-
self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False)
|
| 362 |
-
|
| 363 |
-
@distributed.rank0_only
|
| 364 |
-
def on_validation_step_end(
|
| 365 |
-
self,
|
| 366 |
-
model: ImaginaireModel,
|
| 367 |
-
data_batch: dict[str, torch.Tensor],
|
| 368 |
-
output_batch: dict[str, torch.Tensor],
|
| 369 |
-
loss: torch.Tensor,
|
| 370 |
-
iteration: int = 0,
|
| 371 |
-
) -> None:
|
| 372 |
-
self.val_pbar.update()
|
| 373 |
-
|
| 374 |
-
@distributed.rank0_only
|
| 375 |
-
def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 376 |
-
self.val_pbar.close()
|
| 377 |
-
|
| 378 |
-
@distributed.rank0_only
|
| 379 |
-
def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 380 |
-
self.trainer.checkpointer.finalize()
|
| 381 |
-
self.train_pbar.close()
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
class IterationLoggerCallback(Callback):
|
| 385 |
-
"""The callback class for visualizing the training/validation progress bar in the console."""
|
| 386 |
-
|
| 387 |
-
@distributed.rank0_only
|
| 388 |
-
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 389 |
-
# self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training")
|
| 390 |
-
self.start_iteration_time = time.time()
|
| 391 |
-
self.elapsed_iteration_time = 0
|
| 392 |
-
|
| 393 |
-
@distributed.rank0_only
|
| 394 |
-
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 395 |
-
self.start_iteration_time = time.time()
|
| 396 |
-
|
| 397 |
-
@distributed.rank0_only
|
| 398 |
-
def on_training_step_end(
|
| 399 |
-
self,
|
| 400 |
-
model: ImaginaireModel,
|
| 401 |
-
data_batch: dict[str, torch.Tensor],
|
| 402 |
-
output_batch: dict[str, torch.Tensor],
|
| 403 |
-
loss: torch.Tensor,
|
| 404 |
-
iteration: int = 0,
|
| 405 |
-
) -> None:
|
| 406 |
-
# but this is only called when the optimizer is updated, so it's only the time for the last batch.
|
| 407 |
-
self.elapsed_iteration_time += time.time() - self.start_iteration_time
|
| 408 |
-
|
| 409 |
-
if iteration % self.config.trainer.logging_iter == 0:
|
| 410 |
-
avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter
|
| 411 |
-
log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}")
|
| 412 |
-
|
| 413 |
-
self.elapsed_iteration_time = 0
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
class LowPrecisionCallback(Callback):
|
| 417 |
-
"""The callback class handling low precision training"""
|
| 418 |
-
|
| 419 |
-
def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int):
|
| 420 |
-
self.update_iter = update_iter
|
| 421 |
-
|
| 422 |
-
def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None:
|
| 423 |
-
if model.precision == torch.float32:
|
| 424 |
-
log.critical("Using fp32. We should disable master weights update.")
|
| 425 |
-
self.update_iter = sys.maxsize
|
| 426 |
-
else:
|
| 427 |
-
assert model.precision in [
|
| 428 |
-
torch.bfloat16,
|
| 429 |
-
torch.float16,
|
| 430 |
-
torch.half,
|
| 431 |
-
], "LowPrecisionCallback must use a low precision dtype."
|
| 432 |
-
self.precision_type = model.precision
|
| 433 |
-
|
| 434 |
-
def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None:
|
| 435 |
-
for k, v in data.items():
|
| 436 |
-
if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
|
| 437 |
-
data[k] = v.to(dtype=self.precision_type)
|
| 438 |
-
|
| 439 |
-
def on_validation_step_start(
|
| 440 |
-
self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0
|
| 441 |
-
) -> None:
|
| 442 |
-
for k, v in data.items():
|
| 443 |
-
if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]):
|
| 444 |
-
data[k] = v.to(dtype=self.precision_type)
|
| 445 |
-
|
| 446 |
-
def on_before_zero_grad(
|
| 447 |
-
self,
|
| 448 |
-
model_ddp: distributed.DistributedDataParallel,
|
| 449 |
-
optimizer: torch.optim.Optimizer,
|
| 450 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 451 |
-
iteration: int = 0,
|
| 452 |
-
) -> None:
|
| 453 |
-
if iteration % self.update_iter == 0:
|
| 454 |
-
if getattr(optimizer, "master_weights", False):
|
| 455 |
-
params, master_params = [], []
|
| 456 |
-
for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master):
|
| 457 |
-
for p, p_master in zip(group["params"], group_master["params"]):
|
| 458 |
-
params.append(get_local_tensor_if_DTensor(p.data))
|
| 459 |
-
master_params.append(p_master.data)
|
| 460 |
-
torch._foreach_copy_(params, master_params)
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
class NVTXCallback(Callback):
|
| 464 |
-
"""The callback for creating NVTX ranges"""
|
| 465 |
-
|
| 466 |
-
def __init__(
|
| 467 |
-
self,
|
| 468 |
-
synchronize: bool = False,
|
| 469 |
-
config: Optional["Config"] = None,
|
| 470 |
-
trainer: Optional["ImaginaireTrainer"] = None,
|
| 471 |
-
):
|
| 472 |
-
super().__init__(config, trainer)
|
| 473 |
-
self.synchronize = synchronize
|
| 474 |
-
|
| 475 |
-
def on_before_forward(self, iteration: int = 0) -> None:
|
| 476 |
-
if self.synchronize:
|
| 477 |
-
torch.cuda.synchronize()
|
| 478 |
-
torch.cuda.nvtx.range_push("forward")
|
| 479 |
-
|
| 480 |
-
def on_after_forward(self, iteration: int = 0) -> None:
|
| 481 |
-
if self.synchronize:
|
| 482 |
-
torch.cuda.synchronize()
|
| 483 |
-
torch.cuda.nvtx.range_pop()
|
| 484 |
-
|
| 485 |
-
def on_before_backward(
|
| 486 |
-
self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0
|
| 487 |
-
) -> None:
|
| 488 |
-
if self.synchronize:
|
| 489 |
-
torch.cuda.synchronize()
|
| 490 |
-
torch.cuda.nvtx.range_push("backward")
|
| 491 |
-
|
| 492 |
-
def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None:
|
| 493 |
-
if self.synchronize:
|
| 494 |
-
torch.cuda.synchronize()
|
| 495 |
-
torch.cuda.nvtx.range_pop()
|
| 496 |
-
|
| 497 |
-
def on_before_optimizer_step(
|
| 498 |
-
self,
|
| 499 |
-
model_ddp: distributed.DistributedDataParallel,
|
| 500 |
-
optimizer: torch.optim.Optimizer,
|
| 501 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 502 |
-
grad_scaler: torch.amp.GradScaler,
|
| 503 |
-
iteration: int = 0,
|
| 504 |
-
) -> None:
|
| 505 |
-
if self.synchronize:
|
| 506 |
-
torch.cuda.synchronize()
|
| 507 |
-
torch.cuda.nvtx.range_push("optimizer_step")
|
| 508 |
-
|
| 509 |
-
def on_before_zero_grad(
|
| 510 |
-
self,
|
| 511 |
-
model_ddp: distributed.DistributedDataParallel,
|
| 512 |
-
optimizer: torch.optim.Optimizer,
|
| 513 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 514 |
-
iteration: int = 0,
|
| 515 |
-
) -> None:
|
| 516 |
-
if self.synchronize:
|
| 517 |
-
torch.cuda.synchronize()
|
| 518 |
-
torch.cuda.nvtx.range_pop()
|
| 519 |
-
|
| 520 |
-
def on_before_dataloading(self, iteration: int = 0) -> None:
|
| 521 |
-
torch.cuda.nvtx.range_push("dataloading")
|
| 522 |
-
|
| 523 |
-
def on_after_dataloading(self, iteration: int = 0) -> None:
|
| 524 |
-
torch.cuda.nvtx.range_pop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/checkpointer.py
DELETED
|
@@ -1,372 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import os
|
| 19 |
-
import threading
|
| 20 |
-
from typing import TYPE_CHECKING, List, NamedTuple, Tuple
|
| 21 |
-
|
| 22 |
-
import torch
|
| 23 |
-
|
| 24 |
-
from lyra_2._ext.imaginaire.model import ImaginaireModel
|
| 25 |
-
from lyra_2._ext.imaginaire.utils import callback, distributed, log, misc, object_store
|
| 26 |
-
|
| 27 |
-
if TYPE_CHECKING:
|
| 28 |
-
from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig
|
| 29 |
-
|
| 30 |
-
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
|
| 31 |
-
if TORCH_VERSION >= (1, 11):
|
| 32 |
-
from torch.ao import quantization
|
| 33 |
-
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
|
| 34 |
-
elif (
|
| 35 |
-
TORCH_VERSION >= (1, 8)
|
| 36 |
-
and hasattr(torch.quantization, "FakeQuantizeBase")
|
| 37 |
-
and hasattr(torch.quantization, "ObserverBase")
|
| 38 |
-
):
|
| 39 |
-
from torch import quantization
|
| 40 |
-
from torch.quantization import FakeQuantizeBase, ObserverBase
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class _IncompatibleKeys(
|
| 44 |
-
NamedTuple(
|
| 45 |
-
"IncompatibleKeys",
|
| 46 |
-
[
|
| 47 |
-
("missing_keys", List[str]),
|
| 48 |
-
("unexpected_keys", List[str]),
|
| 49 |
-
("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
|
| 50 |
-
],
|
| 51 |
-
)
|
| 52 |
-
):
|
| 53 |
-
pass
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class Checkpointer:
|
| 57 |
-
"""The checkpointer class. Supports checkpoint saving/loading to both local disk or object store."""
|
| 58 |
-
|
| 59 |
-
def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
|
| 60 |
-
"""Constructor of the checkpointer.
|
| 61 |
-
|
| 62 |
-
Args:
|
| 63 |
-
config_checkpoint (CheckpointConfig): The config object for the checkpointer.
|
| 64 |
-
"""
|
| 65 |
-
# Set the callback functions.
|
| 66 |
-
self.callbacks = callbacks
|
| 67 |
-
|
| 68 |
-
self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints"
|
| 69 |
-
self.checkpoint_dir_object_store = f"{config_job.path}/checkpoints"
|
| 70 |
-
self.save_to_object_store = config_checkpoint.save_to_object_store.enabled
|
| 71 |
-
self.load_from_object_store = config_checkpoint.load_from_object_store.enabled
|
| 72 |
-
self.strict_resume = config_checkpoint.strict_resume
|
| 73 |
-
self.load_path = config_checkpoint.load_path or None
|
| 74 |
-
self.load_training_state = config_checkpoint.load_training_state
|
| 75 |
-
self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
|
| 76 |
-
self.save_thread = None
|
| 77 |
-
# Create the object store client interface.
|
| 78 |
-
if self.save_to_object_store:
|
| 79 |
-
self.object_store_saver = object_store.ObjectStore(config_checkpoint.save_to_object_store)
|
| 80 |
-
if self.load_from_object_store:
|
| 81 |
-
self.object_store_loader = object_store.ObjectStore(config_checkpoint.load_from_object_store)
|
| 82 |
-
|
| 83 |
-
def save(
|
| 84 |
-
self,
|
| 85 |
-
model: ImaginaireModel,
|
| 86 |
-
optimizer: torch.optim.Optimizer,
|
| 87 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler,
|
| 88 |
-
grad_scaler: torch.amp.GradScaler,
|
| 89 |
-
iteration: int,
|
| 90 |
-
) -> None:
|
| 91 |
-
"""Save network weights, optimizer parameters, scheduler parameters to a checkpoint.
|
| 92 |
-
|
| 93 |
-
Args:
|
| 94 |
-
model (ImaginaireModel): The PyTorch model.
|
| 95 |
-
optimizer (torch.optim.Optimizer): The model optimizer.
|
| 96 |
-
scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
|
| 97 |
-
grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training).
|
| 98 |
-
iteration (int): Current iteration number.
|
| 99 |
-
"""
|
| 100 |
-
self.callbacks.on_save_checkpoint_start(model, iteration)
|
| 101 |
-
|
| 102 |
-
checkpoint_file = f"iter_{iteration:09}.pt"
|
| 103 |
-
|
| 104 |
-
if distributed.get_rank() == 0:
|
| 105 |
-
state_dict = dict(
|
| 106 |
-
model=model.state_dict(),
|
| 107 |
-
optimizer=optimizer.state_dict(),
|
| 108 |
-
scheduler=scheduler.state_dict(),
|
| 109 |
-
grad_scaler=grad_scaler.state_dict(),
|
| 110 |
-
iteration=iteration,
|
| 111 |
-
)
|
| 112 |
-
state_dict = misc.to(state_dict, device="cpu")
|
| 113 |
-
self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
|
| 114 |
-
# Wait for previous saver thread to end.
|
| 115 |
-
if self.save_thread:
|
| 116 |
-
self.save_thread.join()
|
| 117 |
-
# Run the checkpoint saver in a separate thread.
|
| 118 |
-
self.save_thread = threading.Thread(
|
| 119 |
-
target=self._save_worker_object_store if self.save_to_object_store else self._save_worker_local,
|
| 120 |
-
daemon=False,
|
| 121 |
-
args=(state_dict, checkpoint_file, distributed.get_rank()),
|
| 122 |
-
)
|
| 123 |
-
self.save_thread.start()
|
| 124 |
-
|
| 125 |
-
# Note: Checkpoints are saved on a separate thread and this callback is not accurate.
|
| 126 |
-
# Please check logs from on_save_checkpoint_success() for better accuracy
|
| 127 |
-
self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)
|
| 128 |
-
|
| 129 |
-
@misc.timer("checkpoint saving (local)")
|
| 130 |
-
def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None:
|
| 131 |
-
"""Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training).
|
| 132 |
-
|
| 133 |
-
Args:
|
| 134 |
-
state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
|
| 135 |
-
checkpoint_file (str): The file name of the model checkpoint.
|
| 136 |
-
rank (int): GPU device (default: 0).
|
| 137 |
-
"""
|
| 138 |
-
checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
|
| 139 |
-
os.makedirs(self.checkpoint_dir_local, exist_ok=True)
|
| 140 |
-
try:
|
| 141 |
-
torch.save(state_dict, checkpoint_path)
|
| 142 |
-
if rank == 0:
|
| 143 |
-
self._write_latest_checkpoint_file(checkpoint_file)
|
| 144 |
-
log.success(f"Saved checkpoint (local): {checkpoint_path}")
|
| 145 |
-
iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
|
| 146 |
-
self.callbacks.on_save_checkpoint_success(iteration=iteration)
|
| 147 |
-
except Exception as e: # noqa: BLE001
|
| 148 |
-
log.exception(f"Checkpoint failed to save (local): {e}")
|
| 149 |
-
|
| 150 |
-
@misc.timer("checkpoint saving (object store)")
|
| 151 |
-
def _save_worker_object_store(
|
| 152 |
-
self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0
|
| 153 |
-
) -> None:
|
| 154 |
-
"""Worker to upload checkpoint to object store, spawned with a child thread (in parallel with the training).
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler.
|
| 158 |
-
checkpoint_file (str): The file name of the model checkpoint.
|
| 159 |
-
rank (int): GPU device (default: 0).
|
| 160 |
-
"""
|
| 161 |
-
checkpoint_path = os.path.join(self.checkpoint_dir_object_store, checkpoint_file)
|
| 162 |
-
try:
|
| 163 |
-
self.object_store_saver.save_object(state_dict, key=checkpoint_path, type="torch")
|
| 164 |
-
if rank == 0:
|
| 165 |
-
self._write_latest_checkpoint_file(checkpoint_file)
|
| 166 |
-
log.success(f"Saved checkpoint (object store): {checkpoint_path}")
|
| 167 |
-
iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
|
| 168 |
-
self.callbacks.on_save_checkpoint_success(iteration=iteration)
|
| 169 |
-
except Exception as e: # noqa: BLE001
|
| 170 |
-
log.exception(f"Checkpoint failed to upload (object store): {e}")
|
| 171 |
-
|
| 172 |
-
@misc.timer("checkpoint loading")
|
| 173 |
-
def load(
|
| 174 |
-
self,
|
| 175 |
-
model: ImaginaireModel,
|
| 176 |
-
optimizer: torch.optim.Optimizer | None = None,
|
| 177 |
-
scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
|
| 178 |
-
grad_scaler: torch.amp.GradScaler | None = None,
|
| 179 |
-
) -> int:
|
| 180 |
-
"""Load network weights and optimizer states from a checkpoint in a single process.
|
| 181 |
-
|
| 182 |
-
The priority of the checkpoint loading logic is:
|
| 183 |
-
1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name.
|
| 184 |
-
2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path.
|
| 185 |
-
- This is typically used for inference mode.
|
| 186 |
-
- If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states.
|
| 187 |
-
3. If none of the above, randomly initialize the model parameters and train from scratch.
|
| 188 |
-
|
| 189 |
-
Args:
|
| 190 |
-
model (ImaginaireModel): The PyTorch model.
|
| 191 |
-
optimizer (torch.optim.Optimizer | None): The model optimizer (default: None).
|
| 192 |
-
scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None).
|
| 193 |
-
grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training).
|
| 194 |
-
|
| 195 |
-
Returns:
|
| 196 |
-
iteration (int): the iteration number to start/resume from.
|
| 197 |
-
"""
|
| 198 |
-
self.callbacks.on_load_checkpoint_start(model)
|
| 199 |
-
|
| 200 |
-
latest_checkpoint_file = self._read_latest_checkpoint_file()
|
| 201 |
-
if latest_checkpoint_file is not None:
|
| 202 |
-
# 1. Resume training from latest_checkpoint.txt under the same name.
|
| 203 |
-
checkpoint_dir = (
|
| 204 |
-
self.checkpoint_dir_object_store if self.load_from_object_store else self.checkpoint_dir_local
|
| 205 |
-
)
|
| 206 |
-
checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file)
|
| 207 |
-
resume = True
|
| 208 |
-
only_resume_scheduler = True
|
| 209 |
-
else:
|
| 210 |
-
if self.load_path:
|
| 211 |
-
# 2. Load the module weights specified by config_checkpoint.path.
|
| 212 |
-
checkpoint_path = self.load_path
|
| 213 |
-
resume = self.load_training_state
|
| 214 |
-
only_resume_scheduler = self.only_load_scheduler_state
|
| 215 |
-
else:
|
| 216 |
-
# 3. Randomly initialize the model parameters and train from scratch.
|
| 217 |
-
checkpoint_path = None
|
| 218 |
-
resume = False
|
| 219 |
-
only_resume_scheduler = False
|
| 220 |
-
# Load checkpoint.
|
| 221 |
-
if checkpoint_path is not None:
|
| 222 |
-
self._check_checkpoint_exists(checkpoint_path)
|
| 223 |
-
if self.load_from_object_store:
|
| 224 |
-
log.info(f"Loading checkpoint (object store): {checkpoint_path}")
|
| 225 |
-
state_dict = self.object_store_loader.load_object(key=checkpoint_path, type="torch", max_attempts=20)
|
| 226 |
-
log.success(f"Complete loading checkpoint (object store): {checkpoint_path}")
|
| 227 |
-
else:
|
| 228 |
-
log.info(f"Loading checkpoint (local): {checkpoint_path}")
|
| 229 |
-
state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
|
| 230 |
-
log.success(f"Complete loading checkpoint (local): {checkpoint_path}")
|
| 231 |
-
self.callbacks.on_load_checkpoint(model, state_dict=state_dict)
|
| 232 |
-
# Load the state dicts.
|
| 233 |
-
log.info("- Loading the model...")
|
| 234 |
-
model.load_state_dict(state_dict["model"], strict=self.strict_resume)
|
| 235 |
-
if resume or only_resume_scheduler:
|
| 236 |
-
iteration = state_dict["iteration"]
|
| 237 |
-
assert scheduler
|
| 238 |
-
log.info("- Loading the scheduler...")
|
| 239 |
-
scheduler.load_state_dict(state_dict["scheduler"])
|
| 240 |
-
scheduler.last_epoch = iteration
|
| 241 |
-
else:
|
| 242 |
-
iteration = 0
|
| 243 |
-
if resume:
|
| 244 |
-
assert optimizer
|
| 245 |
-
log.info("- Loading the optimizer...")
|
| 246 |
-
optimizer.load_state_dict(state_dict["optimizer"])
|
| 247 |
-
log.info("- Loading the gradient scaler...")
|
| 248 |
-
grad_scaler.load_state_dict(state_dict["grad_scaler"])
|
| 249 |
-
log.success(f"Done with loading the checkpoint (iteration {iteration}).")
|
| 250 |
-
else:
|
| 251 |
-
log.success("Done with loading the checkpoint.")
|
| 252 |
-
else:
|
| 253 |
-
# Checkpoint not found and not specified. We will train everything from scratch.
|
| 254 |
-
iteration = 0
|
| 255 |
-
log.info("Training from scratch.")
|
| 256 |
-
torch.cuda.empty_cache()
|
| 257 |
-
|
| 258 |
-
self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path)
|
| 259 |
-
|
| 260 |
-
return iteration
|
| 261 |
-
|
| 262 |
-
def _read_latest_checkpoint_file(self) -> str | None:
|
| 263 |
-
"""Get the file name of the latest saved checkpoint. If it doesn't exist, return None.
|
| 264 |
-
|
| 265 |
-
Returns:
|
| 266 |
-
checkpoint_file (str | None): file name of the latest saved checkpoint.
|
| 267 |
-
"""
|
| 268 |
-
checkpoint_file = None
|
| 269 |
-
if self.load_from_object_store:
|
| 270 |
-
latest_path = os.path.join(self.checkpoint_dir_object_store, "latest_checkpoint.txt")
|
| 271 |
-
if self.object_store_loader.object_exists(key=latest_path):
|
| 272 |
-
checkpoint_file = self.object_store_loader.load_object(key=latest_path, type="text").strip()
|
| 273 |
-
else:
|
| 274 |
-
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
|
| 275 |
-
if os.path.isfile(latest_path):
|
| 276 |
-
checkpoint_file = open(latest_path).read().strip()
|
| 277 |
-
return checkpoint_file
|
| 278 |
-
|
| 279 |
-
def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
|
| 280 |
-
"""Track the file name of the latest saved checkpoint.
|
| 281 |
-
|
| 282 |
-
Args:
|
| 283 |
-
checkpoint_file (str): file name of the latest saved checkpoint.
|
| 284 |
-
"""
|
| 285 |
-
content = f"{checkpoint_file}\n"
|
| 286 |
-
if self.save_to_object_store:
|
| 287 |
-
latest_path = os.path.join(self.checkpoint_dir_object_store, "latest_checkpoint.txt")
|
| 288 |
-
self.object_store_saver.save_object(content, key=latest_path, type="text")
|
| 289 |
-
else:
|
| 290 |
-
latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt")
|
| 291 |
-
with open(latest_path, "w") as file:
|
| 292 |
-
file.write(content)
|
| 293 |
-
|
| 294 |
-
def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
|
| 295 |
-
"""If the file checkpoint_path does not exist, raise an error.
|
| 296 |
-
|
| 297 |
-
Args:
|
| 298 |
-
checkpoint_path (str): full path to the checkpoint.
|
| 299 |
-
"""
|
| 300 |
-
if self.load_from_object_store:
|
| 301 |
-
if not self.object_store_loader.object_exists(key=checkpoint_path):
|
| 302 |
-
raise FileNotFoundError(f"File not found (object store): {checkpoint_path}")
|
| 303 |
-
else:
|
| 304 |
-
if not os.path.exists(checkpoint_path):
|
| 305 |
-
raise FileNotFoundError(f"File not found (local): {checkpoint_path}")
|
| 306 |
-
|
| 307 |
-
def finalize(self) -> None:
|
| 308 |
-
"""Finalize the checkpointer."""
|
| 309 |
-
if self.save_thread:
|
| 310 |
-
self.save_thread.join()
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py
|
| 314 |
-
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
|
| 315 |
-
# workaround https://github.com/pytorch/pytorch/issues/24139
|
| 316 |
-
model_state_dict = model.state_dict()
|
| 317 |
-
incorrect_shapes = []
|
| 318 |
-
for k in list(checkpoint_state_dict.keys()):
|
| 319 |
-
if k in model_state_dict:
|
| 320 |
-
if "_extra_state" in k: # Key introduced by TransformerEngine for FP8
|
| 321 |
-
log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
|
| 322 |
-
continue
|
| 323 |
-
model_param = model_state_dict[k]
|
| 324 |
-
# Allow mismatch for uninitialized parameters
|
| 325 |
-
if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
|
| 326 |
-
continue
|
| 327 |
-
if not isinstance(model_param, torch.Tensor):
|
| 328 |
-
raise ValueError(
|
| 329 |
-
f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not."
|
| 330 |
-
)
|
| 331 |
-
|
| 332 |
-
shape_model = tuple(model_param.shape)
|
| 333 |
-
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|
| 334 |
-
if shape_model != shape_checkpoint:
|
| 335 |
-
has_observer_base_classes = (
|
| 336 |
-
TORCH_VERSION >= (1, 8)
|
| 337 |
-
and hasattr(quantization, "ObserverBase")
|
| 338 |
-
and hasattr(quantization, "FakeQuantizeBase")
|
| 339 |
-
)
|
| 340 |
-
if has_observer_base_classes:
|
| 341 |
-
# Handle the special case of quantization per channel observers,
|
| 342 |
-
# where buffer shape mismatches are expected.
|
| 343 |
-
def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
|
| 344 |
-
# foo.bar.param_or_buffer_name -> [foo, bar]
|
| 345 |
-
key_parts = key.split(".")[:-1]
|
| 346 |
-
cur_module = model
|
| 347 |
-
for key_part in key_parts:
|
| 348 |
-
cur_module = getattr(cur_module, key_part)
|
| 349 |
-
return cur_module
|
| 350 |
-
|
| 351 |
-
cls_to_skip = (
|
| 352 |
-
ObserverBase,
|
| 353 |
-
FakeQuantizeBase,
|
| 354 |
-
)
|
| 355 |
-
target_module = _get_module_for_key(model, k)
|
| 356 |
-
if isinstance(target_module, cls_to_skip):
|
| 357 |
-
# Do not remove modules with expected shape mismatches
|
| 358 |
-
# them from the state_dict loading. They have special logic
|
| 359 |
-
# in _load_from_state_dict to handle the mismatches.
|
| 360 |
-
continue
|
| 361 |
-
|
| 362 |
-
incorrect_shapes.append((k, shape_checkpoint, shape_model))
|
| 363 |
-
checkpoint_state_dict.pop(k)
|
| 364 |
-
incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
|
| 365 |
-
# Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling
|
| 366 |
-
missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
|
| 367 |
-
unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
|
| 368 |
-
return _IncompatibleKeys(
|
| 369 |
-
missing_keys=missing_keys,
|
| 370 |
-
unexpected_keys=unexpected_keys,
|
| 371 |
-
incorrect_shapes=incorrect_shapes,
|
| 372 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/config_helper.py
DELETED
|
@@ -1,214 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import importlib
|
| 17 |
-
import os
|
| 18 |
-
import pkgutil
|
| 19 |
-
import sys
|
| 20 |
-
from dataclasses import fields as dataclass_fields
|
| 21 |
-
from dataclasses import is_dataclass
|
| 22 |
-
from typing import Any, Dict, Optional
|
| 23 |
-
|
| 24 |
-
import attr
|
| 25 |
-
import attrs
|
| 26 |
-
from hydra import compose, initialize
|
| 27 |
-
from hydra.core.config_store import ConfigStore
|
| 28 |
-
from hydra.core.global_hydra import GlobalHydra
|
| 29 |
-
from omegaconf import DictConfig, OmegaConf
|
| 30 |
-
|
| 31 |
-
from lyra_2._ext.imaginaire.config import Config
|
| 32 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def is_attrs_or_dataclass(obj) -> bool:
|
| 36 |
-
"""
|
| 37 |
-
Check if the object is an instance of an attrs class or a dataclass.
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
obj: The object to check.
|
| 41 |
-
|
| 42 |
-
Returns:
|
| 43 |
-
bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
|
| 44 |
-
"""
|
| 45 |
-
return is_dataclass(obj) or attr.has(type(obj))
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def get_fields(obj):
|
| 49 |
-
"""
|
| 50 |
-
Get the fields of an attrs class or a dataclass.
|
| 51 |
-
|
| 52 |
-
Args:
|
| 53 |
-
obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
list: A list of field names.
|
| 57 |
-
|
| 58 |
-
Raises:
|
| 59 |
-
ValueError: If the object is neither an attrs class nor a dataclass.
|
| 60 |
-
"""
|
| 61 |
-
if is_dataclass(obj):
|
| 62 |
-
return [field.name for field in dataclass_fields(obj)]
|
| 63 |
-
elif attr.has(type(obj)):
|
| 64 |
-
return [field.name for field in attr.fields(type(obj))]
|
| 65 |
-
else:
|
| 66 |
-
raise ValueError("The object is neither an attrs class nor a dataclass.")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
|
| 70 |
-
"""
|
| 71 |
-
:param config: the instance of class `Config` (usually from `make_config`)
|
| 72 |
-
:param overrides: list of overrides for config
|
| 73 |
-
:return: the composed instance of class `Config`
|
| 74 |
-
"""
|
| 75 |
-
# Store the class of the config for reconstruction after overriding.
|
| 76 |
-
# config_class = type(config)
|
| 77 |
-
|
| 78 |
-
# Convert Config object to a DictConfig object
|
| 79 |
-
config_dict = attrs.asdict(config)
|
| 80 |
-
config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
|
| 81 |
-
# Enforce "--" separator between the script arguments and overriding configs.
|
| 82 |
-
if overrides:
|
| 83 |
-
if overrides[0] != "--":
|
| 84 |
-
raise ValueError(
|
| 85 |
-
f'Hydra config overrides must be separated with a "--" token. but got overrides={overrides}, and overrides[0]={overrides[0]}'
|
| 86 |
-
)
|
| 87 |
-
overrides = overrides[1:]
|
| 88 |
-
# Use Hydra to handle overrides
|
| 89 |
-
cs = ConfigStore.instance()
|
| 90 |
-
cs.store(name="config", node=config_omegaconf)
|
| 91 |
-
if not GlobalHydra().is_initialized():
|
| 92 |
-
with initialize(version_base=None):
|
| 93 |
-
config_omegaconf = compose(config_name="config", overrides=overrides)
|
| 94 |
-
OmegaConf.resolve(config_omegaconf)
|
| 95 |
-
else:
|
| 96 |
-
config_omegaconf = compose(config_name="config", overrides=overrides)
|
| 97 |
-
OmegaConf.resolve(config_omegaconf)
|
| 98 |
-
|
| 99 |
-
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
|
| 100 |
-
"""
|
| 101 |
-
Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
ref_instance: The reference instance to determine the type and fields when needed
|
| 105 |
-
kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data
|
| 106 |
-
|
| 107 |
-
Returns:
|
| 108 |
-
Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data
|
| 109 |
-
|
| 110 |
-
Raises:
|
| 111 |
-
AssertionError: If the fields do not match or if extra keys are found.
|
| 112 |
-
Exception: If there is an error constructing the new instance.
|
| 113 |
-
"""
|
| 114 |
-
is_type = is_attrs_or_dataclass(ref_instance)
|
| 115 |
-
if not is_type:
|
| 116 |
-
return kwargs
|
| 117 |
-
else:
|
| 118 |
-
ref_fields = set(get_fields(ref_instance))
|
| 119 |
-
assert isinstance(kwargs, dict) or isinstance(kwargs, DictConfig), (
|
| 120 |
-
"kwargs must be a dictionary or a DictConfig"
|
| 121 |
-
)
|
| 122 |
-
keys = set(kwargs.keys())
|
| 123 |
-
|
| 124 |
-
# ref_fields must equal to or include all keys
|
| 125 |
-
extra_keys = keys - ref_fields
|
| 126 |
-
assert ref_fields == keys or keys.issubset(ref_fields), (
|
| 127 |
-
f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
resolved_kwargs: Dict[str, Any] = {}
|
| 131 |
-
for f in keys:
|
| 132 |
-
resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
|
| 133 |
-
try:
|
| 134 |
-
new_instance = type(ref_instance)(**resolved_kwargs)
|
| 135 |
-
except Exception as e:
|
| 136 |
-
log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
|
| 137 |
-
log.error(e)
|
| 138 |
-
raise e
|
| 139 |
-
return new_instance
|
| 140 |
-
|
| 141 |
-
config = config_from_dict(config, config_omegaconf)
|
| 142 |
-
|
| 143 |
-
return config
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def get_config_module(config_file: str) -> str:
|
| 147 |
-
if not config_file.endswith(".py"):
|
| 148 |
-
log.error("Config file cannot be specified as module.")
|
| 149 |
-
log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).")
|
| 150 |
-
assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found."
|
| 151 |
-
# Convert to importable module format.
|
| 152 |
-
config_module = config_file.replace("/", ".").replace(".py", "")
|
| 153 |
-
return config_module
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def import_module(full_module_name: str, reload: bool = False):
|
| 157 |
-
"""
|
| 158 |
-
Import a module by name.
|
| 159 |
-
|
| 160 |
-
Args:
|
| 161 |
-
full_module_name: The fully qualified name of the module to import.
|
| 162 |
-
reload: If True, reload the module if it's already imported.
|
| 163 |
-
"""
|
| 164 |
-
if full_module_name in sys.modules and reload:
|
| 165 |
-
importlib.reload(sys.modules[full_module_name])
|
| 166 |
-
else:
|
| 167 |
-
importlib.import_module(full_module_name)
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
|
| 171 |
-
"""
|
| 172 |
-
Import all modules from the specified package path recursively.
|
| 173 |
-
|
| 174 |
-
This function is typically used in conjunction with Hydra to ensure that all modules
|
| 175 |
-
within a specified package are imported, which is necessary for registering configurations.
|
| 176 |
-
|
| 177 |
-
Example usage:
|
| 178 |
-
```python
|
| 179 |
-
import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False)
|
| 180 |
-
```
|
| 181 |
-
|
| 182 |
-
Args:
|
| 183 |
-
package_path (str): The dotted path to the package from which to import all modules.
|
| 184 |
-
reload (bool): Flag to determine whether to reload modules if they're already imported.
|
| 185 |
-
skip_underscore (bool): If True, skips importing modules that start with an underscore.
|
| 186 |
-
"""
|
| 187 |
-
log.info(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
|
| 188 |
-
package = importlib.import_module(package_path)
|
| 189 |
-
package_directory = package.__path__
|
| 190 |
-
|
| 191 |
-
def import_modules_recursively(directory: str, prefix: str) -> None:
|
| 192 |
-
"""
|
| 193 |
-
Recursively imports or reloads all modules in the given directory.
|
| 194 |
-
|
| 195 |
-
Args:
|
| 196 |
-
directory (str): The file system path to the current package directory.
|
| 197 |
-
prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config').
|
| 198 |
-
"""
|
| 199 |
-
for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
|
| 200 |
-
if skip_underscore and module_name.startswith("_"):
|
| 201 |
-
log.debug(f"Skipping module {module_name} as it starts with an underscore")
|
| 202 |
-
continue
|
| 203 |
-
|
| 204 |
-
full_module_name = f"{prefix}.{module_name}"
|
| 205 |
-
log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
|
| 206 |
-
|
| 207 |
-
import_module(full_module_name, reload=reload)
|
| 208 |
-
|
| 209 |
-
if is_pkg:
|
| 210 |
-
sub_package_directory = os.path.join(directory, module_name)
|
| 211 |
-
import_modules_recursively(sub_package_directory, full_module_name)
|
| 212 |
-
|
| 213 |
-
for directory in package_directory:
|
| 214 |
-
import_modules_recursively(directory, package_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/context_managers.py
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from contextlib import ExitStack, contextmanager
|
| 17 |
-
from typing import Generator
|
| 18 |
-
|
| 19 |
-
from lyra_2._ext.imaginaire.utils.misc import timer
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@contextmanager
|
| 23 |
-
def data_loader_init() -> Generator[None, None, None]:
|
| 24 |
-
"""
|
| 25 |
-
Wrap the data loader initialization with multiple context managers used for telemetry and one logger.
|
| 26 |
-
"""
|
| 27 |
-
contexts = [
|
| 28 |
-
timer("init_data_loader"),
|
| 29 |
-
]
|
| 30 |
-
with ExitStack() as stack:
|
| 31 |
-
yield [stack.enter_context(cm) for cm in contexts]
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
@contextmanager
|
| 35 |
-
def model_init(set_barrier: bool = False) -> Generator[None, None, None]:
|
| 36 |
-
"""
|
| 37 |
-
Wrap the instantiation of the model with multiple context managers used for telemetry and one logger.
|
| 38 |
-
"""
|
| 39 |
-
contexts = [
|
| 40 |
-
timer("init_model"),
|
| 41 |
-
]
|
| 42 |
-
with ExitStack() as stack:
|
| 43 |
-
yield [stack.enter_context(cm) for cm in contexts]
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
@contextmanager
|
| 47 |
-
def distributed_init() -> Generator[None, None, None]:
|
| 48 |
-
"""
|
| 49 |
-
Wrap the distributed initialization, used for telemetry and timers
|
| 50 |
-
"""
|
| 51 |
-
contexts = [
|
| 52 |
-
timer("init_distributed"),
|
| 53 |
-
]
|
| 54 |
-
with ExitStack() as stack:
|
| 55 |
-
yield [stack.enter_context(cm) for cm in contexts]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/count_params.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from torch import nn
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def disabled_train(self, mode: bool = True):
|
| 20 |
-
"""Overwrite model.train with this function to make sure train/eval mode
|
| 21 |
-
does not change anymore."""
|
| 22 |
-
return self
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def count_params(model: nn.Module, verbose=False) -> int:
|
| 26 |
-
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 27 |
-
if verbose:
|
| 28 |
-
print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
|
| 29 |
-
return total_params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/device.py
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import gc
|
| 17 |
-
import math
|
| 18 |
-
import os
|
| 19 |
-
|
| 20 |
-
import pynvml
|
| 21 |
-
from loguru import logger as logging
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def get_gpu_architecture():
|
| 25 |
-
"""
|
| 26 |
-
Retrieves the GPU architecture of the available GPUs.
|
| 27 |
-
|
| 28 |
-
Returns:
|
| 29 |
-
str: The GPU architecture, which can be "H100", "A100", or "Other".
|
| 30 |
-
"""
|
| 31 |
-
try:
|
| 32 |
-
pynvml.nvmlInit()
|
| 33 |
-
device_count = pynvml.nvmlDeviceGetCount()
|
| 34 |
-
for i in range(device_count):
|
| 35 |
-
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
| 36 |
-
model_name = pynvml.nvmlDeviceGetName(handle)
|
| 37 |
-
if isinstance(model_name, bytes):
|
| 38 |
-
model_name = model_name.decode("utf-8")
|
| 39 |
-
print(f"GPU {i}: Model: {model_name}")
|
| 40 |
-
|
| 41 |
-
# Check for specific models like H100 or A100
|
| 42 |
-
if "H100" in model_name or "H200" in model_name:
|
| 43 |
-
return "H100"
|
| 44 |
-
elif "A100" in model_name:
|
| 45 |
-
return "A100"
|
| 46 |
-
elif "L40S" in model_name:
|
| 47 |
-
return "L40S"
|
| 48 |
-
elif "B200" in model_name:
|
| 49 |
-
return "B200"
|
| 50 |
-
except pynvml.NVMLError as error:
|
| 51 |
-
print(f"Failed to get GPU info: {error}")
|
| 52 |
-
finally:
|
| 53 |
-
pynvml.nvmlShutdown()
|
| 54 |
-
|
| 55 |
-
# return "Other" incase of non hopper/ampere or error
|
| 56 |
-
return "Other"
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class GPUArchitectureNotSupported(Exception):
|
| 60 |
-
"""
|
| 61 |
-
Custom exception raised when the expected GPU architecture is not supported.
|
| 62 |
-
"""
|
| 63 |
-
|
| 64 |
-
pass
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def print_gpu_mem(str=None):
|
| 68 |
-
try:
|
| 69 |
-
pynvml.nvmlInit()
|
| 70 |
-
meminfo = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(0))
|
| 71 |
-
logging.info(
|
| 72 |
-
f"{str}: {meminfo.used / 1024 / 1024}/{meminfo.total / 1024 / 1024}MiB used ({meminfo.free / 1024 / 1024}MiB free)"
|
| 73 |
-
)
|
| 74 |
-
except pynvml.NVMLError as error:
|
| 75 |
-
print(f"Failed to get GPU memory info: {error}")
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def force_gc():
|
| 79 |
-
print_gpu_mem()
|
| 80 |
-
print("gc()")
|
| 81 |
-
gc.collect()
|
| 82 |
-
print_gpu_mem()
|
| 83 |
-
print("empty cuda cache")
|
| 84 |
-
# print(torch.cuda.memory_summary())
|
| 85 |
-
print_gpu_mem()
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def gpu0_has_80gb_or_less():
|
| 89 |
-
try:
|
| 90 |
-
pynvml.nvmlInit()
|
| 91 |
-
meminfo = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(0))
|
| 92 |
-
return meminfo.total / 1024 / 1024 / 1024 <= 80
|
| 93 |
-
except pynvml.NVMLError as error:
|
| 94 |
-
print(f"Failed to get GPU memory info: {error}")
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
class Device:
|
| 98 |
-
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore
|
| 99 |
-
|
| 100 |
-
def __init__(self, device_idx: int):
|
| 101 |
-
super().__init__()
|
| 102 |
-
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
|
| 103 |
-
|
| 104 |
-
def get_name(self) -> str:
|
| 105 |
-
return pynvml.nvmlDeviceGetName(self.handle)
|
| 106 |
-
|
| 107 |
-
def get_cpu_affinity(self) -> list[int]:
|
| 108 |
-
affinity_string = ""
|
| 109 |
-
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
|
| 110 |
-
# assume nvml returns list of 64 bit ints
|
| 111 |
-
affinity_string = "{:064b}".format(j) + affinity_string
|
| 112 |
-
affinity_list = [int(x) for x in affinity_string]
|
| 113 |
-
affinity_list.reverse() # so core 0 is in 0th element of list
|
| 114 |
-
return [i for i, e in enumerate(affinity_list) if e != 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/distributed.py
DELETED
|
@@ -1,443 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
import collections
|
| 19 |
-
import collections.abc
|
| 20 |
-
import ctypes
|
| 21 |
-
import functools
|
| 22 |
-
import os
|
| 23 |
-
from contextlib import contextmanager
|
| 24 |
-
from datetime import timedelta
|
| 25 |
-
from typing import TYPE_CHECKING, Any, Callable, Container, Optional
|
| 26 |
-
|
| 27 |
-
import pynvml
|
| 28 |
-
import torch
|
| 29 |
-
import torch.distributed as dist
|
| 30 |
-
from torch.distributed import get_process_group_ranks
|
| 31 |
-
|
| 32 |
-
from lyra_2._ext.imaginaire.utils.device import Device
|
| 33 |
-
|
| 34 |
-
if dist.is_available():
|
| 35 |
-
from torch.distributed.distributed_c10d import _get_default_group
|
| 36 |
-
from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes
|
| 37 |
-
|
| 38 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 39 |
-
|
| 40 |
-
if TYPE_CHECKING:
|
| 41 |
-
from lyra_2._ext.imaginaire.config import DDPConfig
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
from megatron.core import parallel_state
|
| 45 |
-
except ImportError:
|
| 46 |
-
print("Megatron-core is not installed.")
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def init() -> int | None:
|
| 50 |
-
"""Initialize distributed training."""
|
| 51 |
-
if dist.is_initialized():
|
| 52 |
-
return torch.cuda.current_device()
|
| 53 |
-
|
| 54 |
-
# Set GPU affinity.
|
| 55 |
-
pynvml.nvmlInit()
|
| 56 |
-
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 57 |
-
try:
|
| 58 |
-
device = Device(local_rank)
|
| 59 |
-
os.sched_setaffinity(0, device.get_cpu_affinity())
|
| 60 |
-
except Exception as e:
|
| 61 |
-
log.warning(f"Failed to set device affinity: {e}")
|
| 62 |
-
# Set up NCCL communication.
|
| 63 |
-
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0"
|
| 64 |
-
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
| 65 |
-
if dist.is_available():
|
| 66 |
-
torch.cuda.set_device(local_rank)
|
| 67 |
-
# Get the timeout value from environment variable
|
| 68 |
-
timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800)
|
| 69 |
-
# Convert the timeout to an integer (if it isn't already) and then to a timedelta
|
| 70 |
-
timeout_timedelta = timedelta(seconds=int(timeout_seconds))
|
| 71 |
-
dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta)
|
| 72 |
-
log.info(
|
| 73 |
-
f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}",
|
| 74 |
-
rank0_only=False,
|
| 75 |
-
)
|
| 76 |
-
# Increase the L2 fetch granularity for faster speed.
|
| 77 |
-
_libcudart = ctypes.CDLL("libcudart.so")
|
| 78 |
-
# Set device limit on the current device.
|
| 79 |
-
p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
| 80 |
-
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
|
| 81 |
-
_libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05))
|
| 82 |
-
log.info(f"Training with {get_world_size()} GPUs.")
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
|
| 86 |
-
"""Get the rank (GPU device) of the worker.
|
| 87 |
-
|
| 88 |
-
Returns:
|
| 89 |
-
rank (int): The rank of the worker.
|
| 90 |
-
"""
|
| 91 |
-
rank = 0
|
| 92 |
-
if dist.is_available() and dist.is_initialized():
|
| 93 |
-
rank = dist.get_rank(group)
|
| 94 |
-
return rank
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
|
| 98 |
-
"""Get world size. How many GPUs are available in this job.
|
| 99 |
-
|
| 100 |
-
Returns:
|
| 101 |
-
world_size (int): The total number of GPUs available in this job.
|
| 102 |
-
"""
|
| 103 |
-
world_size = 1
|
| 104 |
-
if dist.is_available() and dist.is_initialized():
|
| 105 |
-
world_size = dist.get_world_size(group)
|
| 106 |
-
return world_size
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def is_rank0() -> bool:
|
| 110 |
-
"""Check if current process is the master GPU.
|
| 111 |
-
|
| 112 |
-
Returns:
|
| 113 |
-
(bool): True if this function is called from the master GPU, else False.
|
| 114 |
-
"""
|
| 115 |
-
return get_rank() == 0
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def is_local_rank0() -> bool:
|
| 119 |
-
"""Check if current process is the local master GPU in the current node.
|
| 120 |
-
|
| 121 |
-
Returns:
|
| 122 |
-
(bool): True if this function is called from the local master GPU, else False.
|
| 123 |
-
"""
|
| 124 |
-
return torch.cuda.current_device() == 0
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def rank0_only(func: Callable) -> Callable:
|
| 128 |
-
"""Apply this function only to the master GPU.
|
| 129 |
-
|
| 130 |
-
Example usage:
|
| 131 |
-
@rank0_only
|
| 132 |
-
def func(x):
|
| 133 |
-
return x + 3
|
| 134 |
-
|
| 135 |
-
Args:
|
| 136 |
-
func (Callable): a function.
|
| 137 |
-
|
| 138 |
-
Returns:
|
| 139 |
-
(Callable): A function wrapper executing the function only on the master GPU.
|
| 140 |
-
"""
|
| 141 |
-
|
| 142 |
-
@functools.wraps(func)
|
| 143 |
-
def wrapper(*args, **kwargs): # noqa: ANN202
|
| 144 |
-
if is_rank0():
|
| 145 |
-
return func(*args, **kwargs)
|
| 146 |
-
else:
|
| 147 |
-
return None
|
| 148 |
-
|
| 149 |
-
return wrapper
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def barrier() -> None:
|
| 153 |
-
"""Barrier for all GPUs."""
|
| 154 |
-
if dist.is_available() and dist.is_initialized():
|
| 155 |
-
dist.barrier()
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
def rank0_first(func: Callable) -> Callable:
|
| 159 |
-
"""run the function on rank 0 first, then on other ranks."""
|
| 160 |
-
|
| 161 |
-
@functools.wraps(func)
|
| 162 |
-
def wrapper(*args, **kwargs): # noqa: ANN202
|
| 163 |
-
if is_rank0():
|
| 164 |
-
result = func(*args, **kwargs)
|
| 165 |
-
barrier()
|
| 166 |
-
if not is_rank0():
|
| 167 |
-
result = func(*args, **kwargs)
|
| 168 |
-
return result
|
| 169 |
-
|
| 170 |
-
return wrapper
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel:
|
| 174 |
-
"""Wraps the model to enable data parallalism for training across multiple GPU devices.
|
| 175 |
-
|
| 176 |
-
Args:
|
| 177 |
-
config_ddp (DDPConfig): The data parallel config.
|
| 178 |
-
model (torch.nn.Module): The PyTorch module.
|
| 179 |
-
|
| 180 |
-
Returns:
|
| 181 |
-
model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper
|
| 182 |
-
if distributed environment is available, otherwise return the original model.
|
| 183 |
-
"""
|
| 184 |
-
if dist.is_available() and dist.is_initialized():
|
| 185 |
-
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
| 186 |
-
try:
|
| 187 |
-
ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
|
| 188 |
-
except Exception as e:
|
| 189 |
-
log.info(e)
|
| 190 |
-
log.info("parallel_state not initialized, treating all GPUs equally for DDP")
|
| 191 |
-
ddp_group = None
|
| 192 |
-
|
| 193 |
-
model = DistributedDataParallel(
|
| 194 |
-
model,
|
| 195 |
-
device_ids=[local_rank],
|
| 196 |
-
output_device=local_rank,
|
| 197 |
-
find_unused_parameters=config_ddp.find_unused_parameters,
|
| 198 |
-
static_graph=config_ddp.static_graph,
|
| 199 |
-
broadcast_buffers=config_ddp.broadcast_buffers,
|
| 200 |
-
process_group=ddp_group,
|
| 201 |
-
)
|
| 202 |
-
return model
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel):
|
| 206 |
-
"""This extends torch.nn.parallel.DistributedDataParallel with .training_step().
|
| 207 |
-
|
| 208 |
-
This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an ImaginaireModel such that
|
| 209 |
-
model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling
|
| 210 |
-
model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward ->
|
| 211 |
-
training_step), allowing us to preserve the function names and signatures.
|
| 212 |
-
"""
|
| 213 |
-
|
| 214 |
-
def __init__(self, model: torch.nn.Module, *args, **kwargs):
|
| 215 |
-
super().__init__(model, *args, **kwargs)
|
| 216 |
-
self.show_sync_grad_static_graph_warning = True
|
| 217 |
-
|
| 218 |
-
def training_step(self, *args, **kwargs) -> Any:
|
| 219 |
-
# Cache the original model.forward() method.
|
| 220 |
-
original_forward = self.module.forward
|
| 221 |
-
|
| 222 |
-
def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202
|
| 223 |
-
# Unpatch immediately before calling training_step() because itself may want to call the real forward.
|
| 224 |
-
self.module.forward = original_forward
|
| 225 |
-
# The actual .training_step().
|
| 226 |
-
return self.module.training_step(*_args, **_kwargs)
|
| 227 |
-
|
| 228 |
-
# Patch the original_module's forward so we can redirect the arguments back to the real method.
|
| 229 |
-
self.module.forward = wrapped_training_step
|
| 230 |
-
# Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step().
|
| 231 |
-
# Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed.
|
| 232 |
-
return self(*args, **kwargs)
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
@contextmanager
|
| 236 |
-
def ddp_sync_grad(model, enabled):
|
| 237 |
-
r"""
|
| 238 |
-
Context manager to enable/disable gradient synchronizations across DDP processes for DDP model.
|
| 239 |
-
Modified from:
|
| 240 |
-
https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
|
| 241 |
-
Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True.
|
| 242 |
-
|
| 243 |
-
Within this context, gradients will be accumulated on module
|
| 244 |
-
variables, which will later be synchronized in the first
|
| 245 |
-
forward-backward pass exiting the context.
|
| 246 |
-
|
| 247 |
-
.. warning::
|
| 248 |
-
The forward pass should be included inside the context manager, or
|
| 249 |
-
else gradients will still be synchronized.
|
| 250 |
-
"""
|
| 251 |
-
assert isinstance(model, torch.nn.Module)
|
| 252 |
-
if isinstance(model, DistributedDataParallel):
|
| 253 |
-
old_require_backward_grad_sync = model.require_backward_grad_sync
|
| 254 |
-
if model.static_graph and model.require_backward_grad_sync != enabled:
|
| 255 |
-
if model.show_sync_grad_static_graph_warning:
|
| 256 |
-
log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.")
|
| 257 |
-
model.show_sync_grad_static_graph_warning = False
|
| 258 |
-
else:
|
| 259 |
-
model.require_backward_grad_sync = enabled
|
| 260 |
-
try:
|
| 261 |
-
yield
|
| 262 |
-
finally:
|
| 263 |
-
if isinstance(model, DistributedDataParallel):
|
| 264 |
-
model.require_backward_grad_sync = old_require_backward_grad_sync
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]:
|
| 268 |
-
"""Aggregate the list of data batches from all devices and process the results.
|
| 269 |
-
|
| 270 |
-
This is used for gathering validation data batches with lyra_2._ext.imaginaire.utils.dataloader.DistributedEvalSampler.
|
| 271 |
-
It will return the data/output of the entire validation set in its original index order. The sizes of data_batches
|
| 272 |
-
in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be
|
| 273 |
-
created before calling dis.all_gather().
|
| 274 |
-
|
| 275 |
-
Args:
|
| 276 |
-
data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where
|
| 277 |
-
leaf entries are tensors.
|
| 278 |
-
|
| 279 |
-
Returns:
|
| 280 |
-
data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where
|
| 281 |
-
leaf entries are concatenated tensors.
|
| 282 |
-
"""
|
| 283 |
-
if isinstance(data_batches[0], torch.Tensor):
|
| 284 |
-
# Concatenate the local data batches.
|
| 285 |
-
data_concat = torch.cat(data_batches, dim=0) # type: ignore
|
| 286 |
-
# Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank.
|
| 287 |
-
max_num_local_samples = torch.tensor(len(data_concat), device="cuda")
|
| 288 |
-
dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX)
|
| 289 |
-
if len(data_concat) < max_num_local_samples:
|
| 290 |
-
assert len(data_concat) + 1 == max_num_local_samples
|
| 291 |
-
dummy = torch.empty_like(data_concat[:1])
|
| 292 |
-
data_concat = torch.cat([data_concat, dummy], dim=0)
|
| 293 |
-
dummy_count = torch.tensor(1, device="cuda")
|
| 294 |
-
else:
|
| 295 |
-
dummy_count = torch.tensor(0, device="cuda")
|
| 296 |
-
# Get all concatenated batches from all ranks and concatenate again.
|
| 297 |
-
dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM)
|
| 298 |
-
data_concat = all_gather_tensor(data_concat.contiguous())
|
| 299 |
-
data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1)
|
| 300 |
-
# Remove the dummy samples.
|
| 301 |
-
if dummy_count > 0:
|
| 302 |
-
data_collate = data_collate[:-dummy_count]
|
| 303 |
-
elif isinstance(data_batches[0], collections.abc.Mapping):
|
| 304 |
-
data_collate = dict()
|
| 305 |
-
for key in data_batches[0].keys():
|
| 306 |
-
data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore
|
| 307 |
-
else:
|
| 308 |
-
raise TypeError
|
| 309 |
-
return data_collate
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
@torch.no_grad()
|
| 313 |
-
def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]:
|
| 314 |
-
"""Gather the corresponding tensor from all GPU devices to a list.
|
| 315 |
-
|
| 316 |
-
Args:
|
| 317 |
-
tensor (torch.Tensor): Pytorch tensor.
|
| 318 |
-
|
| 319 |
-
Returns:
|
| 320 |
-
tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices.
|
| 321 |
-
"""
|
| 322 |
-
tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
|
| 323 |
-
dist.all_gather(tensor_list, tensor)
|
| 324 |
-
return tensor_list
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
def broadcast(tensor, src, group=None, async_op=False):
|
| 328 |
-
world_size = get_world_size()
|
| 329 |
-
if world_size < 2:
|
| 330 |
-
return tensor
|
| 331 |
-
dist.broadcast(tensor, src=src, group=group, async_op=async_op)
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
def dist_reduce_tensor(tensor, rank=0, reduce="mean"):
|
| 335 |
-
r"""Reduce to rank 0"""
|
| 336 |
-
world_size = get_world_size()
|
| 337 |
-
if world_size < 2:
|
| 338 |
-
return tensor
|
| 339 |
-
with torch.no_grad():
|
| 340 |
-
dist.reduce(tensor, dst=rank)
|
| 341 |
-
if get_rank() == rank:
|
| 342 |
-
if reduce == "mean":
|
| 343 |
-
tensor /= world_size
|
| 344 |
-
elif reduce == "sum":
|
| 345 |
-
pass
|
| 346 |
-
else:
|
| 347 |
-
raise NotImplementedError
|
| 348 |
-
return tensor
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
def sync_model_states(
|
| 352 |
-
model: torch.nn.Module,
|
| 353 |
-
process_group: Optional[dist.ProcessGroup] = None,
|
| 354 |
-
src: int = 0,
|
| 355 |
-
params_and_buffers_to_ignore: Optional[Container[str]] = None,
|
| 356 |
-
broadcast_buffers: bool = True,
|
| 357 |
-
):
|
| 358 |
-
"""
|
| 359 |
-
Modify based on DDP source code
|
| 360 |
-
Synchronizes the parameters and buffers of a model across different processes in a distributed setting.
|
| 361 |
-
|
| 362 |
-
This function ensures that all processes in the specified process group have the same initial parameters and
|
| 363 |
-
buffers from the source rank, typically rank 0. It is useful when different processes start with different model
|
| 364 |
-
states and a synchronization is required to ensure consistency across all ranks.
|
| 365 |
-
|
| 366 |
-
Args:
|
| 367 |
-
model (nn.Module): The model whose parameters and buffers are to be synchronized.
|
| 368 |
-
process_group (dist.ProcessGroup, optional): The process group for communication. If None,
|
| 369 |
-
the default group is used. Defaults to None.
|
| 370 |
-
src (int, optional): The source rank from which parameters and buffers will be broadcasted.
|
| 371 |
-
Defaults to 0.
|
| 372 |
-
params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer
|
| 373 |
-
names to exclude from synchronization. Defaults to None, which means all parameters and buffers are
|
| 374 |
-
included.
|
| 375 |
-
broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True.
|
| 376 |
-
|
| 377 |
-
Side Effects:
|
| 378 |
-
This function modifies the state of the model in-place to synchronize it with the source rank's model state.
|
| 379 |
-
|
| 380 |
-
Raises:
|
| 381 |
-
RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised.
|
| 382 |
-
|
| 383 |
-
Examples:
|
| 384 |
-
>>> # downloading duplicated model weights from s3 in each rank and save network bandwidth
|
| 385 |
-
>>> # useful and save our time when model weights are huge
|
| 386 |
-
>>> if dist.get_rank == 0:
|
| 387 |
-
>>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path))
|
| 388 |
-
>>> dist.barrir()
|
| 389 |
-
>>> sync_model_states(model) # sync rank0 weights to other ranks
|
| 390 |
-
"""
|
| 391 |
-
if not dist.is_available() or not dist.is_initialized():
|
| 392 |
-
return
|
| 393 |
-
if process_group is None:
|
| 394 |
-
process_group = _get_default_group()
|
| 395 |
-
if not params_and_buffers_to_ignore:
|
| 396 |
-
params_and_buffers_to_ignore = set()
|
| 397 |
-
|
| 398 |
-
log.info(
|
| 399 |
-
f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}."
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
-
# Build tuple of (module, parameter) for all parameters that require grads.
|
| 403 |
-
modules_and_parameters = [
|
| 404 |
-
(module, parameter)
|
| 405 |
-
for module_name, module in model.named_modules()
|
| 406 |
-
for parameter in [
|
| 407 |
-
param
|
| 408 |
-
# Note that we access module.named_parameters instead of
|
| 409 |
-
# parameters(module). parameters(module) is only needed in the
|
| 410 |
-
# single-process multi device case, where it accesses replicated
|
| 411 |
-
# parameters through _former_parameters.
|
| 412 |
-
for param_name, param in module.named_parameters(recurse=False)
|
| 413 |
-
if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
|
| 414 |
-
# if param.requires_grad
|
| 415 |
-
# and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore
|
| 416 |
-
]
|
| 417 |
-
]
|
| 418 |
-
|
| 419 |
-
# Deduplicate any parameters that might be shared across child modules.
|
| 420 |
-
memo = set()
|
| 421 |
-
modules_and_parameters = [
|
| 422 |
-
# "p not in memo" is the deduplication check.
|
| 423 |
-
# "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
|
| 424 |
-
(m, p)
|
| 425 |
-
for m, p in modules_and_parameters
|
| 426 |
-
if p not in memo and not memo.add(p) # type: ignore[func-returns-value]
|
| 427 |
-
]
|
| 428 |
-
|
| 429 |
-
# Build list of parameters.
|
| 430 |
-
parameters = [parameter for _, parameter in modules_and_parameters]
|
| 431 |
-
if len(parameters) == 0:
|
| 432 |
-
return
|
| 433 |
-
|
| 434 |
-
_verify_param_shape_across_processes(process_group, parameters)
|
| 435 |
-
|
| 436 |
-
_sync_module_states(
|
| 437 |
-
module=model,
|
| 438 |
-
process_group=process_group,
|
| 439 |
-
broadcast_bucket_size=int(250 * 1024 * 1024),
|
| 440 |
-
src=src,
|
| 441 |
-
params_and_buffers_to_ignore=params_and_buffers_to_ignore,
|
| 442 |
-
broadcast_buffers=broadcast_buffers,
|
| 443 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/__init__.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
|
| 17 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_backend import Boto3Backend
|
| 18 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
|
| 19 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.local_backend import LocalBackend
|
| 20 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.registry_utils import (
|
| 21 |
-
backends,
|
| 22 |
-
prefix_to_backends,
|
| 23 |
-
register_backend,
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
__all__ = [
|
| 27 |
-
"BaseStorageBackend",
|
| 28 |
-
"LocalBackend",
|
| 29 |
-
"HTTPBackend",
|
| 30 |
-
"Boto3Backend",
|
| 31 |
-
"register_backend",
|
| 32 |
-
"backends",
|
| 33 |
-
"prefix_to_backends",
|
| 34 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import contextlib
|
| 17 |
-
import json
|
| 18 |
-
from typing import Any, Optional
|
| 19 |
-
|
| 20 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 21 |
-
from lyra_2._ext.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS, CRED_ENVS_DICT
|
| 22 |
-
|
| 23 |
-
DEPLOYMENT_ENVS = ["prod", "dev", "stg"]
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# context manger to open a file or read from env variable
|
| 27 |
-
@contextlib.contextmanager
|
| 28 |
-
def open_auth(s3_credential_path: Optional[Any], mode: str):
|
| 29 |
-
if not s3_credential_path:
|
| 30 |
-
log.info(f"No credential file provided {s3_credential_path}.")
|
| 31 |
-
yield None
|
| 32 |
-
return
|
| 33 |
-
|
| 34 |
-
name = s3_credential_path.split("/")[-1].split(".")[0]
|
| 35 |
-
if not name:
|
| 36 |
-
raise ValueError(f"Could not parse into env var: {s3_credential_path}")
|
| 37 |
-
cred_env_name = f"PROD_{name.upper()}"
|
| 38 |
-
|
| 39 |
-
if CRED_ENVS.APP_ENV in DEPLOYMENT_ENVS and cred_env_name in CRED_ENVS_DICT:
|
| 40 |
-
object_storage_config = get_creds_from_env(cred_env_name)
|
| 41 |
-
log.info(f"using ENV vars for {cred_env_name}")
|
| 42 |
-
yield object_storage_config
|
| 43 |
-
else:
|
| 44 |
-
log.info(f"using credential file: {s3_credential_path}")
|
| 45 |
-
with open(s3_credential_path, mode) as f:
|
| 46 |
-
yield f
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def get_creds_from_env(cred_env_name: str) -> dict[str, str]:
|
| 50 |
-
try:
|
| 51 |
-
object_storage_config = CRED_ENVS_DICT[cred_env_name]
|
| 52 |
-
except KeyError:
|
| 53 |
-
raise ValueError(f"Could not find {cred_env_name} in CRED_ENVS")
|
| 54 |
-
empty_args = {key.upper() for key in object_storage_config if object_storage_config[key] == ""}
|
| 55 |
-
if empty_args:
|
| 56 |
-
raise ValueError(f"Some required environment variable(s) were not provided for {cred_env_name}", empty_args)
|
| 57 |
-
return object_storage_config
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def json_load_auth(f):
|
| 61 |
-
if CRED_ENVS.APP_ENV in DEPLOYMENT_ENVS:
|
| 62 |
-
return f if f else {}
|
| 63 |
-
else:
|
| 64 |
-
return json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
import os.path as osp
|
| 18 |
-
from abc import ABCMeta, abstractmethod
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def mkdir_or_exist(dir_name, mode=0o777):
|
| 22 |
-
if dir_name == "":
|
| 23 |
-
return
|
| 24 |
-
dir_name = osp.expanduser(dir_name)
|
| 25 |
-
os.makedirs(dir_name, mode=mode, exist_ok=True)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def has_method(obj, method):
|
| 29 |
-
return hasattr(obj, method) and callable(getattr(obj, method))
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class BaseStorageBackend(metaclass=ABCMeta):
|
| 33 |
-
"""Abstract class of storage backends.
|
| 34 |
-
|
| 35 |
-
All backends need to implement two apis: :meth:`get()` and
|
| 36 |
-
:meth:`get_text()`.
|
| 37 |
-
|
| 38 |
-
- :meth:`get()` reads the file as a byte stream.
|
| 39 |
-
- :meth:`get_text()` reads the file as texts.
|
| 40 |
-
"""
|
| 41 |
-
|
| 42 |
-
# a flag to indicate whether the backend can create a symlink for a file
|
| 43 |
-
# This attribute will be deprecated in future.
|
| 44 |
-
_allow_symlink = False
|
| 45 |
-
|
| 46 |
-
@property
|
| 47 |
-
def allow_symlink(self):
|
| 48 |
-
return self._allow_symlink
|
| 49 |
-
|
| 50 |
-
@property
|
| 51 |
-
def name(self):
|
| 52 |
-
return self.__class__.__name__
|
| 53 |
-
|
| 54 |
-
@abstractmethod
|
| 55 |
-
def get(self, filepath):
|
| 56 |
-
pass
|
| 57 |
-
|
| 58 |
-
@abstractmethod
|
| 59 |
-
def get_text(self, filepath):
|
| 60 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py
DELETED
|
@@ -1,841 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import io
|
| 17 |
-
import os
|
| 18 |
-
import re
|
| 19 |
-
import tempfile
|
| 20 |
-
from contextlib import contextmanager
|
| 21 |
-
from pathlib import Path
|
| 22 |
-
from shutil import SameFileError
|
| 23 |
-
from typing import Generator, Iterator, Optional, Tuple, Union
|
| 24 |
-
|
| 25 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 26 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import (
|
| 27 |
-
BaseStorageBackend,
|
| 28 |
-
has_method,
|
| 29 |
-
mkdir_or_exist,
|
| 30 |
-
)
|
| 31 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_client import Boto3Client
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class Boto3Backend(BaseStorageBackend):
|
| 35 |
-
"""boto3 storage backend (for internal usage).
|
| 36 |
-
|
| 37 |
-
Boto3Backend supports reading and writing data to multiple clusters.
|
| 38 |
-
If the file path contains the cluster name, Boto3Backend will read data
|
| 39 |
-
from specified cluster or write data to it. Otherwise, Boto3Backend will
|
| 40 |
-
access the default cluster.
|
| 41 |
-
|
| 42 |
-
Args:
|
| 43 |
-
path_mapping (dict, optional): Path mapping dict from local path to
|
| 44 |
-
Boto3 path. When ``path_mapping={'src': 'dst'}``, ``src`` in
|
| 45 |
-
``filepath`` will be replaced by ``dst``. Defaults to None.
|
| 46 |
-
s3_credential_path (str, optional): Config path of Boto3 client. Default: None.
|
| 47 |
-
`New in version 0.3.3`.
|
| 48 |
-
|
| 49 |
-
Examples:
|
| 50 |
-
>>> backend = Boto3Backend()
|
| 51 |
-
>>> filepath1 = 's3://path/of/file'
|
| 52 |
-
>>> filepath2 = 'cluster-name:s3://path/of/file'
|
| 53 |
-
>>> backend.get(filepath1) # get data from default cluster
|
| 54 |
-
>>> client.get(filepath2) # get data from 'cluster-name' cluster
|
| 55 |
-
"""
|
| 56 |
-
|
| 57 |
-
def __init__(
|
| 58 |
-
self,
|
| 59 |
-
s3_credential_path: str = "",
|
| 60 |
-
path_mapping: Optional[dict] = None,
|
| 61 |
-
):
|
| 62 |
-
self._client = Boto3Client(s3_credential_path=s3_credential_path)
|
| 63 |
-
assert isinstance(path_mapping, dict) or path_mapping is None
|
| 64 |
-
self.path_mapping = path_mapping
|
| 65 |
-
if path_mapping:
|
| 66 |
-
for k, v in path_mapping.items():
|
| 67 |
-
log.critical(f"Path mapping: {k} -> {v}", rank0_only=False)
|
| 68 |
-
|
| 69 |
-
def _map_path(self, filepath: Union[str, Path]) -> str:
|
| 70 |
-
"""Map ``filepath`` to a string path whose prefix will be replaced by
|
| 71 |
-
:attr:`self.path_mapping`.
|
| 72 |
-
|
| 73 |
-
Args:
|
| 74 |
-
filepath (str or Path): Path to be mapped.
|
| 75 |
-
"""
|
| 76 |
-
filepath = str(filepath)
|
| 77 |
-
if self.path_mapping is not None:
|
| 78 |
-
for k, v in self.path_mapping.items():
|
| 79 |
-
filepath = filepath.replace(k, v, 1)
|
| 80 |
-
return filepath
|
| 81 |
-
|
| 82 |
-
def _format_path(self, filepath: str) -> str:
|
| 83 |
-
"""Convert a ``filepath`` to standard format of s3 oss.
|
| 84 |
-
|
| 85 |
-
If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
|
| 86 |
-
environment, the ``filepath`` will be the format of
|
| 87 |
-
's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
|
| 88 |
-
above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
|
| 89 |
-
|
| 90 |
-
Args:
|
| 91 |
-
filepath (str): Path to be formatted.
|
| 92 |
-
"""
|
| 93 |
-
return re.sub(r"\\+", "/", filepath)
|
| 94 |
-
|
| 95 |
-
def _replace_prefix(self, filepath: Union[str, Path]) -> str:
|
| 96 |
-
filepath = str(filepath)
|
| 97 |
-
return filepath
|
| 98 |
-
# return filepath.replace('s3://', 's3://')
|
| 99 |
-
|
| 100 |
-
def get(self, filepath: Union[str, Path]) -> bytes:
|
| 101 |
-
"""Read bytes from a given ``filepath`` with 'rb' mode.
|
| 102 |
-
|
| 103 |
-
Args:
|
| 104 |
-
filepath (str or Path): Path to read data.
|
| 105 |
-
|
| 106 |
-
Returns:
|
| 107 |
-
bytes: Return bytes read from filepath.
|
| 108 |
-
|
| 109 |
-
Examples:
|
| 110 |
-
>>> backend = Boto3Backend()
|
| 111 |
-
>>> filepath = 's3://path/of/file'
|
| 112 |
-
>>> backend.get(filepath)
|
| 113 |
-
b'hello world'
|
| 114 |
-
"""
|
| 115 |
-
filepath = self._map_path(filepath)
|
| 116 |
-
filepath = self._format_path(filepath)
|
| 117 |
-
filepath = self._replace_prefix(filepath)
|
| 118 |
-
value = self._client.get(filepath)
|
| 119 |
-
return value
|
| 120 |
-
|
| 121 |
-
def get_text(
|
| 122 |
-
self,
|
| 123 |
-
filepath: Union[str, Path],
|
| 124 |
-
encoding: str = "utf-8",
|
| 125 |
-
) -> str:
|
| 126 |
-
"""Read text from a given ``filepath`` with 'r' mode.
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
filepath (str or Path): Path to read data.
|
| 130 |
-
encoding (str): The encoding format used to open the ``filepath``.
|
| 131 |
-
Defaults to 'utf-8'.
|
| 132 |
-
|
| 133 |
-
Returns:
|
| 134 |
-
str: Expected text reading from ``filepath``.
|
| 135 |
-
|
| 136 |
-
Examples:
|
| 137 |
-
>>> backend = Boto3Backend()
|
| 138 |
-
>>> filepath = 's3://path/of/file'
|
| 139 |
-
>>> backend.get_text(filepath)
|
| 140 |
-
'hello world'
|
| 141 |
-
"""
|
| 142 |
-
return str(self.get(filepath), encoding=encoding)
|
| 143 |
-
|
| 144 |
-
def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None:
|
| 145 |
-
"""Write bytes to a given ``filepath``.
|
| 146 |
-
|
| 147 |
-
Args:
|
| 148 |
-
obj (bytes): Data to be saved.
|
| 149 |
-
filepath (str or Path): Path to write data.
|
| 150 |
-
|
| 151 |
-
Examples:
|
| 152 |
-
>>> backend = Boto3Backend()
|
| 153 |
-
>>> filepath = 's3://path/of/file'
|
| 154 |
-
>>> backend.put(b'hello world', filepath)
|
| 155 |
-
"""
|
| 156 |
-
filepath = self._map_path(filepath)
|
| 157 |
-
filepath = self._format_path(filepath)
|
| 158 |
-
filepath = self._replace_prefix(filepath)
|
| 159 |
-
self._client.put(obj, filepath)
|
| 160 |
-
|
| 161 |
-
def fast_put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path], num_processes: int = 32) -> None:
|
| 162 |
-
"""Write bytes to a given ``filepath`` with multiple processes and async"""
|
| 163 |
-
assert num_processes > 1
|
| 164 |
-
filepath = self._map_path(filepath)
|
| 165 |
-
filepath = self._format_path(filepath)
|
| 166 |
-
filepath = self._replace_prefix(filepath)
|
| 167 |
-
self._client.fast_put(obj, filepath, num_processes=num_processes)
|
| 168 |
-
|
| 169 |
-
def put_text(
|
| 170 |
-
self,
|
| 171 |
-
obj: str,
|
| 172 |
-
filepath: Union[str, Path],
|
| 173 |
-
encoding: str = "utf-8",
|
| 174 |
-
) -> None:
|
| 175 |
-
"""Write text to a given ``filepath``.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
obj (str): Data to be written.
|
| 179 |
-
filepath (str or Path): Path to write data.
|
| 180 |
-
encoding (str): The encoding format used to encode the ``obj``.
|
| 181 |
-
Defaults to 'utf-8'.
|
| 182 |
-
|
| 183 |
-
Examples:
|
| 184 |
-
>>> backend = Boto3Backend()
|
| 185 |
-
>>> filepath = 's3://path/of/file'
|
| 186 |
-
>>> backend.put_text('hello world', filepath)
|
| 187 |
-
"""
|
| 188 |
-
self.put(bytes(obj, encoding=encoding), filepath)
|
| 189 |
-
|
| 190 |
-
def exists(self, filepath: Union[str, Path]) -> bool:
|
| 191 |
-
"""Check whether a file path exists.
|
| 192 |
-
|
| 193 |
-
Args:
|
| 194 |
-
filepath (str or Path): Path to be checked whether exists.
|
| 195 |
-
|
| 196 |
-
Returns:
|
| 197 |
-
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
|
| 198 |
-
|
| 199 |
-
Examples:
|
| 200 |
-
>>> backend = Boto3Backend()
|
| 201 |
-
>>> filepath = 's3://path/of/file'
|
| 202 |
-
>>> backend.exists(filepath)
|
| 203 |
-
True
|
| 204 |
-
"""
|
| 205 |
-
if not (has_method(self._client, "contains") and has_method(self._client, "isdir")):
|
| 206 |
-
raise NotImplementedError(
|
| 207 |
-
"Current version of Boto3 Python SDK has not supported "
|
| 208 |
-
"the `contains` and `isdir` methods, please use a higher"
|
| 209 |
-
"version or dev branch instead."
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
filepath = self._map_path(filepath)
|
| 213 |
-
filepath = self._format_path(filepath)
|
| 214 |
-
filepath = self._replace_prefix(filepath)
|
| 215 |
-
return self._client.contains(filepath) or self._client.isdir(filepath)
|
| 216 |
-
|
| 217 |
-
def isdir(self, filepath: Union[str, Path]) -> bool:
|
| 218 |
-
"""Check whether a file path is a directory.
|
| 219 |
-
|
| 220 |
-
Args:
|
| 221 |
-
filepath (str or Path): Path to be checked whether it is a
|
| 222 |
-
directory.
|
| 223 |
-
|
| 224 |
-
Returns:
|
| 225 |
-
bool: Return ``True`` if ``filepath`` points to a directory,
|
| 226 |
-
``False`` otherwise.
|
| 227 |
-
|
| 228 |
-
Examples:
|
| 229 |
-
>>> backend = Boto3Backend()
|
| 230 |
-
>>> filepath = 's3://path/of/dir'
|
| 231 |
-
>>> backend.isdir(filepath)
|
| 232 |
-
True
|
| 233 |
-
"""
|
| 234 |
-
if not has_method(self._client, "isdir"):
|
| 235 |
-
raise NotImplementedError(
|
| 236 |
-
"Current version of Boto3 Python SDK has not supported "
|
| 237 |
-
"the `isdir` method, please use a higher version or dev"
|
| 238 |
-
" branch instead."
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
filepath = self._map_path(filepath)
|
| 242 |
-
filepath = self._format_path(filepath)
|
| 243 |
-
filepath = self._replace_prefix(filepath)
|
| 244 |
-
return self._client.isdir(filepath)
|
| 245 |
-
|
| 246 |
-
def isfile(self, filepath: Union[str, Path]) -> bool:
|
| 247 |
-
"""Check whether a file path is a file.
|
| 248 |
-
|
| 249 |
-
Args:
|
| 250 |
-
filepath (str or Path): Path to be checked whether it is a file.
|
| 251 |
-
|
| 252 |
-
Returns:
|
| 253 |
-
bool: Return ``True`` if ``filepath`` points to a file, ``False``
|
| 254 |
-
otherwise.
|
| 255 |
-
|
| 256 |
-
Examples:
|
| 257 |
-
>>> backend = Boto3Backend()
|
| 258 |
-
>>> filepath = 's3://path/of/file'
|
| 259 |
-
>>> backend.isfile(filepath)
|
| 260 |
-
True
|
| 261 |
-
"""
|
| 262 |
-
if not has_method(self._client, "contains"):
|
| 263 |
-
raise NotImplementedError(
|
| 264 |
-
"Current version of Boto3 Python SDK has not supported "
|
| 265 |
-
"the `contains` method, please use a higher version or "
|
| 266 |
-
"dev branch instead."
|
| 267 |
-
)
|
| 268 |
-
|
| 269 |
-
filepath = self._map_path(filepath)
|
| 270 |
-
filepath = self._format_path(filepath)
|
| 271 |
-
filepath = self._replace_prefix(filepath)
|
| 272 |
-
return self._client.contains(filepath)
|
| 273 |
-
|
| 274 |
-
def join_path(
|
| 275 |
-
self,
|
| 276 |
-
filepath: Union[str, Path],
|
| 277 |
-
*filepaths: Union[str, Path],
|
| 278 |
-
) -> str:
|
| 279 |
-
r"""Concatenate all file paths.
|
| 280 |
-
|
| 281 |
-
Join one or more filepath components intelligently. The return value
|
| 282 |
-
is the concatenation of filepath and any members of \*filepaths.
|
| 283 |
-
|
| 284 |
-
Args:
|
| 285 |
-
filepath (str or Path): Path to be concatenated.
|
| 286 |
-
|
| 287 |
-
Returns:
|
| 288 |
-
str: The result after concatenation.
|
| 289 |
-
|
| 290 |
-
Examples:
|
| 291 |
-
>>> backend = Boto3Backend()
|
| 292 |
-
>>> filepath = 's3://path/of/file'
|
| 293 |
-
>>> backend.join_path(filepath, 'another/path')
|
| 294 |
-
's3://path/of/file/another/path'
|
| 295 |
-
>>> backend.join_path(filepath, '/another/path')
|
| 296 |
-
's3://path/of/file/another/path'
|
| 297 |
-
"""
|
| 298 |
-
filepath = self._format_path(self._map_path(filepath))
|
| 299 |
-
if filepath.endswith("/"):
|
| 300 |
-
filepath = filepath[:-1]
|
| 301 |
-
formatted_paths = [filepath]
|
| 302 |
-
for path in filepaths:
|
| 303 |
-
formatted_path = self._format_path(self._map_path(path))
|
| 304 |
-
formatted_paths.append(formatted_path.lstrip("/"))
|
| 305 |
-
|
| 306 |
-
return "/".join(formatted_paths)
|
| 307 |
-
|
| 308 |
-
@contextmanager
|
| 309 |
-
def get_local_path(
|
| 310 |
-
self,
|
| 311 |
-
filepath: Union[str, Path],
|
| 312 |
-
) -> Generator[Union[str, Path], None, None]:
|
| 313 |
-
"""Download a file from ``filepath`` to a local temporary directory,
|
| 314 |
-
and return the temporary path.
|
| 315 |
-
|
| 316 |
-
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
|
| 317 |
-
can be called with ``with`` statement, and when exists from the
|
| 318 |
-
``with`` statement, the temporary path will be released.
|
| 319 |
-
|
| 320 |
-
Args:
|
| 321 |
-
filepath (str or Path): Download a file from ``filepath``.
|
| 322 |
-
|
| 323 |
-
Yields:
|
| 324 |
-
Iterable[str]: Only yield one temporary path.
|
| 325 |
-
|
| 326 |
-
Examples:
|
| 327 |
-
>>> backend = Boto3Backend()
|
| 328 |
-
>>> # After existing from the ``with`` clause,
|
| 329 |
-
>>> # the path will be removed
|
| 330 |
-
>>> filepath = 's3://path/of/file'
|
| 331 |
-
>>> with backend.get_local_path(filepath) as path:
|
| 332 |
-
... # do something here
|
| 333 |
-
"""
|
| 334 |
-
assert self.isfile(filepath)
|
| 335 |
-
try:
|
| 336 |
-
f = tempfile.NamedTemporaryFile(delete=False)
|
| 337 |
-
f.write(self.get(filepath))
|
| 338 |
-
f.close()
|
| 339 |
-
yield f.name
|
| 340 |
-
finally:
|
| 341 |
-
os.remove(f.name)
|
| 342 |
-
|
| 343 |
-
def copyfile(
|
| 344 |
-
self,
|
| 345 |
-
src: Union[str, Path],
|
| 346 |
-
dst: Union[str, Path],
|
| 347 |
-
) -> str:
|
| 348 |
-
"""Copy a file src to dst and return the destination file.
|
| 349 |
-
|
| 350 |
-
src and dst should have the same prefix. If dst specifies a directory,
|
| 351 |
-
the file will be copied into dst using the base filename from src. If
|
| 352 |
-
dst specifies a file that already exists, it will be replaced.
|
| 353 |
-
|
| 354 |
-
Args:
|
| 355 |
-
src (str or Path): A file to be copied.
|
| 356 |
-
dst (str or Path): Copy file to dst.
|
| 357 |
-
|
| 358 |
-
Returns:
|
| 359 |
-
str: The destination file.
|
| 360 |
-
|
| 361 |
-
Raises:
|
| 362 |
-
SameFileError: If src and dst are the same file, a SameFileError
|
| 363 |
-
will be raised.
|
| 364 |
-
|
| 365 |
-
Examples:
|
| 366 |
-
>>> backend = Boto3Backend()
|
| 367 |
-
>>> # dst is a file
|
| 368 |
-
>>> src = 's3://path/of/file'
|
| 369 |
-
>>> dst = 's3://path/of/file1'
|
| 370 |
-
>>> backend.copyfile(src, dst)
|
| 371 |
-
's3://path/of/file1'
|
| 372 |
-
|
| 373 |
-
>>> # dst is a directory
|
| 374 |
-
>>> dst = 's3://path/of/dir'
|
| 375 |
-
>>> backend.copyfile(src, dst)
|
| 376 |
-
's3://path/of/dir/file'
|
| 377 |
-
"""
|
| 378 |
-
src = self._format_path(self._map_path(src))
|
| 379 |
-
dst = self._format_path(self._map_path(dst))
|
| 380 |
-
if self.isdir(dst):
|
| 381 |
-
dst = self.join_path(dst, src.split("/")[-1])
|
| 382 |
-
|
| 383 |
-
if src == dst:
|
| 384 |
-
raise SameFileError("src and dst should not be same")
|
| 385 |
-
|
| 386 |
-
self.put(self.get(src), dst)
|
| 387 |
-
return dst
|
| 388 |
-
|
| 389 |
-
def copytree(
|
| 390 |
-
self,
|
| 391 |
-
src: Union[str, Path],
|
| 392 |
-
dst: Union[str, Path],
|
| 393 |
-
) -> str:
|
| 394 |
-
"""Recursively copy an entire directory tree rooted at src to a
|
| 395 |
-
directory named dst and return the destination directory.
|
| 396 |
-
|
| 397 |
-
src and dst should have the same prefix.
|
| 398 |
-
|
| 399 |
-
Args:
|
| 400 |
-
src (str or Path): A directory to be copied.
|
| 401 |
-
dst (str or Path): Copy directory to dst.
|
| 402 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 403 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 404 |
-
|
| 405 |
-
Returns:
|
| 406 |
-
str: The destination directory.
|
| 407 |
-
|
| 408 |
-
Raises:
|
| 409 |
-
FileExistsError: If dst had already existed, a FileExistsError will
|
| 410 |
-
be raised.
|
| 411 |
-
|
| 412 |
-
Examples:
|
| 413 |
-
>>> backend = Boto3Backend()
|
| 414 |
-
>>> src = 's3://path/of/dir'
|
| 415 |
-
>>> dst = 's3://path/of/dir1'
|
| 416 |
-
>>> backend.copytree(src, dst)
|
| 417 |
-
's3://path/of/dir1'
|
| 418 |
-
"""
|
| 419 |
-
src = self._format_path(self._map_path(src))
|
| 420 |
-
dst = self._format_path(self._map_path(dst))
|
| 421 |
-
|
| 422 |
-
if self.exists(dst):
|
| 423 |
-
raise FileExistsError("dst should not exist")
|
| 424 |
-
|
| 425 |
-
for path in self.list_dir_or_file(src, list_dir=False, recursive=True):
|
| 426 |
-
src_path = self.join_path(src, path)
|
| 427 |
-
dst_path = self.join_path(dst, path)
|
| 428 |
-
self.put(self.get(src_path), dst_path)
|
| 429 |
-
|
| 430 |
-
return dst
|
| 431 |
-
|
| 432 |
-
def copyfile_from_local(
|
| 433 |
-
self,
|
| 434 |
-
src: Union[str, Path],
|
| 435 |
-
dst: Union[str, Path],
|
| 436 |
-
) -> str:
|
| 437 |
-
"""Upload a local file src to dst and return the destination file.
|
| 438 |
-
|
| 439 |
-
Args:
|
| 440 |
-
src (str or Path): A local file to be copied.
|
| 441 |
-
dst (str or Path): Copy file to dst.
|
| 442 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 443 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 444 |
-
|
| 445 |
-
Returns:
|
| 446 |
-
str: If dst specifies a directory, the file will be copied into dst
|
| 447 |
-
using the base filename from src.
|
| 448 |
-
|
| 449 |
-
Examples:
|
| 450 |
-
>>> backend = Boto3Backend()
|
| 451 |
-
>>> # dst is a file
|
| 452 |
-
>>> src = 'path/of/your/file'
|
| 453 |
-
>>> dst = 's3://path/of/file1'
|
| 454 |
-
>>> backend.copyfile_from_local(src, dst)
|
| 455 |
-
's3://path/of/file1'
|
| 456 |
-
|
| 457 |
-
>>> # dst is a directory
|
| 458 |
-
>>> dst = 's3://path/of/dir'
|
| 459 |
-
>>> backend.copyfile_from_local(src, dst)
|
| 460 |
-
's3://path/of/dir/file'
|
| 461 |
-
"""
|
| 462 |
-
dst = self._format_path(self._map_path(dst))
|
| 463 |
-
if self.isdir(dst):
|
| 464 |
-
dst = self.join_path(dst, os.path.basename(src))
|
| 465 |
-
|
| 466 |
-
with open(src, "rb") as f:
|
| 467 |
-
self.put(f.read(), dst)
|
| 468 |
-
|
| 469 |
-
return dst
|
| 470 |
-
|
| 471 |
-
def copytree_from_local(
|
| 472 |
-
self,
|
| 473 |
-
src: Union[str, Path],
|
| 474 |
-
dst: Union[str, Path],
|
| 475 |
-
) -> str:
|
| 476 |
-
"""Recursively copy an entire directory tree rooted at src to a
|
| 477 |
-
directory named dst and return the destination directory.
|
| 478 |
-
|
| 479 |
-
Args:
|
| 480 |
-
src (str or Path): A local directory to be copied.
|
| 481 |
-
dst (str or Path): Copy directory to dst.
|
| 482 |
-
|
| 483 |
-
Returns:
|
| 484 |
-
str: The destination directory.
|
| 485 |
-
|
| 486 |
-
Raises:
|
| 487 |
-
FileExistsError: If dst had already existed, a FileExistsError will
|
| 488 |
-
be raised.
|
| 489 |
-
|
| 490 |
-
Examples:
|
| 491 |
-
>>> backend = Boto3Backend()
|
| 492 |
-
>>> src = 'path/of/your/dir'
|
| 493 |
-
>>> dst = 's3://path/of/dir1'
|
| 494 |
-
>>> backend.copytree_from_local(src, dst)
|
| 495 |
-
's3://path/of/dir1'
|
| 496 |
-
"""
|
| 497 |
-
dst = self._format_path(self._map_path(dst))
|
| 498 |
-
if self.exists(dst):
|
| 499 |
-
raise FileExistsError("dst should not exist")
|
| 500 |
-
|
| 501 |
-
src = str(src)
|
| 502 |
-
|
| 503 |
-
for cur_dir, _, files in os.walk(src):
|
| 504 |
-
for f in files:
|
| 505 |
-
src_path = os.path.join(cur_dir, f)
|
| 506 |
-
dst_path = self.join_path(dst, src_path.replace(src, ""))
|
| 507 |
-
self.copyfile_from_local(src_path, dst_path)
|
| 508 |
-
|
| 509 |
-
return dst
|
| 510 |
-
|
| 511 |
-
def copyfile_to_local(
|
| 512 |
-
self,
|
| 513 |
-
src: Union[str, Path],
|
| 514 |
-
dst: Union[str, Path],
|
| 515 |
-
dst_type: str, # Choose from ["file", "dir"]
|
| 516 |
-
) -> Union[str, Path]:
|
| 517 |
-
"""Copy the file src to local dst and return the destination file.
|
| 518 |
-
|
| 519 |
-
If dst specifies a directory, the file will be copied into dst using
|
| 520 |
-
the base filename from src. If dst specifies a file that already
|
| 521 |
-
exists, it will be replaced.
|
| 522 |
-
|
| 523 |
-
Args:
|
| 524 |
-
src (str or Path): A file to be copied.
|
| 525 |
-
dst (str or Path): Copy file to to local dst.
|
| 526 |
-
|
| 527 |
-
Returns:
|
| 528 |
-
str: If dst specifies a directory, the file will be copied into dst
|
| 529 |
-
using the base filename from src.
|
| 530 |
-
|
| 531 |
-
Examples:
|
| 532 |
-
>>> backend = Boto3Backend()
|
| 533 |
-
>>> # dst is a file
|
| 534 |
-
>>> src = 's3://path/of/file'
|
| 535 |
-
>>> dst = 'path/of/your/file'
|
| 536 |
-
>>> backend.copyfile_to_local(src, dst)
|
| 537 |
-
'path/of/your/file'
|
| 538 |
-
|
| 539 |
-
>>> # dst is a directory
|
| 540 |
-
>>> dst = 'path/of/your/dir'
|
| 541 |
-
>>> backend.copyfile_to_local(src, dst)
|
| 542 |
-
'path/of/your/dir/file'
|
| 543 |
-
"""
|
| 544 |
-
assert dst_type in ["file", "dir"]
|
| 545 |
-
# There is no good way to detect whether dst is a directory or a file, so we make dst_type required
|
| 546 |
-
if dst_type == "dir":
|
| 547 |
-
basename = os.path.basename(src)
|
| 548 |
-
if isinstance(dst, str):
|
| 549 |
-
dst = os.path.join(dst, basename)
|
| 550 |
-
else:
|
| 551 |
-
assert isinstance(dst, Path)
|
| 552 |
-
dst = dst / basename
|
| 553 |
-
|
| 554 |
-
# Create parent directory if it doesn't exist
|
| 555 |
-
parent_dir = os.path.dirname(dst)
|
| 556 |
-
os.makedirs(parent_dir, exist_ok=True)
|
| 557 |
-
|
| 558 |
-
try:
|
| 559 |
-
with open(dst, "wb") as f:
|
| 560 |
-
data = self.get(src)
|
| 561 |
-
f.write(data)
|
| 562 |
-
except Exception as e:
|
| 563 |
-
log.error(f"Failed to write file: {e}")
|
| 564 |
-
raise
|
| 565 |
-
|
| 566 |
-
return dst
|
| 567 |
-
|
| 568 |
-
def copytree_to_local(
|
| 569 |
-
self,
|
| 570 |
-
src: Union[str, Path],
|
| 571 |
-
dst: Union[str, Path],
|
| 572 |
-
) -> Union[str, Path]:
|
| 573 |
-
"""Recursively copy an entire directory tree rooted at src to a local
|
| 574 |
-
directory named dst and return the destination directory.
|
| 575 |
-
|
| 576 |
-
Args:
|
| 577 |
-
src (str or Path): A directory to be copied.
|
| 578 |
-
dst (str or Path): Copy directory to local dst.
|
| 579 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 580 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 581 |
-
|
| 582 |
-
Returns:
|
| 583 |
-
str: The destination directory.
|
| 584 |
-
|
| 585 |
-
Examples:
|
| 586 |
-
>>> backend = Boto3Backend()
|
| 587 |
-
>>> src = 's3://path/of/dir'
|
| 588 |
-
>>> dst = 'path/of/your/dir'
|
| 589 |
-
>>> backend.copytree_to_local(src, dst)
|
| 590 |
-
'path/of/your/dir'
|
| 591 |
-
"""
|
| 592 |
-
for path in self.list_dir_or_file(src, list_dir=False, recursive=True):
|
| 593 |
-
dst_path = os.path.join(dst, path)
|
| 594 |
-
mkdir_or_exist(os.path.dirname(dst_path))
|
| 595 |
-
with open(dst_path, "wb") as f:
|
| 596 |
-
f.write(self.get(self.join_path(src, path)))
|
| 597 |
-
|
| 598 |
-
return dst
|
| 599 |
-
|
| 600 |
-
def remove(self, filepath: Union[str, Path]) -> None:
|
| 601 |
-
"""Remove a file.
|
| 602 |
-
|
| 603 |
-
Args:
|
| 604 |
-
filepath (str or Path): Path to be removed.
|
| 605 |
-
|
| 606 |
-
Raises:
|
| 607 |
-
FileNotFoundError: If filepath does not exist, an FileNotFoundError
|
| 608 |
-
will be raised.
|
| 609 |
-
IsADirectoryError: If filepath is a directory, an IsADirectoryError
|
| 610 |
-
will be raised.
|
| 611 |
-
|
| 612 |
-
Examples:
|
| 613 |
-
>>> backend = Boto3Backend()
|
| 614 |
-
>>> filepath = 's3://path/of/file'
|
| 615 |
-
>>> backend.remove(filepath)
|
| 616 |
-
"""
|
| 617 |
-
if not has_method(self._client, "delete"):
|
| 618 |
-
raise NotImplementedError(
|
| 619 |
-
"Current version of Boto3 Python SDK has not supported "
|
| 620 |
-
"the `delete` method, please use a higher version or dev "
|
| 621 |
-
"branch instead."
|
| 622 |
-
)
|
| 623 |
-
|
| 624 |
-
if not self.exists(filepath):
|
| 625 |
-
raise FileNotFoundError(f"filepath {filepath} does not exist")
|
| 626 |
-
|
| 627 |
-
if self.isdir(filepath):
|
| 628 |
-
raise IsADirectoryError("filepath should be a file")
|
| 629 |
-
|
| 630 |
-
filepath = self._map_path(filepath)
|
| 631 |
-
filepath = self._format_path(filepath)
|
| 632 |
-
filepath = self._replace_prefix(filepath)
|
| 633 |
-
self._client.delete(filepath)
|
| 634 |
-
|
| 635 |
-
def rmtree(self, dir_path: Union[str, Path]) -> None:
|
| 636 |
-
"""Recursively delete a directory tree.
|
| 637 |
-
|
| 638 |
-
Args:
|
| 639 |
-
dir_path (str or Path): A directory to be removed.
|
| 640 |
-
|
| 641 |
-
Examples:
|
| 642 |
-
>>> backend = Boto3Backend()
|
| 643 |
-
>>> dir_path = 's3://path/of/dir'
|
| 644 |
-
>>> backend.rmtree(dir_path)
|
| 645 |
-
"""
|
| 646 |
-
for path in self.list_dir_or_file(dir_path, list_dir=False, recursive=True):
|
| 647 |
-
filepath = self.join_path(dir_path, path)
|
| 648 |
-
self.remove(filepath)
|
| 649 |
-
|
| 650 |
-
def copy_if_symlink_fails(
|
| 651 |
-
self,
|
| 652 |
-
src: Union[str, Path],
|
| 653 |
-
dst: Union[str, Path],
|
| 654 |
-
) -> bool:
|
| 655 |
-
"""Create a symbolic link pointing to src named dst.
|
| 656 |
-
|
| 657 |
-
Directly copy src to dst because PetrelBacekend does not support create
|
| 658 |
-
a symbolic link.
|
| 659 |
-
|
| 660 |
-
Args:
|
| 661 |
-
src (str or Path): A file or directory to be copied.
|
| 662 |
-
dst (str or Path): Copy a file or directory to dst.
|
| 663 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 664 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 665 |
-
|
| 666 |
-
Returns:
|
| 667 |
-
bool: Return False because Boto3Backend does not support create
|
| 668 |
-
a symbolic link.
|
| 669 |
-
|
| 670 |
-
Examples:
|
| 671 |
-
>>> backend = Boto3Backend()
|
| 672 |
-
>>> src = 's3://path/of/file'
|
| 673 |
-
>>> dst = 's3://path/of/your/file'
|
| 674 |
-
>>> backend.copy_if_symlink_fails(src, dst)
|
| 675 |
-
False
|
| 676 |
-
>>> src = 's3://path/of/dir'
|
| 677 |
-
>>> dst = 's3://path/of/your/dir'
|
| 678 |
-
>>> backend.copy_if_symlink_fails(src, dst)
|
| 679 |
-
False
|
| 680 |
-
"""
|
| 681 |
-
if self.isfile(src):
|
| 682 |
-
self.copyfile(src, dst)
|
| 683 |
-
else:
|
| 684 |
-
self.copytree(src, dst)
|
| 685 |
-
return False
|
| 686 |
-
|
| 687 |
-
def list_dir(self, dir_path: Union[str, Path]):
|
| 688 |
-
"""List all folders in an S3 bucket with a given prefix.
|
| 689 |
-
|
| 690 |
-
Args:
|
| 691 |
-
dir_path (str | Path): Path of the directory.
|
| 692 |
-
|
| 693 |
-
Examples:
|
| 694 |
-
>>> backend = Boto3Backend()
|
| 695 |
-
>>> dir_path = 's3://path/of/dir'
|
| 696 |
-
>>> backend.list_dir(dir_path)
|
| 697 |
-
"""
|
| 698 |
-
dir_path = self._map_path(dir_path)
|
| 699 |
-
dir_path = self._format_path(dir_path)
|
| 700 |
-
dir_path = self._replace_prefix(dir_path)
|
| 701 |
-
return self._client.ls_dir(dir_path)
|
| 702 |
-
|
| 703 |
-
def list_dir_or_file( # pylint: disable=too-many-arguments
|
| 704 |
-
self,
|
| 705 |
-
dir_path: Union[str, Path],
|
| 706 |
-
list_dir: bool = True,
|
| 707 |
-
list_file: bool = True,
|
| 708 |
-
suffix: Optional[Union[str, Tuple[str]]] = None,
|
| 709 |
-
recursive: bool = False,
|
| 710 |
-
) -> Iterator[str]:
|
| 711 |
-
"""Scan a directory to find the interested directories or files in
|
| 712 |
-
arbitrary order.
|
| 713 |
-
|
| 714 |
-
Note:
|
| 715 |
-
Boto3 has no concept of directories but it simulates the directory
|
| 716 |
-
hierarchy in the filesystem through public prefixes. In addition,
|
| 717 |
-
if the returned path ends with '/', it means the path is a public
|
| 718 |
-
prefix which is a logical directory.
|
| 719 |
-
|
| 720 |
-
Note:
|
| 721 |
-
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
|
| 722 |
-
In addition, the returned path of directory will not contains the
|
| 723 |
-
suffix '/' which is consistent with other backends.
|
| 724 |
-
|
| 725 |
-
Args:
|
| 726 |
-
dir_path (str | Path): Path of the directory.
|
| 727 |
-
list_dir (bool): List the directories. Defaults to True.
|
| 728 |
-
list_file (bool): List the path of files. Defaults to True.
|
| 729 |
-
suffix (str or tuple[str], optional): File suffix
|
| 730 |
-
that we are interested in. Defaults to None.
|
| 731 |
-
recursive (bool): If set to True, recursively scan the
|
| 732 |
-
directory. Defaults to False.
|
| 733 |
-
|
| 734 |
-
Yields:
|
| 735 |
-
Iterable[str]: A relative path to ``dir_path``.
|
| 736 |
-
|
| 737 |
-
Examples:
|
| 738 |
-
>>> backend = Boto3Backend()
|
| 739 |
-
>>> dir_path = 's3://path/of/dir'
|
| 740 |
-
>>> # list those files and directories in current directory
|
| 741 |
-
>>> for file_path in backend.list_dir_or_file(dir_path):
|
| 742 |
-
... print(file_path)
|
| 743 |
-
>>> # only list files
|
| 744 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False):
|
| 745 |
-
... print(file_path)
|
| 746 |
-
>>> # only list directories
|
| 747 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, list_file=False):
|
| 748 |
-
... print(file_path)
|
| 749 |
-
>>> # only list files ending with specified suffixes
|
| 750 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'):
|
| 751 |
-
... print(file_path)
|
| 752 |
-
>>> # list all files and directory recursively
|
| 753 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
|
| 754 |
-
... print(file_path)
|
| 755 |
-
""" # noqa: E501
|
| 756 |
-
if not has_method(self._client, "list"):
|
| 757 |
-
raise NotImplementedError(
|
| 758 |
-
"Current version of Boto3 Python SDK has not supported "
|
| 759 |
-
"the `list` method, please use a higher version or dev"
|
| 760 |
-
" branch instead."
|
| 761 |
-
)
|
| 762 |
-
|
| 763 |
-
dir_path = self._map_path(dir_path)
|
| 764 |
-
dir_path = self._format_path(dir_path)
|
| 765 |
-
dir_path = self._replace_prefix(dir_path)
|
| 766 |
-
if list_dir and suffix is not None:
|
| 767 |
-
raise TypeError("`list_dir` should be False when `suffix` is not None")
|
| 768 |
-
|
| 769 |
-
if list_dir and not list_file and not recursive:
|
| 770 |
-
raise TypeError(
|
| 771 |
-
"Please use `list_dir` instead of `list_dir_or_file` when you only want to list the first level directories."
|
| 772 |
-
)
|
| 773 |
-
|
| 774 |
-
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
| 775 |
-
raise TypeError("`suffix` must be a string or tuple of strings")
|
| 776 |
-
|
| 777 |
-
# Boto3's simulated directory hierarchy assumes that directory paths
|
| 778 |
-
# should end with `/`
|
| 779 |
-
if not dir_path.endswith("/"):
|
| 780 |
-
dir_path += "/"
|
| 781 |
-
|
| 782 |
-
root = dir_path
|
| 783 |
-
|
| 784 |
-
def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
|
| 785 |
-
# Keep track of directories we've already yielded to avoid duplicates
|
| 786 |
-
yielded_dirs = set() if list_dir else None
|
| 787 |
-
|
| 788 |
-
for path in self._client.list(dir_path):
|
| 789 |
-
# All paths returned by S3 list are file paths, never directory paths
|
| 790 |
-
absolute_path = self.join_path(dir_path, path)
|
| 791 |
-
rel_path = absolute_path[len(root) :]
|
| 792 |
-
|
| 793 |
-
# If we want directories, extract directory prefixes from file paths
|
| 794 |
-
# boto3 client actually never return dir, it only return file paths
|
| 795 |
-
if list_dir and "/" in rel_path:
|
| 796 |
-
if not recursive:
|
| 797 |
-
# Non-recursive: only yield immediate child directory (first level)
|
| 798 |
-
first_slash_pos = rel_path.find("/")
|
| 799 |
-
immediate_child_dir = rel_path[:first_slash_pos]
|
| 800 |
-
|
| 801 |
-
if immediate_child_dir not in yielded_dirs:
|
| 802 |
-
yielded_dirs.add(immediate_child_dir)
|
| 803 |
-
yield immediate_child_dir
|
| 804 |
-
else:
|
| 805 |
-
# Recursive: yield all directory levels
|
| 806 |
-
path_parts = rel_path.split("/")[:-1] # Exclude filename
|
| 807 |
-
current_dir = ""
|
| 808 |
-
for part in path_parts:
|
| 809 |
-
if current_dir:
|
| 810 |
-
current_dir += "/" + part
|
| 811 |
-
else:
|
| 812 |
-
current_dir = part
|
| 813 |
-
|
| 814 |
-
if current_dir not in yielded_dirs:
|
| 815 |
-
yielded_dirs.add(current_dir)
|
| 816 |
-
yield current_dir
|
| 817 |
-
|
| 818 |
-
# Handle file listing
|
| 819 |
-
if (suffix is None or rel_path.endswith(suffix)) and list_file:
|
| 820 |
-
yield rel_path
|
| 821 |
-
|
| 822 |
-
return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
|
| 823 |
-
|
| 824 |
-
def generate_presigned_url(self, url: str, client_method: str = "get_object", expires_in: int = 3600) -> str:
|
| 825 |
-
"""Generate the presigned url of video stream which can be passed to
|
| 826 |
-
mmcv.VideoReader. Now only work on Boto3 backend.
|
| 827 |
-
|
| 828 |
-
Note:
|
| 829 |
-
Now only work on Boto3 backend.
|
| 830 |
-
|
| 831 |
-
Args:
|
| 832 |
-
url (str): Url of video stream.
|
| 833 |
-
client_method (str): Method of client, 'get_object' or
|
| 834 |
-
'put_object'. Default: 'get_object'.
|
| 835 |
-
expires_in (int): expires, in seconds. Default: 3600.
|
| 836 |
-
|
| 837 |
-
Returns:
|
| 838 |
-
str: Generated presigned url.
|
| 839 |
-
"""
|
| 840 |
-
raise NotImplementedError("generate_presigned_url is not supported in Boto3Backend")
|
| 841 |
-
return self._client.generate_presigned_url(url, client_method, expires_in)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py
DELETED
|
@@ -1,565 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import asyncio
|
| 17 |
-
import concurrent.futures
|
| 18 |
-
import io
|
| 19 |
-
import os
|
| 20 |
-
import time
|
| 21 |
-
from math import ceil
|
| 22 |
-
from multiprocessing import shared_memory
|
| 23 |
-
from typing import Any, Dict, Generator, List, Tuple
|
| 24 |
-
|
| 25 |
-
import boto3
|
| 26 |
-
import numpy as np
|
| 27 |
-
from botocore.config import Config as S3Config
|
| 28 |
-
from botocore.exceptions import ClientError
|
| 29 |
-
|
| 30 |
-
import lyra_2._ext.imaginaire.utils.easy_io.backends.auto_auth as auto
|
| 31 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 32 |
-
from lyra_2._ext.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS
|
| 33 |
-
|
| 34 |
-
try:
|
| 35 |
-
import aioboto3
|
| 36 |
-
import aioboto3.session
|
| 37 |
-
from aiobotocore.config import AioConfig
|
| 38 |
-
from aiobotocore.session import AioSession
|
| 39 |
-
except ImportError:
|
| 40 |
-
aioboto3 = None
|
| 41 |
-
AioSession = None
|
| 42 |
-
|
| 43 |
-
MAX_RETRIES = 5
|
| 44 |
-
RETRY_DELAY = 1 # seconds
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
async def upload_single_part_async(
|
| 48 |
-
s3: AioSession, bucket: str, key: str, part_number: int, data: bytes, upload_id: str
|
| 49 |
-
) -> Dict[str, Any]:
|
| 50 |
-
"""
|
| 51 |
-
Uploads a single part of a file asynchronously to S3.
|
| 52 |
-
|
| 53 |
-
Args:
|
| 54 |
-
s3 (S3): The S3 client.
|
| 55 |
-
bucket (str): The S3 bucket name.
|
| 56 |
-
key (str): The S3 key (file path).
|
| 57 |
-
part_number (int): The part number of the upload.
|
| 58 |
-
data (bytes): The data to upload.
|
| 59 |
-
upload_id (str): The upload ID for the multipart upload.
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
Dict[str, Any]: A dictionary containing the part number and ETag.
|
| 63 |
-
"""
|
| 64 |
-
for attempt in range(MAX_RETRIES):
|
| 65 |
-
try:
|
| 66 |
-
response = await s3.upload_part(
|
| 67 |
-
Bucket=bucket, Key=key, PartNumber=part_number, UploadId=upload_id, Body=data
|
| 68 |
-
)
|
| 69 |
-
return {"PartNumber": part_number, "ETag": response["ETag"]}
|
| 70 |
-
except (ClientError, asyncio.TimeoutError, Exception) as e:
|
| 71 |
-
log.warning(f"Attempt {attempt + 1} failed for part {part_number}: {str(e)}", rank0_only=False)
|
| 72 |
-
if attempt < MAX_RETRIES - 1:
|
| 73 |
-
await asyncio.sleep(RETRY_DELAY * (2**attempt)) # Exponential backoff
|
| 74 |
-
else:
|
| 75 |
-
log.error(f"Failed to upload part {part_number} after {MAX_RETRIES} attempts", rank0_only=False)
|
| 76 |
-
raise
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
async def upload_parts_async(
|
| 80 |
-
part_size: int,
|
| 81 |
-
part_numbers: range,
|
| 82 |
-
upload_id: str,
|
| 83 |
-
data: bytes,
|
| 84 |
-
bucket: str,
|
| 85 |
-
key: str,
|
| 86 |
-
client_config: Dict[str, Any],
|
| 87 |
-
) -> List[Dict[str, Any]]:
|
| 88 |
-
"""
|
| 89 |
-
Uploads multiple parts of a file asynchronously to S3.
|
| 90 |
-
|
| 91 |
-
Args:
|
| 92 |
-
part_size (int): The size of each part in bytes.
|
| 93 |
-
part_numbers (range): The range of part numbers to upload.
|
| 94 |
-
upload_id (str): The upload ID for the multipart upload.
|
| 95 |
-
data (bytes): The data to upload.
|
| 96 |
-
bucket (str): The S3 bucket name.
|
| 97 |
-
key (str): The S3 key (file path).
|
| 98 |
-
client_config (Dict[str, Any]): The S3 client configuration.
|
| 99 |
-
|
| 100 |
-
Returns:
|
| 101 |
-
List[Dict[str, Any]]: A list of dictionaries containing part numbers and ETags.
|
| 102 |
-
"""
|
| 103 |
-
session = aioboto3.Session()
|
| 104 |
-
config = AioConfig(retries={"max_attempts": 3, "mode": "adaptive"}, connect_timeout=5, read_timeout=10)
|
| 105 |
-
start_idx = part_numbers[0]
|
| 106 |
-
async with session.client("s3", config=config, **client_config) as s3:
|
| 107 |
-
tasks = []
|
| 108 |
-
for part_number in part_numbers:
|
| 109 |
-
start = (part_number - start_idx) * part_size
|
| 110 |
-
end = min(start + part_size, len(data))
|
| 111 |
-
part_data = data[start:end]
|
| 112 |
-
tasks.append(upload_single_part_async(s3, bucket, key, part_number + 1, part_data, upload_id))
|
| 113 |
-
|
| 114 |
-
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 115 |
-
|
| 116 |
-
successful_parts = []
|
| 117 |
-
failed_parts = []
|
| 118 |
-
for part_number, result in enumerate(results, start=start_idx + 1):
|
| 119 |
-
if isinstance(result, Exception):
|
| 120 |
-
failed_parts.append(part_number)
|
| 121 |
-
else:
|
| 122 |
-
successful_parts.append(result)
|
| 123 |
-
|
| 124 |
-
if failed_parts:
|
| 125 |
-
log.error(f"Failed to upload parts: {failed_parts}", rank0_only=False)
|
| 126 |
-
raise Exception(f"Failed to upload {len(failed_parts)} parts")
|
| 127 |
-
|
| 128 |
-
successful_parts.sort(key=lambda part: part["PartNumber"])
|
| 129 |
-
return successful_parts
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def upload_parts_to_s3(args: Tuple[range, str, int, bytes, str, str, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 133 |
-
"""
|
| 134 |
-
Uploads parts of a file to S3 using a new event loop.
|
| 135 |
-
|
| 136 |
-
Args:
|
| 137 |
-
args (Tuple[range, str, int, bytes, str, str, Dict[str, Any]]): The arguments for uploading parts, including:
|
| 138 |
-
part_numbers (range): The range of part numbers to upload.
|
| 139 |
-
upload_id (str): The upload ID for the multipart upload.
|
| 140 |
-
part_size (int): The size of each part in bytes.
|
| 141 |
-
data (bytes): The data to upload.
|
| 142 |
-
bucket (str): The S3 bucket name.
|
| 143 |
-
key (str): The S3 key (file path).
|
| 144 |
-
client_config (Dict[str, Any]): The S3 client configuration.
|
| 145 |
-
|
| 146 |
-
Returns:
|
| 147 |
-
List[Dict[str, Any]]: A list of dictionaries containing part numbers and ETags.
|
| 148 |
-
"""
|
| 149 |
-
part_numbers, upload_id, part_size, data, bucket, key, client_config = args
|
| 150 |
-
loop = asyncio.new_event_loop()
|
| 151 |
-
asyncio.set_event_loop(loop)
|
| 152 |
-
parts = loop.run_until_complete(
|
| 153 |
-
upload_parts_async(part_size, part_numbers, upload_id, data, bucket, key, client_config)
|
| 154 |
-
)
|
| 155 |
-
loop.close()
|
| 156 |
-
return parts
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
async def download_single_part_async(
|
| 160 |
-
s3, bucket: str, key: str, part_number: int, start: int, end: int, shm_name: str, part_size: int
|
| 161 |
-
) -> None:
|
| 162 |
-
"""
|
| 163 |
-
Downloads a single part of a file asynchronously and writes it to shared memory.
|
| 164 |
-
|
| 165 |
-
Args:
|
| 166 |
-
s3 (S3): The S3 client.
|
| 167 |
-
bucket (str): The S3 bucket name.
|
| 168 |
-
key (str): The S3 key (file path).
|
| 169 |
-
part_number (int): The part number.
|
| 170 |
-
start (int): The start byte of the part.
|
| 171 |
-
end (int): The end byte of the part.
|
| 172 |
-
shm_name (str): The name of the shared memory block.
|
| 173 |
-
part_size (int): The size of each part in bytes.
|
| 174 |
-
"""
|
| 175 |
-
for attempt in range(MAX_RETRIES):
|
| 176 |
-
try:
|
| 177 |
-
range_header = f"bytes={start}-{end}"
|
| 178 |
-
response = await s3.get_object(Bucket=bucket, Key=key, Range=range_header)
|
| 179 |
-
data = await response["Body"].read()
|
| 180 |
-
|
| 181 |
-
shm = shared_memory.SharedMemory(name=shm_name)
|
| 182 |
-
offset = part_number * part_size
|
| 183 |
-
shm.buf[offset : offset + len(data)] = data
|
| 184 |
-
shm.close()
|
| 185 |
-
return
|
| 186 |
-
except (ClientError, asyncio.TimeoutError, Exception) as e:
|
| 187 |
-
log.warning(f"Attempt {attempt + 1} failed for part {part_number}: {str(e)}", rank0_only=False)
|
| 188 |
-
if attempt < MAX_RETRIES - 1:
|
| 189 |
-
await asyncio.sleep(RETRY_DELAY * (2**attempt)) # Exponential backoff
|
| 190 |
-
else:
|
| 191 |
-
log.error(f"Failed to download part {part_number} after {MAX_RETRIES} attempts", rank0_only=False)
|
| 192 |
-
raise
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
async def download_parts_async(
|
| 196 |
-
part_size: int, part_numbers: range, bucket: str, key: str, client_config: Dict[str, Any], shm_name: str
|
| 197 |
-
) -> None:
|
| 198 |
-
"""
|
| 199 |
-
Downloads multiple parts of a file asynchronously and writes them to shared memory.
|
| 200 |
-
|
| 201 |
-
Args:
|
| 202 |
-
part_size (int): The size of each part in bytes.
|
| 203 |
-
part_numbers (range): The range of part numbers to download.
|
| 204 |
-
bucket (str): The S3 bucket name.
|
| 205 |
-
key (str): The S3 key (file path).
|
| 206 |
-
client_config (Dict[str, Any]): The S3 client configuration.
|
| 207 |
-
shm_name (str): The name of the shared memory block.
|
| 208 |
-
"""
|
| 209 |
-
session = aioboto3.Session()
|
| 210 |
-
config = AioConfig(retries={"max_attempts": 5, "mode": "adaptive"}, connect_timeout=10, read_timeout=30)
|
| 211 |
-
async with session.client("s3", config=config, **client_config) as s3:
|
| 212 |
-
tasks = [
|
| 213 |
-
download_single_part_async(
|
| 214 |
-
s3,
|
| 215 |
-
bucket,
|
| 216 |
-
key,
|
| 217 |
-
part_number,
|
| 218 |
-
part_number * part_size,
|
| 219 |
-
(part_number + 1) * part_size - 1,
|
| 220 |
-
shm_name,
|
| 221 |
-
part_size,
|
| 222 |
-
)
|
| 223 |
-
for part_number in part_numbers
|
| 224 |
-
]
|
| 225 |
-
results = await asyncio.gather(*tasks, return_exceptions=True)
|
| 226 |
-
failed_parts = [part for part, result in zip(part_numbers, results) if isinstance(result, Exception)]
|
| 227 |
-
|
| 228 |
-
if failed_parts:
|
| 229 |
-
log.error(f"Failed to download parts: {failed_parts}", rank0_only=False)
|
| 230 |
-
raise Exception(f"Failed to download {len(failed_parts)} parts")
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
def download_parts_to_s3(args: Tuple[range, int, str, str, Dict[str, Any], str]) -> bytes:
|
| 234 |
-
"""
|
| 235 |
-
Downloads parts of a file using a new event loop.
|
| 236 |
-
|
| 237 |
-
Args:
|
| 238 |
-
args (Tuple[range, int, str, str, Dict[str, Any]]): The arguments for downloading parts, including:
|
| 239 |
-
part_numbers (range): The range of part numbers to download.
|
| 240 |
-
part_size (int): The size of each part in bytes.
|
| 241 |
-
bucket (str): The S3 bucket name.
|
| 242 |
-
key (str): The S3 key (file path).
|
| 243 |
-
client_config (Dict[str, Any]): The S3 client configuration.
|
| 244 |
-
|
| 245 |
-
Returns:
|
| 246 |
-
bytes: The combined file data from all downloaded parts.
|
| 247 |
-
"""
|
| 248 |
-
part_numbers, part_size, bucket, key, client_config, shm_name = args
|
| 249 |
-
loop = asyncio.new_event_loop()
|
| 250 |
-
asyncio.set_event_loop(loop)
|
| 251 |
-
loop.run_until_complete(download_parts_async(part_size, part_numbers, bucket, key, client_config, shm_name))
|
| 252 |
-
loop.close()
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
class Boto3Client:
|
| 256 |
-
def __init__(
|
| 257 |
-
self,
|
| 258 |
-
s3_credential_path: str,
|
| 259 |
-
max_attempt: int = 3,
|
| 260 |
-
):
|
| 261 |
-
self.max_attempt = max_attempt
|
| 262 |
-
assert s3_credential_path, "s3_credential_path is required"
|
| 263 |
-
assert os.path.exists(s3_credential_path) or CRED_ENVS.APP_ENV in [
|
| 264 |
-
"prod",
|
| 265 |
-
"dev",
|
| 266 |
-
"stg",
|
| 267 |
-
], f"Credential file not found: {s3_credential_path}"
|
| 268 |
-
with auto.open_auth(s3_credential_path, "r") as f:
|
| 269 |
-
conf = auto.json_load_auth(f)
|
| 270 |
-
|
| 271 |
-
s3_config = S3Config(
|
| 272 |
-
signature_version="s3v4",
|
| 273 |
-
s3={"addressing_style": "virtual"},
|
| 274 |
-
response_checksum_validation="when_required",
|
| 275 |
-
request_checksum_calculation="when_required",
|
| 276 |
-
)
|
| 277 |
-
|
| 278 |
-
self._client = boto3.client("s3", **conf, config=s3_config)
|
| 279 |
-
self._s3_cred_info = conf
|
| 280 |
-
self._mc_kv_store = None
|
| 281 |
-
|
| 282 |
-
def get(self, filepath):
|
| 283 |
-
filepath = self._check_path(filepath)
|
| 284 |
-
|
| 285 |
-
if self._mc_kv_store and self._mc_kv_store.available:
|
| 286 |
-
if self._mc_kv_store.has(filepath):
|
| 287 |
-
return self._mc_kv_store.get(filepath)
|
| 288 |
-
|
| 289 |
-
attempt = 0
|
| 290 |
-
while attempt < self.max_attempt:
|
| 291 |
-
try:
|
| 292 |
-
buffer = io.BytesIO()
|
| 293 |
-
self._client.download_fileobj(
|
| 294 |
-
Bucket=filepath.split("/")[0],
|
| 295 |
-
Key="/".join(filepath.split("/")[1:]),
|
| 296 |
-
Fileobj=buffer,
|
| 297 |
-
)
|
| 298 |
-
buffer.seek(0)
|
| 299 |
-
if self._mc_kv_store and self._mc_kv_store.available:
|
| 300 |
-
self._mc_kv_store.put(filepath, buffer.read())
|
| 301 |
-
|
| 302 |
-
return buffer.read()
|
| 303 |
-
except Exception as e:
|
| 304 |
-
attempt += 1
|
| 305 |
-
log.error(f"Got an exception: attempt={attempt} - {e} - {filepath}", rank0_only=False)
|
| 306 |
-
|
| 307 |
-
raise ConnectionError("Unable to read {} from. {} attempts tried.".format(filepath, attempt))
|
| 308 |
-
|
| 309 |
-
def _get_file_size(self, bucket, key, max_retries=10):
|
| 310 |
-
retries = 0
|
| 311 |
-
while retries < max_retries:
|
| 312 |
-
try:
|
| 313 |
-
# Try to get the file size
|
| 314 |
-
file_size = self._client.head_object(Bucket=bucket, Key=key)["ContentLength"]
|
| 315 |
-
return file_size # Return file size if successful
|
| 316 |
-
except ClientError as e:
|
| 317 |
-
retries += 1
|
| 318 |
-
log.error(f"Attempt {retries} failed for s3://{bucket}/{key}: {e}", rank0_only=False)
|
| 319 |
-
if retries >= max_retries:
|
| 320 |
-
raise # Re-raise the exception after max retries
|
| 321 |
-
time.sleep(2) # Wait for 2 seconds before retrying
|
| 322 |
-
except Exception as e:
|
| 323 |
-
retries += 1
|
| 324 |
-
log.error(
|
| 325 |
-
f"Attempt {retries} failed for s3://{bucket}/{key}: due to an unexpected error: {e}",
|
| 326 |
-
rank0_only=False,
|
| 327 |
-
)
|
| 328 |
-
if retries >= max_retries:
|
| 329 |
-
raise # Re-raise the exception after max retries
|
| 330 |
-
time.sleep(2) # Wait for 2 seconds before retrying
|
| 331 |
-
|
| 332 |
-
def put(self, obj, filepath):
|
| 333 |
-
filepath = self._check_path(filepath)
|
| 334 |
-
bucket_name = filepath.split("/")[0]
|
| 335 |
-
key = "/".join(filepath.split("/")[1:])
|
| 336 |
-
attempt = 0
|
| 337 |
-
while attempt < self.max_attempt:
|
| 338 |
-
try:
|
| 339 |
-
# If obj is a string path to a local file, use upload_file instead
|
| 340 |
-
if isinstance(obj, str) and os.path.isfile(obj):
|
| 341 |
-
self._client.upload_file(Filename=obj, Bucket=bucket_name, Key=key)
|
| 342 |
-
return
|
| 343 |
-
if isinstance(obj, io.BytesIO):
|
| 344 |
-
obj.seek(0)
|
| 345 |
-
self._client.upload_fileobj(obj, Bucket=bucket_name, Key=key)
|
| 346 |
-
return
|
| 347 |
-
if isinstance(obj, bytes):
|
| 348 |
-
self._client.put_object(Body=obj, Bucket=bucket_name, Key=key)
|
| 349 |
-
return
|
| 350 |
-
else:
|
| 351 |
-
raise ValueError("Unsupported object type for upload")
|
| 352 |
-
except ClientError as e:
|
| 353 |
-
attempt += 1
|
| 354 |
-
log.error(f"Got an exception: attempt={attempt} - {e} - {filepath}", rank0_only=False)
|
| 355 |
-
|
| 356 |
-
raise ConnectionError("Unable to write {} to. {} attempts tried.".format(filepath, attempt))
|
| 357 |
-
|
| 358 |
-
def fast_put(self, obj, filepath, num_processes: int = 32):
|
| 359 |
-
assert aioboto3 is not None, "aioboto3 is required for fast_put"
|
| 360 |
-
original_filepath = filepath
|
| 361 |
-
filepath = self._check_path(filepath)
|
| 362 |
-
bucket = filepath.split("/")[0]
|
| 363 |
-
key = "/".join(filepath.split("/")[1:])
|
| 364 |
-
part_size = 16 * 1024 * 1024 # 16 MB part size
|
| 365 |
-
|
| 366 |
-
if isinstance(obj, bytes):
|
| 367 |
-
data = obj
|
| 368 |
-
elif isinstance(obj, str) and os.path.isfile(obj):
|
| 369 |
-
with open(obj, "rb") as f:
|
| 370 |
-
data = f.read()
|
| 371 |
-
elif isinstance(obj, io.BytesIO):
|
| 372 |
-
obj.seek(0)
|
| 373 |
-
data = obj.read()
|
| 374 |
-
else:
|
| 375 |
-
raise ValueError("Unsupported object type for upload")
|
| 376 |
-
|
| 377 |
-
file_size = len(data)
|
| 378 |
-
if file_size <= part_size * num_processes:
|
| 379 |
-
return self.put(data, original_filepath)
|
| 380 |
-
num_parts = ceil(file_size / part_size)
|
| 381 |
-
upload_id = self._client.create_multipart_upload(Bucket=bucket, Key=key)["UploadId"]
|
| 382 |
-
|
| 383 |
-
part_numbers = np.array_split(np.arange(num_parts), num_processes)
|
| 384 |
-
|
| 385 |
-
with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor:
|
| 386 |
-
args = []
|
| 387 |
-
for i in range(num_processes):
|
| 388 |
-
cur_parts = part_numbers[i].tolist()
|
| 389 |
-
cur_data = data[cur_parts[0] * part_size : min(cur_parts[-1] * part_size + part_size, file_size)]
|
| 390 |
-
args.append((cur_parts, upload_id, part_size, cur_data, bucket, key, self._s3_cred_info))
|
| 391 |
-
results = executor.map(upload_parts_to_s3, args)
|
| 392 |
-
parts = []
|
| 393 |
-
for result in results:
|
| 394 |
-
parts.extend(result)
|
| 395 |
-
|
| 396 |
-
parts = sorted(parts, key=lambda part: part["PartNumber"])
|
| 397 |
-
self._client.complete_multipart_upload(
|
| 398 |
-
Bucket=bucket, Key=key, UploadId=upload_id, MultipartUpload={"Parts": parts}
|
| 399 |
-
)
|
| 400 |
-
|
| 401 |
-
def contains(self, filepath: str, max_retries=10) -> bool:
|
| 402 |
-
"""
|
| 403 |
-
Checks if the specified object exists in the S3 bucket with retry logic for errors.
|
| 404 |
-
|
| 405 |
-
Args:
|
| 406 |
-
filepath (str): The s3 path of the file to check, must start with "s3://".
|
| 407 |
-
|
| 408 |
-
Returns:
|
| 409 |
-
bool: True if the object exists in the S3 bucket, False otherwise.
|
| 410 |
-
|
| 411 |
-
Raises:
|
| 412 |
-
ClientError: If an error response other than "404 Not Found" is returned from the S3 service.
|
| 413 |
-
"""
|
| 414 |
-
filepath = self._check_path(filepath)
|
| 415 |
-
bucket = filepath.split("/")[0]
|
| 416 |
-
key = "/".join(filepath.split("/")[1:])
|
| 417 |
-
|
| 418 |
-
retries = 0
|
| 419 |
-
while retries < max_retries:
|
| 420 |
-
try:
|
| 421 |
-
# Try to check if the object exists
|
| 422 |
-
self._client.head_object(Bucket=bucket, Key=key)
|
| 423 |
-
return True # Object exists
|
| 424 |
-
except ClientError as e:
|
| 425 |
-
if e.response["Error"]["Code"] == "404":
|
| 426 |
-
return False # Object does not exist
|
| 427 |
-
else:
|
| 428 |
-
retries += 1
|
| 429 |
-
print(f"Attempt {retries} failed with error: {e}")
|
| 430 |
-
if retries >= max_retries:
|
| 431 |
-
raise # Re-raise the exception if max retries are reached
|
| 432 |
-
time.sleep(2) # Wait for 2 seconds before retrying
|
| 433 |
-
except Exception as e:
|
| 434 |
-
retries += 1
|
| 435 |
-
print(f"Attempt {retries} failed due to an unexpected error: {e}")
|
| 436 |
-
if retries >= max_retries:
|
| 437 |
-
raise # Re-raise the exception if max retries are reached
|
| 438 |
-
time.sleep(2) # Wait for 2 seconds before retrying
|
| 439 |
-
|
| 440 |
-
def isdir(self, filepath: str, max_retries=10) -> bool:
|
| 441 |
-
"""
|
| 442 |
-
Determines if the specified path corresponds to a directory in S3 with retry logic.
|
| 443 |
-
|
| 444 |
-
A directory in S3 is implied if there are any objects stored with the given prefix,
|
| 445 |
-
which means this function checks for the existence of any objects at or under the specified path.
|
| 446 |
-
|
| 447 |
-
Args:
|
| 448 |
-
filepath (str): The s3 path to check, must start with "s3://".
|
| 449 |
-
|
| 450 |
-
Returns:
|
| 451 |
-
bool: True if the specified path corresponds to a directory in S3, False otherwise.
|
| 452 |
-
Directories in S3 are not physical entities but are implied by object keys.
|
| 453 |
-
|
| 454 |
-
Raises:
|
| 455 |
-
ClientError: An error from the S3 API that isn't related to the absence of the directory
|
| 456 |
-
(logged but not raised further).
|
| 457 |
-
"""
|
| 458 |
-
filepath = self._check_path(filepath)
|
| 459 |
-
if not filepath.endswith("/"):
|
| 460 |
-
filepath += "/"
|
| 461 |
-
|
| 462 |
-
bucket = filepath.split("/")[0]
|
| 463 |
-
prefix = "/".join(filepath.split("/")[1:])
|
| 464 |
-
|
| 465 |
-
retries = 0
|
| 466 |
-
while retries < max_retries:
|
| 467 |
-
try:
|
| 468 |
-
# Try to check if any objects exist with the given prefix (i.e., directory in S3)
|
| 469 |
-
resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter="/", MaxKeys=1)
|
| 470 |
-
# Check if any content or prefixes exist under the given path
|
| 471 |
-
return "CommonPrefixes" in resp or "Contents" in resp
|
| 472 |
-
except ClientError as e:
|
| 473 |
-
retries += 1
|
| 474 |
-
log.error(f"Attempt {retries} failed: {e}", rank0_only=False)
|
| 475 |
-
if retries >= max_retries:
|
| 476 |
-
return False # Return False if maximum retries are reached
|
| 477 |
-
time.sleep(2) # Wait for 2 seconds before retrying
|
| 478 |
-
except Exception as e:
|
| 479 |
-
retries += 1
|
| 480 |
-
log.error(f"Attempt {retries} failed due to an unexpected error: {e}", rank0_only=False)
|
| 481 |
-
if retries >= max_retries:
|
| 482 |
-
return False # Return False if maximum retries are reached
|
| 483 |
-
time.sleep(2) # Wait for 2 seconds before retrying
|
| 484 |
-
|
| 485 |
-
def delete(self, filepath):
|
| 486 |
-
filepath = self._check_path(filepath)
|
| 487 |
-
self._client.delete_object(Bucket=filepath.split("/")[0], Key="/".join(filepath.split("/")[1:]))
|
| 488 |
-
|
| 489 |
-
def ls_dir(self, filepath: str) -> Generator[str, None, None]:
|
| 490 |
-
"""
|
| 491 |
-
List all folders in an S3 bucket with a given prefix.
|
| 492 |
-
|
| 493 |
-
Args:
|
| 494 |
-
filepath (str): The S3 path of the folder to list.
|
| 495 |
-
|
| 496 |
-
Yields:
|
| 497 |
-
str: The keys of the folders in the S3 bucket.
|
| 498 |
-
"""
|
| 499 |
-
filepath = self._check_path(filepath)
|
| 500 |
-
bucket = filepath.split("/")[0]
|
| 501 |
-
prefix = "/".join(filepath.split("/")[1:])
|
| 502 |
-
continuation_token = None
|
| 503 |
-
if prefix and not prefix.endswith("/"):
|
| 504 |
-
prefix += "/"
|
| 505 |
-
|
| 506 |
-
while True:
|
| 507 |
-
if continuation_token:
|
| 508 |
-
resp = self._client.list_objects_v2(
|
| 509 |
-
Bucket=bucket, Prefix=prefix, Delimiter="/", ContinuationToken=continuation_token
|
| 510 |
-
)
|
| 511 |
-
else:
|
| 512 |
-
resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter="/")
|
| 513 |
-
|
| 514 |
-
if "CommonPrefixes" in resp:
|
| 515 |
-
for item in resp["CommonPrefixes"]:
|
| 516 |
-
yield item["Prefix"][len(prefix) :]
|
| 517 |
-
|
| 518 |
-
# Check if there are more keys to retrieve
|
| 519 |
-
if resp.get("IsTruncated"): # If IsTruncated is True, there are more keys
|
| 520 |
-
continuation_token = resp.get("NextContinuationToken")
|
| 521 |
-
else:
|
| 522 |
-
break
|
| 523 |
-
|
| 524 |
-
def list(self, filepath: str, exclude_prefix: str = None) -> Generator[str, None, None]:
|
| 525 |
-
"""
|
| 526 |
-
List all keys in an S3 bucket with a given prefix, excluding files that start with
|
| 527 |
-
specified prefix.
|
| 528 |
-
|
| 529 |
-
Args:
|
| 530 |
-
filepath (str): The S3 path of the file to list.
|
| 531 |
-
exclude_prefix (str): Files starting with this prefix will be excluded from results.
|
| 532 |
-
Defaults to "real".
|
| 533 |
-
|
| 534 |
-
Yields:
|
| 535 |
-
str: The keys of the files in the S3 bucket that don't start with exclude_prefix.
|
| 536 |
-
"""
|
| 537 |
-
filepath = self._check_path(filepath)
|
| 538 |
-
bucket = filepath.split("/")[0]
|
| 539 |
-
prefix = "/".join(filepath.split("/")[1:])
|
| 540 |
-
|
| 541 |
-
continuation_token = None
|
| 542 |
-
|
| 543 |
-
while True:
|
| 544 |
-
if continuation_token:
|
| 545 |
-
resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, ContinuationToken=continuation_token)
|
| 546 |
-
else:
|
| 547 |
-
resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix)
|
| 548 |
-
|
| 549 |
-
if "Contents" in resp:
|
| 550 |
-
for item in resp["Contents"]:
|
| 551 |
-
key = item["Key"][len(prefix) :]
|
| 552 |
-
# Skip files that start with the excluded prefix
|
| 553 |
-
if exclude_prefix is None or not key.startswith(exclude_prefix):
|
| 554 |
-
yield key
|
| 555 |
-
|
| 556 |
-
# Check if there are more keys to retrieve
|
| 557 |
-
if resp.get("IsTruncated"): # If IsTruncated is True, there are more keys
|
| 558 |
-
continuation_token = resp.get("NextContinuationToken")
|
| 559 |
-
else:
|
| 560 |
-
break
|
| 561 |
-
|
| 562 |
-
def _check_path(self, filepath: str):
|
| 563 |
-
assert filepath.startswith("s3://")
|
| 564 |
-
filepath = filepath[5:]
|
| 565 |
-
return filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
import tempfile
|
| 18 |
-
from contextlib import contextmanager
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
from typing import Generator, Union
|
| 21 |
-
from urllib.request import urlopen
|
| 22 |
-
|
| 23 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class HTTPBackend(BaseStorageBackend):
|
| 27 |
-
"""HTTP and HTTPS storage bachend."""
|
| 28 |
-
|
| 29 |
-
def get(self, filepath: str) -> bytes:
|
| 30 |
-
"""Read bytes from a given ``filepath``.
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
filepath (str): Path to read data.
|
| 34 |
-
|
| 35 |
-
Returns:
|
| 36 |
-
bytes: Expected bytes object.
|
| 37 |
-
|
| 38 |
-
Examples:
|
| 39 |
-
>>> backend = HTTPBackend()
|
| 40 |
-
>>> backend.get('http://path/of/file')
|
| 41 |
-
b'hello world'
|
| 42 |
-
"""
|
| 43 |
-
return urlopen(filepath).read()
|
| 44 |
-
|
| 45 |
-
def get_text(self, filepath, encoding="utf-8") -> str:
|
| 46 |
-
"""Read text from a given ``filepath``.
|
| 47 |
-
|
| 48 |
-
Args:
|
| 49 |
-
filepath (str): Path to read data.
|
| 50 |
-
encoding (str): The encoding format used to open the ``filepath``.
|
| 51 |
-
Defaults to 'utf-8'.
|
| 52 |
-
|
| 53 |
-
Returns:
|
| 54 |
-
str: Expected text reading from ``filepath``.
|
| 55 |
-
|
| 56 |
-
Examples:
|
| 57 |
-
>>> backend = HTTPBackend()
|
| 58 |
-
>>> backend.get_text('http://path/of/file')
|
| 59 |
-
'hello world'
|
| 60 |
-
"""
|
| 61 |
-
return urlopen(filepath).read().decode(encoding)
|
| 62 |
-
|
| 63 |
-
@contextmanager
|
| 64 |
-
def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]:
|
| 65 |
-
"""Download a file from ``filepath`` to a local temporary directory,
|
| 66 |
-
and return the temporary path.
|
| 67 |
-
|
| 68 |
-
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
|
| 69 |
-
can be called with ``with`` statement, and when exists from the
|
| 70 |
-
``with`` statement, the temporary path will be released.
|
| 71 |
-
|
| 72 |
-
Args:
|
| 73 |
-
filepath (str): Download a file from ``filepath``.
|
| 74 |
-
|
| 75 |
-
Yields:
|
| 76 |
-
Iterable[str]: Only yield one temporary path.
|
| 77 |
-
|
| 78 |
-
Examples:
|
| 79 |
-
>>> backend = HTTPBackend()
|
| 80 |
-
>>> # After existing from the ``with`` clause,
|
| 81 |
-
>>> # the path will be removed
|
| 82 |
-
>>> with backend.get_local_path('http://path/of/file') as path:
|
| 83 |
-
... # do something here
|
| 84 |
-
"""
|
| 85 |
-
try:
|
| 86 |
-
f = tempfile.NamedTemporaryFile(delete=False)
|
| 87 |
-
f.write(self.get(filepath))
|
| 88 |
-
f.close()
|
| 89 |
-
yield f.name
|
| 90 |
-
finally:
|
| 91 |
-
os.remove(f.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py
DELETED
|
@@ -1,550 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import io
|
| 17 |
-
import os
|
| 18 |
-
import os.path as osp
|
| 19 |
-
import shutil
|
| 20 |
-
from contextlib import contextmanager
|
| 21 |
-
from pathlib import Path
|
| 22 |
-
from typing import Generator, Iterator, Optional, Tuple, Union
|
| 23 |
-
|
| 24 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class LocalBackend(BaseStorageBackend):
|
| 28 |
-
"""Raw local storage backend."""
|
| 29 |
-
|
| 30 |
-
_allow_symlink = True
|
| 31 |
-
|
| 32 |
-
def get(self, filepath: Union[str, Path]) -> bytes:
|
| 33 |
-
"""Read bytes from a given ``filepath`` with 'rb' mode.
|
| 34 |
-
|
| 35 |
-
Args:
|
| 36 |
-
filepath (str or Path): Path to read data.
|
| 37 |
-
|
| 38 |
-
Returns:
|
| 39 |
-
bytes: Expected bytes object.
|
| 40 |
-
|
| 41 |
-
Examples:
|
| 42 |
-
>>> backend = LocalBackend()
|
| 43 |
-
>>> filepath = '/path/of/file'
|
| 44 |
-
>>> backend.get(filepath)
|
| 45 |
-
b'hello world'
|
| 46 |
-
"""
|
| 47 |
-
with open(filepath, "rb") as f:
|
| 48 |
-
value = f.read()
|
| 49 |
-
return value
|
| 50 |
-
|
| 51 |
-
def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str:
|
| 52 |
-
"""Read text from a given ``filepath`` with 'r' mode.
|
| 53 |
-
|
| 54 |
-
Args:
|
| 55 |
-
filepath (str or Path): Path to read data.
|
| 56 |
-
encoding (str): The encoding format used to open the ``filepath``.
|
| 57 |
-
Defaults to 'utf-8'.
|
| 58 |
-
|
| 59 |
-
Returns:
|
| 60 |
-
str: Expected text reading from ``filepath``.
|
| 61 |
-
|
| 62 |
-
Examples:
|
| 63 |
-
>>> backend = LocalBackend()
|
| 64 |
-
>>> filepath = '/path/of/file'
|
| 65 |
-
>>> backend.get_text(filepath)
|
| 66 |
-
'hello world'
|
| 67 |
-
"""
|
| 68 |
-
with open(filepath, encoding=encoding) as f:
|
| 69 |
-
text = f.read()
|
| 70 |
-
return text
|
| 71 |
-
|
| 72 |
-
def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None:
|
| 73 |
-
"""Write bytes to a given ``filepath`` with 'wb' mode.
|
| 74 |
-
|
| 75 |
-
Note:
|
| 76 |
-
``put`` will create a directory if the directory of
|
| 77 |
-
``filepath`` does not exist.
|
| 78 |
-
|
| 79 |
-
Args:
|
| 80 |
-
obj (bytes): Data to be written.
|
| 81 |
-
filepath (str or Path): Path to write data.
|
| 82 |
-
|
| 83 |
-
Examples:
|
| 84 |
-
>>> backend = LocalBackend()
|
| 85 |
-
>>> filepath = '/path/of/file'
|
| 86 |
-
>>> backend.put(b'hello world', filepath)
|
| 87 |
-
"""
|
| 88 |
-
mkdir_or_exist(osp.dirname(filepath))
|
| 89 |
-
if isinstance(obj, io.BytesIO):
|
| 90 |
-
obj.seek(0)
|
| 91 |
-
obj = obj.getvalue()
|
| 92 |
-
with open(filepath, "wb") as f:
|
| 93 |
-
f.write(obj)
|
| 94 |
-
|
| 95 |
-
def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None:
|
| 96 |
-
"""Write text to a given ``filepath`` with 'w' mode.
|
| 97 |
-
|
| 98 |
-
Note:
|
| 99 |
-
``put_text`` will create a directory if the directory of
|
| 100 |
-
``filepath`` does not exist.
|
| 101 |
-
|
| 102 |
-
Args:
|
| 103 |
-
obj (str): Data to be written.
|
| 104 |
-
filepath (str or Path): Path to write data.
|
| 105 |
-
encoding (str): The encoding format used to open the ``filepath``.
|
| 106 |
-
Defaults to 'utf-8'.
|
| 107 |
-
|
| 108 |
-
Examples:
|
| 109 |
-
>>> backend = LocalBackend()
|
| 110 |
-
>>> filepath = '/path/of/file'
|
| 111 |
-
>>> backend.put_text('hello world', filepath)
|
| 112 |
-
"""
|
| 113 |
-
mkdir_or_exist(osp.dirname(filepath))
|
| 114 |
-
with open(filepath, "w", encoding=encoding) as f:
|
| 115 |
-
f.write(obj)
|
| 116 |
-
|
| 117 |
-
def exists(self, filepath: Union[str, Path]) -> bool:
|
| 118 |
-
"""Check whether a file path exists.
|
| 119 |
-
|
| 120 |
-
Args:
|
| 121 |
-
filepath (str or Path): Path to be checked whether exists.
|
| 122 |
-
|
| 123 |
-
Returns:
|
| 124 |
-
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
|
| 125 |
-
|
| 126 |
-
Examples:
|
| 127 |
-
>>> backend = LocalBackend()
|
| 128 |
-
>>> filepath = '/path/of/file'
|
| 129 |
-
>>> backend.exists(filepath)
|
| 130 |
-
True
|
| 131 |
-
"""
|
| 132 |
-
return osp.exists(filepath)
|
| 133 |
-
|
| 134 |
-
def isdir(self, filepath: Union[str, Path]) -> bool:
|
| 135 |
-
"""Check whether a file path is a directory.
|
| 136 |
-
|
| 137 |
-
Args:
|
| 138 |
-
filepath (str or Path): Path to be checked whether it is a
|
| 139 |
-
directory.
|
| 140 |
-
|
| 141 |
-
Returns:
|
| 142 |
-
bool: Return ``True`` if ``filepath`` points to a directory,
|
| 143 |
-
``False`` otherwise.
|
| 144 |
-
|
| 145 |
-
Examples:
|
| 146 |
-
>>> backend = LocalBackend()
|
| 147 |
-
>>> filepath = '/path/of/dir'
|
| 148 |
-
>>> backend.isdir(filepath)
|
| 149 |
-
True
|
| 150 |
-
"""
|
| 151 |
-
return osp.isdir(filepath)
|
| 152 |
-
|
| 153 |
-
def isfile(self, filepath: Union[str, Path]) -> bool:
|
| 154 |
-
"""Check whether a file path is a file.
|
| 155 |
-
|
| 156 |
-
Args:
|
| 157 |
-
filepath (str or Path): Path to be checked whether it is a file.
|
| 158 |
-
|
| 159 |
-
Returns:
|
| 160 |
-
bool: Return ``True`` if ``filepath`` points to a file, ``False``
|
| 161 |
-
otherwise.
|
| 162 |
-
|
| 163 |
-
Examples:
|
| 164 |
-
>>> backend = LocalBackend()
|
| 165 |
-
>>> filepath = '/path/of/file'
|
| 166 |
-
>>> backend.isfile(filepath)
|
| 167 |
-
True
|
| 168 |
-
"""
|
| 169 |
-
return osp.isfile(filepath)
|
| 170 |
-
|
| 171 |
-
def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str:
|
| 172 |
-
r"""Concatenate all file paths.
|
| 173 |
-
|
| 174 |
-
Join one or more filepath components intelligently. The return value
|
| 175 |
-
is the concatenation of filepath and any members of \*filepaths.
|
| 176 |
-
|
| 177 |
-
Args:
|
| 178 |
-
filepath (str or Path): Path to be concatenated.
|
| 179 |
-
|
| 180 |
-
Returns:
|
| 181 |
-
str: The result of concatenation.
|
| 182 |
-
|
| 183 |
-
Examples:
|
| 184 |
-
>>> backend = LocalBackend()
|
| 185 |
-
>>> filepath1 = '/path/of/dir1'
|
| 186 |
-
>>> filepath2 = 'dir2'
|
| 187 |
-
>>> filepath3 = 'path/of/file'
|
| 188 |
-
>>> backend.join_path(filepath1, filepath2, filepath3)
|
| 189 |
-
'/path/of/dir/dir2/path/of/file'
|
| 190 |
-
"""
|
| 191 |
-
return osp.join(filepath, *filepaths)
|
| 192 |
-
|
| 193 |
-
@contextmanager
|
| 194 |
-
def get_local_path(
|
| 195 |
-
self,
|
| 196 |
-
filepath: Union[str, Path],
|
| 197 |
-
) -> Generator[Union[str, Path], None, None]:
|
| 198 |
-
"""Only for unified API and do nothing.
|
| 199 |
-
|
| 200 |
-
Args:
|
| 201 |
-
filepath (str or Path): Path to be read data.
|
| 202 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 203 |
-
corresponding backend. Defaults to None.
|
| 204 |
-
|
| 205 |
-
Examples:
|
| 206 |
-
>>> backend = LocalBackend()
|
| 207 |
-
>>> with backend.get_local_path('s3://bucket/abc.jpg') as path:
|
| 208 |
-
... # do something here
|
| 209 |
-
"""
|
| 210 |
-
yield filepath
|
| 211 |
-
|
| 212 |
-
def copyfile(
|
| 213 |
-
self,
|
| 214 |
-
src: Union[str, Path],
|
| 215 |
-
dst: Union[str, Path],
|
| 216 |
-
) -> str:
|
| 217 |
-
"""Copy a file src to dst and return the destination file.
|
| 218 |
-
|
| 219 |
-
src and dst should have the same prefix. If dst specifies a directory,
|
| 220 |
-
the file will be copied into dst using the base filename from src. If
|
| 221 |
-
dst specifies a file that already exists, it will be replaced.
|
| 222 |
-
|
| 223 |
-
Args:
|
| 224 |
-
src (str or Path): A file to be copied.
|
| 225 |
-
dst (str or Path): Copy file to dst.
|
| 226 |
-
|
| 227 |
-
Returns:
|
| 228 |
-
str: The destination file.
|
| 229 |
-
|
| 230 |
-
Raises:
|
| 231 |
-
SameFileError: If src and dst are the same file, a SameFileError
|
| 232 |
-
will be raised.
|
| 233 |
-
|
| 234 |
-
Examples:
|
| 235 |
-
>>> backend = LocalBackend()
|
| 236 |
-
>>> # dst is a file
|
| 237 |
-
>>> src = '/path/of/file'
|
| 238 |
-
>>> dst = '/path1/of/file1'
|
| 239 |
-
>>> # src will be copied to '/path1/of/file1'
|
| 240 |
-
>>> backend.copyfile(src, dst)
|
| 241 |
-
'/path1/of/file1'
|
| 242 |
-
|
| 243 |
-
>>> # dst is a directory
|
| 244 |
-
>>> dst = '/path1/of/dir'
|
| 245 |
-
>>> # src will be copied to '/path1/of/dir/file'
|
| 246 |
-
>>> backend.copyfile(src, dst)
|
| 247 |
-
'/path1/of/dir/file'
|
| 248 |
-
"""
|
| 249 |
-
return shutil.copy(src, dst)
|
| 250 |
-
|
| 251 |
-
def copytree(
|
| 252 |
-
self,
|
| 253 |
-
src: Union[str, Path],
|
| 254 |
-
dst: Union[str, Path],
|
| 255 |
-
) -> str:
|
| 256 |
-
"""Recursively copy an entire directory tree rooted at src to a
|
| 257 |
-
directory named dst and return the destination directory.
|
| 258 |
-
|
| 259 |
-
src and dst should have the same prefix and dst must not already exist.
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
Args:
|
| 263 |
-
src (str or Path): A directory to be copied.
|
| 264 |
-
dst (str or Path): Copy directory to dst.
|
| 265 |
-
|
| 266 |
-
Returns:
|
| 267 |
-
str: The destination directory.
|
| 268 |
-
|
| 269 |
-
Raises:
|
| 270 |
-
FileExistsError: If dst had already existed, a FileExistsError will
|
| 271 |
-
be raised.
|
| 272 |
-
|
| 273 |
-
Examples:
|
| 274 |
-
>>> backend = LocalBackend()
|
| 275 |
-
>>> src = '/path/of/dir1'
|
| 276 |
-
>>> dst = '/path/of/dir2'
|
| 277 |
-
>>> backend.copytree(src, dst)
|
| 278 |
-
'/path/of/dir2'
|
| 279 |
-
"""
|
| 280 |
-
return shutil.copytree(src, dst)
|
| 281 |
-
|
| 282 |
-
def copyfile_from_local(
|
| 283 |
-
self,
|
| 284 |
-
src: Union[str, Path],
|
| 285 |
-
dst: Union[str, Path],
|
| 286 |
-
) -> str:
|
| 287 |
-
"""Copy a local file src to dst and return the destination file. Same
|
| 288 |
-
as :meth:`copyfile`.
|
| 289 |
-
|
| 290 |
-
Args:
|
| 291 |
-
src (str or Path): A local file to be copied.
|
| 292 |
-
dst (str or Path): Copy file to dst.
|
| 293 |
-
|
| 294 |
-
Returns:
|
| 295 |
-
str: If dst specifies a directory, the file will be copied into dst
|
| 296 |
-
using the base filename from src.
|
| 297 |
-
|
| 298 |
-
Raises:
|
| 299 |
-
SameFileError: If src and dst are the same file, a SameFileError
|
| 300 |
-
will be raised.
|
| 301 |
-
|
| 302 |
-
Examples:
|
| 303 |
-
>>> backend = LocalBackend()
|
| 304 |
-
>>> # dst is a file
|
| 305 |
-
>>> src = '/path/of/file'
|
| 306 |
-
>>> dst = '/path1/of/file1'
|
| 307 |
-
>>> # src will be copied to '/path1/of/file1'
|
| 308 |
-
>>> backend.copyfile_from_local(src, dst)
|
| 309 |
-
'/path1/of/file1'
|
| 310 |
-
|
| 311 |
-
>>> # dst is a directory
|
| 312 |
-
>>> dst = '/path1/of/dir'
|
| 313 |
-
>>> # src will be copied to
|
| 314 |
-
>>> backend.copyfile_from_local(src, dst)
|
| 315 |
-
'/path1/of/dir/file'
|
| 316 |
-
"""
|
| 317 |
-
return self.copyfile(src, dst)
|
| 318 |
-
|
| 319 |
-
def copytree_from_local(
|
| 320 |
-
self,
|
| 321 |
-
src: Union[str, Path],
|
| 322 |
-
dst: Union[str, Path],
|
| 323 |
-
) -> str:
|
| 324 |
-
"""Recursively copy an entire directory tree rooted at src to a
|
| 325 |
-
directory named dst and return the destination directory. Same as
|
| 326 |
-
:meth:`copytree`.
|
| 327 |
-
|
| 328 |
-
Args:
|
| 329 |
-
src (str or Path): A local directory to be copied.
|
| 330 |
-
dst (str or Path): Copy directory to dst.
|
| 331 |
-
|
| 332 |
-
Returns:
|
| 333 |
-
str: The destination directory.
|
| 334 |
-
|
| 335 |
-
Examples:
|
| 336 |
-
>>> backend = LocalBackend()
|
| 337 |
-
>>> src = '/path/of/dir1'
|
| 338 |
-
>>> dst = '/path/of/dir2'
|
| 339 |
-
>>> backend.copytree_from_local(src, dst)
|
| 340 |
-
'/path/of/dir2'
|
| 341 |
-
"""
|
| 342 |
-
return self.copytree(src, dst)
|
| 343 |
-
|
| 344 |
-
def copyfile_to_local(
|
| 345 |
-
self,
|
| 346 |
-
src: Union[str, Path],
|
| 347 |
-
dst: Union[str, Path],
|
| 348 |
-
dst_type: Optional[str] = None,
|
| 349 |
-
) -> str:
|
| 350 |
-
"""Copy the file src to local dst and return the destination file. Same
|
| 351 |
-
as :meth:`copyfile`.
|
| 352 |
-
|
| 353 |
-
If dst specifies a directory, the file will be copied into dst using
|
| 354 |
-
the base filename from src. If dst specifies a file that already
|
| 355 |
-
exists, it will be replaced.
|
| 356 |
-
|
| 357 |
-
Args:
|
| 358 |
-
src (str or Path): A file to be copied.
|
| 359 |
-
dst (str or Path): Copy file to to local dst.
|
| 360 |
-
|
| 361 |
-
Returns:
|
| 362 |
-
str: If dst specifies a directory, the file will be copied into dst
|
| 363 |
-
using the base filename from src.
|
| 364 |
-
|
| 365 |
-
Examples:
|
| 366 |
-
>>> backend = LocalBackend()
|
| 367 |
-
>>> # dst is a file
|
| 368 |
-
>>> src = '/path/of/file'
|
| 369 |
-
>>> dst = '/path1/of/file1'
|
| 370 |
-
>>> # src will be copied to '/path1/of/file1'
|
| 371 |
-
>>> backend.copyfile_to_local(src, dst)
|
| 372 |
-
'/path1/of/file1'
|
| 373 |
-
|
| 374 |
-
>>> # dst is a directory
|
| 375 |
-
>>> dst = '/path1/of/dir'
|
| 376 |
-
>>> # src will be copied to
|
| 377 |
-
>>> backend.copyfile_to_local(src, dst)
|
| 378 |
-
'/path1/of/dir/file'
|
| 379 |
-
"""
|
| 380 |
-
return self.copyfile(src, dst)
|
| 381 |
-
|
| 382 |
-
def copytree_to_local(
|
| 383 |
-
self,
|
| 384 |
-
src: Union[str, Path],
|
| 385 |
-
dst: Union[str, Path],
|
| 386 |
-
) -> str:
|
| 387 |
-
"""Recursively copy an entire directory tree rooted at src to a local
|
| 388 |
-
directory named dst and return the destination directory.
|
| 389 |
-
|
| 390 |
-
Args:
|
| 391 |
-
src (str or Path): A directory to be copied.
|
| 392 |
-
dst (str or Path): Copy directory to local dst.
|
| 393 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 394 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 395 |
-
|
| 396 |
-
Returns:
|
| 397 |
-
str: The destination directory.
|
| 398 |
-
|
| 399 |
-
Examples:
|
| 400 |
-
>>> backend = LocalBackend()
|
| 401 |
-
>>> src = '/path/of/dir1'
|
| 402 |
-
>>> dst = '/path/of/dir2'
|
| 403 |
-
>>> backend.copytree_from_local(src, dst)
|
| 404 |
-
'/path/of/dir2'
|
| 405 |
-
"""
|
| 406 |
-
return self.copytree(src, dst)
|
| 407 |
-
|
| 408 |
-
def remove(self, filepath: Union[str, Path]) -> None:
|
| 409 |
-
"""Remove a file.
|
| 410 |
-
|
| 411 |
-
Args:
|
| 412 |
-
filepath (str or Path): Path to be removed.
|
| 413 |
-
|
| 414 |
-
Raises:
|
| 415 |
-
IsADirectoryError: If filepath is a directory, an IsADirectoryError
|
| 416 |
-
will be raised.
|
| 417 |
-
FileNotFoundError: If filepath does not exist, an FileNotFoundError
|
| 418 |
-
will be raised.
|
| 419 |
-
|
| 420 |
-
Examples:
|
| 421 |
-
>>> backend = LocalBackend()
|
| 422 |
-
>>> filepath = '/path/of/file'
|
| 423 |
-
>>> backend.remove(filepath)
|
| 424 |
-
"""
|
| 425 |
-
if not self.exists(filepath):
|
| 426 |
-
raise FileNotFoundError(f"filepath {filepath} does not exist")
|
| 427 |
-
|
| 428 |
-
if self.isdir(filepath):
|
| 429 |
-
raise IsADirectoryError("filepath should be a file")
|
| 430 |
-
|
| 431 |
-
os.remove(filepath)
|
| 432 |
-
|
| 433 |
-
def rmtree(self, dir_path: Union[str, Path]) -> None:
|
| 434 |
-
"""Recursively delete a directory tree.
|
| 435 |
-
|
| 436 |
-
Args:
|
| 437 |
-
dir_path (str or Path): A directory to be removed.
|
| 438 |
-
|
| 439 |
-
Examples:
|
| 440 |
-
>>> dir_path = '/path/of/dir'
|
| 441 |
-
>>> backend.rmtree(dir_path)
|
| 442 |
-
"""
|
| 443 |
-
shutil.rmtree(dir_path)
|
| 444 |
-
|
| 445 |
-
def copy_if_symlink_fails(
|
| 446 |
-
self,
|
| 447 |
-
src: Union[str, Path],
|
| 448 |
-
dst: Union[str, Path],
|
| 449 |
-
) -> bool:
|
| 450 |
-
"""Create a symbolic link pointing to src named dst.
|
| 451 |
-
|
| 452 |
-
If failed to create a symbolic link pointing to src, directly copy src
|
| 453 |
-
to dst instead.
|
| 454 |
-
|
| 455 |
-
Args:
|
| 456 |
-
src (str or Path): Create a symbolic link pointing to src.
|
| 457 |
-
dst (str or Path): Create a symbolic link named dst.
|
| 458 |
-
|
| 459 |
-
Returns:
|
| 460 |
-
bool: Return True if successfully create a symbolic link pointing
|
| 461 |
-
to src. Otherwise, return False.
|
| 462 |
-
|
| 463 |
-
Examples:
|
| 464 |
-
>>> backend = LocalBackend()
|
| 465 |
-
>>> src = '/path/of/file'
|
| 466 |
-
>>> dst = '/path1/of/file1'
|
| 467 |
-
>>> backend.copy_if_symlink_fails(src, dst)
|
| 468 |
-
True
|
| 469 |
-
>>> src = '/path/of/dir'
|
| 470 |
-
>>> dst = '/path1/of/dir1'
|
| 471 |
-
>>> backend.copy_if_symlink_fails(src, dst)
|
| 472 |
-
True
|
| 473 |
-
"""
|
| 474 |
-
try:
|
| 475 |
-
os.symlink(src, dst)
|
| 476 |
-
return True
|
| 477 |
-
except Exception:
|
| 478 |
-
if self.isfile(src):
|
| 479 |
-
self.copyfile(src, dst)
|
| 480 |
-
else:
|
| 481 |
-
self.copytree(src, dst)
|
| 482 |
-
return False
|
| 483 |
-
|
| 484 |
-
def list_dir_or_file(
|
| 485 |
-
self,
|
| 486 |
-
dir_path: Union[str, Path],
|
| 487 |
-
list_dir: bool = True,
|
| 488 |
-
list_file: bool = True,
|
| 489 |
-
suffix: Optional[Union[str, Tuple[str]]] = None,
|
| 490 |
-
recursive: bool = False,
|
| 491 |
-
) -> Iterator[str]:
|
| 492 |
-
"""Scan a directory to find the interested directories or files in
|
| 493 |
-
arbitrary order.
|
| 494 |
-
|
| 495 |
-
Note:
|
| 496 |
-
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
|
| 497 |
-
|
| 498 |
-
Args:
|
| 499 |
-
dir_path (str or Path): Path of the directory.
|
| 500 |
-
list_dir (bool): List the directories. Defaults to True.
|
| 501 |
-
list_file (bool): List the path of files. Defaults to True.
|
| 502 |
-
suffix (str or tuple[str], optional): File suffix that we are
|
| 503 |
-
interested in. Defaults to None.
|
| 504 |
-
recursive (bool): If set to True, recursively scan the directory.
|
| 505 |
-
Defaults to False.
|
| 506 |
-
|
| 507 |
-
Yields:
|
| 508 |
-
Iterable[str]: A relative path to ``dir_path``.
|
| 509 |
-
|
| 510 |
-
Examples:
|
| 511 |
-
>>> backend = LocalBackend()
|
| 512 |
-
>>> dir_path = '/path/of/dir'
|
| 513 |
-
>>> # list those files and directories in current directory
|
| 514 |
-
>>> for file_path in backend.list_dir_or_file(dir_path):
|
| 515 |
-
... print(file_path)
|
| 516 |
-
>>> # only list files
|
| 517 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False):
|
| 518 |
-
... print(file_path)
|
| 519 |
-
>>> # only list directories
|
| 520 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, list_file=False):
|
| 521 |
-
... print(file_path)
|
| 522 |
-
>>> # only list files ending with specified suffixes
|
| 523 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'):
|
| 524 |
-
... print(file_path)
|
| 525 |
-
>>> # list all files and directory recursively
|
| 526 |
-
>>> for file_path in backend.list_dir_or_file(dir_path, recursive=True):
|
| 527 |
-
... print(file_path)
|
| 528 |
-
""" # noqa: E501
|
| 529 |
-
if list_dir and suffix is not None:
|
| 530 |
-
raise TypeError("`suffix` should be None when `list_dir` is True")
|
| 531 |
-
|
| 532 |
-
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
| 533 |
-
raise TypeError("`suffix` must be a string or tuple of strings")
|
| 534 |
-
|
| 535 |
-
root = dir_path
|
| 536 |
-
|
| 537 |
-
def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive):
|
| 538 |
-
for entry in os.scandir(dir_path):
|
| 539 |
-
if not entry.name.startswith(".") and entry.is_file():
|
| 540 |
-
rel_path = osp.relpath(entry.path, root)
|
| 541 |
-
if (suffix is None or rel_path.endswith(suffix)) and list_file:
|
| 542 |
-
yield rel_path
|
| 543 |
-
elif osp.isdir(entry.path):
|
| 544 |
-
if list_dir:
|
| 545 |
-
rel_dir = osp.relpath(entry.path, root)
|
| 546 |
-
yield rel_dir
|
| 547 |
-
if recursive:
|
| 548 |
-
yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive)
|
| 549 |
-
|
| 550 |
-
return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import inspect
|
| 17 |
-
from typing import Optional, Type, Union
|
| 18 |
-
|
| 19 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend
|
| 20 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_backend import Boto3Backend
|
| 21 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.http_backend import HTTPBackend
|
| 22 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends.local_backend import LocalBackend
|
| 23 |
-
|
| 24 |
-
backends: dict = {}
|
| 25 |
-
prefix_to_backends: dict = {}
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def _register_backend(
|
| 29 |
-
name: str,
|
| 30 |
-
backend: Type[BaseStorageBackend],
|
| 31 |
-
force: bool = False,
|
| 32 |
-
prefixes: Union[str, list, tuple, None] = None,
|
| 33 |
-
):
|
| 34 |
-
"""Register a backend.
|
| 35 |
-
|
| 36 |
-
Args:
|
| 37 |
-
name (str): The name of the registered backend.
|
| 38 |
-
backend (BaseStorageBackend): The backend class to be registered,
|
| 39 |
-
which must be a subclass of :class:`BaseStorageBackend`.
|
| 40 |
-
force (bool): Whether to override the backend if the name has already
|
| 41 |
-
been registered. Defaults to False.
|
| 42 |
-
prefixes (str or list[str] or tuple[str], optional): The prefix
|
| 43 |
-
of the registered storage backend. Defaults to None.
|
| 44 |
-
"""
|
| 45 |
-
global backends, prefix_to_backends
|
| 46 |
-
|
| 47 |
-
if not isinstance(name, str):
|
| 48 |
-
raise TypeError(f"the backend name should be a string, but got {type(name)}")
|
| 49 |
-
|
| 50 |
-
if not inspect.isclass(backend):
|
| 51 |
-
raise TypeError(f"backend should be a class, but got {type(backend)}")
|
| 52 |
-
if not issubclass(backend, BaseStorageBackend):
|
| 53 |
-
raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
|
| 54 |
-
|
| 55 |
-
if name in backends and not force:
|
| 56 |
-
raise ValueError(
|
| 57 |
-
f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
|
| 58 |
-
)
|
| 59 |
-
backends[name] = backend
|
| 60 |
-
|
| 61 |
-
if prefixes is not None:
|
| 62 |
-
if isinstance(prefixes, str):
|
| 63 |
-
prefixes = [prefixes]
|
| 64 |
-
else:
|
| 65 |
-
assert isinstance(prefixes, (list, tuple))
|
| 66 |
-
|
| 67 |
-
for prefix in prefixes:
|
| 68 |
-
if prefix in prefix_to_backends and not force:
|
| 69 |
-
raise ValueError(
|
| 70 |
-
f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it'
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
prefix_to_backends[prefix] = backend
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def register_backend(
|
| 77 |
-
name: str,
|
| 78 |
-
backend: Optional[Type[BaseStorageBackend]] = None,
|
| 79 |
-
force: bool = False,
|
| 80 |
-
prefixes: Union[str, list, tuple, None] = None,
|
| 81 |
-
):
|
| 82 |
-
"""Register a backend.
|
| 83 |
-
|
| 84 |
-
Args:
|
| 85 |
-
name (str): The name of the registered backend.
|
| 86 |
-
backend (class, optional): The backend class to be registered,
|
| 87 |
-
which must be a subclass of :class:`BaseStorageBackend`.
|
| 88 |
-
When this method is used as a decorator, backend is None.
|
| 89 |
-
Defaults to None.
|
| 90 |
-
force (bool): Whether to override the backend if the name has already
|
| 91 |
-
been registered. Defaults to False.
|
| 92 |
-
prefixes (str or list[str] or tuple[str], optional): The prefix
|
| 93 |
-
of the registered storage backend. Defaults to None.
|
| 94 |
-
|
| 95 |
-
This method can be used as a normal method or a decorator.
|
| 96 |
-
|
| 97 |
-
Examples:
|
| 98 |
-
|
| 99 |
-
>>> class NewBackend(BaseStorageBackend):
|
| 100 |
-
... def get(self, filepath):
|
| 101 |
-
... return filepath
|
| 102 |
-
...
|
| 103 |
-
... def get_text(self, filepath):
|
| 104 |
-
... return filepath
|
| 105 |
-
>>> register_backend('new', NewBackend)
|
| 106 |
-
|
| 107 |
-
>>> @register_backend('new')
|
| 108 |
-
... class NewBackend(BaseStorageBackend):
|
| 109 |
-
... def get(self, filepath):
|
| 110 |
-
... return filepath
|
| 111 |
-
...
|
| 112 |
-
... def get_text(self, filepath):
|
| 113 |
-
... return filepath
|
| 114 |
-
"""
|
| 115 |
-
if backend is not None:
|
| 116 |
-
_register_backend(name, backend, force=force, prefixes=prefixes)
|
| 117 |
-
return
|
| 118 |
-
|
| 119 |
-
def _register(backend_cls):
|
| 120 |
-
_register_backend(name, backend_cls, force=force, prefixes=prefixes)
|
| 121 |
-
return backend_cls
|
| 122 |
-
|
| 123 |
-
return _register
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
register_backend("local", LocalBackend, prefixes="")
|
| 127 |
-
# To avoid breaking backward Compatibility, 's3' is also used as a
|
| 128 |
-
# prefix for Boto3Backend
|
| 129 |
-
register_backend("s3", Boto3Backend, prefixes=["s3"])
|
| 130 |
-
register_backend("http", HTTPBackend, prefixes=["http", "https"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py
DELETED
|
@@ -1,1085 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import json
|
| 17 |
-
import warnings
|
| 18 |
-
from contextlib import contextmanager
|
| 19 |
-
from io import BytesIO, StringIO
|
| 20 |
-
from pathlib import Path
|
| 21 |
-
from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union
|
| 22 |
-
|
| 23 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends import backends, prefix_to_backends
|
| 24 |
-
from lyra_2._ext.imaginaire.utils.easy_io.file_client import FileClient
|
| 25 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers import file_handlers
|
| 26 |
-
|
| 27 |
-
backend_instances: dict = {}
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def is_filepath(filepath):
|
| 31 |
-
return isinstance(filepath, (str, Path))
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _parse_uri_prefix(uri: Union[str, Path]) -> str:
|
| 35 |
-
"""Parse the prefix of uri.
|
| 36 |
-
|
| 37 |
-
Args:
|
| 38 |
-
uri (str or Path): Uri to be parsed that contains the file prefix.
|
| 39 |
-
|
| 40 |
-
Examples:
|
| 41 |
-
>>> _parse_uri_prefix('/home/path/of/your/file')
|
| 42 |
-
''
|
| 43 |
-
>>> _parse_uri_prefix('s3://path/of/your/file')
|
| 44 |
-
's3'
|
| 45 |
-
>>> _parse_uri_prefix('clusterName:s3://path/of/your/file')
|
| 46 |
-
's3'
|
| 47 |
-
|
| 48 |
-
Returns:
|
| 49 |
-
str: Return the prefix of uri if the uri contains '://'. Otherwise,
|
| 50 |
-
return ''.
|
| 51 |
-
"""
|
| 52 |
-
assert is_filepath(uri)
|
| 53 |
-
uri = str(uri)
|
| 54 |
-
# if uri does not contains '://', the uri will be handled by
|
| 55 |
-
# LocalBackend by default
|
| 56 |
-
if "://" not in uri:
|
| 57 |
-
return ""
|
| 58 |
-
else:
|
| 59 |
-
prefix, _ = uri.split("://")
|
| 60 |
-
# In the case of Boto3Backend, the prefix may contain the cluster
|
| 61 |
-
# name like clusterName:s3://path/of/your/file
|
| 62 |
-
if ":" in prefix:
|
| 63 |
-
_, prefix = prefix.split(":")
|
| 64 |
-
return prefix
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def _get_file_backend(prefix: str, backend_args: dict):
|
| 68 |
-
"""Return a file backend based on the prefix or backend_args.
|
| 69 |
-
|
| 70 |
-
Args:
|
| 71 |
-
prefix (str): Prefix of uri.
|
| 72 |
-
backend_args (dict): Arguments to instantiate the corresponding
|
| 73 |
-
backend.
|
| 74 |
-
"""
|
| 75 |
-
# backend name has a higher priority
|
| 76 |
-
if "backend" in backend_args:
|
| 77 |
-
# backend_args should not be modified
|
| 78 |
-
backend_args_bak = backend_args.copy()
|
| 79 |
-
backend_name = backend_args_bak.pop("backend")
|
| 80 |
-
backend = backends[backend_name](**backend_args_bak)
|
| 81 |
-
else:
|
| 82 |
-
backend = prefix_to_backends[prefix](**backend_args)
|
| 83 |
-
return backend
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def set_s3_backend(
|
| 87 |
-
key: str = "s3:{}",
|
| 88 |
-
backend_args: Optional[dict] = None,
|
| 89 |
-
):
|
| 90 |
-
"""register s3 backend.
|
| 91 |
-
|
| 92 |
-
Args:
|
| 93 |
-
key str: The key to register the s3 backend. Defaults to s3.
|
| 94 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 95 |
-
corresponding backend. Defaults to None.
|
| 96 |
-
"""
|
| 97 |
-
global backend_instances
|
| 98 |
-
if backend_args is None:
|
| 99 |
-
backend_args = {}
|
| 100 |
-
backend = _get_file_backend(key, backend_args)
|
| 101 |
-
backend_instances[key] = backend
|
| 102 |
-
return backend
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def get_file_backend(
|
| 106 |
-
uri: Union[str, Path, None] = None,
|
| 107 |
-
*,
|
| 108 |
-
backend_args: Optional[dict] = None,
|
| 109 |
-
enable_singleton: bool = False,
|
| 110 |
-
backend_key: Optional[str] = None,
|
| 111 |
-
):
|
| 112 |
-
"""Return a file backend based on the prefix of uri or backend_args.
|
| 113 |
-
|
| 114 |
-
Args:
|
| 115 |
-
uri (str or Path): Uri to be parsed that contains the file prefix.
|
| 116 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 117 |
-
corresponding backend. Defaults to None.
|
| 118 |
-
enable_singleton (bool): Whether to enable the singleton pattern.
|
| 119 |
-
If it is True, the backend created will be reused if the
|
| 120 |
-
signature is same with the previous one. Defaults to False.
|
| 121 |
-
backend_key: str: The key to register the backend. Defaults to None.
|
| 122 |
-
|
| 123 |
-
Returns:
|
| 124 |
-
BaseStorageBackend: Instantiated Backend object.
|
| 125 |
-
|
| 126 |
-
Examples:
|
| 127 |
-
>>> # get file backend based on the prefix of uri
|
| 128 |
-
>>> uri = 's3://path/of/your/file'
|
| 129 |
-
>>> backend = get_file_backend(uri)
|
| 130 |
-
>>> # get file backend based on the backend_args
|
| 131 |
-
>>> backend = get_file_backend(backend_args={'backend': 's3'})
|
| 132 |
-
>>> # backend name has a higher priority if 'backend' in backend_args
|
| 133 |
-
>>> backend = get_file_backend(uri, backend_args={'backend': 's3'})
|
| 134 |
-
"""
|
| 135 |
-
global backend_instances
|
| 136 |
-
if backend_key is not None:
|
| 137 |
-
if backend_key in backend_instances:
|
| 138 |
-
return backend_instances[backend_key]
|
| 139 |
-
|
| 140 |
-
if backend_args is None:
|
| 141 |
-
backend_args = {}
|
| 142 |
-
|
| 143 |
-
if uri is None and "backend" not in backend_args and backend_key is None:
|
| 144 |
-
raise ValueError('uri should not be None when "backend" does not exist in backend_args and backend_key is None')
|
| 145 |
-
|
| 146 |
-
if uri is not None:
|
| 147 |
-
prefix = _parse_uri_prefix(uri)
|
| 148 |
-
else:
|
| 149 |
-
prefix = ""
|
| 150 |
-
|
| 151 |
-
if enable_singleton:
|
| 152 |
-
unique_key = f"{prefix}:{json.dumps(backend_args)}"
|
| 153 |
-
if unique_key in backend_instances:
|
| 154 |
-
return backend_instances[unique_key]
|
| 155 |
-
|
| 156 |
-
backend = _get_file_backend(prefix, backend_args)
|
| 157 |
-
backend_instances[unique_key] = backend
|
| 158 |
-
if backend_key is not None:
|
| 159 |
-
backend_instances[backend_key] = backend
|
| 160 |
-
return backend
|
| 161 |
-
else:
|
| 162 |
-
backend = _get_file_backend(prefix, backend_args)
|
| 163 |
-
return backend
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def get(
|
| 167 |
-
filepath: Union[str, Path],
|
| 168 |
-
backend_args: Optional[dict] = None,
|
| 169 |
-
backend_key: Optional[str] = None,
|
| 170 |
-
) -> bytes:
|
| 171 |
-
"""Read bytes from a given ``filepath`` with 'rb' mode.
|
| 172 |
-
|
| 173 |
-
Args:
|
| 174 |
-
filepath (str or Path): Path to read data.
|
| 175 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 176 |
-
corresponding backend. Defaults to None.
|
| 177 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 178 |
-
|
| 179 |
-
Returns:
|
| 180 |
-
bytes: Expected bytes object.
|
| 181 |
-
|
| 182 |
-
Examples:
|
| 183 |
-
>>> filepath = '/path/of/file'
|
| 184 |
-
>>> get(filepath)
|
| 185 |
-
b'hello world'
|
| 186 |
-
"""
|
| 187 |
-
backend = get_file_backend(
|
| 188 |
-
filepath,
|
| 189 |
-
backend_args=backend_args,
|
| 190 |
-
enable_singleton=True,
|
| 191 |
-
backend_key=backend_key,
|
| 192 |
-
)
|
| 193 |
-
return backend.get(filepath)
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def get_text(
|
| 197 |
-
filepath: Union[str, Path],
|
| 198 |
-
encoding="utf-8",
|
| 199 |
-
backend_args: Optional[dict] = None,
|
| 200 |
-
backend_key: Optional[str] = None,
|
| 201 |
-
) -> str:
|
| 202 |
-
"""Read text from a given ``filepath`` with 'r' mode.
|
| 203 |
-
|
| 204 |
-
Args:
|
| 205 |
-
filepath (str or Path): Path to read data.
|
| 206 |
-
encoding (str): The encoding format used to open the ``filepath``.
|
| 207 |
-
Defaults to 'utf-8'.
|
| 208 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 209 |
-
corresponding backend. Defaults to None.
|
| 210 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 211 |
-
|
| 212 |
-
Returns:
|
| 213 |
-
str: Expected text reading from ``filepath``.
|
| 214 |
-
|
| 215 |
-
Examples:
|
| 216 |
-
>>> filepath = '/path/of/file'
|
| 217 |
-
>>> get_text(filepath)
|
| 218 |
-
'hello world'
|
| 219 |
-
"""
|
| 220 |
-
backend = get_file_backend(
|
| 221 |
-
filepath,
|
| 222 |
-
backend_args=backend_args,
|
| 223 |
-
enable_singleton=True,
|
| 224 |
-
backend_key=backend_key,
|
| 225 |
-
)
|
| 226 |
-
return backend.get_text(filepath, encoding)
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
def put(
|
| 230 |
-
obj: bytes,
|
| 231 |
-
filepath: Union[str, Path],
|
| 232 |
-
backend_args: Optional[dict] = None,
|
| 233 |
-
backend_key: Optional[str] = None,
|
| 234 |
-
) -> None:
|
| 235 |
-
"""Write bytes to a given ``filepath`` with 'wb' mode.
|
| 236 |
-
|
| 237 |
-
Note:
|
| 238 |
-
``put`` should create a directory if the directory of
|
| 239 |
-
``filepath`` does not exist.
|
| 240 |
-
|
| 241 |
-
Args:
|
| 242 |
-
obj (bytes): Data to be written.
|
| 243 |
-
filepath (str or Path): Path to write data.
|
| 244 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 245 |
-
corresponding backend. Defaults to None.
|
| 246 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 247 |
-
|
| 248 |
-
Examples:
|
| 249 |
-
>>> filepath = '/path/of/file'
|
| 250 |
-
>>> put(b'hello world', filepath)
|
| 251 |
-
"""
|
| 252 |
-
backend = get_file_backend(
|
| 253 |
-
filepath,
|
| 254 |
-
backend_args=backend_args,
|
| 255 |
-
enable_singleton=True,
|
| 256 |
-
backend_key=backend_key,
|
| 257 |
-
)
|
| 258 |
-
backend.put(obj, filepath)
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
def put_text(
|
| 262 |
-
obj: str,
|
| 263 |
-
filepath: Union[str, Path],
|
| 264 |
-
backend_args: Optional[dict] = None,
|
| 265 |
-
backend_key: Optional[str] = None,
|
| 266 |
-
) -> None:
|
| 267 |
-
"""Write text to a given ``filepath`` with 'w' mode.
|
| 268 |
-
|
| 269 |
-
Note:
|
| 270 |
-
``put_text`` should create a directory if the directory of
|
| 271 |
-
``filepath`` does not exist.
|
| 272 |
-
|
| 273 |
-
Args:
|
| 274 |
-
obj (str): Data to be written.
|
| 275 |
-
filepath (str or Path): Path to write data.
|
| 276 |
-
encoding (str, optional): The encoding format used to open the
|
| 277 |
-
``filepath``. Defaults to 'utf-8'.
|
| 278 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 279 |
-
corresponding backend. Defaults to None.
|
| 280 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 281 |
-
|
| 282 |
-
Examples:
|
| 283 |
-
>>> filepath = '/path/of/file'
|
| 284 |
-
>>> put_text('hello world', filepath)
|
| 285 |
-
"""
|
| 286 |
-
backend = get_file_backend(
|
| 287 |
-
filepath,
|
| 288 |
-
backend_args=backend_args,
|
| 289 |
-
enable_singleton=True,
|
| 290 |
-
backend_key=backend_key,
|
| 291 |
-
)
|
| 292 |
-
backend.put_text(obj, filepath)
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def exists(
|
| 296 |
-
filepath: Union[str, Path],
|
| 297 |
-
backend_args: Optional[dict] = None,
|
| 298 |
-
backend_key: Optional[str] = None,
|
| 299 |
-
) -> bool:
|
| 300 |
-
"""Check whether a file path exists.
|
| 301 |
-
|
| 302 |
-
Args:
|
| 303 |
-
filepath (str or Path): Path to be checked whether exists.
|
| 304 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 305 |
-
corresponding backend. Defaults to None.
|
| 306 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 307 |
-
|
| 308 |
-
Returns:
|
| 309 |
-
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
|
| 310 |
-
|
| 311 |
-
Examples:
|
| 312 |
-
>>> filepath = '/path/of/file'
|
| 313 |
-
>>> exists(filepath)
|
| 314 |
-
True
|
| 315 |
-
"""
|
| 316 |
-
backend = get_file_backend(
|
| 317 |
-
filepath,
|
| 318 |
-
backend_args=backend_args,
|
| 319 |
-
enable_singleton=True,
|
| 320 |
-
backend_key=backend_key,
|
| 321 |
-
)
|
| 322 |
-
return backend.exists(filepath)
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
def isdir(
|
| 326 |
-
filepath: Union[str, Path],
|
| 327 |
-
backend_args: Optional[dict] = None,
|
| 328 |
-
backend_key: Optional[str] = None,
|
| 329 |
-
) -> bool:
|
| 330 |
-
"""Check whether a file path is a directory.
|
| 331 |
-
|
| 332 |
-
Args:
|
| 333 |
-
filepath (str or Path): Path to be checked whether it is a
|
| 334 |
-
directory.
|
| 335 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 336 |
-
corresponding backend. Defaults to None.
|
| 337 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 338 |
-
|
| 339 |
-
Returns:
|
| 340 |
-
bool: Return ``True`` if ``filepath`` points to a directory,
|
| 341 |
-
``False`` otherwise.
|
| 342 |
-
|
| 343 |
-
Examples:
|
| 344 |
-
>>> filepath = '/path/of/dir'
|
| 345 |
-
>>> isdir(filepath)
|
| 346 |
-
True
|
| 347 |
-
"""
|
| 348 |
-
backend = get_file_backend(
|
| 349 |
-
filepath,
|
| 350 |
-
backend_args=backend_args,
|
| 351 |
-
enable_singleton=True,
|
| 352 |
-
backend_key=backend_key,
|
| 353 |
-
)
|
| 354 |
-
return backend.isdir(filepath)
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
def isfile(
|
| 358 |
-
filepath: Union[str, Path],
|
| 359 |
-
backend_args: Optional[dict] = None,
|
| 360 |
-
backend_key: Optional[str] = None,
|
| 361 |
-
) -> bool:
|
| 362 |
-
"""Check whether a file path is a file.
|
| 363 |
-
|
| 364 |
-
Args:
|
| 365 |
-
filepath (str or Path): Path to be checked whether it is a file.
|
| 366 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 367 |
-
corresponding backend. Defaults to None.
|
| 368 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 369 |
-
|
| 370 |
-
Returns:
|
| 371 |
-
bool: Return ``True`` if ``filepath`` points to a file, ``False``
|
| 372 |
-
otherwise.
|
| 373 |
-
|
| 374 |
-
Examples:
|
| 375 |
-
>>> filepath = '/path/of/file'
|
| 376 |
-
>>> isfile(filepath)
|
| 377 |
-
True
|
| 378 |
-
"""
|
| 379 |
-
backend = get_file_backend(
|
| 380 |
-
filepath,
|
| 381 |
-
backend_args=backend_args,
|
| 382 |
-
enable_singleton=True,
|
| 383 |
-
backend_key=backend_key,
|
| 384 |
-
)
|
| 385 |
-
return backend.isfile(filepath)
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
def join_path(
|
| 389 |
-
filepath: Union[str, Path],
|
| 390 |
-
*filepaths: Union[str, Path],
|
| 391 |
-
backend_args: Optional[dict] = None,
|
| 392 |
-
backend_key: Optional[str] = None,
|
| 393 |
-
) -> Union[str, Path]:
|
| 394 |
-
r"""Concatenate all file paths.
|
| 395 |
-
|
| 396 |
-
Join one or more filepath components intelligently. The return value
|
| 397 |
-
is the concatenation of filepath and any members of \*filepaths.
|
| 398 |
-
|
| 399 |
-
Args:
|
| 400 |
-
filepath (str or Path): Path to be concatenated.
|
| 401 |
-
*filepaths (str or Path): Other paths to be concatenated.
|
| 402 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 403 |
-
corresponding backend. Defaults to None.
|
| 404 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 405 |
-
|
| 406 |
-
Returns:
|
| 407 |
-
str: The result of concatenation.
|
| 408 |
-
|
| 409 |
-
Examples:
|
| 410 |
-
>>> filepath1 = '/path/of/dir1'
|
| 411 |
-
>>> filepath2 = 'dir2'
|
| 412 |
-
>>> filepath3 = 'path/of/file'
|
| 413 |
-
>>> join_path(filepath1, filepath2, filepath3)
|
| 414 |
-
'/path/of/dir/dir2/path/of/file'
|
| 415 |
-
"""
|
| 416 |
-
backend = get_file_backend(
|
| 417 |
-
filepath,
|
| 418 |
-
backend_args=backend_args,
|
| 419 |
-
enable_singleton=True,
|
| 420 |
-
backend_key=backend_key,
|
| 421 |
-
)
|
| 422 |
-
return backend.join_path(filepath, *filepaths)
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
@contextmanager
|
| 426 |
-
def get_local_path(
|
| 427 |
-
filepath: Union[str, Path],
|
| 428 |
-
backend_args: Optional[dict] = None,
|
| 429 |
-
backend_key: Optional[str] = None,
|
| 430 |
-
) -> Generator[Union[str, Path], None, None]:
|
| 431 |
-
"""Download data from ``filepath`` and write the data to local path.
|
| 432 |
-
|
| 433 |
-
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
|
| 434 |
-
can be called with ``with`` statement, and when exists from the
|
| 435 |
-
``with`` statement, the temporary path will be released.
|
| 436 |
-
|
| 437 |
-
Note:
|
| 438 |
-
If the ``filepath`` is a local path, just return itself and it will
|
| 439 |
-
not be released (removed).
|
| 440 |
-
|
| 441 |
-
Args:
|
| 442 |
-
filepath (str or Path): Path to be read data.
|
| 443 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 444 |
-
corresponding backend. Defaults to None.
|
| 445 |
-
|
| 446 |
-
Yields:
|
| 447 |
-
Iterable[str]: Only yield one path.
|
| 448 |
-
|
| 449 |
-
Examples:
|
| 450 |
-
>>> with get_local_path('s3://bucket/abc.jpg') as path:
|
| 451 |
-
... # do something here
|
| 452 |
-
"""
|
| 453 |
-
backend = get_file_backend(
|
| 454 |
-
filepath,
|
| 455 |
-
backend_args=backend_args,
|
| 456 |
-
enable_singleton=True,
|
| 457 |
-
backend_key=backend_key,
|
| 458 |
-
)
|
| 459 |
-
with backend.get_local_path(str(filepath)) as local_path:
|
| 460 |
-
yield local_path
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
def copyfile(
|
| 464 |
-
src: Union[str, Path],
|
| 465 |
-
dst: Union[str, Path],
|
| 466 |
-
backend_args: Optional[dict] = None,
|
| 467 |
-
backend_key: Optional[str] = None,
|
| 468 |
-
) -> Union[str, Path]:
|
| 469 |
-
"""Copy a file src to dst and return the destination file.
|
| 470 |
-
|
| 471 |
-
src and dst should have the same prefix. If dst specifies a directory,
|
| 472 |
-
the file will be copied into dst using the base filename from src. If
|
| 473 |
-
dst specifies a file that already exists, it will be replaced.
|
| 474 |
-
|
| 475 |
-
Args:
|
| 476 |
-
src (str or Path): A file to be copied.
|
| 477 |
-
dst (str or Path): Copy file to dst.
|
| 478 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 479 |
-
corresponding backend. Defaults to None.
|
| 480 |
-
|
| 481 |
-
Returns:
|
| 482 |
-
str: The destination file.
|
| 483 |
-
|
| 484 |
-
Raises:
|
| 485 |
-
SameFileError: If src and dst are the same file, a SameFileError will
|
| 486 |
-
be raised.
|
| 487 |
-
|
| 488 |
-
Examples:
|
| 489 |
-
>>> # dst is a file
|
| 490 |
-
>>> src = '/path/of/file'
|
| 491 |
-
>>> dst = '/path1/of/file1'
|
| 492 |
-
>>> # src will be copied to '/path1/of/file1'
|
| 493 |
-
>>> copyfile(src, dst)
|
| 494 |
-
'/path1/of/file1'
|
| 495 |
-
|
| 496 |
-
>>> # dst is a directory
|
| 497 |
-
>>> dst = '/path1/of/dir'
|
| 498 |
-
>>> # src will be copied to '/path1/of/dir/file'
|
| 499 |
-
>>> copyfile(src, dst)
|
| 500 |
-
'/path1/of/dir/file'
|
| 501 |
-
"""
|
| 502 |
-
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 503 |
-
return backend.copyfile(src, dst)
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
def copytree(
|
| 507 |
-
src: Union[str, Path],
|
| 508 |
-
dst: Union[str, Path],
|
| 509 |
-
backend_args: Optional[dict] = None,
|
| 510 |
-
backend_key: Optional[str] = None,
|
| 511 |
-
) -> Union[str, Path]:
|
| 512 |
-
"""Recursively copy an entire directory tree rooted at src to a directory
|
| 513 |
-
named dst and return the destination directory.
|
| 514 |
-
|
| 515 |
-
src and dst should have the same prefix and dst must not already exist.
|
| 516 |
-
|
| 517 |
-
Args:
|
| 518 |
-
src (str or Path): A directory to be copied.
|
| 519 |
-
dst (str or Path): Copy directory to dst.
|
| 520 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 521 |
-
corresponding backend. Defaults to None.
|
| 522 |
-
backend_key (str, optional): The key to get the backend from register.
|
| 523 |
-
|
| 524 |
-
Returns:
|
| 525 |
-
str: The destination directory.
|
| 526 |
-
|
| 527 |
-
Raises:
|
| 528 |
-
FileExistsError: If dst had already existed, a FileExistsError will be
|
| 529 |
-
raised.
|
| 530 |
-
|
| 531 |
-
Examples:
|
| 532 |
-
>>> src = '/path/of/dir1'
|
| 533 |
-
>>> dst = '/path/of/dir2'
|
| 534 |
-
>>> copytree(src, dst)
|
| 535 |
-
'/path/of/dir2'
|
| 536 |
-
"""
|
| 537 |
-
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 538 |
-
return backend.copytree(src, dst)
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
def copyfile_from_local(
|
| 542 |
-
src: Union[str, Path],
|
| 543 |
-
dst: Union[str, Path],
|
| 544 |
-
backend_args: Optional[dict] = None,
|
| 545 |
-
backend_key: Optional[str] = None,
|
| 546 |
-
) -> Union[str, Path]:
|
| 547 |
-
"""Copy a local file src to dst and return the destination file.
|
| 548 |
-
|
| 549 |
-
Note:
|
| 550 |
-
If the backend is the instance of LocalBackend, it does the same
|
| 551 |
-
thing with :func:`copyfile`.
|
| 552 |
-
|
| 553 |
-
Args:
|
| 554 |
-
src (str or Path): A local file to be copied.
|
| 555 |
-
dst (str or Path): Copy file to dst.
|
| 556 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 557 |
-
corresponding backend. Defaults to None.
|
| 558 |
-
|
| 559 |
-
Returns:
|
| 560 |
-
str: If dst specifies a directory, the file will be copied into dst
|
| 561 |
-
using the base filename from src.
|
| 562 |
-
|
| 563 |
-
Examples:
|
| 564 |
-
>>> # dst is a file
|
| 565 |
-
>>> src = '/path/of/file'
|
| 566 |
-
>>> dst = 's3://openmmlab/mmengine/file1'
|
| 567 |
-
>>> # src will be copied to 's3://openmmlab/mmengine/file1'
|
| 568 |
-
>>> copyfile_from_local(src, dst)
|
| 569 |
-
s3://openmmlab/mmengine/file1
|
| 570 |
-
|
| 571 |
-
>>> # dst is a directory
|
| 572 |
-
>>> dst = 's3://openmmlab/mmengine'
|
| 573 |
-
>>> # src will be copied to 's3://openmmlab/mmengine/file''
|
| 574 |
-
>>> copyfile_from_local(src, dst)
|
| 575 |
-
's3://openmmlab/mmengine/file'
|
| 576 |
-
"""
|
| 577 |
-
backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 578 |
-
return backend.copyfile_from_local(src, dst)
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
def copytree_from_local(
|
| 582 |
-
src: Union[str, Path],
|
| 583 |
-
dst: Union[str, Path],
|
| 584 |
-
backend_args: Optional[dict] = None,
|
| 585 |
-
backend_key: Optional[str] = None,
|
| 586 |
-
) -> Union[str, Path]:
|
| 587 |
-
"""Recursively copy an entire directory tree rooted at src to a directory
|
| 588 |
-
named dst and return the destination directory.
|
| 589 |
-
|
| 590 |
-
Note:
|
| 591 |
-
If the backend is the instance of LocalBackend, it does the same
|
| 592 |
-
thing with :func:`copytree`.
|
| 593 |
-
|
| 594 |
-
Args:
|
| 595 |
-
src (str or Path): A local directory to be copied.
|
| 596 |
-
dst (str or Path): Copy directory to dst.
|
| 597 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 598 |
-
corresponding backend. Defaults to None.
|
| 599 |
-
|
| 600 |
-
Returns:
|
| 601 |
-
str: The destination directory.
|
| 602 |
-
|
| 603 |
-
Examples:
|
| 604 |
-
>>> src = '/path/of/dir'
|
| 605 |
-
>>> dst = 's3://openmmlab/mmengine/dir'
|
| 606 |
-
>>> copyfile_from_local(src, dst)
|
| 607 |
-
's3://openmmlab/mmengine/dir'
|
| 608 |
-
"""
|
| 609 |
-
backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 610 |
-
return backend.copytree_from_local(src, dst)
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
def copyfile_to_local(
|
| 614 |
-
src: Union[str, Path],
|
| 615 |
-
dst: Union[str, Path],
|
| 616 |
-
dst_type: str, # Choose from ["file", "dir"]
|
| 617 |
-
backend_args: Optional[dict] = None,
|
| 618 |
-
backend_key: Optional[str] = None,
|
| 619 |
-
) -> Union[str, Path]:
|
| 620 |
-
"""Copy the file src to local dst and return the destination file.
|
| 621 |
-
|
| 622 |
-
If dst specifies a directory, the file will be copied into dst using
|
| 623 |
-
the base filename from src. If dst specifies a file that already
|
| 624 |
-
exists, it will be replaced.
|
| 625 |
-
|
| 626 |
-
Note:
|
| 627 |
-
If the backend is the instance of LocalBackend, it does the same
|
| 628 |
-
thing with :func:`copyfile`.
|
| 629 |
-
|
| 630 |
-
Args:
|
| 631 |
-
src (str or Path): A file to be copied.
|
| 632 |
-
dst (str or Path): Copy file to to local dst.
|
| 633 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 634 |
-
corresponding backend. Defaults to None.
|
| 635 |
-
|
| 636 |
-
Returns:
|
| 637 |
-
str: If dst specifies a directory, the file will be copied into dst
|
| 638 |
-
using the base filename from src.
|
| 639 |
-
|
| 640 |
-
Examples:
|
| 641 |
-
>>> # dst is a file
|
| 642 |
-
>>> src = 's3://openmmlab/mmengine/file'
|
| 643 |
-
>>> dst = '/path/of/file'
|
| 644 |
-
>>> # src will be copied to '/path/of/file'
|
| 645 |
-
>>> copyfile_to_local(src, dst)
|
| 646 |
-
'/path/of/file'
|
| 647 |
-
|
| 648 |
-
>>> # dst is a directory
|
| 649 |
-
>>> dst = '/path/of/dir'
|
| 650 |
-
>>> # src will be copied to '/path/of/dir/file'
|
| 651 |
-
>>> copyfile_to_local(src, dst)
|
| 652 |
-
'/path/of/dir/file'
|
| 653 |
-
"""
|
| 654 |
-
assert dst_type in ["file", "dir"]
|
| 655 |
-
Path(dst).parent.mkdir(parents=True, exist_ok=True)
|
| 656 |
-
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 657 |
-
return backend.copyfile_to_local(src, dst, dst_type=dst_type)
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
def copytree_to_local(
|
| 661 |
-
src: Union[str, Path],
|
| 662 |
-
dst: Union[str, Path],
|
| 663 |
-
backend_args: Optional[dict] = None,
|
| 664 |
-
backend_key: Optional[str] = None,
|
| 665 |
-
) -> Union[str, Path]:
|
| 666 |
-
"""Recursively copy an entire directory tree rooted at src to a local
|
| 667 |
-
directory named dst and return the destination directory.
|
| 668 |
-
|
| 669 |
-
Note:
|
| 670 |
-
If the backend is the instance of LocalBackend, it does the same
|
| 671 |
-
thing with :func:`copytree`.
|
| 672 |
-
|
| 673 |
-
Args:
|
| 674 |
-
src (str or Path): A directory to be copied.
|
| 675 |
-
dst (str or Path): Copy directory to local dst.
|
| 676 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 677 |
-
corresponding backend. Defaults to None.
|
| 678 |
-
|
| 679 |
-
Returns:
|
| 680 |
-
str: The destination directory.
|
| 681 |
-
|
| 682 |
-
Examples:
|
| 683 |
-
>>> src = 's3://openmmlab/mmengine/dir'
|
| 684 |
-
>>> dst = '/path/of/dir'
|
| 685 |
-
>>> copytree_to_local(src, dst)
|
| 686 |
-
'/path/of/dir'
|
| 687 |
-
"""
|
| 688 |
-
Path(dst).parent.mkdir(parents=True, exist_ok=True)
|
| 689 |
-
backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 690 |
-
return backend.copytree_to_local(src, dst)
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
def remove(
|
| 694 |
-
filepath: Union[str, Path],
|
| 695 |
-
backend_args: Optional[dict] = None,
|
| 696 |
-
backend_key: Optional[str] = None,
|
| 697 |
-
) -> None:
|
| 698 |
-
"""Remove a file.
|
| 699 |
-
|
| 700 |
-
Args:
|
| 701 |
-
filepath (str, Path): Path to be removed.
|
| 702 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 703 |
-
corresponding backend. Defaults to None.
|
| 704 |
-
|
| 705 |
-
Raises:
|
| 706 |
-
FileNotFoundError: If filepath does not exist, an FileNotFoundError
|
| 707 |
-
will be raised.
|
| 708 |
-
IsADirectoryError: If filepath is a directory, an IsADirectoryError
|
| 709 |
-
will be raised.
|
| 710 |
-
|
| 711 |
-
Examples:
|
| 712 |
-
>>> filepath = '/path/of/file'
|
| 713 |
-
>>> remove(filepath)
|
| 714 |
-
"""
|
| 715 |
-
backend = get_file_backend(
|
| 716 |
-
filepath,
|
| 717 |
-
backend_args=backend_args,
|
| 718 |
-
enable_singleton=True,
|
| 719 |
-
backend_key=backend_key,
|
| 720 |
-
)
|
| 721 |
-
backend.remove(filepath)
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
def rmtree(
|
| 725 |
-
dir_path: Union[str, Path],
|
| 726 |
-
backend_args: Optional[dict] = None,
|
| 727 |
-
backend_key: Optional[str] = None,
|
| 728 |
-
) -> None:
|
| 729 |
-
"""Recursively delete a directory tree.
|
| 730 |
-
|
| 731 |
-
Args:
|
| 732 |
-
dir_path (str or Path): A directory to be removed.
|
| 733 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 734 |
-
corresponding backend. Defaults to None.
|
| 735 |
-
|
| 736 |
-
Examples:
|
| 737 |
-
>>> dir_path = '/path/of/dir'
|
| 738 |
-
>>> rmtree(dir_path)
|
| 739 |
-
"""
|
| 740 |
-
backend = get_file_backend(
|
| 741 |
-
dir_path,
|
| 742 |
-
backend_args=backend_args,
|
| 743 |
-
enable_singleton=True,
|
| 744 |
-
backend_key=backend_key,
|
| 745 |
-
)
|
| 746 |
-
backend.rmtree(dir_path)
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
def copy_if_symlink_fails(
|
| 750 |
-
src: Union[str, Path],
|
| 751 |
-
dst: Union[str, Path],
|
| 752 |
-
backend_args: Optional[dict] = None,
|
| 753 |
-
backend_key: Optional[str] = None,
|
| 754 |
-
) -> bool:
|
| 755 |
-
"""Create a symbolic link pointing to src named dst.
|
| 756 |
-
|
| 757 |
-
If failed to create a symbolic link pointing to src, directory copy src to
|
| 758 |
-
dst instead.
|
| 759 |
-
|
| 760 |
-
Args:
|
| 761 |
-
src (str or Path): Create a symbolic link pointing to src.
|
| 762 |
-
dst (str or Path): Create a symbolic link named dst.
|
| 763 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 764 |
-
corresponding backend. Defaults to None.
|
| 765 |
-
|
| 766 |
-
Returns:
|
| 767 |
-
bool: Return True if successfully create a symbolic link pointing to
|
| 768 |
-
src. Otherwise, return False.
|
| 769 |
-
|
| 770 |
-
Examples:
|
| 771 |
-
>>> src = '/path/of/file'
|
| 772 |
-
>>> dst = '/path1/of/file1'
|
| 773 |
-
>>> copy_if_symlink_fails(src, dst)
|
| 774 |
-
True
|
| 775 |
-
>>> src = '/path/of/dir'
|
| 776 |
-
>>> dst = '/path1/of/dir1'
|
| 777 |
-
>>> copy_if_symlink_fails(src, dst)
|
| 778 |
-
True
|
| 779 |
-
"""
|
| 780 |
-
backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 781 |
-
return backend.copy_if_symlink_fails(src, dst)
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
def list_dir(
|
| 785 |
-
dir_path: Union[str, Path],
|
| 786 |
-
backend_args: Optional[dict] = None,
|
| 787 |
-
backend_key: Optional[str] = None,
|
| 788 |
-
):
|
| 789 |
-
"""List all folders in an S3 bucket with a given prefix.
|
| 790 |
-
|
| 791 |
-
Args:
|
| 792 |
-
dir_path (str | Path): Path of the directory.
|
| 793 |
-
|
| 794 |
-
Examples:
|
| 795 |
-
>>> dir_path = '/path/of/dir'
|
| 796 |
-
>>> for file_path in list_dir(dir_path):
|
| 797 |
-
... print(file_path)
|
| 798 |
-
"""
|
| 799 |
-
if not dir_path.endswith("/"):
|
| 800 |
-
dir_path += "/"
|
| 801 |
-
backend = get_file_backend(
|
| 802 |
-
dir_path,
|
| 803 |
-
backend_args=backend_args,
|
| 804 |
-
enable_singleton=True,
|
| 805 |
-
backend_key=backend_key,
|
| 806 |
-
)
|
| 807 |
-
|
| 808 |
-
return backend.list_dir(dir_path)
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
def list_dir_or_file(
|
| 812 |
-
dir_path: Union[str, Path],
|
| 813 |
-
list_dir: bool = True,
|
| 814 |
-
list_file: bool = True,
|
| 815 |
-
suffix: Optional[Union[str, Tuple[str]]] = None,
|
| 816 |
-
recursive: bool = False,
|
| 817 |
-
backend_args: Optional[dict] = None,
|
| 818 |
-
backend_key: Optional[str] = None,
|
| 819 |
-
) -> Iterator[str]:
|
| 820 |
-
"""Scan a directory to find the interested directories or files in
|
| 821 |
-
arbitrary order.
|
| 822 |
-
|
| 823 |
-
Note:
|
| 824 |
-
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
|
| 825 |
-
|
| 826 |
-
Args:
|
| 827 |
-
dir_path (str or Path): Path of the directory.
|
| 828 |
-
list_dir (bool): List the directories. Defaults to True.
|
| 829 |
-
list_file (bool): List the path of files. Defaults to True.
|
| 830 |
-
suffix (str or tuple[str], optional): File suffix that we are
|
| 831 |
-
interested in. Defaults to None.
|
| 832 |
-
recursive (bool): If set to True, recursively scan the directory.
|
| 833 |
-
Defaults to False.
|
| 834 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 835 |
-
corresponding backend. Defaults to None.
|
| 836 |
-
|
| 837 |
-
Yields:
|
| 838 |
-
Iterable[str]: A relative path to ``dir_path``.
|
| 839 |
-
|
| 840 |
-
Examples:
|
| 841 |
-
>>> dir_path = '/path/of/dir'
|
| 842 |
-
>>> for file_path in list_dir_or_file(dir_path):
|
| 843 |
-
... print(file_path)
|
| 844 |
-
>>> # list those files and directories in current directory
|
| 845 |
-
>>> for file_path in list_dir_or_file(dir_path):
|
| 846 |
-
... print(file_path)
|
| 847 |
-
>>> # only list files
|
| 848 |
-
>>> for file_path in list_dir_or_file(dir_path, list_dir=False):
|
| 849 |
-
... print(file_path)
|
| 850 |
-
>>> # only list directories
|
| 851 |
-
>>> for file_path in list_dir_or_file(dir_path, list_file=False):
|
| 852 |
-
... print(file_path)
|
| 853 |
-
>>> # only list files ending with specified suffixes
|
| 854 |
-
>>> for file_path in list_dir_or_file(dir_path, suffix='.txt'):
|
| 855 |
-
... print(file_path)
|
| 856 |
-
>>> # list all files and directory recursively
|
| 857 |
-
>>> for file_path in list_dir_or_file(dir_path, recursive=True):
|
| 858 |
-
... print(file_path)
|
| 859 |
-
"""
|
| 860 |
-
backend = get_file_backend(
|
| 861 |
-
dir_path,
|
| 862 |
-
backend_args=backend_args,
|
| 863 |
-
enable_singleton=True,
|
| 864 |
-
backend_key=backend_key,
|
| 865 |
-
)
|
| 866 |
-
yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
def generate_presigned_url(
|
| 870 |
-
url: str,
|
| 871 |
-
client_method: str = "get_object",
|
| 872 |
-
expires_in: int = 3600,
|
| 873 |
-
backend_args: Optional[dict] = None,
|
| 874 |
-
backend_key: Optional[str] = None,
|
| 875 |
-
) -> str:
|
| 876 |
-
"""Generate the presigned url of video stream which can be passed to
|
| 877 |
-
mmcv.VideoReader. Now only work on s3 backend.
|
| 878 |
-
|
| 879 |
-
Note:
|
| 880 |
-
Now only work on s3 backend.
|
| 881 |
-
|
| 882 |
-
Args:
|
| 883 |
-
url (str): Url of video stream.
|
| 884 |
-
client_method (str): Method of client, 'get_object' or
|
| 885 |
-
'put_object'. Defaults to 'get_object'.
|
| 886 |
-
expires_in (int): expires, in seconds. Defaults to 3600.
|
| 887 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 888 |
-
corresponding backend. Defaults to None.
|
| 889 |
-
|
| 890 |
-
Returns:
|
| 891 |
-
str: Generated presigned url.
|
| 892 |
-
"""
|
| 893 |
-
backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True, backend_key=backend_key)
|
| 894 |
-
return backend.generate_presigned_url(url, client_method, expires_in)
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
def load(
|
| 898 |
-
file: Union[str, Path, IO[Any]],
|
| 899 |
-
file_format: Optional[str] = None,
|
| 900 |
-
file_client_args: Optional[dict] = None,
|
| 901 |
-
fast_backend: bool = False,
|
| 902 |
-
backend_args: Optional[dict] = None,
|
| 903 |
-
backend_key: Optional[str] = None,
|
| 904 |
-
**kwargs,
|
| 905 |
-
):
|
| 906 |
-
"""Load data from json/yaml/pickle files.
|
| 907 |
-
|
| 908 |
-
This method provides a unified api for loading data from serialized files.
|
| 909 |
-
|
| 910 |
-
``load`` supports loading data from serialized files those can be storaged
|
| 911 |
-
in different backends.
|
| 912 |
-
|
| 913 |
-
Args:
|
| 914 |
-
file (str or :obj:`Path` or file-like object): Filename or a file-like
|
| 915 |
-
object.
|
| 916 |
-
file_format (str, optional): If not specified, the file format will be
|
| 917 |
-
inferred from the file extension, otherwise use the specified one.
|
| 918 |
-
Currently supported formats include "json", "yaml/yml" and
|
| 919 |
-
"pickle/pkl".
|
| 920 |
-
file_client_args (dict, optional): Arguments to instantiate a
|
| 921 |
-
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
| 922 |
-
Defaults to None. It will be deprecated in future. Please use
|
| 923 |
-
``backend_args`` instead.
|
| 924 |
-
fast_backend: bool: Whether to use multiprocess. Defaults to False.
|
| 925 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 926 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 927 |
-
New in v0.2.0.
|
| 928 |
-
|
| 929 |
-
Examples:
|
| 930 |
-
>>> load('/path/of/your/file') # file is storaged in disk
|
| 931 |
-
>>> load('https://path/of/your/file') # file is storaged in Internet
|
| 932 |
-
>>> load('s3://path/of/your/file') # file is storaged in s3
|
| 933 |
-
|
| 934 |
-
Returns:
|
| 935 |
-
The content from the file.
|
| 936 |
-
"""
|
| 937 |
-
if isinstance(file, Path):
|
| 938 |
-
file = str(file)
|
| 939 |
-
if file_format is None and isinstance(file, str):
|
| 940 |
-
file_format = file.split(".")[-1]
|
| 941 |
-
# convert file_format to lower case
|
| 942 |
-
file_format = file_format.lower()
|
| 943 |
-
if file_format not in file_handlers:
|
| 944 |
-
raise TypeError(f"Unsupported format: {file_format}")
|
| 945 |
-
|
| 946 |
-
if file_client_args is not None:
|
| 947 |
-
warnings.warn(
|
| 948 |
-
'"file_client_args" will be deprecated in future. Please use "backend_args" instead',
|
| 949 |
-
DeprecationWarning,
|
| 950 |
-
)
|
| 951 |
-
if backend_args is not None:
|
| 952 |
-
raise ValueError('"file_client_args and "backend_args" cannot be set at the same time.')
|
| 953 |
-
|
| 954 |
-
handler = file_handlers[file_format]
|
| 955 |
-
if isinstance(file, str):
|
| 956 |
-
if file_client_args is not None:
|
| 957 |
-
file_client = FileClient.infer_client(file_client_args, file)
|
| 958 |
-
file_backend = file_client
|
| 959 |
-
else:
|
| 960 |
-
file_backend = get_file_backend(
|
| 961 |
-
file,
|
| 962 |
-
backend_args=backend_args,
|
| 963 |
-
backend_key=backend_key,
|
| 964 |
-
enable_singleton=True,
|
| 965 |
-
)
|
| 966 |
-
|
| 967 |
-
if handler.str_like:
|
| 968 |
-
with StringIO(file_backend.get_text(file)) as f:
|
| 969 |
-
obj = handler.load_from_fileobj(f, **kwargs)
|
| 970 |
-
else:
|
| 971 |
-
if fast_backend:
|
| 972 |
-
if hasattr(file_backend, "fast_get"):
|
| 973 |
-
with BytesIO(file_backend.fast_get(file)) as f:
|
| 974 |
-
obj = handler.load_from_fileobj(f, **kwargs)
|
| 975 |
-
else:
|
| 976 |
-
warnings.warn(
|
| 977 |
-
f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get"
|
| 978 |
-
)
|
| 979 |
-
with BytesIO(file_backend.get(file)) as f:
|
| 980 |
-
obj = handler.load_from_fileobj(f, **kwargs)
|
| 981 |
-
else:
|
| 982 |
-
with BytesIO(file_backend.get(file)) as f:
|
| 983 |
-
obj = handler.load_from_fileobj(f, **kwargs)
|
| 984 |
-
elif hasattr(file, "read"):
|
| 985 |
-
obj = handler.load_from_fileobj(file, **kwargs)
|
| 986 |
-
else:
|
| 987 |
-
raise TypeError('"file" must be a filepath str or a file-object')
|
| 988 |
-
return obj
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
def dump(
|
| 992 |
-
obj: Any,
|
| 993 |
-
file: Union[str, Path, IO[Any], None] = None,
|
| 994 |
-
file_format: Optional[str] = None,
|
| 995 |
-
file_client_args: Optional[dict] = None,
|
| 996 |
-
fast_backend: bool = False,
|
| 997 |
-
backend_args: Optional[dict] = None,
|
| 998 |
-
backend_key: Optional[str] = None,
|
| 999 |
-
**kwargs,
|
| 1000 |
-
):
|
| 1001 |
-
"""Dump data to json/yaml/pickle strings or files.
|
| 1002 |
-
|
| 1003 |
-
This method provides a unified api for dumping data as strings or to files,
|
| 1004 |
-
and also supports custom arguments for each file format.
|
| 1005 |
-
|
| 1006 |
-
``dump`` supports dumping data as strings or to files which is saved to
|
| 1007 |
-
different backends.
|
| 1008 |
-
|
| 1009 |
-
Args:
|
| 1010 |
-
obj (any): The python object to be dumped.
|
| 1011 |
-
file (str or :obj:`Path` or file-like object, optional): If not
|
| 1012 |
-
specified, then the object is dumped to a str, otherwise to a file
|
| 1013 |
-
specified by the filename or file-like object.
|
| 1014 |
-
file_format (str, optional): Same as :func:`load`.
|
| 1015 |
-
file_client_args (dict, optional): Arguments to instantiate a
|
| 1016 |
-
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
| 1017 |
-
Defaults to None. It will be deprecated in future. Please use
|
| 1018 |
-
``backend_args`` instead.
|
| 1019 |
-
fast_backend: bool: Whether to use multiprocess. Defaults to False.
|
| 1020 |
-
backend_args (dict, optional): Arguments to instantiate the
|
| 1021 |
-
prefix of uri corresponding backend. Defaults to None.
|
| 1022 |
-
New in v0.2.0.
|
| 1023 |
-
backend_key: str: The key to register the backend. Defaults to None.
|
| 1024 |
-
|
| 1025 |
-
Examples:
|
| 1026 |
-
>>> dump('hello world', '/path/of/your/file') # disk
|
| 1027 |
-
>>> dump('hello world', 's3://path/of/your/file') # ceph or s3
|
| 1028 |
-
|
| 1029 |
-
Returns:
|
| 1030 |
-
bool: True for success, False otherwise.
|
| 1031 |
-
"""
|
| 1032 |
-
if isinstance(file, Path):
|
| 1033 |
-
file = str(file)
|
| 1034 |
-
if file_format is None:
|
| 1035 |
-
if isinstance(file, str):
|
| 1036 |
-
file_format = file.split(".")[-1]
|
| 1037 |
-
elif file is None:
|
| 1038 |
-
raise ValueError("file_format must be specified since file is None")
|
| 1039 |
-
# convert file_format to lower case
|
| 1040 |
-
file_format = file_format.lower()
|
| 1041 |
-
if file_format not in file_handlers:
|
| 1042 |
-
raise TypeError(f"Unsupported format: {file_format}")
|
| 1043 |
-
|
| 1044 |
-
if file_client_args is not None:
|
| 1045 |
-
warnings.warn(
|
| 1046 |
-
'"file_client_args" will be deprecated in future. Please use "backend_args" instead',
|
| 1047 |
-
DeprecationWarning,
|
| 1048 |
-
)
|
| 1049 |
-
if backend_args is not None:
|
| 1050 |
-
raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.')
|
| 1051 |
-
|
| 1052 |
-
handler = file_handlers[file_format]
|
| 1053 |
-
if file is None:
|
| 1054 |
-
return handler.dump_to_str(obj, **kwargs)
|
| 1055 |
-
elif isinstance(file, str):
|
| 1056 |
-
if file_client_args is not None:
|
| 1057 |
-
file_client = FileClient.infer_client(file_client_args, file)
|
| 1058 |
-
file_backend = file_client
|
| 1059 |
-
else:
|
| 1060 |
-
file_backend = get_file_backend(
|
| 1061 |
-
file,
|
| 1062 |
-
backend_args=backend_args,
|
| 1063 |
-
backend_key=backend_key,
|
| 1064 |
-
enable_singleton=True,
|
| 1065 |
-
)
|
| 1066 |
-
|
| 1067 |
-
if handler.str_like:
|
| 1068 |
-
with StringIO() as f:
|
| 1069 |
-
handler.dump_to_fileobj(obj, f, **kwargs)
|
| 1070 |
-
file_backend.put_text(f.getvalue(), file)
|
| 1071 |
-
else:
|
| 1072 |
-
with BytesIO() as f:
|
| 1073 |
-
handler.dump_to_fileobj(obj, f, **kwargs)
|
| 1074 |
-
if fast_backend:
|
| 1075 |
-
if hasattr(file_backend, "fast_put"):
|
| 1076 |
-
file_backend.fast_put(f, file)
|
| 1077 |
-
else:
|
| 1078 |
-
warnings.warn("fast_backend is not supported by the backend, fallback to normal put")
|
| 1079 |
-
file_backend.put(f, file)
|
| 1080 |
-
else:
|
| 1081 |
-
file_backend.put(f, file)
|
| 1082 |
-
elif hasattr(file, "write"):
|
| 1083 |
-
handler.dump_to_fileobj(obj, file, **kwargs)
|
| 1084 |
-
else:
|
| 1085 |
-
raise TypeError('"file" must be a filename str or a file-object')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/file_client.py
DELETED
|
@@ -1,458 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import inspect
|
| 17 |
-
from contextlib import contextmanager
|
| 18 |
-
from pathlib import Path
|
| 19 |
-
from typing import Any, Generator, Iterator, Optional, Tuple, Union
|
| 20 |
-
|
| 21 |
-
from lyra_2._ext.imaginaire.utils.easy_io.backends import (
|
| 22 |
-
BaseStorageBackend,
|
| 23 |
-
Boto3Backend,
|
| 24 |
-
HTTPBackend,
|
| 25 |
-
LocalBackend,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def is_filepath(filepath):
|
| 30 |
-
return isinstance(filepath, (str, Path))
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class HardDiskBackend(LocalBackend):
|
| 34 |
-
"""Raw hard disks storage backend."""
|
| 35 |
-
|
| 36 |
-
@property
|
| 37 |
-
def name(self):
|
| 38 |
-
return self.__class__.__name__
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
class FileClient:
|
| 42 |
-
"""A general file client to access files in different backends.
|
| 43 |
-
|
| 44 |
-
The client loads a file or text in a specified backend from its path
|
| 45 |
-
and returns it as a binary or text file. There are two ways to choose a
|
| 46 |
-
backend, the name of backend and the prefix of path. Although both of them
|
| 47 |
-
can be used to choose a storage backend, ``backend`` has a higher priority
|
| 48 |
-
that is if they are all set, the storage backend will be chosen by the
|
| 49 |
-
backend argument. If they are all `None`, the disk backend will be chosen.
|
| 50 |
-
Note that It can also register other backend accessor with a given name,
|
| 51 |
-
prefixes, and backend class. In addition, We use the singleton pattern to
|
| 52 |
-
avoid repeated object creation. If the arguments are the same, the same
|
| 53 |
-
object will be returned.
|
| 54 |
-
|
| 55 |
-
Warning:
|
| 56 |
-
`FileClient` will be deprecated in future. Please use io functions
|
| 57 |
-
in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
backend (str, optional): The storage backend type. Options are "disk",
|
| 61 |
-
"memcached", "lmdb", "http" and "s3". Defaults to None.
|
| 62 |
-
prefix (str, optional): The prefix of the registered storage backend.
|
| 63 |
-
Options are "s3", "http", "https". Defaults to None.
|
| 64 |
-
|
| 65 |
-
Examples:
|
| 66 |
-
>>> # only set backend
|
| 67 |
-
>>> file_client = FileClient(backend='s3')
|
| 68 |
-
>>> # only set prefix
|
| 69 |
-
>>> file_client = FileClient(prefix='s3')
|
| 70 |
-
>>> # set both backend and prefix but use backend to choose client
|
| 71 |
-
>>> file_client = FileClient(backend='s3', prefix='s3')
|
| 72 |
-
>>> # if the arguments are the same, the same object is returned
|
| 73 |
-
>>> file_client1 = FileClient(backend='s3')
|
| 74 |
-
>>> file_client1 is file_client
|
| 75 |
-
True
|
| 76 |
-
|
| 77 |
-
Attributes:
|
| 78 |
-
client (:obj:`BaseStorageBackend`): The backend object.
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
_backends = {
|
| 82 |
-
"disk": HardDiskBackend,
|
| 83 |
-
"s3": Boto3Backend,
|
| 84 |
-
"http": HTTPBackend,
|
| 85 |
-
}
|
| 86 |
-
|
| 87 |
-
_prefix_to_backends: dict = {
|
| 88 |
-
"s3": Boto3Backend,
|
| 89 |
-
"http": HTTPBackend,
|
| 90 |
-
"https": HTTPBackend,
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
_instances: dict = {}
|
| 94 |
-
|
| 95 |
-
client: Any
|
| 96 |
-
|
| 97 |
-
def __new__(cls, backend=None, prefix=None, **kwargs):
|
| 98 |
-
if backend is None and prefix is None:
|
| 99 |
-
backend = "disk"
|
| 100 |
-
if backend is not None and backend not in cls._backends:
|
| 101 |
-
raise ValueError(
|
| 102 |
-
f"Backend {backend} is not supported. Currently supported ones are {list(cls._backends.keys())}"
|
| 103 |
-
)
|
| 104 |
-
if prefix is not None and prefix not in cls._prefix_to_backends:
|
| 105 |
-
raise ValueError(
|
| 106 |
-
f"prefix {prefix} is not supported. Currently supported ones are {list(cls._prefix_to_backends.keys())}"
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
# concatenate the arguments to a unique key for determining whether
|
| 110 |
-
# objects with the same arguments were created
|
| 111 |
-
arg_key = f"{backend}:{prefix}"
|
| 112 |
-
for key, value in kwargs.items():
|
| 113 |
-
arg_key += f":{key}:{value}"
|
| 114 |
-
|
| 115 |
-
# if a backend was overridden, it will create a new object
|
| 116 |
-
if arg_key in cls._instances:
|
| 117 |
-
_instance = cls._instances[arg_key]
|
| 118 |
-
else:
|
| 119 |
-
# create a new object and put it to _instance
|
| 120 |
-
_instance = super().__new__(cls)
|
| 121 |
-
if backend is not None:
|
| 122 |
-
_instance.client = cls._backends[backend](**kwargs)
|
| 123 |
-
else:
|
| 124 |
-
_instance.client = cls._prefix_to_backends[prefix](**kwargs)
|
| 125 |
-
|
| 126 |
-
cls._instances[arg_key] = _instance
|
| 127 |
-
|
| 128 |
-
return _instance
|
| 129 |
-
|
| 130 |
-
@property
|
| 131 |
-
def name(self):
|
| 132 |
-
return self.client.name
|
| 133 |
-
|
| 134 |
-
@property
|
| 135 |
-
def allow_symlink(self):
|
| 136 |
-
return self.client.allow_symlink
|
| 137 |
-
|
| 138 |
-
@staticmethod
|
| 139 |
-
def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
|
| 140 |
-
"""Parse the prefix of a uri.
|
| 141 |
-
|
| 142 |
-
Args:
|
| 143 |
-
uri (str | Path): Uri to be parsed that contains the file prefix.
|
| 144 |
-
|
| 145 |
-
Examples:
|
| 146 |
-
>>> FileClient.parse_uri_prefix('s3://path/of/your/file')
|
| 147 |
-
's3'
|
| 148 |
-
|
| 149 |
-
Returns:
|
| 150 |
-
str | None: Return the prefix of uri if the uri contains '://' else
|
| 151 |
-
``None``.
|
| 152 |
-
"""
|
| 153 |
-
assert is_filepath(uri)
|
| 154 |
-
uri = str(uri)
|
| 155 |
-
if "://" not in uri:
|
| 156 |
-
return None
|
| 157 |
-
else:
|
| 158 |
-
prefix, _ = uri.split("://")
|
| 159 |
-
# In the case of Boto3Backend, the prefix may contains the cluster
|
| 160 |
-
# name like clusterName:s3
|
| 161 |
-
if ":" in prefix:
|
| 162 |
-
_, prefix = prefix.split(":")
|
| 163 |
-
return prefix
|
| 164 |
-
|
| 165 |
-
@classmethod
|
| 166 |
-
def infer_client(
|
| 167 |
-
cls,
|
| 168 |
-
file_client_args: Optional[dict] = None,
|
| 169 |
-
uri: Optional[Union[str, Path]] = None,
|
| 170 |
-
) -> "FileClient":
|
| 171 |
-
"""Infer a suitable file client based on the URI and arguments.
|
| 172 |
-
|
| 173 |
-
Args:
|
| 174 |
-
file_client_args (dict, optional): Arguments to instantiate a
|
| 175 |
-
FileClient. Defaults to None.
|
| 176 |
-
uri (str | Path, optional): Uri to be parsed that contains the file
|
| 177 |
-
prefix. Defaults to None.
|
| 178 |
-
|
| 179 |
-
Examples:
|
| 180 |
-
>>> uri = 's3://path/of/your/file'
|
| 181 |
-
>>> file_client = FileClient.infer_client(uri=uri)
|
| 182 |
-
>>> file_client_args = {'backend': 's3'}
|
| 183 |
-
>>> file_client = FileClient.infer_client(file_client_args)
|
| 184 |
-
|
| 185 |
-
Returns:
|
| 186 |
-
FileClient: Instantiated FileClient object.
|
| 187 |
-
"""
|
| 188 |
-
assert file_client_args is not None or uri is not None
|
| 189 |
-
if file_client_args is None:
|
| 190 |
-
file_prefix = cls.parse_uri_prefix(uri) # type: ignore
|
| 191 |
-
return cls(prefix=file_prefix)
|
| 192 |
-
else:
|
| 193 |
-
return cls(**file_client_args)
|
| 194 |
-
|
| 195 |
-
@classmethod
|
| 196 |
-
def _register_backend(cls, name, backend, force=False, prefixes=None):
|
| 197 |
-
if not isinstance(name, str):
|
| 198 |
-
raise TypeError(f"the backend name should be a string, but got {type(name)}")
|
| 199 |
-
if not inspect.isclass(backend):
|
| 200 |
-
raise TypeError(f"backend should be a class but got {type(backend)}")
|
| 201 |
-
if not issubclass(backend, BaseStorageBackend):
|
| 202 |
-
raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend")
|
| 203 |
-
if not force and name in cls._backends:
|
| 204 |
-
raise KeyError(
|
| 205 |
-
f'{name} is already registered as a storage backend, add "force=True" if you want to override it'
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
if name in cls._backends and force:
|
| 209 |
-
for arg_key, instance in list(cls._instances.items()):
|
| 210 |
-
if isinstance(instance.client, cls._backends[name]):
|
| 211 |
-
cls._instances.pop(arg_key)
|
| 212 |
-
cls._backends[name] = backend
|
| 213 |
-
|
| 214 |
-
if prefixes is not None:
|
| 215 |
-
if isinstance(prefixes, str):
|
| 216 |
-
prefixes = [prefixes]
|
| 217 |
-
else:
|
| 218 |
-
assert isinstance(prefixes, (list, tuple))
|
| 219 |
-
for prefix in prefixes:
|
| 220 |
-
if prefix not in cls._prefix_to_backends:
|
| 221 |
-
cls._prefix_to_backends[prefix] = backend
|
| 222 |
-
elif (prefix in cls._prefix_to_backends) and force:
|
| 223 |
-
overridden_backend = cls._prefix_to_backends[prefix]
|
| 224 |
-
for arg_key, instance in list(cls._instances.items()):
|
| 225 |
-
if isinstance(instance.client, overridden_backend):
|
| 226 |
-
cls._instances.pop(arg_key)
|
| 227 |
-
else:
|
| 228 |
-
raise KeyError(
|
| 229 |
-
f"{prefix} is already registered as a storage backend,"
|
| 230 |
-
' add "force=True" if you want to override it'
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
@classmethod
|
| 234 |
-
def register_backend(cls, name, backend=None, force=False, prefixes=None):
|
| 235 |
-
"""Register a backend to FileClient.
|
| 236 |
-
|
| 237 |
-
This method can be used as a normal class method or a decorator.
|
| 238 |
-
|
| 239 |
-
.. code-block:: python
|
| 240 |
-
|
| 241 |
-
class NewBackend(BaseStorageBackend):
|
| 242 |
-
|
| 243 |
-
def get(self, filepath):
|
| 244 |
-
return filepath
|
| 245 |
-
|
| 246 |
-
def get_text(self, filepath):
|
| 247 |
-
return filepath
|
| 248 |
-
|
| 249 |
-
FileClient.register_backend('new', NewBackend)
|
| 250 |
-
|
| 251 |
-
or
|
| 252 |
-
|
| 253 |
-
.. code-block:: python
|
| 254 |
-
|
| 255 |
-
@FileClient.register_backend('new')
|
| 256 |
-
class NewBackend(BaseStorageBackend):
|
| 257 |
-
|
| 258 |
-
def get(self, filepath):
|
| 259 |
-
return filepath
|
| 260 |
-
|
| 261 |
-
def get_text(self, filepath):
|
| 262 |
-
return filepath
|
| 263 |
-
|
| 264 |
-
Args:
|
| 265 |
-
name (str): The name of the registered backend.
|
| 266 |
-
backend (class, optional): The backend class to be registered,
|
| 267 |
-
which must be a subclass of :class:`BaseStorageBackend`.
|
| 268 |
-
When this method is used as a decorator, backend is None.
|
| 269 |
-
Defaults to None.
|
| 270 |
-
force (bool, optional): Whether to override the backend if the name
|
| 271 |
-
has already been registered. Defaults to False.
|
| 272 |
-
prefixes (str or list[str] or tuple[str], optional): The prefixes
|
| 273 |
-
of the registered storage backend. Defaults to None.
|
| 274 |
-
`New in version 1.3.15.`
|
| 275 |
-
"""
|
| 276 |
-
if backend is not None:
|
| 277 |
-
cls._register_backend(name, backend, force=force, prefixes=prefixes)
|
| 278 |
-
return
|
| 279 |
-
|
| 280 |
-
def _register(backend_cls):
|
| 281 |
-
cls._register_backend(name, backend_cls, force=force, prefixes=prefixes)
|
| 282 |
-
return backend_cls
|
| 283 |
-
|
| 284 |
-
return _register
|
| 285 |
-
|
| 286 |
-
def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
|
| 287 |
-
"""Read data from a given ``filepath`` with 'rb' mode.
|
| 288 |
-
|
| 289 |
-
Note:
|
| 290 |
-
There are two types of return values for ``get``, one is ``bytes``
|
| 291 |
-
and the other is ``memoryview``. The advantage of using memoryview
|
| 292 |
-
is that you can avoid copying, and if you want to convert it to
|
| 293 |
-
``bytes``, you can use ``.tobytes()``.
|
| 294 |
-
|
| 295 |
-
Args:
|
| 296 |
-
filepath (str or Path): Path to read data.
|
| 297 |
-
|
| 298 |
-
Returns:
|
| 299 |
-
bytes | memoryview: Expected bytes object or a memory view of the
|
| 300 |
-
bytes object.
|
| 301 |
-
"""
|
| 302 |
-
return self.client.get(filepath)
|
| 303 |
-
|
| 304 |
-
def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str:
|
| 305 |
-
"""Read data from a given ``filepath`` with 'r' mode.
|
| 306 |
-
|
| 307 |
-
Args:
|
| 308 |
-
filepath (str or Path): Path to read data.
|
| 309 |
-
encoding (str): The encoding format used to open the ``filepath``.
|
| 310 |
-
Defaults to 'utf-8'.
|
| 311 |
-
|
| 312 |
-
Returns:
|
| 313 |
-
str: Expected text reading from ``filepath``.
|
| 314 |
-
"""
|
| 315 |
-
return self.client.get_text(filepath, encoding)
|
| 316 |
-
|
| 317 |
-
def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
|
| 318 |
-
"""Write data to a given ``filepath`` with 'wb' mode.
|
| 319 |
-
|
| 320 |
-
Note:
|
| 321 |
-
``put`` should create a directory if the directory of ``filepath``
|
| 322 |
-
does not exist.
|
| 323 |
-
|
| 324 |
-
Args:
|
| 325 |
-
obj (bytes): Data to be written.
|
| 326 |
-
filepath (str or Path): Path to write data.
|
| 327 |
-
"""
|
| 328 |
-
self.client.put(obj, filepath)
|
| 329 |
-
|
| 330 |
-
def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
|
| 331 |
-
"""Write data to a given ``filepath`` with 'w' mode.
|
| 332 |
-
|
| 333 |
-
Note:
|
| 334 |
-
``put_text`` should create a directory if the directory of
|
| 335 |
-
``filepath`` does not exist.
|
| 336 |
-
|
| 337 |
-
Args:
|
| 338 |
-
obj (str): Data to be written.
|
| 339 |
-
filepath (str or Path): Path to write data.
|
| 340 |
-
encoding (str, optional): The encoding format used to open the
|
| 341 |
-
`filepath`. Defaults to 'utf-8'.
|
| 342 |
-
"""
|
| 343 |
-
self.client.put_text(obj, filepath)
|
| 344 |
-
|
| 345 |
-
def remove(self, filepath: Union[str, Path]) -> None:
|
| 346 |
-
"""Remove a file.
|
| 347 |
-
|
| 348 |
-
Args:
|
| 349 |
-
filepath (str, Path): Path to be removed.
|
| 350 |
-
"""
|
| 351 |
-
self.client.remove(filepath)
|
| 352 |
-
|
| 353 |
-
def exists(self, filepath: Union[str, Path]) -> bool:
|
| 354 |
-
"""Check whether a file path exists.
|
| 355 |
-
|
| 356 |
-
Args:
|
| 357 |
-
filepath (str or Path): Path to be checked whether exists.
|
| 358 |
-
|
| 359 |
-
Returns:
|
| 360 |
-
bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
|
| 361 |
-
"""
|
| 362 |
-
return self.client.exists(filepath)
|
| 363 |
-
|
| 364 |
-
def isdir(self, filepath: Union[str, Path]) -> bool:
|
| 365 |
-
"""Check whether a file path is a directory.
|
| 366 |
-
|
| 367 |
-
Args:
|
| 368 |
-
filepath (str or Path): Path to be checked whether it is a
|
| 369 |
-
directory.
|
| 370 |
-
|
| 371 |
-
Returns:
|
| 372 |
-
bool: Return ``True`` if ``filepath`` points to a directory,
|
| 373 |
-
``False`` otherwise.
|
| 374 |
-
"""
|
| 375 |
-
return self.client.isdir(filepath)
|
| 376 |
-
|
| 377 |
-
def isfile(self, filepath: Union[str, Path]) -> bool:
|
| 378 |
-
"""Check whether a file path is a file.
|
| 379 |
-
|
| 380 |
-
Args:
|
| 381 |
-
filepath (str or Path): Path to be checked whether it is a file.
|
| 382 |
-
|
| 383 |
-
Returns:
|
| 384 |
-
bool: Return ``True`` if ``filepath`` points to a file, ``False``
|
| 385 |
-
otherwise.
|
| 386 |
-
"""
|
| 387 |
-
return self.client.isfile(filepath)
|
| 388 |
-
|
| 389 |
-
def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str:
|
| 390 |
-
r"""Concatenate all file paths.
|
| 391 |
-
|
| 392 |
-
Join one or more filepath components intelligently. The return value
|
| 393 |
-
is the concatenation of filepath and any members of \*filepaths.
|
| 394 |
-
|
| 395 |
-
Args:
|
| 396 |
-
filepath (str or Path): Path to be concatenated.
|
| 397 |
-
|
| 398 |
-
Returns:
|
| 399 |
-
str: The result of concatenation.
|
| 400 |
-
"""
|
| 401 |
-
return self.client.join_path(filepath, *filepaths)
|
| 402 |
-
|
| 403 |
-
@contextmanager
|
| 404 |
-
def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]:
|
| 405 |
-
"""Download data from ``filepath`` and write the data to local path.
|
| 406 |
-
|
| 407 |
-
``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
|
| 408 |
-
can be called with ``with`` statement, and when exists from the
|
| 409 |
-
``with`` statement, the temporary path will be released.
|
| 410 |
-
|
| 411 |
-
Note:
|
| 412 |
-
If the ``filepath`` is a local path, just return itself.
|
| 413 |
-
|
| 414 |
-
.. warning::
|
| 415 |
-
``get_local_path`` is an experimental interface that may change in
|
| 416 |
-
the future.
|
| 417 |
-
|
| 418 |
-
Args:
|
| 419 |
-
filepath (str or Path): Path to be read data.
|
| 420 |
-
|
| 421 |
-
Examples:
|
| 422 |
-
>>> file_client = FileClient(prefix='s3')
|
| 423 |
-
>>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
|
| 424 |
-
... # do something here
|
| 425 |
-
|
| 426 |
-
Yields:
|
| 427 |
-
Iterable[str]: Only yield one path.
|
| 428 |
-
"""
|
| 429 |
-
with self.client.get_local_path(str(filepath)) as local_path:
|
| 430 |
-
yield local_path
|
| 431 |
-
|
| 432 |
-
def list_dir_or_file( # pylint: disable=too-many-arguments
|
| 433 |
-
self,
|
| 434 |
-
dir_path: Union[str, Path],
|
| 435 |
-
list_dir: bool = True,
|
| 436 |
-
list_file: bool = True,
|
| 437 |
-
suffix: Optional[Union[str, Tuple[str]]] = None,
|
| 438 |
-
recursive: bool = False,
|
| 439 |
-
) -> Iterator[str]:
|
| 440 |
-
"""Scan a directory to find the interested directories or files in
|
| 441 |
-
arbitrary order.
|
| 442 |
-
|
| 443 |
-
Note:
|
| 444 |
-
:meth:`list_dir_or_file` returns the path relative to ``dir_path``.
|
| 445 |
-
|
| 446 |
-
Args:
|
| 447 |
-
dir_path (str | Path): Path of the directory.
|
| 448 |
-
list_dir (bool): List the directories. Defaults to True.
|
| 449 |
-
list_file (bool): List the path of files. Defaults to True.
|
| 450 |
-
suffix (str or tuple[str], optional): File suffix
|
| 451 |
-
that we are interested in. Defaults to None.
|
| 452 |
-
recursive (bool): If set to True, recursively scan the
|
| 453 |
-
directory. Defaults to False.
|
| 454 |
-
|
| 455 |
-
Yields:
|
| 456 |
-
Iterable[str]: A relative path to ``dir_path``.
|
| 457 |
-
"""
|
| 458 |
-
yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 17 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.json_handler import JsonHandler
|
| 18 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
|
| 19 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler
|
| 20 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler
|
| 21 |
-
|
| 22 |
-
__all__ = [
|
| 23 |
-
"BaseFileHandler",
|
| 24 |
-
"JsonHandler",
|
| 25 |
-
"PickleHandler",
|
| 26 |
-
"YamlHandler",
|
| 27 |
-
"register_handler",
|
| 28 |
-
"file_handlers",
|
| 29 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from abc import ABCMeta, abstractmethod
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class BaseFileHandler(metaclass=ABCMeta):
|
| 20 |
-
# `str_like` is a flag to indicate whether the type of file object is
|
| 21 |
-
# str-like object or bytes-like object. Pickle only processes bytes-like
|
| 22 |
-
# objects but json only processes str-like object. If it is str-like
|
| 23 |
-
# object, `StringIO` will be used to process the buffer.
|
| 24 |
-
str_like = True
|
| 25 |
-
|
| 26 |
-
@abstractmethod
|
| 27 |
-
def load_from_fileobj(self, file, **kwargs):
|
| 28 |
-
pass
|
| 29 |
-
|
| 30 |
-
@abstractmethod
|
| 31 |
-
def dump_to_fileobj(self, obj, file, **kwargs):
|
| 32 |
-
pass
|
| 33 |
-
|
| 34 |
-
@abstractmethod
|
| 35 |
-
def dump_to_str(self, obj, **kwargs):
|
| 36 |
-
pass
|
| 37 |
-
|
| 38 |
-
def load_from_path(self, filepath, mode="r", **kwargs):
|
| 39 |
-
with open(filepath, mode) as f:
|
| 40 |
-
return self.load_from_fileobj(f, **kwargs)
|
| 41 |
-
|
| 42 |
-
def dump_to_path(self, obj, filepath, mode="w", **kwargs):
|
| 43 |
-
with open(filepath, mode) as f:
|
| 44 |
-
self.dump_to_fileobj(obj, f, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from typing import IO
|
| 17 |
-
|
| 18 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class ByteHandler(BaseFileHandler):
|
| 22 |
-
str_like = False
|
| 23 |
-
|
| 24 |
-
def load_from_fileobj(self, file: IO[bytes], **kwargs):
|
| 25 |
-
file.seek(0)
|
| 26 |
-
# extra all bytes and return
|
| 27 |
-
return file.read()
|
| 28 |
-
|
| 29 |
-
def dump_to_fileobj(
|
| 30 |
-
self,
|
| 31 |
-
obj: bytes,
|
| 32 |
-
file: IO[bytes],
|
| 33 |
-
**kwargs,
|
| 34 |
-
):
|
| 35 |
-
# write all bytes to file
|
| 36 |
-
file.write(obj)
|
| 37 |
-
|
| 38 |
-
def dump_to_str(self, obj, **kwargs):
|
| 39 |
-
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py
DELETED
|
@@ -1,42 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import csv
|
| 17 |
-
from io import StringIO
|
| 18 |
-
|
| 19 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class CsvHandler(BaseFileHandler):
|
| 23 |
-
def load_from_fileobj(self, file, **kwargs):
|
| 24 |
-
del kwargs
|
| 25 |
-
reader = csv.reader(file)
|
| 26 |
-
return list(reader)
|
| 27 |
-
|
| 28 |
-
def dump_to_fileobj(self, obj, file, **kwargs):
|
| 29 |
-
del kwargs
|
| 30 |
-
writer = csv.writer(file)
|
| 31 |
-
if not all(isinstance(row, list) for row in obj):
|
| 32 |
-
raise ValueError("Each row must be a list")
|
| 33 |
-
writer.writerows(obj)
|
| 34 |
-
|
| 35 |
-
def dump_to_str(self, obj, **kwargs):
|
| 36 |
-
del kwargs
|
| 37 |
-
output = StringIO()
|
| 38 |
-
writer = csv.writer(output)
|
| 39 |
-
if not all(isinstance(row, list) for row in obj):
|
| 40 |
-
raise ValueError("Each row must be a list")
|
| 41 |
-
writer.writerows(obj)
|
| 42 |
-
return output.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import gzip
|
| 17 |
-
import pickle
|
| 18 |
-
from io import BytesIO
|
| 19 |
-
from typing import Any
|
| 20 |
-
|
| 21 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class GzipHandler(PickleHandler):
|
| 25 |
-
str_like = False
|
| 26 |
-
|
| 27 |
-
def load_from_fileobj(self, file: BytesIO, **kwargs):
|
| 28 |
-
with gzip.GzipFile(fileobj=file, mode="rb") as f:
|
| 29 |
-
return pickle.load(f)
|
| 30 |
-
|
| 31 |
-
def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs):
|
| 32 |
-
with gzip.GzipFile(fileobj=file, mode="wb") as f:
|
| 33 |
-
pickle.dump(obj, f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from typing import IO, Any, Dict, Tuple
|
| 17 |
-
|
| 18 |
-
import imageio
|
| 19 |
-
import imageio.v3 as iio_v3
|
| 20 |
-
import numpy as np
|
| 21 |
-
import torch
|
| 22 |
-
|
| 23 |
-
from lyra_2._ext.imaginaire.utils import log
|
| 24 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class ImageioVideoHandler(BaseFileHandler):
|
| 28 |
-
str_like = False
|
| 29 |
-
|
| 30 |
-
def load_from_fileobj(
|
| 31 |
-
self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs
|
| 32 |
-
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
| 33 |
-
"""
|
| 34 |
-
Load video from a file-like object using imageio.v3 with specified format and color mode.
|
| 35 |
-
|
| 36 |
-
Parameters:
|
| 37 |
-
file (IO[bytes]): A file-like object containing video data.
|
| 38 |
-
format (str): Format of the video file (default 'mp4').
|
| 39 |
-
mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb').
|
| 40 |
-
|
| 41 |
-
Returns:
|
| 42 |
-
tuple: A tuple containing an array of video frames and metadata about the video.
|
| 43 |
-
"""
|
| 44 |
-
file.seek(0)
|
| 45 |
-
|
| 46 |
-
# The plugin argument in v3 replaces the format argument in v2
|
| 47 |
-
plugin = kwargs.pop("plugin", "pyav")
|
| 48 |
-
|
| 49 |
-
# Load all frames at once using v3 API
|
| 50 |
-
video_frames = iio_v3.imread(file, plugin=plugin, **kwargs)
|
| 51 |
-
|
| 52 |
-
# Handle grayscale conversion if needed
|
| 53 |
-
if mode == "gray":
|
| 54 |
-
import cv2
|
| 55 |
-
|
| 56 |
-
if len(video_frames.shape) == 4: # (frames, height, width, channels)
|
| 57 |
-
gray_frames = []
|
| 58 |
-
for frame in video_frames:
|
| 59 |
-
gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
| 60 |
-
gray_frame = np.expand_dims(gray_frame, axis=2) # Keep dimensions consistent
|
| 61 |
-
gray_frames.append(gray_frame)
|
| 62 |
-
video_frames = np.array(gray_frames)
|
| 63 |
-
|
| 64 |
-
# Extract metadata
|
| 65 |
-
# Note: iio_v3.imread doesn't return metadata directly like v2 did
|
| 66 |
-
# We need to extract it separately
|
| 67 |
-
file.seek(0)
|
| 68 |
-
metadata = self._extract_metadata(file, plugin=plugin)
|
| 69 |
-
|
| 70 |
-
return video_frames, metadata
|
| 71 |
-
|
| 72 |
-
def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> Dict[str, Any]:
|
| 73 |
-
"""
|
| 74 |
-
Extract metadata from a video file.
|
| 75 |
-
|
| 76 |
-
Parameters:
|
| 77 |
-
file (IO[bytes]): File-like object containing video data.
|
| 78 |
-
plugin (str): Plugin to use for reading.
|
| 79 |
-
|
| 80 |
-
Returns:
|
| 81 |
-
dict: Video metadata.
|
| 82 |
-
"""
|
| 83 |
-
try:
|
| 84 |
-
# Create a generator to read frames and metadata
|
| 85 |
-
metadata = iio_v3.immeta(file, plugin=plugin)
|
| 86 |
-
|
| 87 |
-
# Add some standard fields similar to v2 metadata format
|
| 88 |
-
if "fps" not in metadata and "duration" in metadata:
|
| 89 |
-
# Read the first frame to get shape information
|
| 90 |
-
file.seek(0)
|
| 91 |
-
first_frame = iio_v3.imread(file, plugin=plugin, index=0)
|
| 92 |
-
metadata["size"] = first_frame.shape[1::-1] # (width, height)
|
| 93 |
-
metadata["source_size"] = metadata["size"]
|
| 94 |
-
|
| 95 |
-
# Create a consistent metadata structure with v2
|
| 96 |
-
metadata["plugin"] = plugin
|
| 97 |
-
if "codec" not in metadata:
|
| 98 |
-
metadata["codec"] = "unknown"
|
| 99 |
-
if "pix_fmt" not in metadata:
|
| 100 |
-
metadata["pix_fmt"] = "unknown"
|
| 101 |
-
|
| 102 |
-
# Calculate nframes if possible
|
| 103 |
-
if "fps" in metadata and "duration" in metadata:
|
| 104 |
-
metadata["nframes"] = int(metadata["fps"] * metadata["duration"])
|
| 105 |
-
else:
|
| 106 |
-
metadata["nframes"] = float("inf")
|
| 107 |
-
|
| 108 |
-
return metadata
|
| 109 |
-
|
| 110 |
-
except Exception as e:
|
| 111 |
-
# Fallback to basic metadata
|
| 112 |
-
return {
|
| 113 |
-
"plugin": plugin,
|
| 114 |
-
"nframes": float("inf"),
|
| 115 |
-
"codec": "unknown",
|
| 116 |
-
"fps": 30.0, # Default values
|
| 117 |
-
"duration": 0,
|
| 118 |
-
"size": (0, 0),
|
| 119 |
-
}
|
| 120 |
-
|
| 121 |
-
def dump_to_fileobj(
|
| 122 |
-
self,
|
| 123 |
-
obj: np.ndarray | torch.Tensor,
|
| 124 |
-
file: IO[bytes],
|
| 125 |
-
format: str = "mp4", # pylint: disable=redefined-builtin
|
| 126 |
-
fps: int = 17,
|
| 127 |
-
quality: int = 5,
|
| 128 |
-
ffmpeg_params=None,
|
| 129 |
-
**kwargs,
|
| 130 |
-
):
|
| 131 |
-
"""
|
| 132 |
-
Save an array of video frames to a file-like object using imageio.
|
| 133 |
-
|
| 134 |
-
Parameters:
|
| 135 |
-
obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video.
|
| 136 |
-
file (IO[bytes]): A file-like object to which the video data will be written.
|
| 137 |
-
format (str): Format of the video file (default 'mp4').
|
| 138 |
-
fps (int): Frames per second of the output video (default 17).
|
| 139 |
-
quality (int): Quality of the video (0-10, default 5).
|
| 140 |
-
ffmpeg_params (list): Additional parameters to pass to ffmpeg.
|
| 141 |
-
|
| 142 |
-
"""
|
| 143 |
-
if isinstance(obj, torch.Tensor):
|
| 144 |
-
assert obj.dtype == torch.uint8, "Tensor must be of type uint8"
|
| 145 |
-
obj = obj.cpu().numpy()
|
| 146 |
-
h, w = obj.shape[1:-1]
|
| 147 |
-
|
| 148 |
-
# Default ffmpeg params that ensure width and height are set
|
| 149 |
-
default_ffmpeg_params = ["-s", f"{w}x{h}"]
|
| 150 |
-
|
| 151 |
-
# Use provided ffmpeg_params if any, otherwise use defaults
|
| 152 |
-
final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params
|
| 153 |
-
|
| 154 |
-
mimsave_kwargs = {
|
| 155 |
-
"fps": fps,
|
| 156 |
-
"quality": quality,
|
| 157 |
-
"macro_block_size": 1,
|
| 158 |
-
"ffmpeg_params": final_ffmpeg_params,
|
| 159 |
-
"output_params": ["-f", "mp4"],
|
| 160 |
-
}
|
| 161 |
-
# Update with any other kwargs
|
| 162 |
-
mimsave_kwargs.update(kwargs)
|
| 163 |
-
log.debug(f"mimsave_kwargs: {mimsave_kwargs}")
|
| 164 |
-
|
| 165 |
-
imageio.mimsave(file, obj, format, **mimsave_kwargs)
|
| 166 |
-
|
| 167 |
-
def dump_to_str(self, obj, **kwargs):
|
| 168 |
-
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import json
|
| 17 |
-
|
| 18 |
-
import numpy as np
|
| 19 |
-
|
| 20 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def set_default(obj):
|
| 24 |
-
"""Set default json values for non-serializable values.
|
| 25 |
-
|
| 26 |
-
It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
|
| 27 |
-
It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
|
| 28 |
-
etc.) into plain numbers of plain python built-in types.
|
| 29 |
-
"""
|
| 30 |
-
if isinstance(obj, (set, range)):
|
| 31 |
-
return list(obj)
|
| 32 |
-
elif isinstance(obj, np.ndarray):
|
| 33 |
-
return obj.tolist()
|
| 34 |
-
elif isinstance(obj, np.generic):
|
| 35 |
-
return obj.item()
|
| 36 |
-
raise TypeError(f"{type(obj)} is unsupported for json dump")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class JsonHandler(BaseFileHandler):
|
| 40 |
-
def load_from_fileobj(self, file):
|
| 41 |
-
return json.load(file)
|
| 42 |
-
|
| 43 |
-
def dump_to_fileobj(self, obj, file, **kwargs):
|
| 44 |
-
kwargs.setdefault("default", set_default)
|
| 45 |
-
json.dump(obj, file, **kwargs)
|
| 46 |
-
|
| 47 |
-
def dump_to_str(self, obj, **kwargs):
|
| 48 |
-
kwargs.setdefault("default", set_default)
|
| 49 |
-
return json.dumps(obj, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import json
|
| 17 |
-
from typing import IO
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
|
| 21 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def set_default(obj):
|
| 25 |
-
"""Set default json values for non-serializable values.
|
| 26 |
-
|
| 27 |
-
It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
|
| 28 |
-
It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
|
| 29 |
-
etc.) into plain numbers of plain python built-in types.
|
| 30 |
-
"""
|
| 31 |
-
if isinstance(obj, (set, range)):
|
| 32 |
-
return list(obj)
|
| 33 |
-
elif isinstance(obj, np.ndarray):
|
| 34 |
-
return obj.tolist()
|
| 35 |
-
elif isinstance(obj, np.generic):
|
| 36 |
-
return obj.item()
|
| 37 |
-
raise TypeError(f"{type(obj)} is unsupported for json dump")
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class JsonlHandler(BaseFileHandler):
|
| 41 |
-
"""Handler for JSON lines (JSONL) files."""
|
| 42 |
-
|
| 43 |
-
def load_from_fileobj(self, file: IO[bytes]):
|
| 44 |
-
"""Load JSON objects from a newline-delimited JSON (JSONL) file object.
|
| 45 |
-
|
| 46 |
-
Returns:
|
| 47 |
-
A list of Python objects loaded from each JSON line.
|
| 48 |
-
"""
|
| 49 |
-
data = []
|
| 50 |
-
for line in file:
|
| 51 |
-
line = line.strip()
|
| 52 |
-
if not line:
|
| 53 |
-
continue # skip empty lines if any
|
| 54 |
-
data.append(json.loads(line))
|
| 55 |
-
return data
|
| 56 |
-
|
| 57 |
-
def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs):
|
| 58 |
-
"""Dump a list of objects to a newline-delimited JSON (JSONL) file object.
|
| 59 |
-
|
| 60 |
-
Args:
|
| 61 |
-
obj: A list (or iterable) of objects to dump line by line.
|
| 62 |
-
"""
|
| 63 |
-
kwargs.setdefault("default", set_default)
|
| 64 |
-
for item in obj:
|
| 65 |
-
file.write(json.dumps(item, **kwargs) + "\n")
|
| 66 |
-
|
| 67 |
-
def dump_to_str(self, obj, **kwargs):
|
| 68 |
-
"""Dump a list of objects to a newline-delimited JSON (JSONL) string."""
|
| 69 |
-
kwargs.setdefault("default", set_default)
|
| 70 |
-
lines = [json.dumps(item, **kwargs) for item in obj]
|
| 71 |
-
return "\n".join(lines)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
if __name__ == "__main__":
|
| 75 |
-
from lyra_2._ext.imaginaire.utils.easy_io import easy_io
|
| 76 |
-
|
| 77 |
-
easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl")
|
| 78 |
-
print(easy_io.load("test.jsonl"))
|
| 79 |
-
easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl")
|
| 80 |
-
print(easy_io.load("test.jsonl"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py
DELETED
|
@@ -1,89 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
from io import BytesIO
|
| 17 |
-
from typing import IO, Any
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
|
| 21 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class NumpyHandler(BaseFileHandler):
|
| 25 |
-
str_like = False
|
| 26 |
-
|
| 27 |
-
def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any:
|
| 28 |
-
"""
|
| 29 |
-
Load a NumPy array from a file-like object.
|
| 30 |
-
|
| 31 |
-
Parameters:
|
| 32 |
-
file (IO[bytes]): The file-like object containing the NumPy array data.
|
| 33 |
-
**kwargs: Additional keyword arguments passed to `np.load`.
|
| 34 |
-
|
| 35 |
-
Returns:
|
| 36 |
-
numpy.ndarray: The loaded NumPy array.
|
| 37 |
-
"""
|
| 38 |
-
return np.load(file, **kwargs)
|
| 39 |
-
|
| 40 |
-
def load_from_path(self, filepath: str, **kwargs) -> Any:
|
| 41 |
-
"""
|
| 42 |
-
Load a NumPy array from a file path.
|
| 43 |
-
|
| 44 |
-
Parameters:
|
| 45 |
-
filepath (str): The path to the file to load.
|
| 46 |
-
**kwargs: Additional keyword arguments passed to `np.load`.
|
| 47 |
-
|
| 48 |
-
Returns:
|
| 49 |
-
numpy.ndarray: The loaded NumPy array.
|
| 50 |
-
"""
|
| 51 |
-
return super().load_from_path(filepath, mode="rb", **kwargs)
|
| 52 |
-
|
| 53 |
-
def dump_to_str(self, obj: np.ndarray, **kwargs) -> str:
|
| 54 |
-
"""
|
| 55 |
-
Serialize a NumPy array to a string in binary format.
|
| 56 |
-
|
| 57 |
-
Parameters:
|
| 58 |
-
obj (np.ndarray): The NumPy array to serialize.
|
| 59 |
-
**kwargs: Additional keyword arguments passed to `np.save`.
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
str: The serialized NumPy array as a string.
|
| 63 |
-
"""
|
| 64 |
-
with BytesIO() as f:
|
| 65 |
-
np.save(f, obj, **kwargs)
|
| 66 |
-
return f.getvalue()
|
| 67 |
-
|
| 68 |
-
def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs):
|
| 69 |
-
"""
|
| 70 |
-
Dump a NumPy array to a file-like object.
|
| 71 |
-
|
| 72 |
-
Parameters:
|
| 73 |
-
obj (np.ndarray): The NumPy array to dump.
|
| 74 |
-
file (IO[bytes]): The file-like object to which the array is dumped.
|
| 75 |
-
**kwargs: Additional keyword arguments passed to `np.save`.
|
| 76 |
-
"""
|
| 77 |
-
np.save(file, obj, **kwargs)
|
| 78 |
-
|
| 79 |
-
def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs):
|
| 80 |
-
"""
|
| 81 |
-
Dump a NumPy array to a file path.
|
| 82 |
-
|
| 83 |
-
Parameters:
|
| 84 |
-
obj (np.ndarray): The NumPy array to dump.
|
| 85 |
-
filepath (str): The file path where the array should be saved.
|
| 86 |
-
**kwargs: Additional keyword arguments passed to `np.save`.
|
| 87 |
-
"""
|
| 88 |
-
with open(filepath, "wb") as f:
|
| 89 |
-
np.save(f, obj, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py
DELETED
|
@@ -1,31 +0,0 @@
|
|
| 1 |
-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
-
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
|
| 16 |
-
import pandas as pd
|
| 17 |
-
|
| 18 |
-
from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class PandasHandler(BaseFileHandler):
|
| 22 |
-
str_like = False
|
| 23 |
-
|
| 24 |
-
def load_from_fileobj(self, file, **kwargs):
|
| 25 |
-
return pd.read_csv(file, **kwargs)
|
| 26 |
-
|
| 27 |
-
def dump_to_fileobj(self, obj, file, **kwargs):
|
| 28 |
-
obj.to_csv(file, **kwargs)
|
| 29 |
-
|
| 30 |
-
def dump_to_str(self, obj, **kwargs):
|
| 31 |
-
raise NotImplementedError("PandasHandler does not support dumping to str")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|