diff --git a/lyra_2/__init__.py b/lyra_2/__init__.py deleted file mode 100644 index dac9a4d7496eb38831f1f3c820a90d50e25e2a7e..0000000000000000000000000000000000000000 --- a/lyra_2/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/lyra_2/_ext/__init__.py b/lyra_2/_ext/__init__.py deleted file mode 100644 index 8ab6042807808a9884937177e5f28bdf16e915c0..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Copied from cosmos repository, with necessary modifications. -Original source: https://github.com/nvidia-cosmos/cosmos-predict2.5/ -""" diff --git a/lyra_2/_ext/imaginaire/__init__.py b/lyra_2/_ext/imaginaire/__init__.py deleted file mode 100644 index dac9a4d7496eb38831f1f3c820a90d50e25e2a7e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/lyra_2/_ext/imaginaire/checkpointer/__init__.py b/lyra_2/_ext/imaginaire/checkpointer/__init__.py deleted file mode 100644 index dac9a4d7496eb38831f1f3c820a90d50e25e2a7e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/checkpointer/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/lyra_2/_ext/imaginaire/checkpointer/base.py b/lyra_2/_ext/imaginaire/checkpointer/base.py deleted file mode 100644 index a8292fc664894c0b0addda90bb217561774fea07..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/checkpointer/base.py +++ /dev/null @@ -1,177 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from abc import ABC, abstractmethod -from typing import Optional - -import torch - -from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig -from lyra_2._ext.imaginaire.model import ImaginaireModel -from lyra_2._ext.imaginaire.utils import callback -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - -class AbstractCheckpointer(ABC): - """The checkpointer class. Supports checkpoint saving/loading to both local disk or object store.""" - - def __init__( - self, - config_checkpoint: CheckpointConfig, - config_job: JobConfig, - callbacks: Optional[callback.CallBackGroup] = None, - ): - """Constructor of the checkpointer. - - Args: - config_checkpoint (CheckpointConfig): The config object for the checkpointer. - """ - self.config_checkpoint = config_checkpoint - # Set the callback functions. - self.callbacks = callbacks - self.save_to_object_store = config_checkpoint.save_to_object_store.enabled - self.load_from_object_store = config_checkpoint.load_from_object_store.enabled - - # Set checkpoint directories for local and object store paths - self._local_dirname = os.path.join(config_job.path_local, "checkpoints") - self._object_store_dirname = os.path.join(config_job.path, "checkpoints") - - self.strict_resume = config_checkpoint.strict_resume - self.load_path = config_checkpoint.load_path or None - self.load_training_state = config_checkpoint.load_training_state - self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state - self.save_thread = None - self.verbose = config_checkpoint.verbose - self.keys_not_to_resume = config_checkpoint.keys_not_to_resume - self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem - # Create the object store client interface. - if config_checkpoint.load_from_object_store.enabled: - self.load_s3_backend_key = "_ckpt_s3_loader" - easy_io.set_s3_backend( - key="_ckpt_s3_loader", - backend_args={ - "backend": "s3", - "path_mapping": { - "s3://ckpt/": f"s3://{config_checkpoint.load_from_object_store.bucket}/", - }, - "s3_credential_path": config_checkpoint.load_from_object_store.credentials, - }, - ) - else: - self.load_s3_backend_key = None - - if config_checkpoint.save_to_object_store.enabled: - self.save_s3_backend_key = "_ckpt_s3_saver" - easy_io.set_s3_backend( - key="_ckpt_s3_saver", - backend_args={ - "backend": "s3", - "path_mapping": { - "s3://ckpt/": f"s3://{config_checkpoint.save_to_object_store.bucket}/", - }, - "s3_credential_path": config_checkpoint.save_to_object_store.credentials, - }, - ) - else: - self.save_s3_backend_key = None - - @abstractmethod - def save( - self, - model: ImaginaireModel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - pass - - @abstractmethod - def load( - self, - model: ImaginaireModel, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, - grad_scaler: Optional[torch.amp.GradScaler] = None, - ) -> int: - pass - - @property - def save_bucket(self): - """Get the bucket name for saving checkpoints.""" - return self.config_checkpoint.save_to_object_store.bucket if self.save_to_object_store else None - - @property - def load_bucket(self): - """Get the bucket name for loading checkpoints.""" - return self.config_checkpoint.load_from_object_store.bucket if self.load_from_object_store else None - - @property - def save_dirname(self): - return ( - f"s3://{self.save_bucket}/{self._object_store_dirname}" - if self.save_to_object_store - else self._local_dirname - ) - - @property - def load_dirname(self): - return ( - f"s3://{self.load_bucket}/{self._object_store_dirname}" - if self.load_from_object_store - else self._local_dirname - ) - - def finalize(self) -> None: - """Finalize the checkpointer.""" - if self.save_thread: - self.save_thread.join() - - def _read_latest_checkpoint_file(self) -> str | None: - """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. - - Returns: - checkpoint_file (str | None): file name of the latest saved checkpoint. - """ - checkpoint_file = None - checkpoint_path = os.path.join(self.load_dirname, "latest_checkpoint.txt") - if easy_io.exists(f"{checkpoint_path}", backend_key=self.load_s3_backend_key): - checkpoint_file = easy_io.load(f"{checkpoint_path}", backend_key=self.load_s3_backend_key).strip() - - return checkpoint_file - - def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: - """Track the file name of the latest saved checkpoint. - - Args: - checkpoint_file (str): file name of the latest saved checkpoint. - """ - content = f"{checkpoint_file}\n" - checkpoint_path = os.path.join(self.save_dirname, "latest_checkpoint.txt") - easy_io.dump( - content, - checkpoint_path, - backend_key=self.save_s3_backend_key, - ) - - def _check_checkpoint_exists(self, checkpoint_path: str) -> None: - """If the file checkpoint_path does not exist, raise an error. - - Args: - checkpoint_path (str): full path to the checkpoint. - """ - if not easy_io.exists(f"{checkpoint_path}", backend_key=self.load_s3_backend_key): - raise FileNotFoundError(f"File not found (object store): {checkpoint_path}") diff --git a/lyra_2/_ext/imaginaire/checkpointer/dcp.py b/lyra_2/_ext/imaginaire/checkpointer/dcp.py deleted file mode 100644 index f493796d9e5a379ce8d17f11b5372fcbafa4beaf..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/checkpointer/dcp.py +++ /dev/null @@ -1,1003 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -Distributed checkpoint (DCP) directory structure and storage backends. - -The checkpointer saves model state in a sharded format across multiple processes: - -self.save_dirname/ -├── iter_000000005/ # Checkpoint at iteration 5 -│ ├── model/ # Model state shards -│ │ ├── __0_0.distcp # Shard 0 from rank 0 -│ │ └── __1_0.distcp # Shard 1 from rank 1 -│ ├── optim/ # Optimizer state shards -│ │ ├── __0_0.distcp # Shard 0 from rank 0 -│ │ └── __1_0.distcp # Shard 1 from rank 1 -│ ├── scheduler/ # Learning rate scheduler state -│ │ ├── __0_0.distcp # Shard 0 from rank 0 -│ │ └── __1_0.distcp # Shard 1 from rank 1 -│ └── trainer/ # Additional training state -│ ├── __0_0.distcp # Shard 0 from rank 0 -│ └── __1_0.distcp # Shard 1 from rank 1 -└── latest_checkpoint.txt # Points to most recent checkpoint folder, e.g. iter_000000005 - -Storage path format: - self.save_dirname = "{config_job.path_local}/checkpoints" - -The sharded format enables efficient distributed saving/loading by: -1. Parallelizing I/O across processes -2. Reducing memory usage per process -3. Supporting both local and cloud storage backends -""" - -import enum -import functools -import multiprocessing -import os -import queue -import re -import time -import warnings -from collections import namedtuple -from multiprocessing import get_context -from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast - -import torch -import torch.distributed -import torch.distributed.checkpoint as dcp -from torch import nn -from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter -from torch.distributed.checkpoint._storage_utils import _storage_setup -from torch.distributed.checkpoint.default_planner import DefaultSavePlanner -from torch.distributed.checkpoint.logger import _dcp_method_logger -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - get_model_state_dict, - get_optimizer_state_dict, - set_model_state_dict, - set_optimizer_state_dict, -) -from torch.distributed.checkpoint.stateful import Stateful -from torch.distributed.checkpoint.storage import StorageReader -from torch.distributed.checkpoint.utils import _api_bc_check, _DistWrapper, _profile -from torch.distributed.tensor import distribute_tensor - -from lyra_2._ext.imaginaire.checkpointer.base import AbstractCheckpointer -from lyra_2._ext.imaginaire.checkpointer.s3_filesystem import S3StorageReader, S3StorageWriter -from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig -from lyra_2._ext.imaginaire.model import ImaginaireModel -from lyra_2._ext.imaginaire.utils import callback, distributed, log, misc -from lyra_2._ext.imaginaire.utils.easy_io import easy_io -from lyra_2._src.models.wan_t2v_model import WANDiffusionModel as DiffusionModel - -try: - """ - We override the default load function to skip loadding _extra_state keys created by transformer-engine. - """ - from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner as _DefaultLoadPlanner - from torch.distributed.checkpoint.default_planner import ( - DTensor, - LoadPlan, - _create_read_items, - _version, - flatten_state_dict, - ) - from torch.distributed.checkpoint.metadata import Metadata, TensorStorageMetadata - - def create_default_local_load_plan( - state_dict: dict[str, Any], metadata: Metadata, strict: bool = True, dcp_allow_mismatched_size: bool = False - ) -> LoadPlan: - requests = [] - """ - Create the ``LoadPlan`` used by DefaultLoadPlanner. - - It produces one read item per value in ``state_dict`` using the metadata in ``metadata``. - - The default behavior is to match key exactly between state_dict and metadata. - It handles resharding by issuing multiple read requests against storage in order to match - load requirements. - """ - - for fqn, obj in state_dict.items(): - if fqn.endswith("._extra_state"): # dirty TE attention package! - continue - # ignore state_dict keys which do not exist in `state_dict` if strict=False - if fqn not in metadata.state_dict_metadata: - if strict: - raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.") - else: - log.warning(f"Local load plan: Missing key in checkpoint state_dict: {fqn}.") - continue - - md = metadata.state_dict_metadata[fqn] - - if not dcp_allow_mismatched_size: - if ( - isinstance(md, TensorStorageMetadata) - and getattr(obj, "size", None) is not None - and md.size != obj.size() - ): - if not strict: - log.critical(f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}") - continue - else: - raise ValueError( - f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}", - ) - # Since DTensor supports submesh, adding extra check to ensure _create_read_items() - # gets called only when the current rank is part of the mesh for the corresponding DTensor. - if isinstance(obj, DTensor): - if obj.device_mesh.get_coordinate() is not None: - requests += _create_read_items(fqn, md, obj) - else: - requests += _create_read_items(fqn, md, obj) - - return LoadPlan(requests) - - class DefaultLoadPlanner(_DefaultLoadPlanner): - def set_partial_channel_weight(self, dcp_allow_mismatched_size: bool): - self.dcp_allow_mismatched_size = dcp_allow_mismatched_size - - def create_local_plan(self) -> LoadPlan: - assert self.metadata is not None - if self.flatten_state_dict: - # To support checkpoints that are saved before v2.4, we have to - # differentiate if the missing keys are due to old checkpoints. - # The contracts are: - # 1. There are 3 cases when we found a missing key. - # 1.1 Actual missing key, but allow_partial_load is False - # 1.2 Actual missing key, but allow_partial load is True - # 1.3 Old checkpoint, but allow_partial_load is False - # 1.4 Old checkpoint, but allow_partial_load is True - # 2. If we found a missing key, we first convert the keys back to - # the key format of v2.3 - # 3. If the previous missing keys are in the v2.3 keys, we assume - # this is a old checkpoint. - # 4. Pass the state_dict to `create_default_local_load_plan()`, - # which has the logic to check missing for allow_partial_load. - # So for 1.2 and 1.4 cases, we delegate allow_partial_load check to - # `create_default_local_load_plan()`. The logic here is to determine - # whether the checkpoint belong to 2.3 (or before) or 2.4 (or after). - current_keys = set(self.state_dict.keys()) - load_keys = set(self.metadata.state_dict_metadata.keys()) - missing_keys = load_keys - current_keys - if missing_keys: - _version._derived_version = "2_3" - old_state_dict, old_mappings = flatten_state_dict(self.original_state_dict) - old_keys = set(old_state_dict.keys()) - if old_keys & missing_keys: - self.state_dict, self.mappings = old_state_dict, old_mappings - # _derived_version is only used by flatten_state_dict now. - # Set it back to None so that later we can save to a new version. - _version._derived_version = None - - return create_default_local_load_plan( - self.state_dict, - self.metadata, - not self.allow_partial_load, - getattr(self, "dcp_allow_mismatched_size", False), - ) - - log.info("for the back comptiable pytorch! New DefaultLoadPlanner class is created.") - - @_dcp_method_logger(log_exceptions=True) - @_api_bc_check - def load( - state_dict: dict[str, Any], - *, - checkpoint_id: Union[str, os.PathLike, None] = None, - storage_reader: Optional[StorageReader] = None, - planner: Optional[DefaultLoadPlanner] = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - no_dist: bool = False, - ) -> None: - """ - We override the default load function to perform a load plan check for mismatched/missing keys. - ==========================Original Doc string===================================== - Load a checkpoint into a distributed state dict in SPMD style. - - Each rank must have the same keys in their ``state_dict`` provided to this - API. Mismatched keys may result in hangs or errors. If unsure, you can use - the ``utils._assert_same_keys`` API to check (but may incur communication - costs). - - Each rank will try to read the least amount of data necessary - to fullfill the requested `state_dict`. When loading :class:`ShardedTensor` - or :class:`DTensor` instances, each rank only reads data for their local shards. - - For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``), - load will first call ``state_dict`` before attempting deserialization, followed by - ``load_state_dict`` once the deserialization is complete. - For each non-``Stateful`` object, load will deserailize the object, and then replace - it in the ``state_dict`` with the deserialized object. - - .. warning:: - All tensors in ``state_dict`` must be allocated on their - destination device *prior to* calling this function. - - All non-tensor data is loaded using `torch.load()` and modified in place - on state_dict. - - .. warning:: - Users must call `load_state_dict` on the root module to ensure load - pos-processing and non-tensor data properly propagates. - - .. note: - If no process group is initialized, this function will assume the intent - is to load a checkpoint into the local process. This can be useful in the - case of local inference, and when using regular Tensors (as opposed to DTensor - or ShardedTensor) - - .. note: - Rank 0 is assumed to be the coordinator rank. - - Args: - state_dict (Dict[str, Any]): The state_dict to load the checkpoint into. - checkpoint_id (Union[str, os.PathLike, None]): - The ID of this checkpoint instance. The meaning of the checkpoint_id - depends on the storage. It can be a path to a folder or to a file. - It can also be a key if the storage is a key-value store. - (Default: ``None``) - storage_reader (Optional[StorageReader]): - Instance of StorageWriter used to perform reads. If this is not - specified, DCP will automatically infer the reader based on the - checkpoint_id. If checkpoint_id is also None, an exception will - be raised. (Default: ``None``) - planner (Optional[LoadPlanner]): - Instance of LoadPlanner. If this is not specificed, the default - planner will be used. (Default: ``None``) - process_group (Optional[ProcessGroup]): - ProcessGroup to be used for cross-rank synchronization. - (Default: ``None``) - no_dist (bool): If ``True``, this function will assume the intent is to load - a checkpoint without using cross-rank synchronization. (Default: ``False``) - Returns: - None. - - Examples - >>> # xdoctest: +SKIP - >>> my_model = MyModule() - >>> optimizer = Adagrad(my_model.parameters()) - >>> model_state_dict = my_model.state_dict() - >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader( - ... "/checkpoint/1" - ... ) - - >>> torch.distributed.checkpoint.load_state_dict( - >>> state_dict=model_state_dict, - >>> storage_reader=fs_storage_reader, - >>> ) - - >>> # module.load_state_dict() function might have customized steps - >>> # to flush the state_dict, must call it to - >>> # ensure correct behavior. - >>> my_model.load_state_dict(model_state_dict) - - .. note:: - load_state_dict uses collectives to coordinate reads across ranks. - For NCCL-based process groups, internal tensor representations of - objects must be moved to the GPU device before communication takes place. - In this case, the device used is given by ``torch.cuda.current_device()`` - and it is the user's responsibility to ensure that this is set so that each - rank has an individual GPU, via ``torch.cuda.set_device()``. - """ - - no_dist = no_dist or (not torch.distributed.is_available()) or (not torch.distributed.is_initialized()) - if no_dist: - warnings.warn( - "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to load in a single process." - ) - - with _profile(): - storage_reader = cast(StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)) - - # All ranks must have the same keys in their `state_dict` provided to - # this API. See documentation for more details. - # Here we simply sort the keys to ensure that all ranks load values in - # the same order. - keys = sorted(state_dict.keys()) - - statetful_sd = {} - for key in keys: - if key not in state_dict: - continue - elem = state_dict[key] - statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem - - _load_state_dict( - state_dict=statetful_sd, - storage_reader=storage_reader, - process_group=process_group, - no_dist=no_dist, - planner=planner, - ) - for key in keys: - if key not in state_dict: - continue - elem = state_dict[key] - if isinstance(elem, Stateful): - # If the state_dict is a Stateful object, - # DCP does an in-place load in the original state dict. - elem.load_state_dict(statetful_sd[key]) - else: - # Otherwise, replace the state_dict with the loaded state_dict. - state_dict[key] = statetful_sd[key] - - def _load_state_dict( - state_dict: dict[str, Any], - storage_reader: StorageReader, - process_group: Optional[torch.distributed.ProcessGroup] = None, - coordinator_rank: int = 0, - no_dist: bool = False, - planner: Optional[DefaultLoadPlanner] = None, - ) -> None: - torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict") - - distW = _DistWrapper(process_group, not no_dist, coordinator_rank) - if planner is None: - planner = DefaultLoadPlanner() - - ckpt_kwargs = {} - if (ckpt_id := getattr(storage_reader, "checkpoint_id", None)) is not None: - ckpt_kwargs["checkpoint_id"] = ckpt_id - ckpt_kwargs["process_group"] = distW.group - - @_dcp_method_logger(**ckpt_kwargs) - def local_step(): - assert planner is not None - metadata = storage_reader.read_metadata() - planner.set_up_planner(state_dict, metadata, distW.is_coordinator) - storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) - - local_plan = planner.create_local_plan() - local_plan = storage_reader.prepare_local_plan(local_plan) - return local_plan - - @_dcp_method_logger(**ckpt_kwargs) - def global_step(all_local_plans): - assert planner is not None - all_local_plans = planner.create_global_plan(all_local_plans) - all_local_plans = storage_reader.prepare_global_plan(all_local_plans) - return all_local_plans - - central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step) - if distW.is_coordinator: - # Compare central_plan items with storage_reader.storage_data keys - dest_fqns = set() - storage_fqns = set() - - # Extract FQNs from central_plan items - for item in central_plan.items: - if hasattr(item, "dest_index") and hasattr(item.dest_index, "fqn"): - dest_fqns.add(item.dest_index.fqn) - if hasattr(item, "storage_index") and hasattr(item.storage_index, "fqn"): - storage_fqns.add(item.storage_index.fqn) - - # Get storage data keys - storage_data_keys = set() - if hasattr(storage_reader, "storage_data") and storage_reader.storage_data is not None: - storage_data_keys = set(item[0].fqn for item in storage_reader.storage_data.items()) - state_dict_keys = set(state_dict.keys()) - # Compare sets and log differences - # Remove any item that has "_extra_state" as substring in the sets - state_dict_keys = {fqn for fqn in state_dict_keys if "_extra_state" not in fqn} - dest_fqns = {fqn for fqn in dest_fqns if "_extra_state" not in fqn} - storage_fqns = {fqn for fqn in storage_fqns if "_extra_state" not in fqn} - storage_data_keys = {fqn for fqn in storage_data_keys if "_extra_state" not in fqn} - - log.info("=== Load Plan FQN Analysis ===") - log.info(f"State Dict FQNs count: {len(state_dict_keys)}") - log.info(f"Destination FQNs count (without _extra_state): {len(dest_fqns)}") - log.info(f"Loaded FQNs count (without _extra_state): {len(storage_fqns)}") - log.info(f"In Storage keys count (without _extra_state): {len(storage_data_keys)}") - - # Find missing keys in each direction - state_dict_missing_from_dest = state_dict_keys - dest_fqns - storage_data_missing_from_storage_fqns = storage_data_keys - storage_fqns - - if state_dict_missing_from_dest: - log.info( - f"State Dict FQNs missing from load plan ({len(state_dict_missing_from_dest)} items): {sorted(state_dict_missing_from_dest)}" - ) - else: - log.info("✓ All State Dict FQNs found in storage_data") - - if storage_data_missing_from_storage_fqns: - # If there are more than 100 "net_ema" keys in storage_data_missing_from_storage_fqns, summarize them - net_ema_keys = {k for k in storage_data_missing_from_storage_fqns if "net_ema" in k} - if len(net_ema_keys) > 100: - storage_data_missing_from_storage_fqns = storage_data_missing_from_storage_fqns - net_ema_keys - storage_data_missing_from_storage_fqns = set( - storage_data_missing_from_storage_fqns - ) # ensure set type - storage_data_missing_from_storage_fqns.add("net_ema") - log.info( - f"Summarized {len(net_ema_keys)} 'net_ema' keys as 'net_ema' in missing storage data keys." - ) - log.info( - f"Storage data keys not loaded by load plan ({len(storage_data_missing_from_storage_fqns)} items): {sorted(storage_data_missing_from_storage_fqns)}" - ) - else: - log.info("✓ All storage data keys found in Loaded FQNs") - - log.info("=== End Load Plan FQN Analysis ===") - - @_dcp_method_logger(**ckpt_kwargs) - def read_data(): - assert planner is not None - final_local_plan = planner.finish_plan(central_plan) - all_reads = storage_reader.read_data(final_local_plan, planner) - - all_reads.wait() - return None - - _ = distW.all_gather("read", read_data) - -except ImportError as e: - from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner - - log.critical(f"{e}, using default planner") - -StateDictItemPath = namedtuple("StateDictItemPath", ["state_dict", "save_path"]) - - -class ModelWrapper(Stateful): - """Wrapper for model state dict handling""" - - def __init__(self, model: Union[nn.Module, List[nn.Module]], load_ema_to_reg: bool = False): - self.model = [model] if isinstance(model, nn.Module) else model - self.load_ema_to_reg = load_ema_to_reg - if self.load_ema_to_reg: - assert isinstance(model, DiffusionModel) - - def state_dict(self) -> Dict[str, Any]: - _state_dict = {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()} - if self.load_ema_to_reg: - if not self.model[0].config.ema.enabled: - all_keys = list(_state_dict.keys()) - assert all(k.startswith("net.") for k in all_keys), "All keys must start with net." - for k in all_keys: - _state_dict[k.replace("net.", "net_ema.")] = _state_dict.pop(k) - else: - log.warning("EMA is enabled, will only load EMA weights from checkpoint file.") - all_keys = list(_state_dict.keys()) - for k in all_keys: - if k.startswith("net_ema."): - break - else: - raise ValueError("No EMA keys found in state_dict") - # do not load .net keys, since we do not need them anyway. - _state_dict = {k: _state_dict[k] for k in all_keys if not k.startswith("net.")} - - if hasattr(self.model[0].config, "lora_config") and self.model[0].config.lora_config.enabled: - """ - When using LoRA, `inject_adapter_in_model` modifies the target modules in place. - For example, `blocks[0].attn.q_proj.weight` will be modified to `blocks[0].attn.q_proj.base_layer.weight`. - This means that the model will have the key `blocks[0].attn.q_proj.base_layer.weight`, - but the checkpoint will have the key `blocks[0].attn.q_proj.weight`. - We need to map the model key to the checkpoint key. - """ - self.checkpoint_to_model_key = {} - mapping_keys = {"base_layer.": "", "base_model.model.": ""} - keys_to_update = [] - for k in _state_dict.keys(): - new_key = k - for from_key, to_key in mapping_keys.items(): - new_key = new_key.replace(from_key, to_key) - if new_key != k: - keys_to_update.append((k, new_key)) - self.checkpoint_to_model_key[new_key] = k - for k, new_key in keys_to_update: - _state_dict[new_key] = _state_dict.pop(k) - - return _state_dict - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - if hasattr(self.model[0].config, "lora_config") and self.model[0].config.lora_config.enabled: - if hasattr(self, "checkpoint_to_model_key"): - for checkpoint_key, model_key in self.checkpoint_to_model_key.items(): - state_dict[model_key] = state_dict.pop(checkpoint_key) - else: - raise ValueError("checkpoint_to_model_key is not set by `state_dict`") - - if self.load_ema_to_reg: - if not self.model[0].config.ema.enabled: - all_keys = list(state_dict.keys()) - assert all(k.startswith("net_ema.") for k in all_keys), "All keys must start with net_ema." - for k in all_keys: - state_dict[k.replace("net_ema.", "net.")] = state_dict.pop(k) - else: - log.warning("EMA is enabled, will load EMA weights to regular model weights") - all_keys = list(state_dict.keys()) - assert all(not k.startswith("net.") for k in all_keys), "No .net keys should be in state_dict" - for k in all_keys: - if k.startswith("net_ema."): - state_dict[k.replace("net_ema.", "net.")] = torch.clone(state_dict[k]) - func = functools.partial( - set_model_state_dict, - model_state_dict=state_dict, - options=StateDictOptions(strict=False), - ) - list(map(func, self.model)) - - -class OptimizerWrapper(Stateful): - def __init__( - self, - model: Union[nn.Module, List[nn.Module]], - optim: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]], - ) -> None: - self.model = [model] if isinstance(model, nn.Module) else model - self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim - - def state_dict(self) -> Dict[str, Any]: - func = functools.partial( - get_optimizer_state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - return {k: v for sd in map(func, self.model, self.optim) for k, v in sd.items()} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - func = functools.partial( - set_optimizer_state_dict, - optim_state_dict=state_dict, - options=StateDictOptions(flatten_optimizer_state_dict=True), - ) - list(map(func, self.model, self.optim)) - - -class AsyncMode(str, enum.Enum): - DISABLED = "disabled" - ASYNC_WITH_PINNED_MEM = "async_with_pinned_mem" - - -class Terminate: - pass - - -class SaveDone: - def __init__(self, iteration: int, elapsed_time: float, succeeded: bool): - self.iteration = iteration - self.elapsed_time = elapsed_time - self.succeeded = succeeded - - def __str__(self): - return f"SaveDone(iteration={self.iteration}, elapsed_time={self.elapsed_time}, succeeded={self.succeeded})" - - -def save_checkpoint_in_background( - receiver_queue: multiprocessing.Queue, - sender_queue: multiprocessing.Queue, - checkpoint_config: CheckpointConfig, - job_config: JobConfig, -) -> None: - """ - Handles model checkpoint saving in a separate background process using PyTorch's distributed functionality. - This function runs in a dedicated process to avoid blocking the main training loop. - - Args: - receiver_queue: Queue to receive state dictionaries and commands from the main process - sender_queue: Queue to send completion signals back to the main process - checkpoint_config: Configuration settings for checkpoint saving behavior - job_config: Configuration settings for the training job - - Flow: - 1. Initializes distributed processing environment - 2. Continuously waits for state dictionaries to save - 3. Saves checkpoints asynchronously - 4. Signals completion back to main process - 5. Terminates when receiving a Terminate signal - - Raises: - AssertionError: If received object is neither Terminate signal nor valid state dict tuple - - Note: - - Uses a different port than the main process to avoid conflicts - - Disables TorchElastic agent store for checkpoint operations - - Automatically cleans up distributed process group on exit - """ - # Configure distributed environment - os.environ["MASTER_PORT"] = str(int(os.environ["MASTER_PORT"]) + 2) - os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "False" - - # Set up GPU device and distributed processing - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - distributed.init() - - # Initialize checkpointing mechanism - checkpoint_handler = DistributedCheckpointer(checkpoint_config, job_config, None, disable_async=True) - - try: - while True: - log.debug("Checkpoint background process is ready for next task, waiting for new state_dict") - received_data = receiver_queue.get() - log.debug("Received new state_dict") - - if isinstance(received_data, Terminate): - log.info("Received termination signal in checkpoint background process, closing sender queue") - sender_queue.put(Terminate()) - sender_queue.close() - return - - assert isinstance(received_data, tuple), "Received data must be a tuple of (state_dict, checkpoint_path)" - state_dict, checkpoint_path = received_data - - # Save checkpoint and measure time taken - start_time = time.monotonic() - iteration = state_dict["trainer"][0]["iteration"] - elapsed_time = 0 - succeeded = False - try: - checkpoint_handler.save_state_dict_worker(state_dict, checkpoint_path) - elapsed_time = time.monotonic() - start_time - log.info( - f"Checkpoint saved successfully in background process. Time taken: {elapsed_time:.2f} seconds, iteration: {iteration}" - ) - succeeded = True - except Exception as e: - log.error(f"Error saving checkpoint to {checkpoint_path}: {e}") - # continue because if the thread exits, the main thread keeps on adding to the queue - finally: - if elapsed_time == 0: - elapsed_time = time.monotonic() - start_time - sender_queue.put(SaveDone(iteration, elapsed_time, succeeded)) - - finally: - log.info("Cleaning up: destroying distributed process group") - torch.distributed.destroy_process_group() - - -class DistributedCheckpointer(AbstractCheckpointer): - KEYS_TO_SAVE = ["model", "optim", "scheduler", "trainer"] - - def __init__( - self, - config_checkpoint: CheckpointConfig, - config_job: JobConfig, - callbacks: Optional[callback.CallBackGroup] = None, - disable_async: bool = False, - ): - super().__init__(config_checkpoint, config_job, callbacks) - self.config_checkpoint = config_checkpoint - if config_checkpoint.dcp_async_mode_enabled: - self.async_mode = AsyncMode.ASYNC_WITH_PINNED_MEM - else: - self.async_mode = AsyncMode.DISABLED - - if disable_async: - self.async_mode = AsyncMode.DISABLED - - if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - ctx = get_context("spawn") - self.mp_queue_send = ctx.Queue() - self.mp_queue_recv = ctx.Queue() - self.mp = ctx.Process( - target=save_checkpoint_in_background, - args=( - self.mp_queue_send, - self.mp_queue_recv, - config_checkpoint, - config_job, - ), - daemon=True, - ) - self.mp.start() - self.cpu_offload_state_dict = None - self.staging = False - self.staging_ckpt_file = None - self.staging_stream = torch.cuda.Stream() - - def keys_to_resume_during_load(self) -> Tuple[Set, Union[str, None]]: - latest_checkpoint_file = self._read_latest_checkpoint_file() - - resume_keys = [] - - if latest_checkpoint_file is not None: - # 1. Resume training from latest_checkpoint.txt under the same name. - checkpoint_path = os.path.join(self.load_dirname, latest_checkpoint_file) - resume_keys.extend(self.KEYS_TO_SAVE) - else: - if self.load_path and not self.load_path.endswith(".pth"): - # 2. Load the module weights specified by config_checkpoint.path. - checkpoint_path = self.load_path - if self.load_s3_backend_key: - checkpoint_path = f"s3://{self.config_checkpoint.load_from_object_store.bucket}/{checkpoint_path}" - if not re.search(r"/checkpoints/iter_\d{9}/?$", checkpoint_path): - old_ckpt_path = checkpoint_path - # If path doesn't end with specific checkpoint, read latest checkpoint file - latest_ckpt_path = os.path.join(checkpoint_path, "checkpoints/latest_checkpoint.txt") - if easy_io.exists(latest_ckpt_path, backend_key=self.load_s3_backend_key): - checkpoint_file = easy_io.load( - latest_ckpt_path, backend_key=self.load_s3_backend_key - ).strip() - checkpoint_path = f"{checkpoint_path}/checkpoints/{checkpoint_file}" - else: - log.warning( - f"Latest checkpoint file {latest_ckpt_path} not found, load from {old_ckpt_path}" - ) - checkpoint_path = old_ckpt_path - if self.load_training_state: - resume_keys.extend(self.KEYS_TO_SAVE) - else: - resume_keys.append("model") - if self.only_load_scheduler_state: - resume_keys.append("scheduler") - elif self.load_path and self.load_path.endswith(".pth"): - checkpoint_path = self.load_path - else: - checkpoint_path = None - if len(self.keys_not_to_resume) > 0: - for key in self.keys_not_to_resume: - assert key in self.KEYS_TO_SAVE, f"Invalid key to resume: {key} not in {self.KEYS_TO_SAVE}" - resume_keys = [key for key in resume_keys if key not in self.keys_not_to_resume] - return set(resume_keys), checkpoint_path - - @misc.timer("checkpoint loading") - def load( - self, - model: ImaginaireModel, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, - grad_scaler: torch.amp.GradScaler | None = None, - ) -> int: - if self.callbacks is not None: - self.callbacks.on_load_checkpoint_start(model) - - resume_keys, checkpoint_path = self.keys_to_resume_during_load() - resume_keys = sorted(resume_keys) - log.info(f"Resuming ckpt {checkpoint_path} with keys: {resume_keys}") - - iteration = 0 - - if checkpoint_path is not None and not checkpoint_path.endswith(".pth"): - self._check_checkpoint_exists(checkpoint_path) - for key in resume_keys: - load_planner = DefaultLoadPlanner(allow_partial_load=True) - if hasattr(load_planner, "set_partial_channel_weight"): - log.info(f"set_partial_channel_weight: {self.config_checkpoint.dcp_allow_mismatched_size}") - load_planner.set_partial_channel_weight(self.config_checkpoint.dcp_allow_mismatched_size) - cur_key_ckpt_full_path = os.path.join(checkpoint_path, key) - log.info(f"Start loading checkpoint from {checkpoint_path}") - storage_reader = self.get_storage_reader(cur_key_ckpt_full_path) - torch.distributed.barrier() - log.info(f"starting {cur_key_ckpt_full_path}", rank0_only=False) - if key == "model": - log.info("- Loading the model...") - _model_wrapper = ModelWrapper(model) - _state_dict = _model_wrapper.state_dict() - load(_state_dict, storage_reader=storage_reader, planner=load_planner) - _model_wrapper.load_state_dict(_state_dict) - elif key == "optim": - log.info("- Loading the optimizer...") - _optim_wrapper = OptimizerWrapper(model, optimizer) - _state_dict = _optim_wrapper.state_dict() - dcp.load( - _state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - _optim_wrapper.load_state_dict(_state_dict) - elif key == "scheduler": - log.info("- Loading the scheduler...") - _state_dict = scheduler.state_dict() - dcp.load( - _state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - scheduler.load_state_dict(_state_dict) - elif key == "trainer": - log.info("- Loading the trainer...") - _state_dict = { - "grad_scaler": grad_scaler.state_dict(), - "iteration": iteration, - } - dcp.load( - _state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) - grad_scaler.load_state_dict(_state_dict["grad_scaler"]) - iteration = _state_dict["iteration"] - else: - raise ValueError(f"Invalid key: {key}. not support to resume.") - if self.callbacks is not None: - self.callbacks.on_load_checkpoint(model, state_dict=_state_dict) - log.info(f"Loaded checkpoint from {checkpoint_path} in iteration {iteration}") - elif checkpoint_path is not None and checkpoint_path.endswith(".pth"): - state = easy_io.load(checkpoint_path) - model_state = model.net.state_dict() - - for k, v in list(state.items()): - tgt = model_state.get(k, None) - if tgt is None: - continue - # If target param/buffer is a DTensor and checkpoint value is not, distribute it - if isinstance(tgt, DTensor) and not isinstance(v, DTensor): - # Match device, dtype, and placements from the target DTensor - v = v.to(tgt.device, dtype=tgt.dtype, copy=False) - v = distribute_tensor(v, tgt.device_mesh, tgt.placements) - state[k] = v - # If target is a plain Tensor but checkpoint is a DTensor, bring it local - if not isinstance(tgt, DTensor) and isinstance(v, DTensor): - state[k] = v.to_local().to(tgt.device, dtype=tgt.dtype, copy=False) - - model.load_state_dict(state, strict=False, pretrain_copy=True) - # Clear unused reserved memory from fp32 - torch.cuda.empty_cache() - log.critical(f"Loaded checkpoint from {checkpoint_path}. **** This only happen at iteration 0 **** ") - else: - log.info("Training from scratch.") - torch.cuda.empty_cache() - - if self.callbacks is not None: - self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path) - return iteration - - def _async_with_pinned_memory(self, checkpoint_file: str, state_dict: Dict[str, Tuple[Any, str]]) -> None: - try: - from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict - except ImportError as e: - raise ImportError( - "Please install the latest PyTorch nightly to use async checkpointing with pinned memory." - ) from e - if self.cpu_offload_state_dict is None: - log.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f") - self.cpu_offload_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True, share_memory=True) - - log.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f") - with torch.cuda.stream(self.staging_stream): - self.cpu_offload_state_dict = _copy_state_dict( - state_dict, - self.cpu_offload_state_dict, - non_blocking=True, - ) - self.staging = True - self.staging_ckpt_file = checkpoint_file - - self.maybe_wait_for_staging() - - def maybe_wait_for_staging(self) -> None: - if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM and self.staging: - if not self.staging_stream.query(): - self.staging_stream.synchronize() - - def sync_func(): - self.mp_queue_send.put_nowait((self.cpu_offload_state_dict, self.staging_ckpt_file)) - - sync_func() - self.staging = False - - def get_previous_checkpoint_results(self, wait_for: int = 0) -> None: - """Get the results of previously submitted checkpoints and pass them to callbacks if checkpoint succeeded""" - if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - try: - start_time = time.monotonic() - while not self.mp_queue_recv.empty() or wait_for > 0: - try: - ret = self.mp_queue_recv.get(timeout=1) - if isinstance(ret, Terminate): - log.info("Received termination event from checkpoint background process") - break - save_done: SaveDone = ret - log.logger.info(f"Received checkpoint save result: {save_done}") - if self.callbacks is not None and save_done.succeeded: - self.callbacks.on_save_checkpoint_success( - iteration=save_done.iteration, elapsed_time=save_done.elapsed_time - ) - except queue.Empty: - elapsed_time = time.monotonic() - start_time - if elapsed_time > wait_for: - break - except (EOFError, BrokenPipeError): - log.info("Queue was closed by checkpoint background process") - - def get_storage_writer(self, checkpoint_path: str) -> Union[S3StorageWriter, FileSystemWriter]: - if self.save_to_object_store: - return S3StorageWriter( - credential_path=self.config_checkpoint.save_to_object_store.credentials, - path=checkpoint_path, - ) - return FileSystemWriter(path=checkpoint_path) - - def get_storage_reader(self, checkpoint_path: str) -> Union[S3StorageReader, FileSystemReader]: - if self.load_from_object_store: - return S3StorageReader( - credential_path=self.config_checkpoint.load_from_object_store.credentials, - path=checkpoint_path, - ) - return FileSystemReader(checkpoint_path) - - def save_state_dict_worker(self, to_save_dict: Dict[str, Tuple[Any, str]], checkpoint_file: str) -> None: - for k, (v, full_checkpoint_path) in to_save_dict.items(): - storage_writer = self.get_storage_writer(full_checkpoint_path) - dcp.save( - v, - storage_writer=storage_writer, - planner=DefaultSavePlanner(dedup_save_to_lowest_rank=True), - ) - - if distributed.is_rank0(): - print(f"Saving last checkpoint file {checkpoint_file}") - self._write_latest_checkpoint_file(checkpoint_file) - - log.critical(f"Saved checkpoint to {os.path.join(self.save_dirname, checkpoint_file)}", rank0_only=True) - - def save( - self, - model: ImaginaireModel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. - - Args: - model (ImaginaireModel): The PyTorch model. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - iteration (int): Current iteration number. - """ - if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - self.get_previous_checkpoint_results(wait_for=0) - - if self.callbacks is not None: - self.callbacks.on_save_checkpoint_start(model, iteration) - - checkpoint_file = f"iter_{iteration:09}" - to_save_dict = { - "model": ModelWrapper(model).state_dict(), - "optim": OptimizerWrapper(model, optimizer).state_dict(), - "scheduler": scheduler.state_dict(), - "trainer": { - "grad_scaler": grad_scaler.state_dict(), - "iteration": iteration, - }, - } - for k in to_save_dict.keys(): - output_dirname = os.path.join(self.save_dirname, f"iter_{iteration:09}/{k}") - to_save_dict[k] = (to_save_dict[k], output_dirname) - - if self.callbacks is not None: - self.callbacks.on_save_checkpoint(model, state_dict=to_save_dict) - - if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - self._async_with_pinned_memory(checkpoint_file, to_save_dict) - else: - start_time = time.monotonic() - try: - self.save_state_dict_worker(to_save_dict, checkpoint_file) - finally: - if self.callbacks is not None: - self.callbacks.on_save_checkpoint_success( - iteration=iteration, elapsed_time=time.monotonic() - start_time - ) - - # This measures exposed (synchronous) checkpoint time, on_save_checkpoint_success() - # is instead called to measure the entire duration for asynchronous checkpoint for the async case too. - if self.callbacks is not None: - self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) - - def finalize(self) -> None: - super().finalize() - if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: - if self.mp and self.mp.is_alive(): - self.mp_queue_send.put(Terminate()) - self.get_previous_checkpoint_results(wait_for=60) - self.mp.join() diff --git a/lyra_2/_ext/imaginaire/checkpointer/dummy.py b/lyra_2/_ext/imaginaire/checkpointer/dummy.py deleted file mode 100644 index 9eb445a4a7ff5cc2dd315e3e04c508787783933e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/checkpointer/dummy.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -import torch.distributed - -from lyra_2._ext.imaginaire.checkpointer.base import AbstractCheckpointer -from lyra_2._ext.imaginaire.model import ImaginaireModel - - -class Checkpointer(AbstractCheckpointer): - """ - A dummy checkpointer that does not save or load anything. This is useful for debugging jobs or share workload with collobrators. - """ - - def save( - self, - model: ImaginaireModel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - pass - - def load( - self, - model: ImaginaireModel, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, - grad_scaler: Optional[torch.amp.GradScaler] = None, - ) -> int: - return 0 diff --git a/lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py b/lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py deleted file mode 100644 index 438882f6694df89b5e5c56a9dc9480c72522045e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/checkpointer/s3_filesystem.py +++ /dev/null @@ -1,323 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import json -import os -import time -from contextlib import contextmanager -from typing import Generator, Union -from urllib.parse import urlparse - -import boto3 -from botocore.config import Config -from botocore.exceptions import ClientError -from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter -from torch.distributed.checkpoint.filesystem import FileSystemBase - -from lyra_2._ext.imaginaire.utils import log - - -class S3Stream(io.BytesIO): - """ - Workaround for PyTorch manually closing the stream before we can upload it to S3. We override the close() as noop - and instead call our own _true_close() method to close the stream after we are done using it. - The commit at fault is https://github.com/pytorch/pytorch/commit/9c909bf3bb122db2cce95e2eb7459bbe50dfa15a - """ - - def close(self): - self.flush() - # No close - - def _true_close(self): - super().close() - - -class S3FileSystem(FileSystemBase): - """Implementation of FileSystemBase for AWS S3 storage.""" - - def __init__( - self, - credential_path: str, - max_attempts: int = 20, - initial_backoff: float = 1.0, - max_backoff: float = 30.0, - backoff_factor: float = 2.0, - ) -> None: - """ - Initialize S3FileSystem with retry configuration. - - Args: - credential_path: Path to AWS credentials JSON file - max_attempts: Maximum number of retry attempts - initial_backoff: Initial backoff time in seconds - max_backoff: Maximum backoff time in seconds - backoff_factor: Multiplicative factor for backoff time - """ - with open(credential_path, "r") as f: - conf = json.load(f) - - # Configure boto3 with retry settings - config = Config( - retries=dict(max_attempts=max_attempts, mode="adaptive"), # Adaptive mode automatically handles throttling - connect_timeout=60, - read_timeout=60, - request_checksum_calculation="when_required", # Data integrity check for uploads and downloads - response_checksum_validation="when_required", # Data integrity check for uploads and downloads - ) - - self.s3_client = boto3.client("s3", config=config, **conf) - self.max_attempts = max_attempts - self.initial_backoff = initial_backoff - self.max_backoff = max_backoff - self.backoff_factor = backoff_factor - - def _retry_with_backoff(self, operation_func, *args, **kwargs): - """ - Execute an operation with exponential backoff retry logic. - - Args: - operation_func: Function to execute - *args: Positional arguments for the function - **kwargs: Keyword arguments for the function - - Returns: - Result of the operation function - - Raises: - Exception: If all retry attempts fail - """ - last_exception = None - backoff = self.initial_backoff - - for attempt in range(self.max_attempts): - try: - return operation_func(*args, **kwargs) - except ClientError as e: - error_code = e.response.get("Error", {}).get("Code", "") - log.info(f"S3 Filesystem: Received ClientError: {error_code}", rank0_only=False) - - # Handle specific error cases - if error_code in ["SlowDown", "ThrottlingException", "RequestLimitExceeded", "InternalError"]: - last_exception = e - if attempt < self.max_attempts - 1: # Don't sleep on last attempt - current_backoff = min(backoff, self.max_backoff) - log.info(f"S3 Filesystem: Retrying in {current_backoff} seconds", rank0_only=False) - time.sleep(current_backoff) - backoff *= self.backoff_factor - continue - # For other client errors, raise immediately - raise - except Exception as e: - log.info(f"S3 Filesystem: Received Exception: {str(e)}", rank0_only=False) - last_exception = e - if attempt < self.max_attempts - 1: - current_backoff = min(backoff, self.max_backoff) - log.info(f"S3 Filesystem: Retrying in {current_backoff} seconds", rank0_only=False) - time.sleep(current_backoff) - backoff *= self.backoff_factor - continue - - raise last_exception - - @contextmanager - def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]: - """Create a stream for reading from or writing to S3 with retry logic.""" - path_str = str(path) - bucket, key = self._parse_s3_uri(path_str) - log.info(f"S3 Filesystem: Creating stream for {key} in bucket {bucket}", rank0_only=False) - - if mode == "rb": - stream = io.BytesIO() - try: - - def download_operation(): - self.s3_client.download_fileobj(bucket, key, stream) - stream.seek(0) - - log.info(f"S3 Filesystem: Downloading {key} from bucket {bucket}", rank0_only=False) - self._retry_with_backoff(download_operation) - log.info("S3 Filesystem: Download complete", rank0_only=False) - yield stream - finally: - stream.close() - elif mode == "wb": - stream = S3Stream() - try: - yield stream - - def upload_operation(): - stream.seek(0) - self.s3_client.upload_fileobj(stream, bucket, key) - - log.info(f"S3 Filesystem: Uploading {key} to bucket {bucket}", rank0_only=False) - self._retry_with_backoff(upload_operation) - log.info("S3 Filesystem: Upload complete", rank0_only=False) - finally: - stream._true_close() - else: - raise ValueError(f"Unsupported mode: {mode}") - - def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]: - """Concatenate S3 path with suffix.""" - path_str = str(path) - if path_str.endswith("/"): - return f"{path_str}{suffix}" - return f"{path_str}/{suffix}" - - def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]: - """Initialize and validate S3 path.""" - path_str = str(path) - if not path_str.startswith("s3://"): - raise ValueError(f"Invalid S3 URI: {path_str}. Must start with 's3://'") - return path_str - - def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None: - """Rename (move) an object in S3 with retry logic.""" - src_bucket, src_key = self._parse_s3_uri(str(path)) - dst_bucket, dst_key = self._parse_s3_uri(str(new_path)) - - def copy_operation(): - copy_source = {"Bucket": src_bucket, "Key": src_key} - self.s3_client.copy(copy_source, dst_bucket, dst_key) - - self._retry_with_backoff(copy_operation) - - def delete_operation(): - self.s3_client.delete_object(Bucket=src_bucket, Key=src_key) - - self._retry_with_backoff(delete_operation) - - def mkdir(self, path: Union[str, os.PathLike]) -> None: - """ - Create a "directory" in S3. - - Note: S3 doesn't have real directories, but we can create an empty object - with a trailing slash to simulate a directory. - """ - # Creating same buckets from different ranks can cause rate limit issues in GCP. - # In object store, we don't need to create a directory. - pass - - @classmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - """Validate if the checkpoint_id is a valid S3 URI.""" - checkpoint_id_str = str(checkpoint_id) - try: - if not checkpoint_id_str.startswith("s3://"): - return False - parsed = urlparse(checkpoint_id_str) - return bool(parsed.netloc and parsed.path) # Must have bucket and key - except Exception: - return False - - def exists(self, path: Union[str, os.PathLike]) -> bool: - """Check if an object exists in S3 with retry logic.""" - bucket, key = self._parse_s3_uri(str(path)) - try: - - def head_operation(): - self.s3_client.head_object(Bucket=bucket, Key=key) - - self._retry_with_backoff(head_operation) - return True - except ClientError as e: - if e.response.get("Error", {}).get("Code", "") == "404": - return False - raise - - def rm_file(self, path: Union[str, os.PathLike]) -> None: - """Remove a file from S3 with retry logic.""" - bucket, key = self._parse_s3_uri(str(path)) - - def delete_operation(): - self.s3_client.delete_object(Bucket=bucket, Key=key) - - self._retry_with_backoff(delete_operation) - - def _parse_s3_uri(self, uri: str) -> tuple[str, str]: - """ - Parse an S3 URI into bucket and key. - - Args: - uri: S3 URI in the format s3://bucket-name/key - - Returns: - Tuple of (bucket_name, key) - - Raises: - ValueError: If the URI is invalid - """ - uri = uri if isinstance(uri, str) else str(uri) - if not uri.startswith("s3://"): - raise ValueError(f"Invalid S3 URI: {uri}. Must start with 's3://'") - - parsed = urlparse(uri) - bucket = parsed.netloc - - # Remove leading slash from key - key = parsed.path.lstrip("/") - - if not bucket: - raise ValueError(f"Invalid S3 URI: {uri}. No bucket specified") - - return bucket, key - - -class S3StorageWriter(FileSystemWriter): - def __init__( - self, - credential_path: str, - path: str, - **kwargs, - ) -> None: - """ - Initialize an S3 writer for distributed checkpointing. - - Args: - region (str): The AWS region for S3. - path (str): The S3 URI to write checkpoints to. - kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`. - """ - super().__init__( - path=path, - sync_files=False, - **kwargs, - ) - self.fs = S3FileSystem(credential_path) # type: ignore - self.path = self.fs.init_path(path) - - @classmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - return S3FileSystem.validate_checkpoint_id(checkpoint_id) - - -class S3StorageReader(FileSystemReader): - def __init__(self, credential_path: str, path: Union[str, os.PathLike]) -> None: - """ - Initialize an S3 reader for distributed checkpointing. - - Args: - region (str): The AWS region for S3. - path (Union[str, os.PathLike]): The S3 path to read checkpoints from. - """ - super().__init__(path) - self.fs = S3FileSystem(credential_path) # type: ignore - self.path = self.fs.init_path(path) - self.sync_files = False - - @classmethod - def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: - return S3FileSystem.validate_checkpoint_id(checkpoint_id) diff --git a/lyra_2/_ext/imaginaire/config.py b/lyra_2/_ext/imaginaire/config.py deleted file mode 100644 index 9fb6c4d1d2eb043ab2190ad2ade75097a0d8b6fc..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/config.py +++ /dev/null @@ -1,445 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Training config system for Imaginare4""" - -from __future__ import annotations - -import os -from typing import Any, Dict, Optional, Type, TypeVar, Union - -import attrs -import torch -import torch.utils.data -import torch.utils.data.distributed - -try: - from megatron.core import ModelParallelConfig - - USE_MEGATRON = True -except ImportError: - USE_MEGATRON = False - print("Megatron-core is not installed.") - -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._ext.imaginaire.lazy_config import LazyDict -from lyra_2._ext.imaginaire.utils import callback, distributed -from lyra_2._ext.imaginaire.utils.misc import Color - -T = TypeVar("T") - - -def _is_attrs_instance(obj: object) -> bool: - """ - Helper function to check if an object is an instance of an attrs-defined class. - - Args: - obj: The object to check. - - Returns: - bool: True if the object is an instance of an attrs-defined class, False otherwise. - """ - return hasattr(obj, "__attrs_attrs__") - - -def make_freezable(cls: T) -> T: - """ - A decorator that adds the capability to freeze instances of an attrs-defined class. - - NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need - to hack on a "_is_frozen" attribute. - - This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime. - Once an instance is frozen, its attributes cannot be changed. It also recursively freezes - any attrs-defined objects that are attributes of the class. - - Usage: - @make_freezable - @attrs.define(slots=False) - class MyClass: - attribute1: int - attribute2: str - - obj = MyClass(1, 'a') - obj.freeze() # Freeze the instance - obj.attribute1 = 2 # Raises AttributeError - - Args: - cls: The class to be decorated. - - Returns: - The decorated class with added freezing capability. - """ - - if not hasattr(cls, "__dict__"): - raise TypeError( - "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped " - "class was defined with `@attrs.define(slots=False)`" - ) - - original_setattr = cls.__setattr__ - - def setattr_override(self, key, value) -> None: # noqa: ANN001 - """ - Override __setattr__ to allow modifications during initialization - and prevent modifications once the instance is frozen. - """ - if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen": - raise AttributeError("Cannot modify frozen instance") - original_setattr(self, key, value) # type: ignore - - cls.__setattr__ = setattr_override # type: ignore - - def freeze(self: object) -> None: - """ - Freeze the instance and all its attrs-defined attributes. - """ - for _, value in attrs.asdict(self, recurse=False).items(): - if _is_attrs_instance(value) and hasattr(value, "freeze"): - value.freeze() - self._is_frozen = True # type: ignore - - cls.freeze = freeze # type: ignore - - return cls - - -def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str: - """ - Recursively pretty prints attrs objects with color. - """ - - assert attrs.has(obj.__class__) - - lines: list[str] = [] - for attribute in attrs.fields(obj.__class__): - value = getattr(obj, attribute.name) - if attrs.has(value.__class__): - if use_color: - lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":") - else: - lines.append(" " * indent + "* " + attribute.name + ":") - lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color)) - else: - if use_color: - lines.append( - " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value) - ) - else: - lines.append(" " * indent + "* " + attribute.name + ": " + str(value)) - return "\n".join(lines) - - -def pretty_print_overrides(overrides: Optional[list[str]] = None, use_color: bool = False) -> str: - """ - Pretty prints overrides. - """ - - lines: list[str] = [] - lines.append(Color.cyan("* ") + Color.green("overrides") + ": ") - for override in overrides: - if override == "--": - continue - if override.startswith("~"): - attribute_name = override[1:] - attribute_value = None - else: - attribute_name, attribute_value = override.split("=") - if use_color: - lines.append(" " + Color.cyan("* ") + Color.green(attribute_name) + ": " + Color.yellow(attribute_value)) - else: - lines.append(" " + "* " + attribute_name + ": " + str(attribute_value)) - - return "\n".join(lines) - - -@make_freezable -@attrs.define(slots=False) # slots=False is required for make_freezable. See the make_freezable notes for more info. -class ObjectStoreConfig: - # Whether the file I/O is from object store instead of local disk. - enabled: bool = False - # Path to the object store credentials file. - credentials: str = "" - # Object store bucket to read from / write to the objects. - bucket: str = "" - - -@make_freezable -@attrs.define(slots=False) -class JobConfig: - # Project name. - project: str = "" - # Experiment name. - group: str = "" - # Run/job name. - name: str = "" - @property - def path(self) -> str: - return f"{self.project}/{self.group}/{self.name}" - - @property - def path_local(self) -> str: - local_root = os.environ.get("IMAGINAIRE_OUTPUT_ROOT", "/tmp/imaginaire4-output") - return f"{local_root}/{self.path}" - - -@make_freezable -@attrs.define(slots=False) -class EMAConfig: - # Enable tracking a set of exponential moving average (EMA) weights. - enabled: bool = False - # EMA decay rate. - beta: float = 0.9999 - # Enable removing "_orig_mod-" from buffer names that is added by torch.compile - torch_compile_buffer_renaming: bool = False - - -@make_freezable -@attrs.define(slots=False) -class PowerEMAConfig: - # Enable tracking a set of exponential moving average (EMA) weights. - enabled: bool = False - # EDM2 paper EMA decay rate. - s: float = 0.1 - # Enable removing "_orig_mod-" from buffer names that is added by torch.compile - torch_compile_buffer_renaming: bool = False - - -@make_freezable -@attrs.define(slots=False) -class DDPConfig: - # Traverse the computation graph to find parameters that don't receive gradients. - find_unused_parameters: bool = False - # Set to True if the computation graph does not change during the whole training loop. - static_graph: bool = True - # Set to True if we want to synchronize buffers. Set to False if the sync is going to be handled elsewhere. - broadcast_buffers: bool = True - - -@make_freezable -@attrs.define(slots=False) -class CuDNNConfig: - # Set to True for better reproducibility of the results (only using deterministic cudnn functions). - deterministic: bool = False - # If set to True, cudnn will benchmark several algorithms and pick the fastest one. - benchmark: bool = True - - -@make_freezable -@attrs.define(slots=False) -class JITConfig: - # Enable exporting a JIT compiled model. - enabled: bool = False - # Input tensor shape, for example input. - input_shape: Union[list[int], None] = None - # Device to compile onto. - device: str = "cuda" - # # Data type to compile onto. - dtype: str = "bfloat16" - # Strict mode for PyTorch JIT. - strict: bool = True - - -@make_freezable -@attrs.define(slots=False) -class CheckpointConfig: - # possible checkpoint class - type: Optional[Dict] = None - # for dcp, whether to use async mode - dcp_async_mode_enabled: bool = False - # Configs for saving the checkpoints to object store. - save_to_object_store: ObjectStoreConfig = attrs.field(factory=ObjectStoreConfig) - # Save the checkpoint every N iterations. - save_iter: int = 999999999 - # Configs for loading the checkpoints from object store. - load_from_object_store: ObjectStoreConfig = attrs.field(factory=ObjectStoreConfig) - # Path of model weights to resume the checkpoint from. - load_path: str = "" - # Whether to load the training states (optimizer/scheduler/grad-scaler) from the checkpoint path. - load_training_state: bool = False - # Whether to load the scheduler state only from the checkpoint path. If load_training_state is True, this will be ignored. - only_load_scheduler_state: bool = False - # Load state_dict to the models in strict mode. - strict_resume: bool = True - # Configs for JIT compiling EMA model. - jit: JITConfig = attrs.field(factory=JITConfig) - # Print detailed information during checkpoint saving/loading. - verbose: bool = True - # keys not to resume from the checkpoint, choices: ["model", "optim", "scheduler", "trainer"] - keys_not_to_resume: list[str] = [] - # Whether to use the local filesystem for broadcasting checkpoint data (used for Tensor Parallel Checkpointer). - broadcast_via_filesystem: bool = False - load_ema_to_reg: bool = False # used in inference, load EMA weights to regular model - # In dcp planner, skip the weight shape check, load weights into the model even weight shape is different - dcp_allow_mismatched_size: bool = False - - -@make_freezable -@attrs.define(slots=False) -class NVTXConfig: - """Config for NVTX ranges used in the main training loop. - - See tutorials/nanogpt for more details on how to integrate profiling into your model.""" - - # Enable the NVTX ranges. - enabled: bool = False - # Synchronize everything in each NVTX range. - cuda_synchronize: bool = False - - -@make_freezable -@attrs.define(slots=False) -class StragglerDetectionConfig: - """Config for Straggler detection tool.""" - - # Enable the Straggler Detection. - enabled: bool = False - # How frequently should the Straggler reports be generated. - report_freq: int = 100 - # How frequently iterations should be profiled - profile_freq: int = 1 - # What is the maximum relative difference between GPUs after they are considered stragglers - max_diff: float = 2.0 - # Should the error be raised when straggler is detected - raise_error: bool = True - # Analyze kernels in the forward pass. - analyze_forward: bool = True - # Analyze kernels in the backward pass. - analyze_backward: bool = True - # Analyze kernels in the optimizer. - analyze_optimizer: bool = True - # Analyze dataloading time. - analyze_dataloading: bool = True - - -@make_freezable -@attrs.define(slots=False) -class Profiling: - enable_profiling: bool = False - enable_memory_snapshot: bool = False - save_s3: bool = False - profile_freq: int = 1 - # Target ranks for profiling, each entry must be >=0 and < world_size. - target_ranks: list[int] = list(range(8)) - # Set `record_shape` and `profile_memory` to False to reduce profile size. - record_shape: bool = False - profile_memory: bool = False - with_stack: bool = True - with_modules: bool = True - - -@make_freezable -@attrs.define(slots=False) -class TrainerConfig: - from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer - - type: Type[ImaginaireTrainer] = ImaginaireTrainer - # Set the callback class. - # Defaults to the callbacks below. - callbacks: LazyDict = LazyDict( - dict( - ema=L(callback.EMAModelCallback)(), - progress_bar=L(callback.ProgressBarCallback)(), - ) - ) - # distributed parallelism strategy - distributed_parallelism: str = "ddp" - # Distributed data parallel configs. - ddp: DDPConfig = attrs.field(factory=DDPConfig) - # cuDNN configs. - cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig) - # Set the random seed. - seed: int = 0 - # Set the random seed based on current timestamp - timestamp_seed: bool = False - # Gradient scaler arguments (for torch.amp.GradScaler). - grad_scaler_args: dict = attrs.field(factory=lambda: dict(enabled=False)) - # Maximum number of iterations to train the model. - max_iter: int = 999999999 - # Maximum number of iterations to validate the model. If None, validate on the entire dataset. - max_val_iter: int | None = None - # How often we log the training stats. - logging_iter: int = 100 - # Whether we want to run the validation routines. - run_validation: bool = True - # How often we evaluate on the validation set. - validation_iter: int = 999999999 - # Kill the process after N seconds since the last iteration (usually means dead job). - timeout_period: int = 999999999 - # Tensor memory organization format. - memory_format: torch.memory_format = torch.preserve_format - # Gradient accumulation (update step every N iteration). - grad_accum_iter: int = 1 - # Straggler Detection config - straggler_detection: StragglerDetectionConfig = attrs.field(factory=StragglerDetectionConfig) - # Profiling config - profiling: Profiling = attrs.field(factory=Profiling) - - -@make_freezable -@attrs.define(slots=False) -class Config: - """Config for an imaginaire4 job. - - See /README.md/Configuration System for more info. - """ - - # Model configs. - model: LazyDict - # Optimizer configs. - optimizer: LazyDict - # Scheduler configs. - scheduler: LazyDict - # Training data configs. - dataloader_train: LazyDict - # Validation data configs. - dataloader_val: LazyDict - - # Training job configs. - job: JobConfig = attrs.field(factory=JobConfig) - - # Trainer configs. - trainer: TrainerConfig = attrs.field(factory=TrainerConfig) - - if USE_MEGATRON: - # Megatron-Core configs - model_parallel: ModelParallelConfig = attrs.field(factory=ModelParallelConfig) - else: - model_parallel: None = None - - # Checkpointer configs. - checkpoint: CheckpointConfig = attrs.field(factory=CheckpointConfig) - - # enable upload reproducible setup to s3 - upload_reproducible_setup: bool = False - - def pretty_print(self, use_color: bool = False) -> str: - return _pretty_print_attrs_instance(self, 0, use_color) - - def to_dict(self) -> dict[str, Any]: - return attrs.asdict(self) - - def validate(self) -> None: - """Validate that the config has all required fields.""" - - # broadcast job.name across all ranks to make sure it is consistent - # otherwise, unaligned job names leads unaligned path to save checkpoints - job_name_tensor = torch.ByteTensor(bytearray(self.job.name, "utf-8")).cuda() - distributed.broadcast(job_name_tensor, 0) - self.job.name = job_name_tensor.cpu().numpy().tobytes().decode("utf-8") - - assert self.job.project != "" - assert self.job.group != "" - assert self.job.name != "" diff --git a/lyra_2/_ext/imaginaire/flags.py b/lyra_2/_ext/imaginaire/flags.py deleted file mode 100644 index ec6b188fe183e3ff22ad8aed3211871c4a0a00dd..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/flags.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Feature flags.""" - -import os -from dataclasses import dataclass - - -def _parse_bool(value: str) -> bool: - """Parse string to a boolean.""" - return value.lower() in ["true", "1", "yes", "y"] - - -INTERNAL = _parse_bool(os.environ.get("COSMOS_INTERNAL", "0")) -"""Whether to enable internal features.""" - -SMOKE = _parse_bool(os.environ.get("COSMOS_SMOKE", "0")) -"""Whether to enable smoke test. - -Disables expensive operations such as checkpoint loading. -""" - -VERBOSE = _parse_bool(os.environ.get("COSMOS_VERBOSE", "0")) -"""Whether to enable verbose output.""" - -EXPERIMENTAL_CHECKPOINTS = _parse_bool(os.environ.get("COSMOS_EXPERIMENTAL_CHECKPOINTS", "0")) -"""Whether to enable experimental checkpoints.""" - - -@dataclass -class Flags: - internal: bool = INTERNAL - smoke: bool = SMOKE - verbose: bool = VERBOSE - experimental_checkpoints: bool = EXPERIMENTAL_CHECKPOINTS - - -FLAGS = Flags() -"""Convenience object for accessing flags.""" diff --git a/lyra_2/_ext/imaginaire/functional/batch_ops.py b/lyra_2/_ext/imaginaire/functional/batch_ops.py deleted file mode 100644 index 549e81119e341c785ff5799ffd0446f805d66dd7..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/functional/batch_ops.py +++ /dev/null @@ -1,61 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Functions for performing operations with broadcasting to the right axis -# -# Example -# input1: tensor of size (N1, N2) -# input2: tensor of size (N1, N2, N3, N4) -# batch_mul(input1, input2) = input1[:, :, None, None] * input2 -# -# If the common dimensions don't match, we raise an assertion error. - -from torch import Tensor - - -def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: - ndims1 = x.ndim - ndims2 = y.ndim - - common_ndims = min(ndims1, ndims2) - for axis in range(common_ndims): - assert x.shape[axis] == y.shape[axis], "Dimensions not equal at axis {}".format(axis) - - if ndims1 < ndims2: - x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) - elif ndims2 < ndims1: - y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) - - return x, y - - -def batch_add(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x + y - - -def batch_mul(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x * y - - -def batch_sub(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x - y - - -def batch_div(x: Tensor, y: Tensor) -> Tensor: - x, y = common_broadcast(x, y) - return x / y diff --git a/lyra_2/_ext/imaginaire/functional/lr_scheduler.py b/lyra_2/_ext/imaginaire/functional/lr_scheduler.py deleted file mode 100644 index a84e7b2028ffb76a229e8adb17721ff97ce1c20e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/functional/lr_scheduler.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import numpy as np - -from lyra_2._ext.imaginaire.utils import distributed, log - - -class TeroPolyScheduler: - def __init__( - self, - total_Mimg: int, - batch_size: int, - ref_Mimg: Optional[int] = None, - ref_batches: float = 70e3 / 1024, - max_lr_ratio: Optional[float] = 1.0, - min_lr_ratio: Optional[float] = None, - rampup_Mimg: float = 0, - rampdown_Mimg: int = 0, - verbosity_interval: int = 0, - formula: str = "poly", - poly_exp: float = 0.5, - ): - self.total_Mimg = total_Mimg - self.batch_size = batch_size * distributed.get_world_size() - self.ref_Mimg = ref_Mimg or ref_batches * batch_size / 1e6 - self.ref_batches = ref_batches - self.max_lr_ratio = max_lr_ratio - self.min_lr_ratio = min_lr_ratio - self.rampup_Mimg = rampup_Mimg - self.rampdown_Mimg = rampdown_Mimg - self.verbosity_interval = verbosity_interval - self.formula = formula - self.poly_exp = poly_exp - - self._model = None - - @property - def model(self): - return self._model - - @model.setter - def model(self, model): - self._model = model - - def schedule(self, n, **kwargs): - cur_Mimg = getattr(self.model, "sample_counter", 0) / 1e6 - - if self.formula == "constant": - lr = 1.0 - elif self.formula == "poly": - lr = max(cur_Mimg / self.ref_Mimg, 1e-8) ** -self.poly_exp - else: - raise ValueError(f'Invalid learning rate formula "{self.formula}"') - - if self.max_lr_ratio is not None: - lr = min(lr, self.max_lr_ratio) - if self.min_lr_ratio is not None: - lr = max(lr, self.min_lr_ratio) - - if self.rampup_Mimg > 0 and cur_Mimg < self.rampup_Mimg: - lr *= cur_Mimg / self.rampup_Mimg - if self.rampdown_Mimg > 0 and cur_Mimg > self.total_Mimg - self.rampdown_Mimg: - lr *= (self.total_Mimg - cur_Mimg) / self.rampdown_Mimg - - return lr - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaWarmUpCosineScheduler: - """ - A learning rate scheduler that combines warm-up with a cosine decay schedule for multiple cycles. - It supports different configurations for each cycle, including the number of warm-up steps, minimum - and maximum scaling factors for the learning rate. - - The scheduler is intended to be used with a base learning rate of 1.0, where the actual learning - rate at any step is the base learning rate multiplied by the scaling factor computed by the scheduler. - - Parameters: - warm_up_steps (list[int]): List of integers where each element represents the number of warm-up - steps for the corresponding cycle. - f_min (list[float]): List of the minimum scaling factors for each cycle after warm-up. - f_max (list[float]): List of the maximum scaling factors at the start and end of each cosine cycle. - f_start (list[float]): List of starting scaling factors for each warm-up phase. - cycle_lengths (list[int]): List of the total lengths of each cycle, including warm-up steps. - verbosity_interval (int, optional): Interval of training steps at which to print current step and - scaling factor information. Set to 0 by default to disable verbosity. - - Examples: - >>> scheduler = LambdaWarmUpCosineScheduler2( - warm_up_steps=[10, 10], - f_min=[0.1, 0.1], - f_max=[1.0, 1.0], - f_start=[0.01, 0.01], - cycle_lengths=[50, 50], - verbosity_interval=10) - >>> for step in range(100): - >>> lr_multiplier = scheduler(step) - >>> print(f"Step {step}: LR Multiplier = {lr_multiplier}") - """ - - def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): - assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) - self.lr_warm_up_steps = warm_up_steps - self.f_start = f_start - self.f_min = f_min - self.f_max = f_max - self.cycle_lengths = cycle_lengths - self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0.0 - self.verbosity_interval = verbosity_interval - - def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: - if n <= cl: - return interval - interval += 1 - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}") - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f - else: - t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) - t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) - self.last_f = f - return f - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaLinearScheduler(LambdaWarmUpCosineScheduler): - """ - Linear instead of cosine decay for the main part of the cycle. - """ - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - log.info(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}") - - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f - else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( - self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] - ) - self.last_f = f - return f diff --git a/lyra_2/_ext/imaginaire/lazy_config/__init__.py b/lyra_2/_ext/imaginaire/lazy_config/__init__.py deleted file mode 100644 index 8f222a2eacdbf7fe1713e4872690da4af3c9c000..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/lazy_config/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -import os - -from omegaconf import DictConfig, OmegaConf - -from lyra_2._ext.imaginaire.lazy_config.instantiate import instantiate -from lyra_2._ext.imaginaire.lazy_config.lazy import LazyCall, LazyConfig -from lyra_2._ext.imaginaire.lazy_config.omegaconf_patch import to_object - -OmegaConf.to_object = to_object - -PLACEHOLDER = None -LazyDict = DictConfig - -__all__ = ["instantiate", "LazyCall", "LazyConfig", "PLACEHOLDER", "LazyDict"] - - -DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py - - -def fixup_module_metadata(module_name, namespace, keys=None): - """ - Fix the __qualname__ of module members to be their exported api name, so - when they are referenced in docs, sphinx can find them. Reference: - https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 - """ - if not DOC_BUILDING: - return - seen_ids = set() - - def fix_one(qualname, name, obj): - # avoid infinite recursion (relevant when using - # typing.Generic, for example) - if id(obj) in seen_ids: - return - seen_ids.add(id(obj)) - - mod = getattr(obj, "__module__", None) - if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): - obj.__module__ = module_name - # Modules, unlike everything else in Python, put fully-qualitied - # names into their __name__ attribute. We check for "." to avoid - # rewriting these. - if hasattr(obj, "__name__") and "." not in obj.__name__: - obj.__name__ = name - obj.__qualname__ = qualname - if isinstance(obj, type): - for attr_name, attr_value in obj.__dict__.items(): - fix_one(objname + "." + attr_name, attr_name, attr_value) - - if keys is None: - keys = namespace.keys() - for objname in keys: - if not objname.startswith("_"): - obj = namespace[objname] - fix_one(objname, objname, obj) - - -fixup_module_metadata(__name__, globals(), __all__) -del fixup_module_metadata diff --git a/lyra_2/_ext/imaginaire/lazy_config/file_io.py b/lyra_2/_ext/imaginaire/lazy_config/file_io.py deleted file mode 100644 index d63d90ced6c75bb05087a0a9ec2feb432b7603dd..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/lazy_config/file_io.py +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler -from iopath.common.file_io import PathManager as PathManagerBase - -__all__ = ["PathManager", "PathHandler"] - - -PathManager = PathManagerBase() -PathManager.register_handler(HTTPURLHandler()) -PathManager.register_handler(OneDrivePathHandler()) diff --git a/lyra_2/_ext/imaginaire/lazy_config/instantiate.py b/lyra_2/_ext/imaginaire/lazy_config/instantiate.py deleted file mode 100644 index 094100d9b482614d3eacf5fc7bef51172d3e3a17..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/lazy_config/instantiate.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections.abc as abc -import dataclasses -import logging -from typing import Any - -import attrs - -from lyra_2._ext.imaginaire.lazy_config.registry import _convert_target_to_string, locate - -__all__ = ["dump_dataclass", "instantiate"] - - -def is_dataclass_or_attrs(target): - return dataclasses.is_dataclass(target) or attrs.has(target) - - -def dump_dataclass(obj: Any): - """ - Dump a dataclass recursively into a dict that can be later instantiated. - - Args: - obj: a dataclass object - - Returns: - dict - """ - assert dataclasses.is_dataclass(obj) and not isinstance(obj, type), ( - "dump_dataclass() requires an instance of a dataclass." - ) - ret = {"_target_": _convert_target_to_string(type(obj))} - for f in dataclasses.fields(obj): - v = getattr(obj, f.name) - if dataclasses.is_dataclass(v): - v = dump_dataclass(v) - if isinstance(v, (list, tuple)): - v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] - ret[f.name] = v - return ret - - -def instantiate(cfg, *args, **kwargs): - """ - Recursively instantiate objects defined in dictionaries by - "_target_" and arguments. - - Args: - cfg: a dict-like object with "_target_" that defines the caller, and - other keys that define the arguments - args: Optional positional parameters pass-through. - kwargs: Optional named parameters pass-through. - - Returns: - object instantiated by cfg - """ - from omegaconf import DictConfig, ListConfig, OmegaConf - - if isinstance(cfg, ListConfig): - lst = [instantiate(x) for x in cfg] - return ListConfig(lst, flags={"allow_objects": True}) - if isinstance(cfg, list): - # Specialize for list, because many classes take - # list[objects] as arguments, such as ResNet, DatasetMapper - return [instantiate(x) for x in cfg] - - # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), - # instantiate it to the actual dataclass. - if isinstance(cfg, DictConfig) and is_dataclass_or_attrs(cfg._metadata.object_type): - return OmegaConf.to_object(cfg) - - if isinstance(cfg, abc.Mapping) and "_target_" in cfg: - # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, - # but faster: https://github.com/facebookresearch/hydra/issues/1200 - is_recursive = getattr(cfg, "_recursive_", True) - if is_recursive: - cfg = {k: instantiate(v) for k, v in cfg.items()} - else: - cfg = {k: v for k, v in cfg.items()} - # pop the _recursive_ key to avoid passing it as a parameter - if "_recursive_" in cfg: - cfg.pop("_recursive_") - cls = cfg.pop("_target_") - cls = instantiate(cls) - - if isinstance(cls, str): - cls_name = cls - cls = locate(cls_name) - assert cls is not None, cls_name - else: - try: - cls_name = cls.__module__ + "." + cls.__qualname__ - except Exception: - # target could be anything, so the above could fail - cls_name = str(cls) - assert callable(cls), f"_target_ {cls} does not define a callable object" - try: - # override config with kwargs - instantiate_kwargs = {} - instantiate_kwargs.update(cfg) - instantiate_kwargs.update(kwargs) - return cls(*args, **instantiate_kwargs) - except TypeError: - logger = logging.getLogger(__name__) - logger.error(f"Error when instantiating {cls_name}!") - raise - return cfg # return as-is if don't know what to do diff --git a/lyra_2/_ext/imaginaire/lazy_config/lazy.py b/lyra_2/_ext/imaginaire/lazy_config/lazy.py deleted file mode 100644 index 9cd810ca3058bceec7d03298a15c960e14edd69a..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/lazy_config/lazy.py +++ /dev/null @@ -1,430 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ast -import builtins -import collections.abc as abc -import importlib -import inspect -import logging -import os -import pickle -import uuid -from collections import OrderedDict -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import is_dataclass -from typing import Any, Dict, List, Tuple, Union - -import attrs -import yaml -from omegaconf import DictConfig, ListConfig, OmegaConf - -try: - import dill as dill_pickle -except ImportError: - dill_pickle = None - -try: - import cloudpickle -except ImportError: - cloudpickle = None - -from lyra_2._ext.imaginaire.lazy_config.file_io import PathManager -from lyra_2._ext.imaginaire.lazy_config.registry import _convert_target_to_string - -__all__ = ["LazyCall", "LazyConfig"] - - -def sort_dict(d: Dict[str, Any]) -> OrderedDict[str, Any]: - return OrderedDict(sorted(d.items(), key=lambda x: x[0])) - - -def dict_representer(dumper: yaml.Dumper, data: OrderedDict[str, Any]) -> yaml.nodes.MappingNode: - return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) - - -def sort_recursive(obj: Union[Dict[str, Any], List[Any], Any]) -> Union[OrderedDict[str, Any], List[Any], Any]: - if isinstance(obj, dict): - return sort_dict({k: sort_recursive(v) for k, v in obj.items()}) - elif isinstance(obj, list): - return [sort_recursive(item) for item in obj] - return obj - - -yaml.add_representer(OrderedDict, dict_representer) - -OmegaConf.register_new_resolver("add", lambda *vals: sum(vals)) -OmegaConf.register_new_resolver("subtract", lambda *vals: vals[0] - sum(vals[1:])) - - -def get_default_params(cls_or_func): - if callable(cls_or_func): - # inspect signature for function - signature = inspect.signature(cls_or_func) - else: - # inspect signature for class - signature = inspect.signature(cls_or_func.__init__) - params = signature.parameters - default_params = { - name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty - } - return default_params - - -class LazyCall: - """ - Wrap a callable so that when it's called, the call will not be executed, - but returns a dict that describes the call. - - LazyCall object has to be called with only keyword arguments. Positional - arguments are not yet supported. - - Examples: - :: - from detectron2.config import instantiate, LazyCall - - layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32) - layer_cfg.out_channels = 64 # can edit it afterwards - layer = instantiate(layer_cfg) - """ - - def __init__(self, target): - if not (callable(target) or isinstance(target, (str, abc.Mapping))): - raise TypeError(f"target of LazyCall must be a callable or defines a callable! Got {target}") - self._target = target - - def __call__(self, **kwargs): - if is_dataclass(self._target) or attrs.has(self._target): - # omegaconf object cannot hold dataclass type - # https://github.com/omry/omegaconf/issues/784 - target = _convert_target_to_string(self._target) - else: - target = self._target - kwargs["_target_"] = target - - _final_params = get_default_params(self._target) - _final_params.update(kwargs) - - return DictConfig(content=_final_params, flags={"allow_objects": True}) - - -def _visit_dict_config(cfg, func): - """ - Apply func recursively to all DictConfig in cfg. - """ - if isinstance(cfg, DictConfig): - func(cfg) - for v in cfg.values(): - _visit_dict_config(v, func) - elif isinstance(cfg, ListConfig): - for v in cfg: - _visit_dict_config(v, func) - - -def _validate_py_syntax(filename): - # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py - with PathManager.open(filename, "r") as f: - content = f.read() - try: - ast.parse(content) - except SyntaxError as e: - raise SyntaxError(f"Config file {filename} has syntax error!") from e - - -def _cast_to_config(obj): - # if given a dict, return DictConfig instead - if isinstance(obj, dict): - return DictConfig(obj, flags={"allow_objects": True}) - return obj - - -_CFG_PACKAGE_NAME = "detectron2._cfg_loader" -""" -A namespace to put all imported config into. -""" - - -def _random_package_name(filename): - # generate a random package name when loading config files - return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename) - - -@contextmanager -def _patch_import(): - """ - Enhance relative import statements in config files, so that they: - 1. locate files purely based on relative location, regardless of packages. - e.g. you can import file without having __init__ - 2. do not cache modules globally; modifications of module states has no side effect - 3. support other storage system through PathManager, so config files can be in the cloud - 4. imported dict are turned into omegaconf.DictConfig automatically - """ - old_import = builtins.__import__ - - def find_relative_file(original_file, relative_import_path, level): - # NOTE: "from . import x" is not handled. Because then it's unclear - # if such import should produce `x` as a python module or DictConfig. - # This can be discussed further if needed. - relative_import_err = """ -Relative import of directories is not allowed within config files. -Within a config file, relative import can only import other config files. -""".replace("\n", " ") - if not len(relative_import_path): - raise ImportError(relative_import_err) - - cur_file = os.path.dirname(original_file) - for _ in range(level - 1): - cur_file = os.path.dirname(cur_file) - cur_name = relative_import_path.lstrip(".") - for part in cur_name.split("."): - cur_file = os.path.join(cur_file, part) - if not cur_file.endswith(".py"): - cur_file += ".py" - if not PathManager.isfile(cur_file): - cur_file_no_suffix = cur_file[: -len(".py")] - if PathManager.isdir(cur_file_no_suffix): - raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err) - else: - raise ImportError( - f"Cannot import name {relative_import_path} from {original_file}: {cur_file} does not exist." - ) - return cur_file - - def new_import(name, globals=None, locals=None, fromlist=(), level=0): - if ( - # Only deal with relative imports inside config files - level != 0 and globals is not None and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME) - ): - cur_file = find_relative_file(globals["__file__"], name, level) - _validate_py_syntax(cur_file) - spec = importlib.machinery.ModuleSpec(_random_package_name(cur_file), None, origin=cur_file) - module = importlib.util.module_from_spec(spec) - module.__file__ = cur_file - with PathManager.open(cur_file) as f: - content = f.read() - exec(compile(content, cur_file, "exec"), module.__dict__) - for name in fromlist: # turn imported dict into DictConfig automatically - val = _cast_to_config(module.__dict__[name]) - module.__dict__[name] = val - return module - return old_import(name, globals, locals, fromlist=fromlist, level=level) - - builtins.__import__ = new_import - yield new_import - builtins.__import__ = old_import - - -class LazyConfig: - """ - Provide methods to save, load, and overrides an omegaconf config object - which may contain definition of lazily-constructed objects. - """ - - @staticmethod - def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): - """ - Similar to :meth:`load()`, but load path relative to the caller's - source file. - - This has the same functionality as a relative import, except that this method - accepts filename as a string, so more characters are allowed in the filename. - """ - caller_frame = inspect.stack()[1] - caller_fname = caller_frame[0].f_code.co_filename - assert caller_fname != "", "load_rel Unable to find caller" - caller_dir = os.path.dirname(caller_fname) - filename = os.path.join(caller_dir, filename) - return LazyConfig.load(filename, keys) - - @staticmethod - def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): - """ - Load a config file. - - Args: - filename: absolute path or relative path w.r.t. the current working directory - keys: keys to load and return. If not given, return all keys - (whose values are config objects) in a dict. - """ - has_keys = keys is not None - filename = filename.replace("/./", "/") # redundant - if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: - raise ValueError(f"Config file {filename} has to be a python or yaml file.") - if filename.endswith(".py"): - _validate_py_syntax(filename) - - with _patch_import(): - # Record the filename - module_namespace = { - "__file__": filename, - "__package__": _random_package_name(filename), - } - with PathManager.open(filename) as f: - content = f.read() - # Compile first with filename to: - # 1. make filename appears in stacktrace - # 2. make load_rel able to find its parent's (possibly remote) location - exec(compile(content, filename, "exec"), module_namespace) - - ret = module_namespace - else: - with PathManager.open(filename) as f: - obj = yaml.unsafe_load(f) - ret = OmegaConf.create(obj, flags={"allow_objects": True}) - - if has_keys: - if isinstance(keys, str): - return _cast_to_config(ret[keys]) - else: - return tuple(_cast_to_config(ret[a]) for a in keys) - else: - if filename.endswith(".py"): - # when not specified, only load those that are config objects - ret = DictConfig( - { - name: _cast_to_config(value) - for name, value in ret.items() - if isinstance(value, (DictConfig, ListConfig, dict)) and not name.startswith("_") - }, - flags={"allow_objects": True}, - ) - return ret - - @staticmethod - def save_pkl(cfg, filename: str) -> str: - """ - Saves a Config object to a file using pickle serialization. This method is typically used - when the configuration object contains complex objects, such as lambdas, that are not supported by - simpler serialization methods like YAML. The function attempts to create a deep copy of the configuration - object before serialization to ensure that the original object remains unmodified. - - Args: - cfg: A Config object to be serialized and saved. - filename: The path and name of the file where the configuration should be saved. The function - assumes the file extension indicates a pickle format (e.g., .pkl). - - Returns: - str: The filename to which the configuration was saved. This can be used to verify the file location - or log the outcome. - - Notes: - - The function logs a warning if the configuration is successfully saved using pickle. - - If saving fails, an error is logged with the exception details. - """ - logger = logging.getLogger(__name__) - try: - cfg = deepcopy(cfg) - except Exception: - pass - - try: - with PathManager.open(filename, "wb") as f: - pickle.dump(cfg, f) - logger.warning(f"Config is saved using pickle at {filename}.") - except Exception as e: - logger.error(f"Failed to save config to {filename}: {e}. Trying dill or cloudpickle instead") - if dill_pickle: - try: - with PathManager.open(filename, "wb") as f: - pickle.dump(dill_pickle.dumps(cfg, recurse=True), f) - logger.warning(f"Config is saved using dill at {filename}.") - except Exception as e: - logger.error(f"Failed to save config to {filename}: {e}.") - if cloudpickle: - try: - with PathManager.open(filename, "wb") as f: - pickle.dump(cloudpickle.dumps(cfg), f) - logger.warning(f"Config is saved using cloudpickle at {filename}.") - except Exception as e: - logger.error(f"Failed to save config to {filename}: {e}.") - else: - logger.error("cloudpickle is not available. Cannot save the config.") - raise e - - return filename - - @staticmethod - def save_yaml(cfg, filename: str) -> str: - """ - Saves a Config object to a file using YAML serialization. This method is beneficial when the configuration object's content needs to be human-readable and easily editable. YAML is suitable for configurations that do not contain complex types like lambdas, which must be handled differently. The function converts unserializable items to strings before saving to ensure compatibility with YAML serialization. - - Args: - cfg: A Config object to be serialized and saved. It handles both DictConfig and ListConfig types. - filename: The path and name of the file where the configuration should be saved. The function does not require a specific file extension but typically uses '.yaml'. - - Returns: - str: The filename to which the configuration was saved. This can be used to verify the file location or log the outcome. - - Notes: - - The function logs a warning if the configuration is successfully saved using YAML. - - If saving fails, an error is logged with the exception details. - """ - logger = logging.getLogger(__name__) - try: - cfg = deepcopy(cfg) - except Exception: - pass - - # Define a function to check if an item is serializable to YAML - def is_serializable(item): - try: - OmegaConf.to_yaml(item) - return True - except Exception as e: - return False - - # Function to convert unserializable items to strings - def serialize_config(config): - if isinstance(config, DictConfig): - for key, value in config.items(): - if isinstance(value, (DictConfig, ListConfig)): - try: - if "_target_" in value: - default_params = get_default_params(value["_target_"]) - for default_key, default_v in default_params.items(): - if default_key not in value: - value[default_key] = default_v - except Exception as e: - logger.error(f"Failed to add default argument values: {e}") - - serialize_config(value) - else: - if not is_serializable(value) and value is not None: - config[key] = str(value) - elif isinstance(config, ListConfig): - for i, item in enumerate(config): - if isinstance(item, (DictConfig, ListConfig)): - serialize_config(item) - else: - if not is_serializable(item) and item is not None: - config[i] = str(item) - else: - raise NotImplementedError("Input config must be a DictConfig or ListConfig.") - return config - - # Convert Config object to a DictConfig object. - config_dict = attrs.asdict(cfg) - config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) - - # Serialize the DictConfig object by converting non-serializable objects to strings. - config_omegaconf = serialize_config(config_omegaconf) - - config_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True) - sorted_config: OrderedDict[str, Any] = sort_recursive(config_dict) - with open(filename, "w") as f: - yaml.dump(sorted_config, f, default_flow_style=False) - logger.warning(f"Config is saved using omegaconf at {filename}.") - return filename diff --git a/lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py b/lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py deleted file mode 100644 index fd1ba013974070bde6189f1f8e5d5f0270f8fde3..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/lazy_config/omegaconf_patch.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, List, Union - -from omegaconf import OmegaConf -from omegaconf.base import DictKeyType, SCMode -from omegaconf.dictconfig import DictConfig # pragma: no cover - - -def to_object(cfg: Any) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: - """ - Converts an OmegaConf configuration object to a native Python container (dict or list), unless - the configuration is specifically created by LazyCall, in which case the original configuration - is returned directly. - - This function serves as a modification of the original `to_object` method from OmegaConf, - preventing DictConfig objects created by LazyCall from being automatically converted to Python - dictionaries. This ensures that configurations meant to be lazily evaluated retain their intended - structure and behavior. - - Differences from OmegaConf's original `to_object`: - - Adds a check at the beginning to return the configuration unchanged if it is created by LazyCall. - - Reference: - - Original OmegaConf `to_object` method: https://github.com/omry/omegaconf/blob/master/omegaconf/omegaconf.py#L595 - - Args: - cfg (Any): The OmegaConf configuration object to convert. - - Returns: - Union[Dict[DictKeyType, Any], List[Any], None, str, Any]: The converted Python container if - `cfg` is not a LazyCall created configuration, otherwise the unchanged `cfg`. - - Examples: - >>> cfg = DictConfig({"key": "value", "_target_": "Model"}) - >>> to_object(cfg) - DictConfig({"key": "value", "_target_": "Model"}) - - >>> cfg = DictConfig({"list": [1, 2, 3]}) - >>> to_object(cfg) - {'list': [1, 2, 3]} - """ - if isinstance(cfg, DictConfig) and "_target_" in cfg.keys(): - return cfg - - return OmegaConf.to_container( - cfg=cfg, - resolve=True, - throw_on_missing=True, - enum_to_str=False, - structured_config_mode=SCMode.INSTANTIATE, - ) diff --git a/lyra_2/_ext/imaginaire/lazy_config/registry.py b/lyra_2/_ext/imaginaire/lazy_config/registry.py deleted file mode 100644 index e435ca49df8275e9a0fae9cc6e45df25e1d867ed..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/lazy_config/registry.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pydoc -from typing import Any - -from fvcore.common.registry import Registry # for backward compatibility. - -""" -``Registry`` and `locate` provide ways to map a string (typically found -in config files) to callable objects. -""" - -__all__ = ["Registry", "locate"] - - -def _convert_target_to_string(t: Any) -> str: - """ - Inverse of ``locate()``. - - Args: - t: any object with ``__module__`` and ``__qualname__`` - """ - module, qualname = t.__module__, t.__qualname__ - - # Compress the path to this object, e.g. ``module.submodule._impl.class`` - # may become ``module.submodule.class``, if the later also resolves to the same - # object. This simplifies the string, and also is less affected by moving the - # class implementation. - module_parts = module.split(".") - for k in range(1, len(module_parts)): - prefix = ".".join(module_parts[:k]) - candidate = f"{prefix}.{qualname}" - try: - if locate(candidate) is t: - return candidate - except ImportError: - pass - return f"{module}.{qualname}" - - -def locate(name: str) -> Any: - """ - Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, - such as "module.submodule.class_name". - - Raise Exception if it cannot be found. - """ - obj = pydoc.locate(name) - - # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly - # by pydoc.locate. Try a private function from hydra. - if obj is None: - try: - # from hydra.utils import get_method - will print many errors - from hydra.utils import _locate - except ImportError as e: - raise ImportError(f"Cannot dynamically locate object {name}!") from e - else: - obj = _locate(name) # it raises if fails - - return obj diff --git a/lyra_2/_ext/imaginaire/model.py b/lyra_2/_ext/imaginaire/model.py deleted file mode 100644 index 61d4d8d472dae944c5e48dc47bd2ffa96f712c4a..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/model.py +++ /dev/null @@ -1,129 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -import torch - -from lyra_2._ext.imaginaire.lazy_config import LazyDict, instantiate - - -class ImaginaireModel(torch.nn.Module): - """The base model class of Imaginaire. It is inherited from torch.nn.Module. - - All models in Imaginaire should inherit ImaginaireModel. It should include the implementions for all the - computation graphs. All inheriting child classes should implement the following methods: - - training_step(): The training step of the model, including the loss computation. - - validation_step(): The validation step of the model, including the loss computation. - - forward(): The computation graph for model inference. - The following methods have default implementations in ImaginaireModel: - - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model. - """ - - def __init__(self) -> None: - super().__init__() - - def init_optimizer_scheduler( - self, optimizer_config: LazyDict, scheduler_config: LazyDict - ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: - """Creates the optimizer and scheduler for the model. - - Args: - config_model (ModelConfig): The config object for the model. - - Returns: - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - """ - optimizer_config.params = self.parameters() - optimizer = instantiate(optimizer_config) - scheduler_config.optimizer = optimizer - scheduler = instantiate(scheduler_config) - return optimizer, scheduler - - def training_step( - self, data_batch: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """The training step of the model, including the loss computation. - - Args: - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - iteration (int): Current iteration number. - - Returns: - output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch. - loss (torch.Tensor): The total loss for backprop (weighted sum of various losses). - """ - raise NotImplementedError - - @torch.no_grad() - def validation_step( - self, data_batch: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """The validation step of the model, including the loss computation. - - Args: - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - iteration (int): Current iteration number. - - Returns: - output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch. - loss (torch.Tensor): The total loss (weighted sum of various losses). - """ - raise NotImplementedError - - @torch.inference_mode() - def forward(self, *args: Any, **kwargs: Any) -> Any: - """The computation graph for model inference. - - Args: - *args: Whatever you decide to pass into the forward method. - **kwargs: Keyword arguments are also possible. - - Return: - Your model's output. - """ - raise NotImplementedError - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - """The model preparation before the training is launched - - Args: - memory_format (torch.memory_format): Memory format of the model. - """ - pass - - def on_before_zero_grad( - self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int - ) -> None: - """Hook before zero_grad() is called. - - Args: - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - iteration (int): Current iteration number. - """ - pass - - def on_after_backward(self, iteration: int = 0) -> None: - """Hook after loss.backward() is called. - - This method is called immediately after the backward pass, allowing for custom operations - or modifications to be performed on the gradients before the optimizer step. - - Args: - iteration (int): Current iteration number. - """ - pass diff --git a/lyra_2/_ext/imaginaire/trainer.py b/lyra_2/_ext/imaginaire/trainer.py deleted file mode 100644 index 9b0cc001f01f3a4994cc6bede61058dc4a79e970..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/trainer.py +++ /dev/null @@ -1,379 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import inspect -import os -import signal - -import torch -import torch.distributed as dist -import torch.utils.data - -from lyra_2._ext.imaginaire.utils.context_managers import distributed_init -from lyra_2._ext.imaginaire.utils.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling - -try: - from megatron.core import parallel_state - - USE_MEGATRON = True -except ImportError: - USE_MEGATRON = False - print("Megatron-core is not installed.") - - -from lyra_2._ext.imaginaire.lazy_config import LazyConfig, instantiate -from lyra_2._ext.imaginaire.model import ImaginaireModel -from lyra_2._ext.imaginaire.utils import callback, distributed, ema, log, misc -from lyra_2._ext.imaginaire.utils.checkpointer import Checkpointer -from lyra_2._ext.imaginaire.utils.misc import StragglerDetectorV2 - -try: - from lyra_2._src.callbacks.smart_stop import TimeoutException -except ImportError: - # Define a dummy exception if smart_stop is not available - class TimeoutException(Exception): - pass - - -class ImaginaireTrainer: - """The base trainer class of Imaginaire. - - All trainers in Imaginaire should inherit ImaginaireTrainer. It contains the basic functionality for model training - (particularly suited for large-scale training), including data parallel (DDP/FSDP), model weight average (EMA), - mixed-precision training (fp16/bf16). - - Attributes: - checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states. - training_timer (misc.Timer): Timer object to time code blocks and functions. - """ - - def __init__(self, config): - """Constructor of the trainer. - - Args: - config (Config): The config object for the Imaginaire codebase. - """ - super().__init__() - self.config = config - # Set up the distributed computing environment. - with distributed_init(): - distributed.init() - # Set up parallel states. - if hasattr(config.model, "context_parallel_size"): - if config.model_parallel.context_parallel_size > 1: - raise ValueError( - "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. " - "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size." - ) - else: - log.critical( - "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead." - ) - config.model_parallel.context_parallel_size = config.model.context_parallel_size - if USE_MEGATRON: - if ( - "create_gloo_process_groups" - in inspect.signature(parallel_state.initialize_model_parallel).parameters - ): - parallel_state.initialize_model_parallel( - pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size, - tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size, - context_parallel_size=config.model_parallel.context_parallel_size, - create_gloo_process_groups=False, - ) - else: - parallel_state.initialize_model_parallel( - pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size, - tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size, - context_parallel_size=config.model_parallel.context_parallel_size, - ) - # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism. - # It is not part of the original `parallel_state` API, so we need to set it manually. - parallel_state.sequence_parallel = config.model_parallel.sequence_parallel - if parallel_state.sequence_parallel: - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - - # Create the local job directory, save the config file, and pipe to a local log. - if distributed.is_rank0(): - os.makedirs(config.job.path_local, exist_ok=True) - # Save the config as .pkl for reproducibility. - LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl") - # Save the config as .yaml for reading or parsing experiment hyperparameters. - LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml") - dist.barrier() - log.init_loguru_file(f"{config.job.path_local}/stdout.log") - if distributed.is_rank0(): - # Print important environment variables and the effective config. - log.info("Config:\n" + config.pretty_print(use_color=True)) - misc.print_environ_variables(["HF_HOME", "IMAGINAIRE_OUTPUT_ROOT"]) - # Set the random seed. If multi-GPU, different ranks are set with different seeds. - misc.set_random_seed(seed=config.trainer.seed, by_rank=True) - # Initialize cuDNN. - torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic - torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark - # Floating-point precision settings. - torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True - # Initialize the callback functions. - self.callbacks = callback.CallBackGroup(config=config, trainer=self) - # Initialize the model checkpointer. - if config.checkpoint.type is None: - self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks) - else: - self.checkpointer: Checkpointer = instantiate( - config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks - ) - # Initialize the timer for speed benchmarking. - self.training_timer = misc.TrainingTimer() - # Initialize Straggler Detection - self.straggler_detector = StragglerDetectorV2( - enabled=self.config.trainer.straggler_detection.enabled, - report_freq=self.config.trainer.straggler_detection.report_freq, - profile_freq=self.config.trainer.straggler_detection.profile_freq, - max_diff=self.config.trainer.straggler_detection.max_diff, - raise_error=self.config.trainer.straggler_detection.raise_error, - ) - self.straggler_detector.initialize() - # Send a TimeoutError if a training step takes over timeout_period seconds. - signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period)) # type: ignore - - def train( - self, - model: ImaginaireModel, - dataloader_train: torch.utils.data.DataLoader, - dataloader_val: torch.utils.data.DataLoader, - ) -> None: - """The training function. - - Args: - model (ImaginaireModel): The PyTorch model. - dataloader_train (torch.utils.data.DataLoader): The training data loader. - dataloader_val (torch.utils.data.DataLoader): The validation data loader. - """ - # Leaving this for backward compability for now, but we can think about moving this to model.on_train_start for all models. - model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore - model.on_train_start(self.config.trainer.memory_format) - - # Initialize the optimizer, scheduler, and grad_scaler. - self.callbacks.on_optimizer_init_start() - optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) - grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) - self.callbacks.on_optimizer_init_end() - # Load the model checkpoint and get the starting iteration number. - iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) - grad_accum_iter = 0 - log.info(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") - if self.config.trainer.distributed_parallelism == "ddp": - # Create a DDP model wrapper. - model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) - elif self.config.trainer.distributed_parallelism == "fsdp": - model_ddp = model - else: - raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") - - log.info("Starting training...") - self.callbacks.on_train_start(model, iteration=iteration) - # Initial validation. - if self.config.trainer.run_validation and iteration == 0: - self.validate(model, dataloader_val, iteration=iteration) - _end_training = False - _smart_stop_triggered = False - _oom_triggered = False - with ( - maybe_enable_profiling(self.config, global_step=iteration) as torch_profiler, - maybe_enable_memory_snapshot(self.config, global_step=iteration) as memory_profiler, - ): - while True: - dataloader_train_iter = iter(dataloader_train) - while True: - self.callbacks.on_before_dataloading(iteration) - try: - with ( - self.training_timer("dataloader_train"), - self.straggler_detector.profile_section( - "dataloading", - self.config.trainer.straggler_detection.analyze_dataloading, - profile_cuda=False, - ), - ): - data_batch = next(dataloader_train_iter) - except StopIteration: - break - finally: - self.callbacks.on_after_dataloading(iteration) - # If max_iter is reached, exit the training loop. - if iteration >= self.config.trainer.max_iter: - _end_training = True - break - # Move all tensors in the data batch to GPU device. - data_batch = misc.to(data_batch, device="cuda") - # The actual training step. - self.callbacks.on_training_step_start(model, data_batch, iteration=iteration) - self.callbacks.on_training_step_batch_start(model, data_batch, iteration=iteration) - if not model.training: - model_ddp.train() - assert model_ddp.training, "model_ddp is not in training mode." - assert model.training, "model is not in training mode." - try: - output_batch, loss, grad_accum_iter = self.training_step( - model_ddp, - optimizer, - scheduler, - grad_scaler, - data_batch, - iteration=iteration, - grad_accum_iter=grad_accum_iter, - ) - except torch.OutOfMemoryError as e: - # CUDA OOM error - save checkpoint and exit gracefully - _oom_triggered = True - log.error(f"CUDA Out of Memory error caught: {e}") - log.info("Saving checkpoint due to CUDA OOM error...") - self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) - log.success("Checkpoint saved successfully. Exiting training gracefully due to OOM.") - _end_training = True - break - self.callbacks.on_training_step_batch_end( - model, data_batch, output_batch, loss, iteration=iteration - ) - # If the gradients are still being accumulated, continue to load the next training batch. - if grad_accum_iter != 0: - continue - # Do the following when an actual optimizer (update) step has been made. - iteration += 1 - # Save checkpoint. - if iteration % self.config.checkpoint.save_iter == 0: - self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) - # Call on_training_step_end with SmartStop exception handling - try: - self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration) - except TimeoutException as e: - # Smart stop triggered - save checkpoint and exit gracefully - _smart_stop_triggered = True - log.warning(f"SmartStop exception caught: {e}") - log.info("Saving checkpoint due to time limit...") - self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) - log.success("Checkpoint saved successfully. Exiting training gracefully.") - _end_training = True - break - # Validation. - if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0: - self.validate(model, dataloader_val, iteration=iteration) - # This iteration is successful; reset the timeout signal. - signal.alarm(self.config.trainer.timeout_period) - self.straggler_detector.generate_report(iteration) - if torch_profiler: - torch_profiler.step() - if memory_profiler: - memory_profiler.step() - if _end_training: - break - log.success("Done with training.") - if iteration % self.config.checkpoint.save_iter != 0 and not _smart_stop_triggered and not _oom_triggered: - self.checkpointer.save(model, optimizer, scheduler, grad_scaler, iteration=iteration) - self.callbacks.on_train_end(model, iteration=iteration) - self.checkpointer.finalize() - distributed.barrier() - self.callbacks.on_app_end() - - def training_step( - self, - model_ddp: torch.nn.Module | distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - data: dict[str, torch.Tensor], - iteration: int = 0, - grad_accum_iter: int = 0, - ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]: - """The training step. - - Args: - model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare - module, depending on whether distributed training is enabled or not. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - iteration (int): Current iteration number. - grad_accum_iter (int): Number of gradient accumulation iterations. - - Returns: - output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). - loss (torch.Tensor): The total loss of the training data batch. - """ - # Only let DDP sync gradient at the last iteration of the gradient accumulation window - with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1): - self.callbacks.on_before_forward(iteration=iteration) - with self.training_timer("forward"): - with self.straggler_detector.profile_section( - "fwd", self.config.trainer.straggler_detection.analyze_forward - ): - output_batch, loss = model_ddp.training_step(data, iteration) - self.callbacks.on_after_forward(iteration=iteration) - self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration) - with self.training_timer("backward"): - with self.straggler_detector.profile_section( - "bwd", self.config.trainer.straggler_detection.analyze_backward - ): - loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter) - loss_scaled.backward() - if self.config.trainer.distributed_parallelism == "ddp": - model_ddp.module.on_after_backward() - else: - model_ddp.on_after_backward() - self.callbacks.on_after_backward(model_ddp, iteration=iteration) - grad_accum_iter += 1 - if grad_accum_iter == self.config.trainer.grad_accum_iter: - with self.training_timer("optimizer_step"): - with self.straggler_detector.profile_section( - "opt", self.config.trainer.straggler_detection.analyze_optimizer - ): - self.callbacks.on_before_optimizer_step( - model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration - ) - grad_scaler.step(optimizer) - grad_scaler.update() - scheduler.step() - self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration) - if self.config.trainer.distributed_parallelism == "ddp": - model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration) - else: - model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration) - optimizer.zero_grad(set_to_none=True) - grad_accum_iter = 0 - return output_batch, loss, grad_accum_iter - - @torch.no_grad() - def validate(self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None: - """Validate on the full validation dataset. - - Args: - model (ImaginaireModel): The PyTorch model. - dataloader_val (torch.utils.data.DataLoader): The validation data loader. - iteration (int): Current iteration number. - """ - self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration) - model.eval() - # Evaluate on the full validation set. - with ema.ema_scope(model, enabled=model.config.ema.enabled): - for val_iter, data_batch in enumerate(dataloader_val): - if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter: - break - data_batch = misc.to(data_batch, device="cuda") - self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration) - output_batch, loss = model.validation_step(data_batch, iteration) - self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration) - self.callbacks.on_validation_end(model, iteration=iteration) diff --git a/lyra_2/_ext/imaginaire/types/denoise_prediction.py b/lyra_2/_ext/imaginaire/types/denoise_prediction.py deleted file mode 100644 index 3588651babfe35ba176a093d75b519c3f146eb92..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/types/denoise_prediction.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Optional - -import torch - - -@dataclass -class DenoisePrediction: - x0: torch.Tensor # clean data prediction - eps: Optional[torch.Tensor] = None # noise prediction - logvar: Optional[torch.Tensor] = None # log variance of noise prediction, can be used a confidence / uncertainty diff --git a/lyra_2/_ext/imaginaire/utils/__init__.py b/lyra_2/_ext/imaginaire/utils/__init__.py deleted file mode 100644 index dac9a4d7496eb38831f1f3c820a90d50e25e2a7e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/lyra_2/_ext/imaginaire/utils/callback.py b/lyra_2/_ext/imaginaire/utils/callback.py deleted file mode 100644 index de470969177d8483a4f2555ea430e5cca5c586d9..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/callback.py +++ /dev/null @@ -1,524 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import sys -import time -import warnings -from typing import TYPE_CHECKING, Any, Callable, Optional - -import omegaconf -import torch -import torch.distributed as dist -import torch.utils.data -import tqdm -from lyra_2._ext.imaginaire.lazy_config import instantiate -from lyra_2._ext.imaginaire.utils import distributed, log, misc -from lyra_2._ext.imaginaire.utils.misc import get_local_tensor_if_DTensor - -try: - from megatron.core import parallel_state -except ImportError: - parallel_state = None - print("Megatron-core is not installed.") - - -if TYPE_CHECKING: - from lyra_2._ext.imaginaire.config import Config - from lyra_2._ext.imaginaire.model import ImaginaireModel - from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer - - -class CallBackGroup: - """A class for hosting a collection of callback objects. - - It is used to execute callback functions of multiple callback objects with the same method name. - When callbackgroup.func(args) is executed, internally it loops through the objects in self._callbacks and runs - self._callbacks[0].func(args), self._callbacks[1].func(args), etc. The method name and arguments should match. - - Attributes: - _callbacks (list[Callback]): List of callback objects. - """ - - def __init__(self, config: Config, trainer: ImaginaireTrainer) -> None: - """Initializes the list of callback objects. - - Args: - config (Config): The config object for the Imaginaire codebase. - trainer (ImaginaireTrainer): The main trainer. - """ - self._callbacks = [] - callback_configs = config.trainer.callbacks - if callback_configs: - if isinstance(callback_configs, list) or isinstance(callback_configs, omegaconf.listconfig.ListConfig): - warnings.warn( - "The 'config.trainer.callbacks' parameter should be a dict instead of a list. " - "Please update your code", - DeprecationWarning, - stacklevel=2, - ) - callback_configs = {f"callback_{i}": v for i, v in enumerate(callback_configs)} - for callback_name, current_callback_cfg in callback_configs.items(): - if "_target_" not in current_callback_cfg: - log.critical( - f"Callback {callback_name} is missing the '_target_' field. \n SKip {current_callback_cfg}" - ) - continue - log.info(f"Instantiating callback {callback_name}: {current_callback_cfg}") - _callback = instantiate(current_callback_cfg) - assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback." - _callback.config = config - _callback.trainer = trainer - self._callbacks.append(_callback) - - def __getattr__(self, method_name: str) -> Callable: - """Loops through the callback objects to call the corresponding callback function. - - Args: - method_name (str): Callback method name. - """ - - def multi_callback_wrapper(*args, **kwargs) -> None: - for callback in self._callbacks: - assert hasattr(callback, method_name) - method = getattr(callback, method_name) - assert callable(method) - _ = method(*args, **kwargs) - - return multi_callback_wrapper - - -class Callback: - """The base class for all callbacks. - - All callbacks should inherit from this class and adhere to the established method names and signatures. - """ - - def __init__(self, config: Optional["Config"] = None, trainer: Optional["ImaginaireTrainer"] = None): - """Initializes a Callback object. - - Args: - config (Optional[Config]): The configuration object for the Imaginaire codebase, if available. - trainer (Optional[ImaginaireTrainer]): The main trainer handling the training loop, if available. - - Notes: - The config and trainer parameters are optional to maintain backward compatibility. - In future releases, these parameters will be removed. Upon using these parameters, a deprecation - warning will be issued. - - """ - if config is not None or trainer is not None: - warnings.warn( - "The 'config' and 'trainer' parameters are deprecated and will be removed in a future release. " - "Please update your code to create Callback instances without these parameters.", - DeprecationWarning, - stacklevel=2, - ) - del config, trainer - - def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - pass - - def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - """ - Called before the training step, for each batch. This is paired with on_training_step_end() but note that - when using gradient accumulation, while on_training_step_end() is only called when the optimizer is updated, - this function is called for every batch. - Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called - for every batch, albeit with the same iteration number. - FIXME - should this either be deprecated, or called only when a new training step is started after having updated - the optimizer? - """ - pass - - def on_training_step_batch_start( - self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 - ) -> None: - """ - Called before the training step, for each batch, similarly to on_training_step_start(). This function is paired with - on_training_step_batch_end(), and both functions are called for every batch even when using gradient accumulation. - Note that the iteration is only updated when the optimizer is updated, and therefore it may be the same for multiple invocations. - """ - pass - - def on_before_forward(self, iteration: int = 0) -> None: - pass - - def on_after_forward(self, iteration: int = 0) -> None: - pass - - def on_before_backward( - self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0 - ) -> None: - pass - - def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None: - pass - - def on_before_dataloading(self, iteration: int = 0) -> None: - pass - - def on_after_dataloading(self, iteration: int = 0) -> None: - pass - - def on_optimizer_init_start(self) -> None: - pass - - def on_optimizer_init_end(self) -> None: - pass - - def on_before_optimizer_step( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int = 0, - ) -> None: - pass - - def on_before_zero_grad( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - iteration: int = 0, - ) -> None: - pass - - def on_training_step_batch_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - """ - Called at the end of a training step for every batch even when using gradient accumulation. - This is paired with on_training_step_batch_start(). Note that the iteration is only updated when the optimizer is updated, - and therefore it may be the same for multiple batches. - """ - pass - - def on_training_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - """ - Called at the end of a training step, but note that when using gradient accumulation, this is only called - when the optimizer is updated, and the iteration incremented, whereas on_training_step_start is called every time. - Use on_training_step_batch_start and on_training_step_batch_end if you need callbacks that are called - for every batch. - """ - pass - - def on_validation_start( - self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 - ) -> None: - pass - - def on_validation_step_start( - self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 - ) -> None: - pass - - def on_validation_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - pass - - def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - pass - - def on_load_checkpoint_start(self, model: ImaginaireModel) -> None: - pass - - def on_load_checkpoint_end( - self, model: ImaginaireModel, iteration: int = 0, checkpoint_path: Optional[str] = None - ) -> None: - pass - - def on_load_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None: - """ - Called when checkpoint loading is about to start, but after on_save_checkpoint_start(). - FIXME - why do we need this callback, can't we just use on_save_checkpoint_start()? - """ - pass - - def on_save_checkpoint_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - """ - Called when checkpoint saving is about to start. - """ - pass - - def on_save_checkpoint_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - """ - Called when the synchronous part of checkpointing is finished, this function can be used - along with on_save_checkpoint_start() to measure the exposed (synchronous) checkpoint time. - Note that for asynchronous checkpoint, the checkpoint may still be ongoing, so this function - does not mean the checkpoint is finished for the asynchronous case, use on_save_checkpoint_success() - for that. - """ - pass - - def on_save_checkpoint_success(self, iteration: int = 0, elapsed_time: float = 0) -> None: - """ - Called when checkpoint saving is fully finished, and succeeded. Not called if checkpoint failed. - For synchronous checkpoint, it is called at the same time as on_save_checkpoint_end(), but for asynchronous - checkpoint, it is called after the asynchronous part has also finished. For checkpointers with out-of-process - checkpointing, this function is called as soon as the notification is received from the checkpointer process, - which may not be immediately after the checkpoint has completed but later on. Therefore, if you need to measure - the full checkpoint duration for the asynchronous part, use the elapsed_time parameter, do not measure it directly - as this would be a significant overestimate. - """ - pass - - def on_save_checkpoint(self, model: ImaginaireModel, state_dict: dict[Any]) -> None: - pass - - def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - pass - - def on_app_end(self) -> None: - pass - - -class EMAModelCallback(Callback): - """The callback class for tracking EMA model weights.""" - - def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - # Set up the EMA model weight tracker. - if model.config.ema.enabled: - assert hasattr(model, "ema"), "EMA should be initialized from ImaginaireModel" - # EMA model must be kept in FP32 precision. - model.ema = model.ema.to(dtype=torch.float32) - else: - assert not hasattr(model, "ema"), "There should be no EMA initialized." - - def on_training_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - # Update the EMA model with the new regular weights. - if model.config.ema.enabled: - model.ema.update_average(model, iteration) - - -class ProgressBarCallback(Callback): - """The callback class for visualizing the training/validation progress bar in the console.""" - - @distributed.rank0_only - def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") - - @distributed.rank0_only - def on_training_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.train_pbar.update() - - @distributed.rank0_only - def on_validation_start( - self, model: ImaginaireModel, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0 - ) -> None: - if self.config.trainer.max_val_iter is not None: - num_iter = self.config.trainer.max_val_iter - else: - num_iter = len(dataloader_val) - assert num_iter is not None and num_iter > 0, f"Invalid number of validation iterations: {num_iter}" - self.val_pbar = tqdm.trange(num_iter, desc="Validating", position=1, leave=False) - - @distributed.rank0_only - def on_validation_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - self.val_pbar.update() - - @distributed.rank0_only - def on_validation_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.val_pbar.close() - - @distributed.rank0_only - def on_train_end(self, model: ImaginaireModel, iteration: int = 0) -> None: - self.trainer.checkpointer.finalize() - self.train_pbar.close() - - -class IterationLoggerCallback(Callback): - """The callback class for visualizing the training/validation progress bar in the console.""" - - @distributed.rank0_only - def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - # self.train_pbar = tqdm.trange(self.config.trainer.max_iter, initial=iteration, desc="Training") - self.start_iteration_time = time.time() - self.elapsed_iteration_time = 0 - - @distributed.rank0_only - def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - self.start_iteration_time = time.time() - - @distributed.rank0_only - def on_training_step_end( - self, - model: ImaginaireModel, - data_batch: dict[str, torch.Tensor], - output_batch: dict[str, torch.Tensor], - loss: torch.Tensor, - iteration: int = 0, - ) -> None: - # but this is only called when the optimizer is updated, so it's only the time for the last batch. - self.elapsed_iteration_time += time.time() - self.start_iteration_time - - if iteration % self.config.trainer.logging_iter == 0: - avg_time = self.elapsed_iteration_time / self.config.trainer.logging_iter - log.info(f"Iteration: {iteration}, average iter time: {avg_time:2f}, total loss {loss.item():4f}") - - self.elapsed_iteration_time = 0 - - -class LowPrecisionCallback(Callback): - """The callback class handling low precision training""" - - def __init__(self, config: Config, trainer: ImaginaireTrainer, update_iter: int): - self.update_iter = update_iter - - def on_train_start(self, model: ImaginaireModel, iteration: int = 0) -> None: - if model.precision == torch.float32: - log.critical("Using fp32. We should disable master weights update.") - self.update_iter = sys.maxsize - else: - assert model.precision in [ - torch.bfloat16, - torch.float16, - torch.half, - ], "LowPrecisionCallback must use a low precision dtype." - self.precision_type = model.precision - - def on_training_step_start(self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0) -> None: - for k, v in data.items(): - if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): - data[k] = v.to(dtype=self.precision_type) - - def on_validation_step_start( - self, model: ImaginaireModel, data: dict[str, torch.Tensor], iteration: int = 0 - ) -> None: - for k, v in data.items(): - if isinstance(v, torch.Tensor) and torch.is_floating_point(data[k]): - data[k] = v.to(dtype=self.precision_type) - - def on_before_zero_grad( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - iteration: int = 0, - ) -> None: - if iteration % self.update_iter == 0: - if getattr(optimizer, "master_weights", False): - params, master_params = [], [] - for group, group_master in zip(optimizer.param_groups, optimizer.param_groups_master): - for p, p_master in zip(group["params"], group_master["params"]): - params.append(get_local_tensor_if_DTensor(p.data)) - master_params.append(p_master.data) - torch._foreach_copy_(params, master_params) - - -class NVTXCallback(Callback): - """The callback for creating NVTX ranges""" - - def __init__( - self, - synchronize: bool = False, - config: Optional["Config"] = None, - trainer: Optional["ImaginaireTrainer"] = None, - ): - super().__init__(config, trainer) - self.synchronize = synchronize - - def on_before_forward(self, iteration: int = 0) -> None: - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_push("forward") - - def on_after_forward(self, iteration: int = 0) -> None: - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - def on_before_backward( - self, model_ddp: distributed.DistributedDataParallel, loss: torch.Tensor, iteration: int = 0 - ) -> None: - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_push("backward") - - def on_after_backward(self, model_ddp: distributed.DistributedDataParallel, iteration: int = 0) -> None: - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - def on_before_optimizer_step( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int = 0, - ) -> None: - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_push("optimizer_step") - - def on_before_zero_grad( - self, - model_ddp: distributed.DistributedDataParallel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - iteration: int = 0, - ) -> None: - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - def on_before_dataloading(self, iteration: int = 0) -> None: - torch.cuda.nvtx.range_push("dataloading") - - def on_after_dataloading(self, iteration: int = 0) -> None: - torch.cuda.nvtx.range_pop() diff --git a/lyra_2/_ext/imaginaire/utils/checkpointer.py b/lyra_2/_ext/imaginaire/utils/checkpointer.py deleted file mode 100644 index 1d0fb47eeccddd5b3ad41b605a2d7a37861222e5..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/checkpointer.py +++ /dev/null @@ -1,372 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import os -import threading -from typing import TYPE_CHECKING, List, NamedTuple, Tuple - -import torch - -from lyra_2._ext.imaginaire.model import ImaginaireModel -from lyra_2._ext.imaginaire.utils import callback, distributed, log, misc, object_store - -if TYPE_CHECKING: - from lyra_2._ext.imaginaire.config import CheckpointConfig, JobConfig - -TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) -if TORCH_VERSION >= (1, 11): - from torch.ao import quantization - from torch.ao.quantization import FakeQuantizeBase, ObserverBase -elif ( - TORCH_VERSION >= (1, 8) - and hasattr(torch.quantization, "FakeQuantizeBase") - and hasattr(torch.quantization, "ObserverBase") -): - from torch import quantization - from torch.quantization import FakeQuantizeBase, ObserverBase - - -class _IncompatibleKeys( - NamedTuple( - "IncompatibleKeys", - [ - ("missing_keys", List[str]), - ("unexpected_keys", List[str]), - ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]), - ], - ) -): - pass - - -class Checkpointer: - """The checkpointer class. Supports checkpoint saving/loading to both local disk or object store.""" - - def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup): - """Constructor of the checkpointer. - - Args: - config_checkpoint (CheckpointConfig): The config object for the checkpointer. - """ - # Set the callback functions. - self.callbacks = callbacks - - self.checkpoint_dir_local = f"{config_job.path_local}/checkpoints" - self.checkpoint_dir_object_store = f"{config_job.path}/checkpoints" - self.save_to_object_store = config_checkpoint.save_to_object_store.enabled - self.load_from_object_store = config_checkpoint.load_from_object_store.enabled - self.strict_resume = config_checkpoint.strict_resume - self.load_path = config_checkpoint.load_path or None - self.load_training_state = config_checkpoint.load_training_state - self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state - self.save_thread = None - # Create the object store client interface. - if self.save_to_object_store: - self.object_store_saver = object_store.ObjectStore(config_checkpoint.save_to_object_store) - if self.load_from_object_store: - self.object_store_loader = object_store.ObjectStore(config_checkpoint.load_from_object_store) - - def save( - self, - model: ImaginaireModel, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler.LRScheduler, - grad_scaler: torch.amp.GradScaler, - iteration: int, - ) -> None: - """Save network weights, optimizer parameters, scheduler parameters to a checkpoint. - - Args: - model (ImaginaireModel): The PyTorch model. - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). - iteration (int): Current iteration number. - """ - self.callbacks.on_save_checkpoint_start(model, iteration) - - checkpoint_file = f"iter_{iteration:09}.pt" - - if distributed.get_rank() == 0: - state_dict = dict( - model=model.state_dict(), - optimizer=optimizer.state_dict(), - scheduler=scheduler.state_dict(), - grad_scaler=grad_scaler.state_dict(), - iteration=iteration, - ) - state_dict = misc.to(state_dict, device="cpu") - self.callbacks.on_save_checkpoint(model, state_dict=state_dict) - # Wait for previous saver thread to end. - if self.save_thread: - self.save_thread.join() - # Run the checkpoint saver in a separate thread. - self.save_thread = threading.Thread( - target=self._save_worker_object_store if self.save_to_object_store else self._save_worker_local, - daemon=False, - args=(state_dict, checkpoint_file, distributed.get_rank()), - ) - self.save_thread.start() - - # Note: Checkpoints are saved on a separate thread and this callback is not accurate. - # Please check logs from on_save_checkpoint_success() for better accuracy - self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration) - - @misc.timer("checkpoint saving (local)") - def _save_worker_local(self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0) -> None: - """Worker to save checkpoint to local disk, spawned with a child thread (runs in parallel with the training). - - Args: - state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. - checkpoint_file (str): The file name of the model checkpoint. - rank (int): GPU device (default: 0). - """ - checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file) - os.makedirs(self.checkpoint_dir_local, exist_ok=True) - try: - torch.save(state_dict, checkpoint_path) - if rank == 0: - self._write_latest_checkpoint_file(checkpoint_file) - log.success(f"Saved checkpoint (local): {checkpoint_path}") - iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) - self.callbacks.on_save_checkpoint_success(iteration=iteration) - except Exception as e: # noqa: BLE001 - log.exception(f"Checkpoint failed to save (local): {e}") - - @misc.timer("checkpoint saving (object store)") - def _save_worker_object_store( - self, state_dict: dict[str, torch.Tensor], checkpoint_file: str, rank: int = 0 - ) -> None: - """Worker to upload checkpoint to object store, spawned with a child thread (in parallel with the training). - - Args: - state_dict (dict[str, torch.Tensor]): The state dict of the model/optimizer/scheduler. - checkpoint_file (str): The file name of the model checkpoint. - rank (int): GPU device (default: 0). - """ - checkpoint_path = os.path.join(self.checkpoint_dir_object_store, checkpoint_file) - try: - self.object_store_saver.save_object(state_dict, key=checkpoint_path, type="torch") - if rank == 0: - self._write_latest_checkpoint_file(checkpoint_file) - log.success(f"Saved checkpoint (object store): {checkpoint_path}") - iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", "")) - self.callbacks.on_save_checkpoint_success(iteration=iteration) - except Exception as e: # noqa: BLE001 - log.exception(f"Checkpoint failed to upload (object store): {e}") - - @misc.timer("checkpoint loading") - def load( - self, - model: ImaginaireModel, - optimizer: torch.optim.Optimizer | None = None, - scheduler: torch.optim.lr_scheduler.LRScheduler | None = None, - grad_scaler: torch.amp.GradScaler | None = None, - ) -> int: - """Load network weights and optimizer states from a checkpoint in a single process. - - The priority of the checkpoint loading logic is: - 1. Attempt to resume training if possible by looking for latest_checkpoint.txt under the same name. - 2. If no latest checkpoint were found, it loads the model weights specified by config_checkpoint.path. - - This is typically used for inference mode. - - If config_checkpoint.load_optimizer_state is True, then also load the optimizer and scheduler states. - 3. If none of the above, randomly initialize the model parameters and train from scratch. - - Args: - model (ImaginaireModel): The PyTorch model. - optimizer (torch.optim.Optimizer | None): The model optimizer (default: None). - scheduler (torch.optim.lr_scheduler.LRScheduler | None): The optimization scheduler (default: None). - grad_scaler (torch.amp.GradScaler | None): The gradient scaler (for mixed precision training). - - Returns: - iteration (int): the iteration number to start/resume from. - """ - self.callbacks.on_load_checkpoint_start(model) - - latest_checkpoint_file = self._read_latest_checkpoint_file() - if latest_checkpoint_file is not None: - # 1. Resume training from latest_checkpoint.txt under the same name. - checkpoint_dir = ( - self.checkpoint_dir_object_store if self.load_from_object_store else self.checkpoint_dir_local - ) - checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint_file) - resume = True - only_resume_scheduler = True - else: - if self.load_path: - # 2. Load the module weights specified by config_checkpoint.path. - checkpoint_path = self.load_path - resume = self.load_training_state - only_resume_scheduler = self.only_load_scheduler_state - else: - # 3. Randomly initialize the model parameters and train from scratch. - checkpoint_path = None - resume = False - only_resume_scheduler = False - # Load checkpoint. - if checkpoint_path is not None: - self._check_checkpoint_exists(checkpoint_path) - if self.load_from_object_store: - log.info(f"Loading checkpoint (object store): {checkpoint_path}") - state_dict = self.object_store_loader.load_object(key=checkpoint_path, type="torch", max_attempts=20) - log.success(f"Complete loading checkpoint (object store): {checkpoint_path}") - else: - log.info(f"Loading checkpoint (local): {checkpoint_path}") - state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) - log.success(f"Complete loading checkpoint (local): {checkpoint_path}") - self.callbacks.on_load_checkpoint(model, state_dict=state_dict) - # Load the state dicts. - log.info("- Loading the model...") - model.load_state_dict(state_dict["model"], strict=self.strict_resume) - if resume or only_resume_scheduler: - iteration = state_dict["iteration"] - assert scheduler - log.info("- Loading the scheduler...") - scheduler.load_state_dict(state_dict["scheduler"]) - scheduler.last_epoch = iteration - else: - iteration = 0 - if resume: - assert optimizer - log.info("- Loading the optimizer...") - optimizer.load_state_dict(state_dict["optimizer"]) - log.info("- Loading the gradient scaler...") - grad_scaler.load_state_dict(state_dict["grad_scaler"]) - log.success(f"Done with loading the checkpoint (iteration {iteration}).") - else: - log.success("Done with loading the checkpoint.") - else: - # Checkpoint not found and not specified. We will train everything from scratch. - iteration = 0 - log.info("Training from scratch.") - torch.cuda.empty_cache() - - self.callbacks.on_load_checkpoint_end(model, iteration=iteration, checkpoint_path=checkpoint_path) - - return iteration - - def _read_latest_checkpoint_file(self) -> str | None: - """Get the file name of the latest saved checkpoint. If it doesn't exist, return None. - - Returns: - checkpoint_file (str | None): file name of the latest saved checkpoint. - """ - checkpoint_file = None - if self.load_from_object_store: - latest_path = os.path.join(self.checkpoint_dir_object_store, "latest_checkpoint.txt") - if self.object_store_loader.object_exists(key=latest_path): - checkpoint_file = self.object_store_loader.load_object(key=latest_path, type="text").strip() - else: - latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") - if os.path.isfile(latest_path): - checkpoint_file = open(latest_path).read().strip() - return checkpoint_file - - def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None: - """Track the file name of the latest saved checkpoint. - - Args: - checkpoint_file (str): file name of the latest saved checkpoint. - """ - content = f"{checkpoint_file}\n" - if self.save_to_object_store: - latest_path = os.path.join(self.checkpoint_dir_object_store, "latest_checkpoint.txt") - self.object_store_saver.save_object(content, key=latest_path, type="text") - else: - latest_path = os.path.join(self.checkpoint_dir_local, "latest_checkpoint.txt") - with open(latest_path, "w") as file: - file.write(content) - - def _check_checkpoint_exists(self, checkpoint_path: str) -> None: - """If the file checkpoint_path does not exist, raise an error. - - Args: - checkpoint_path (str): full path to the checkpoint. - """ - if self.load_from_object_store: - if not self.object_store_loader.object_exists(key=checkpoint_path): - raise FileNotFoundError(f"File not found (object store): {checkpoint_path}") - else: - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"File not found (local): {checkpoint_path}") - - def finalize(self) -> None: - """Finalize the checkpointer.""" - if self.save_thread: - self.save_thread.join() - - -# https://github.com/facebookresearch/fvcore/blob/9d683aae73fb899dd35d6cf6720e5ef567761c57/fvcore/common/checkpoint.py -def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys: - # workaround https://github.com/pytorch/pytorch/issues/24139 - model_state_dict = model.state_dict() - incorrect_shapes = [] - for k in list(checkpoint_state_dict.keys()): - if k in model_state_dict: - if "_extra_state" in k: # Key introduced by TransformerEngine for FP8 - log.warning(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.") - continue - model_param = model_state_dict[k] - # Allow mismatch for uninitialized parameters - if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter): - continue - if not isinstance(model_param, torch.Tensor): - raise ValueError( - f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not." - ) - - shape_model = tuple(model_param.shape) - shape_checkpoint = tuple(checkpoint_state_dict[k].shape) - if shape_model != shape_checkpoint: - has_observer_base_classes = ( - TORCH_VERSION >= (1, 8) - and hasattr(quantization, "ObserverBase") - and hasattr(quantization, "FakeQuantizeBase") - ) - if has_observer_base_classes: - # Handle the special case of quantization per channel observers, - # where buffer shape mismatches are expected. - def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module: - # foo.bar.param_or_buffer_name -> [foo, bar] - key_parts = key.split(".")[:-1] - cur_module = model - for key_part in key_parts: - cur_module = getattr(cur_module, key_part) - return cur_module - - cls_to_skip = ( - ObserverBase, - FakeQuantizeBase, - ) - target_module = _get_module_for_key(model, k) - if isinstance(target_module, cls_to_skip): - # Do not remove modules with expected shape mismatches - # them from the state_dict loading. They have special logic - # in _load_from_state_dict to handle the mismatches. - continue - - incorrect_shapes.append((k, shape_checkpoint, shape_model)) - checkpoint_state_dict.pop(k) - incompatible = model.load_state_dict(checkpoint_state_dict, strict=False) - # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling - missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k] - unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k] - return _IncompatibleKeys( - missing_keys=missing_keys, - unexpected_keys=unexpected_keys, - incorrect_shapes=incorrect_shapes, - ) diff --git a/lyra_2/_ext/imaginaire/utils/config_helper.py b/lyra_2/_ext/imaginaire/utils/config_helper.py deleted file mode 100644 index 5d26027fc087eaafcc5c8c2ff6b5711009a5bd11..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/config_helper.py +++ /dev/null @@ -1,214 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import os -import pkgutil -import sys -from dataclasses import fields as dataclass_fields -from dataclasses import is_dataclass -from typing import Any, Dict, Optional - -import attr -import attrs -from hydra import compose, initialize -from hydra.core.config_store import ConfigStore -from hydra.core.global_hydra import GlobalHydra -from omegaconf import DictConfig, OmegaConf - -from lyra_2._ext.imaginaire.config import Config -from lyra_2._ext.imaginaire.utils import log - - -def is_attrs_or_dataclass(obj) -> bool: - """ - Check if the object is an instance of an attrs class or a dataclass. - - Args: - obj: The object to check. - - Returns: - bool: True if the object is an instance of an attrs class or a dataclass, False otherwise. - """ - return is_dataclass(obj) or attr.has(type(obj)) - - -def get_fields(obj): - """ - Get the fields of an attrs class or a dataclass. - - Args: - obj: The object to get fields from. Must be an instance of an attrs class or a dataclass. - - Returns: - list: A list of field names. - - Raises: - ValueError: If the object is neither an attrs class nor a dataclass. - """ - if is_dataclass(obj): - return [field.name for field in dataclass_fields(obj)] - elif attr.has(type(obj)): - return [field.name for field in attr.fields(type(obj))] - else: - raise ValueError("The object is neither an attrs class nor a dataclass.") - - -def override(config: Config, overrides: Optional[list[str]] = None) -> Config: - """ - :param config: the instance of class `Config` (usually from `make_config`) - :param overrides: list of overrides for config - :return: the composed instance of class `Config` - """ - # Store the class of the config for reconstruction after overriding. - # config_class = type(config) - - # Convert Config object to a DictConfig object - config_dict = attrs.asdict(config) - config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True}) - # Enforce "--" separator between the script arguments and overriding configs. - if overrides: - if overrides[0] != "--": - raise ValueError( - f'Hydra config overrides must be separated with a "--" token. but got overrides={overrides}, and overrides[0]={overrides[0]}' - ) - overrides = overrides[1:] - # Use Hydra to handle overrides - cs = ConfigStore.instance() - cs.store(name="config", node=config_omegaconf) - if not GlobalHydra().is_initialized(): - with initialize(version_base=None): - config_omegaconf = compose(config_name="config", overrides=overrides) - OmegaConf.resolve(config_omegaconf) - else: - config_omegaconf = compose(config_name="config", overrides=overrides) - OmegaConf.resolve(config_omegaconf) - - def config_from_dict(ref_instance: Any, kwargs: Any) -> Any: - """ - Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data - - Args: - ref_instance: The reference instance to determine the type and fields when needed - kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data - - Returns: - Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data - - Raises: - AssertionError: If the fields do not match or if extra keys are found. - Exception: If there is an error constructing the new instance. - """ - is_type = is_attrs_or_dataclass(ref_instance) - if not is_type: - return kwargs - else: - ref_fields = set(get_fields(ref_instance)) - assert isinstance(kwargs, dict) or isinstance(kwargs, DictConfig), ( - "kwargs must be a dictionary or a DictConfig" - ) - keys = set(kwargs.keys()) - - # ref_fields must equal to or include all keys - extra_keys = keys - ref_fields - assert ref_fields == keys or keys.issubset(ref_fields), ( - f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}" - ) - - resolved_kwargs: Dict[str, Any] = {} - for f in keys: - resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f]) - try: - new_instance = type(ref_instance)(**resolved_kwargs) - except Exception as e: - log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}") - log.error(e) - raise e - return new_instance - - config = config_from_dict(config, config_omegaconf) - - return config - - -def get_config_module(config_file: str) -> str: - if not config_file.endswith(".py"): - log.error("Config file cannot be specified as module.") - log.error("Please provide the path to the Python config file (relative to the Imaginaire4 root).") - assert os.path.isfile(config_file), f"Imaginaire4 config file ({config_file}) not found." - # Convert to importable module format. - config_module = config_file.replace("/", ".").replace(".py", "") - return config_module - - -def import_module(full_module_name: str, reload: bool = False): - """ - Import a module by name. - - Args: - full_module_name: The fully qualified name of the module to import. - reload: If True, reload the module if it's already imported. - """ - if full_module_name in sys.modules and reload: - importlib.reload(sys.modules[full_module_name]) - else: - importlib.import_module(full_module_name) - - -def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None: - """ - Import all modules from the specified package path recursively. - - This function is typically used in conjunction with Hydra to ensure that all modules - within a specified package are imported, which is necessary for registering configurations. - - Example usage: - ```python - import_all_modules_from_package("projects.cosmos.diffusion.v1.config.experiment", reload=True, skip_underscore=False) - ``` - - Args: - package_path (str): The dotted path to the package from which to import all modules. - reload (bool): Flag to determine whether to reload modules if they're already imported. - skip_underscore (bool): If True, skips importing modules that start with an underscore. - """ - log.info(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}") - package = importlib.import_module(package_path) - package_directory = package.__path__ - - def import_modules_recursively(directory: str, prefix: str) -> None: - """ - Recursively imports or reloads all modules in the given directory. - - Args: - directory (str): The file system path to the current package directory. - prefix (str): The module prefix (e.g., 'projects.cosmos.diffusion.v1.config'). - """ - for _, module_name, is_pkg in pkgutil.iter_modules([directory]): - if skip_underscore and module_name.startswith("_"): - log.debug(f"Skipping module {module_name} as it starts with an underscore") - continue - - full_module_name = f"{prefix}.{module_name}" - log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}") - - import_module(full_module_name, reload=reload) - - if is_pkg: - sub_package_directory = os.path.join(directory, module_name) - import_modules_recursively(sub_package_directory, full_module_name) - - for directory in package_directory: - import_modules_recursively(directory, package_path) diff --git a/lyra_2/_ext/imaginaire/utils/context_managers.py b/lyra_2/_ext/imaginaire/utils/context_managers.py deleted file mode 100644 index c282a2543f14213a3254ad9e65bc97115a647f3f..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/context_managers.py +++ /dev/null @@ -1,55 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from contextlib import ExitStack, contextmanager -from typing import Generator - -from lyra_2._ext.imaginaire.utils.misc import timer - - -@contextmanager -def data_loader_init() -> Generator[None, None, None]: - """ - Wrap the data loader initialization with multiple context managers used for telemetry and one logger. - """ - contexts = [ - timer("init_data_loader"), - ] - with ExitStack() as stack: - yield [stack.enter_context(cm) for cm in contexts] - - -@contextmanager -def model_init(set_barrier: bool = False) -> Generator[None, None, None]: - """ - Wrap the instantiation of the model with multiple context managers used for telemetry and one logger. - """ - contexts = [ - timer("init_model"), - ] - with ExitStack() as stack: - yield [stack.enter_context(cm) for cm in contexts] - - -@contextmanager -def distributed_init() -> Generator[None, None, None]: - """ - Wrap the distributed initialization, used for telemetry and timers - """ - contexts = [ - timer("init_distributed"), - ] - with ExitStack() as stack: - yield [stack.enter_context(cm) for cm in contexts] diff --git a/lyra_2/_ext/imaginaire/utils/count_params.py b/lyra_2/_ext/imaginaire/utils/count_params.py deleted file mode 100644 index 7308f58e3ad4d9bcf378fc841935cb55c467a2a4..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/count_params.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from torch import nn - - -def disabled_train(self, mode: bool = True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def count_params(model: nn.Module, verbose=False) -> int: - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.") - return total_params diff --git a/lyra_2/_ext/imaginaire/utils/device.py b/lyra_2/_ext/imaginaire/utils/device.py deleted file mode 100644 index 9c1ac58b5e537daea655c497a4201b501b4504b5..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/device.py +++ /dev/null @@ -1,114 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import math -import os - -import pynvml -from loguru import logger as logging - - -def get_gpu_architecture(): - """ - Retrieves the GPU architecture of the available GPUs. - - Returns: - str: The GPU architecture, which can be "H100", "A100", or "Other". - """ - try: - pynvml.nvmlInit() - device_count = pynvml.nvmlDeviceGetCount() - for i in range(device_count): - handle = pynvml.nvmlDeviceGetHandleByIndex(i) - model_name = pynvml.nvmlDeviceGetName(handle) - if isinstance(model_name, bytes): - model_name = model_name.decode("utf-8") - print(f"GPU {i}: Model: {model_name}") - - # Check for specific models like H100 or A100 - if "H100" in model_name or "H200" in model_name: - return "H100" - elif "A100" in model_name: - return "A100" - elif "L40S" in model_name: - return "L40S" - elif "B200" in model_name: - return "B200" - except pynvml.NVMLError as error: - print(f"Failed to get GPU info: {error}") - finally: - pynvml.nvmlShutdown() - - # return "Other" incase of non hopper/ampere or error - return "Other" - - -class GPUArchitectureNotSupported(Exception): - """ - Custom exception raised when the expected GPU architecture is not supported. - """ - - pass - - -def print_gpu_mem(str=None): - try: - pynvml.nvmlInit() - meminfo = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(0)) - logging.info( - f"{str}: {meminfo.used / 1024 / 1024}/{meminfo.total / 1024 / 1024}MiB used ({meminfo.free / 1024 / 1024}MiB free)" - ) - except pynvml.NVMLError as error: - print(f"Failed to get GPU memory info: {error}") - - -def force_gc(): - print_gpu_mem() - print("gc()") - gc.collect() - print_gpu_mem() - print("empty cuda cache") - # print(torch.cuda.memory_summary()) - print_gpu_mem() - - -def gpu0_has_80gb_or_less(): - try: - pynvml.nvmlInit() - meminfo = pynvml.nvmlDeviceGetMemoryInfo(pynvml.nvmlDeviceGetHandleByIndex(0)) - return meminfo.total / 1024 / 1024 / 1024 <= 80 - except pynvml.NVMLError as error: - print(f"Failed to get GPU memory info: {error}") - - -class Device: - _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) # type: ignore - - def __init__(self, device_idx: int): - super().__init__() - self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) - - def get_name(self) -> str: - return pynvml.nvmlDeviceGetName(self.handle) - - def get_cpu_affinity(self) -> list[int]: - affinity_string = "" - for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements): - # assume nvml returns list of 64 bit ints - affinity_string = "{:064b}".format(j) + affinity_string - affinity_list = [int(x) for x in affinity_string] - affinity_list.reverse() # so core 0 is in 0th element of list - return [i for i, e in enumerate(affinity_list) if e != 0] diff --git a/lyra_2/_ext/imaginaire/utils/distributed.py b/lyra_2/_ext/imaginaire/utils/distributed.py deleted file mode 100644 index ffbdd3fe8ada49c95df7f43baece887ccee6a775..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/distributed.py +++ /dev/null @@ -1,443 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import collections -import collections.abc -import ctypes -import functools -import os -from contextlib import contextmanager -from datetime import timedelta -from typing import TYPE_CHECKING, Any, Callable, Container, Optional - -import pynvml -import torch -import torch.distributed as dist -from torch.distributed import get_process_group_ranks - -from lyra_2._ext.imaginaire.utils.device import Device - -if dist.is_available(): - from torch.distributed.distributed_c10d import _get_default_group - from torch.distributed.utils import _sync_module_states, _verify_param_shape_across_processes - -from lyra_2._ext.imaginaire.utils import log - -if TYPE_CHECKING: - from lyra_2._ext.imaginaire.config import DDPConfig - -try: - from megatron.core import parallel_state -except ImportError: - print("Megatron-core is not installed.") - - -def init() -> int | None: - """Initialize distributed training.""" - if dist.is_initialized(): - return torch.cuda.current_device() - - # Set GPU affinity. - pynvml.nvmlInit() - local_rank = int(os.getenv("LOCAL_RANK", 0)) - try: - device = Device(local_rank) - os.sched_setaffinity(0, device.get_cpu_affinity()) - except Exception as e: - log.warning(f"Failed to set device affinity: {e}") - # Set up NCCL communication. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "0" - os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1" - if dist.is_available(): - torch.cuda.set_device(local_rank) - # Get the timeout value from environment variable - timeout_seconds = os.getenv("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", 1800) - # Convert the timeout to an integer (if it isn't already) and then to a timedelta - timeout_timedelta = timedelta(seconds=int(timeout_seconds)) - dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout_timedelta) - log.info( - f"Initialized distributed training with local rank {local_rank} with timeout {timeout_seconds}", - rank0_only=False, - ) - # Increase the L2 fetch granularity for faster speed. - _libcudart = ctypes.CDLL("libcudart.so") - # Set device limit on the current device. - p_value = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) - _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) - _libcudart.cudaDeviceGetLimit(p_value, ctypes.c_int(0x05)) - log.info(f"Training with {get_world_size()} GPUs.") - - -def get_rank(group: Optional[dist.ProcessGroup] = None) -> int: - """Get the rank (GPU device) of the worker. - - Returns: - rank (int): The rank of the worker. - """ - rank = 0 - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank(group) - return rank - - -def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int: - """Get world size. How many GPUs are available in this job. - - Returns: - world_size (int): The total number of GPUs available in this job. - """ - world_size = 1 - if dist.is_available() and dist.is_initialized(): - world_size = dist.get_world_size(group) - return world_size - - -def is_rank0() -> bool: - """Check if current process is the master GPU. - - Returns: - (bool): True if this function is called from the master GPU, else False. - """ - return get_rank() == 0 - - -def is_local_rank0() -> bool: - """Check if current process is the local master GPU in the current node. - - Returns: - (bool): True if this function is called from the local master GPU, else False. - """ - return torch.cuda.current_device() == 0 - - -def rank0_only(func: Callable) -> Callable: - """Apply this function only to the master GPU. - - Example usage: - @rank0_only - def func(x): - return x + 3 - - Args: - func (Callable): a function. - - Returns: - (Callable): A function wrapper executing the function only on the master GPU. - """ - - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - if is_rank0(): - return func(*args, **kwargs) - else: - return None - - return wrapper - - -def barrier() -> None: - """Barrier for all GPUs.""" - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - -def rank0_first(func: Callable) -> Callable: - """run the function on rank 0 first, then on other ranks.""" - - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - if is_rank0(): - result = func(*args, **kwargs) - barrier() - if not is_rank0(): - result = func(*args, **kwargs) - return result - - return wrapper - - -def parallel_model_wrapper(config_ddp: DDPConfig, model: torch.nn.Module) -> torch.nn.Module | DistributedDataParallel: - """Wraps the model to enable data parallalism for training across multiple GPU devices. - - Args: - config_ddp (DDPConfig): The data parallel config. - model (torch.nn.Module): The PyTorch module. - - Returns: - model (torch.nn.Module | DistributedDataParallel): The data parallel model wrapper - if distributed environment is available, otherwise return the original model. - """ - if dist.is_available() and dist.is_initialized(): - local_rank = int(os.getenv("LOCAL_RANK", 0)) - try: - ddp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) - except Exception as e: - log.info(e) - log.info("parallel_state not initialized, treating all GPUs equally for DDP") - ddp_group = None - - model = DistributedDataParallel( - model, - device_ids=[local_rank], - output_device=local_rank, - find_unused_parameters=config_ddp.find_unused_parameters, - static_graph=config_ddp.static_graph, - broadcast_buffers=config_ddp.broadcast_buffers, - process_group=ddp_group, - ) - return model - - -class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): - """This extends torch.nn.parallel.DistributedDataParallel with .training_step(). - - This borrows the concept of `forward-redirection` from Pytorch lightning. It wraps an ImaginaireModel such that - model.training_step() would be executed when calling self.training_step(), while preserving the behavior of calling - model() for Pytorch modules. Internally, this is a double rerouting mechanism (training_step -> forward -> - training_step), allowing us to preserve the function names and signatures. - """ - - def __init__(self, model: torch.nn.Module, *args, **kwargs): - super().__init__(model, *args, **kwargs) - self.show_sync_grad_static_graph_warning = True - - def training_step(self, *args, **kwargs) -> Any: - # Cache the original model.forward() method. - original_forward = self.module.forward - - def wrapped_training_step(*_args, **_kwargs): # noqa: ANN202 - # Unpatch immediately before calling training_step() because itself may want to call the real forward. - self.module.forward = original_forward - # The actual .training_step(). - return self.module.training_step(*_args, **_kwargs) - - # Patch the original_module's forward so we can redirect the arguments back to the real method. - self.module.forward = wrapped_training_step - # Call self, which implicitly calls self.forward() --> model.forward(), which is now model.training_step(). - # Without calling self.forward() or model.forward() explciitly, implicit hooks are also executed. - return self(*args, **kwargs) - - -@contextmanager -def ddp_sync_grad(model, enabled): - r""" - Context manager to enable/disable gradient synchronizations across DDP processes for DDP model. - Modified from: - https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync - Note that this is incompatible with static_graph=True and will be an no-op if static_graph=True. - - Within this context, gradients will be accumulated on module - variables, which will later be synchronized in the first - forward-backward pass exiting the context. - - .. warning:: - The forward pass should be included inside the context manager, or - else gradients will still be synchronized. - """ - assert isinstance(model, torch.nn.Module) - if isinstance(model, DistributedDataParallel): - old_require_backward_grad_sync = model.require_backward_grad_sync - if model.static_graph and model.require_backward_grad_sync != enabled: - if model.show_sync_grad_static_graph_warning: - log.warning("DDP static_graph=True is incompatible with sync_grad(). Performance will be reduced.") - model.show_sync_grad_static_graph_warning = False - else: - model.require_backward_grad_sync = enabled - try: - yield - finally: - if isinstance(model, DistributedDataParallel): - model.require_backward_grad_sync = old_require_backward_grad_sync - - -def collate_batches(data_batches: list[dict[str, torch.Tensor]]) -> torch.Tensor | dict[str, torch.Tensor]: - """Aggregate the list of data batches from all devices and process the results. - - This is used for gathering validation data batches with lyra_2._ext.imaginaire.utils.dataloader.DistributedEvalSampler. - It will return the data/output of the entire validation set in its original index order. The sizes of data_batches - in different ranks may differ by 1 (if dataset size is not evenly divisible), in which case a dummy sample will be - created before calling dis.all_gather(). - - Args: - data_batches (list[dict[str, torch.Tensor]]): List of tensors or (hierarchical) dictionary where - leaf entries are tensors. - - Returns: - data_gather (torch.Tensor | dict[str, torch.Tensor]): tensors or (hierarchical) dictionary where - leaf entries are concatenated tensors. - """ - if isinstance(data_batches[0], torch.Tensor): - # Concatenate the local data batches. - data_concat = torch.cat(data_batches, dim=0) # type: ignore - # Get the largest number of local samples from all ranks to determine whether to dummy-pad on this rank. - max_num_local_samples = torch.tensor(len(data_concat), device="cuda") - dist.all_reduce(max_num_local_samples, op=dist.ReduceOp.MAX) - if len(data_concat) < max_num_local_samples: - assert len(data_concat) + 1 == max_num_local_samples - dummy = torch.empty_like(data_concat[:1]) - data_concat = torch.cat([data_concat, dummy], dim=0) - dummy_count = torch.tensor(1, device="cuda") - else: - dummy_count = torch.tensor(0, device="cuda") - # Get all concatenated batches from all ranks and concatenate again. - dist.all_reduce(dummy_count, op=dist.ReduceOp.SUM) - data_concat = all_gather_tensor(data_concat.contiguous()) - data_collate = torch.stack(data_concat, dim=1).flatten(start_dim=0, end_dim=1) - # Remove the dummy samples. - if dummy_count > 0: - data_collate = data_collate[:-dummy_count] - elif isinstance(data_batches[0], collections.abc.Mapping): - data_collate = dict() - for key in data_batches[0].keys(): - data_collate[key] = collate_batches([data[key] for data in data_batches]) # type: ignore - else: - raise TypeError - return data_collate - - -@torch.no_grad() -def all_gather_tensor(tensor: torch.Tensor) -> list[torch.Tensor]: - """Gather the corresponding tensor from all GPU devices to a list. - - Args: - tensor (torch.Tensor): Pytorch tensor. - - Returns: - tensor_list (list[torch.Tensor]): A list of Pytorch tensors gathered from all GPU devices. - """ - tensor_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] - dist.all_gather(tensor_list, tensor) - return tensor_list - - -def broadcast(tensor, src, group=None, async_op=False): - world_size = get_world_size() - if world_size < 2: - return tensor - dist.broadcast(tensor, src=src, group=group, async_op=async_op) - - -def dist_reduce_tensor(tensor, rank=0, reduce="mean"): - r"""Reduce to rank 0""" - world_size = get_world_size() - if world_size < 2: - return tensor - with torch.no_grad(): - dist.reduce(tensor, dst=rank) - if get_rank() == rank: - if reduce == "mean": - tensor /= world_size - elif reduce == "sum": - pass - else: - raise NotImplementedError - return tensor - - -def sync_model_states( - model: torch.nn.Module, - process_group: Optional[dist.ProcessGroup] = None, - src: int = 0, - params_and_buffers_to_ignore: Optional[Container[str]] = None, - broadcast_buffers: bool = True, -): - """ - Modify based on DDP source code - Synchronizes the parameters and buffers of a model across different processes in a distributed setting. - - This function ensures that all processes in the specified process group have the same initial parameters and - buffers from the source rank, typically rank 0. It is useful when different processes start with different model - states and a synchronization is required to ensure consistency across all ranks. - - Args: - model (nn.Module): The model whose parameters and buffers are to be synchronized. - process_group (dist.ProcessGroup, optional): The process group for communication. If None, - the default group is used. Defaults to None. - src (int, optional): The source rank from which parameters and buffers will be broadcasted. - Defaults to 0. - params_and_buffers_to_ignore (Optional[Container[str]], optional): A container of parameter and buffer - names to exclude from synchronization. Defaults to None, which means all parameters and buffers are - included. - broadcast_buffers (bool, optional): Whether to broadcast buffers or not. Defaults to True. - - Side Effects: - This function modifies the state of the model in-place to synchronize it with the source rank's model state. - - Raises: - RuntimeError: If the shapes of parameters across processes do not match, a runtime error will be raised. - - Examples: - >>> # downloading duplicated model weights from s3 in each rank and save network bandwidth - >>> # useful and save our time when model weights are huge - >>> if dist.get_rank == 0: - >>> model.load_state_dict(network_bound_weights_download_fn(s3_weights_path)) - >>> dist.barrir() - >>> sync_model_states(model) # sync rank0 weights to other ranks - """ - if not dist.is_available() or not dist.is_initialized(): - return - if process_group is None: - process_group = _get_default_group() - if not params_and_buffers_to_ignore: - params_and_buffers_to_ignore = set() - - log.info( - f"Synchronizing model states from rank {src} to all ranks in process group {get_process_group_ranks(process_group)}." - ) - - # Build tuple of (module, parameter) for all parameters that require grads. - modules_and_parameters = [ - (module, parameter) - for module_name, module in model.named_modules() - for parameter in [ - param - # Note that we access module.named_parameters instead of - # parameters(module). parameters(module) is only needed in the - # single-process multi device case, where it accesses replicated - # parameters through _former_parameters. - for param_name, param in module.named_parameters(recurse=False) - if f"{module_name}.{param_name}" not in params_and_buffers_to_ignore - # if param.requires_grad - # and f"{module_name}.{param_name}" not in params_and_buffers_to_ignore - ] - ] - - # Deduplicate any parameters that might be shared across child modules. - memo = set() - modules_and_parameters = [ - # "p not in memo" is the deduplication check. - # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed. - (m, p) - for m, p in modules_and_parameters - if p not in memo and not memo.add(p) # type: ignore[func-returns-value] - ] - - # Build list of parameters. - parameters = [parameter for _, parameter in modules_and_parameters] - if len(parameters) == 0: - return - - _verify_param_shape_across_processes(process_group, parameters) - - _sync_module_states( - module=model, - process_group=process_group, - broadcast_bucket_size=int(250 * 1024 * 1024), - src=src, - params_and_buffers_to_ignore=params_and_buffers_to_ignore, - broadcast_buffers=broadcast_buffers, - ) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/__init__.py b/lyra_2/_ext/imaginaire/utils/easy_io/__init__.py deleted file mode 100644 index dac9a4d7496eb38831f1f3c820a90d50e25e2a7e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py deleted file mode 100644 index 8fc95233facfd2cfe6afeab6948555117649c8ad..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend -from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_backend import Boto3Backend -from lyra_2._ext.imaginaire.utils.easy_io.backends.http_backend import HTTPBackend -from lyra_2._ext.imaginaire.utils.easy_io.backends.local_backend import LocalBackend -from lyra_2._ext.imaginaire.utils.easy_io.backends.registry_utils import ( - backends, - prefix_to_backends, - register_backend, -) - -__all__ = [ - "BaseStorageBackend", - "LocalBackend", - "HTTPBackend", - "Boto3Backend", - "register_backend", - "backends", - "prefix_to_backends", -] diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py deleted file mode 100644 index a09443640780993a519c1168c77bcf6077d75f9d..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/auto_auth.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import json -from typing import Any, Optional - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS, CRED_ENVS_DICT - -DEPLOYMENT_ENVS = ["prod", "dev", "stg"] - - -# context manger to open a file or read from env variable -@contextlib.contextmanager -def open_auth(s3_credential_path: Optional[Any], mode: str): - if not s3_credential_path: - log.info(f"No credential file provided {s3_credential_path}.") - yield None - return - - name = s3_credential_path.split("/")[-1].split(".")[0] - if not name: - raise ValueError(f"Could not parse into env var: {s3_credential_path}") - cred_env_name = f"PROD_{name.upper()}" - - if CRED_ENVS.APP_ENV in DEPLOYMENT_ENVS and cred_env_name in CRED_ENVS_DICT: - object_storage_config = get_creds_from_env(cred_env_name) - log.info(f"using ENV vars for {cred_env_name}") - yield object_storage_config - else: - log.info(f"using credential file: {s3_credential_path}") - with open(s3_credential_path, mode) as f: - yield f - - -def get_creds_from_env(cred_env_name: str) -> dict[str, str]: - try: - object_storage_config = CRED_ENVS_DICT[cred_env_name] - except KeyError: - raise ValueError(f"Could not find {cred_env_name} in CRED_ENVS") - empty_args = {key.upper() for key in object_storage_config if object_storage_config[key] == ""} - if empty_args: - raise ValueError(f"Some required environment variable(s) were not provided for {cred_env_name}", empty_args) - return object_storage_config - - -def json_load_auth(f): - if CRED_ENVS.APP_ENV in DEPLOYMENT_ENVS: - return f if f else {} - else: - return json.load(f) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py deleted file mode 100644 index 3cdc6849fe8dda331ccafe9045f3bcca56e0b11f..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/base_backend.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import os.path as osp -from abc import ABCMeta, abstractmethod - - -def mkdir_or_exist(dir_name, mode=0o777): - if dir_name == "": - return - dir_name = osp.expanduser(dir_name) - os.makedirs(dir_name, mode=mode, exist_ok=True) - - -def has_method(obj, method): - return hasattr(obj, method) and callable(getattr(obj, method)) - - -class BaseStorageBackend(metaclass=ABCMeta): - """Abstract class of storage backends. - - All backends need to implement two apis: :meth:`get()` and - :meth:`get_text()`. - - - :meth:`get()` reads the file as a byte stream. - - :meth:`get_text()` reads the file as texts. - """ - - # a flag to indicate whether the backend can create a symlink for a file - # This attribute will be deprecated in future. - _allow_symlink = False - - @property - def allow_symlink(self): - return self._allow_symlink - - @property - def name(self): - return self.__class__.__name__ - - @abstractmethod - def get(self, filepath): - pass - - @abstractmethod - def get_text(self, filepath): - pass diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py deleted file mode 100644 index bc226c2fa6ee37a275780804efcb4da9ddf4e4d1..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_backend.py +++ /dev/null @@ -1,841 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import os -import re -import tempfile -from contextlib import contextmanager -from pathlib import Path -from shutil import SameFileError -from typing import Generator, Iterator, Optional, Tuple, Union - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import ( - BaseStorageBackend, - has_method, - mkdir_or_exist, -) -from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_client import Boto3Client - - -class Boto3Backend(BaseStorageBackend): - """boto3 storage backend (for internal usage). - - Boto3Backend supports reading and writing data to multiple clusters. - If the file path contains the cluster name, Boto3Backend will read data - from specified cluster or write data to it. Otherwise, Boto3Backend will - access the default cluster. - - Args: - path_mapping (dict, optional): Path mapping dict from local path to - Boto3 path. When ``path_mapping={'src': 'dst'}``, ``src`` in - ``filepath`` will be replaced by ``dst``. Defaults to None. - s3_credential_path (str, optional): Config path of Boto3 client. Default: None. - `New in version 0.3.3`. - - Examples: - >>> backend = Boto3Backend() - >>> filepath1 = 's3://path/of/file' - >>> filepath2 = 'cluster-name:s3://path/of/file' - >>> backend.get(filepath1) # get data from default cluster - >>> client.get(filepath2) # get data from 'cluster-name' cluster - """ - - def __init__( - self, - s3_credential_path: str = "", - path_mapping: Optional[dict] = None, - ): - self._client = Boto3Client(s3_credential_path=s3_credential_path) - assert isinstance(path_mapping, dict) or path_mapping is None - self.path_mapping = path_mapping - if path_mapping: - for k, v in path_mapping.items(): - log.critical(f"Path mapping: {k} -> {v}", rank0_only=False) - - def _map_path(self, filepath: Union[str, Path]) -> str: - """Map ``filepath`` to a string path whose prefix will be replaced by - :attr:`self.path_mapping`. - - Args: - filepath (str or Path): Path to be mapped. - """ - filepath = str(filepath) - if self.path_mapping is not None: - for k, v in self.path_mapping.items(): - filepath = filepath.replace(k, v, 1) - return filepath - - def _format_path(self, filepath: str) -> str: - """Convert a ``filepath`` to standard format of s3 oss. - - If the ``filepath`` is concatenated by ``os.path.join``, in a Windows - environment, the ``filepath`` will be the format of - 's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the - above ``filepath`` will be converted to 's3://bucket_name/image.jpg'. - - Args: - filepath (str): Path to be formatted. - """ - return re.sub(r"\\+", "/", filepath) - - def _replace_prefix(self, filepath: Union[str, Path]) -> str: - filepath = str(filepath) - return filepath - # return filepath.replace('s3://', 's3://') - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Return bytes read from filepath. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - value = self._client.get(filepath) - return value - - def get_text( - self, - filepath: Union[str, Path], - encoding: str = "utf-8", - ) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - return str(self.get(filepath), encoding=encoding) - - def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None: - """Write bytes to a given ``filepath``. - - Args: - obj (bytes): Data to be saved. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.put(b'hello world', filepath) - """ - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - self._client.put(obj, filepath) - - def fast_put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path], num_processes: int = 32) -> None: - """Write bytes to a given ``filepath`` with multiple processes and async""" - assert num_processes > 1 - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - self._client.fast_put(obj, filepath, num_processes=num_processes) - - def put_text( - self, - obj: str, - filepath: Union[str, Path], - encoding: str = "utf-8", - ) -> None: - """Write text to a given ``filepath``. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to encode the ``obj``. - Defaults to 'utf-8'. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.put_text('hello world', filepath) - """ - self.put(bytes(obj, encoding=encoding), filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.exists(filepath) - True - """ - if not (has_method(self._client, "contains") and has_method(self._client, "isdir")): - raise NotImplementedError( - "Current version of Boto3 Python SDK has not supported " - "the `contains` and `isdir` methods, please use a higher" - "version or dev branch instead." - ) - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - return self._client.contains(filepath) or self._client.isdir(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/dir' - >>> backend.isdir(filepath) - True - """ - if not has_method(self._client, "isdir"): - raise NotImplementedError( - "Current version of Boto3 Python SDK has not supported " - "the `isdir` method, please use a higher version or dev" - " branch instead." - ) - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - return self._client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.isfile(filepath) - True - """ - if not has_method(self._client, "contains"): - raise NotImplementedError( - "Current version of Boto3 Python SDK has not supported " - "the `contains` method, please use a higher version or " - "dev branch instead." - ) - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - return self._client.contains(filepath) - - def join_path( - self, - filepath: Union[str, Path], - *filepaths: Union[str, Path], - ) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result after concatenation. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.join_path(filepath, 'another/path') - 's3://path/of/file/another/path' - >>> backend.join_path(filepath, '/another/path') - 's3://path/of/file/another/path' - """ - filepath = self._format_path(self._map_path(filepath)) - if filepath.endswith("/"): - filepath = filepath[:-1] - formatted_paths = [filepath] - for path in filepaths: - formatted_path = self._format_path(self._map_path(path)) - formatted_paths.append(formatted_path.lstrip("/")) - - return "/".join(formatted_paths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, Path], - ) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` to a local temporary directory, - and return the temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str or Path): Download a file from ``filepath``. - - Yields: - Iterable[str]: Only yield one temporary path. - - Examples: - >>> backend = Boto3Backend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> filepath = 's3://path/of/file' - >>> with backend.get_local_path(filepath) as path: - ... # do something here - """ - assert self.isfile(filepath) - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) - - def copyfile( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = Boto3Backend() - >>> # dst is a file - >>> src = 's3://path/of/file' - >>> dst = 's3://path/of/file1' - >>> backend.copyfile(src, dst) - 's3://path/of/file1' - - >>> # dst is a directory - >>> dst = 's3://path/of/dir' - >>> backend.copyfile(src, dst) - 's3://path/of/dir/file' - """ - src = self._format_path(self._map_path(src)) - dst = self._format_path(self._map_path(dst)) - if self.isdir(dst): - dst = self.join_path(dst, src.split("/")[-1]) - - if src == dst: - raise SameFileError("src and dst should not be same") - - self.put(self.get(src), dst) - return dst - - def copytree( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - src and dst should have the same prefix. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = Boto3Backend() - >>> src = 's3://path/of/dir' - >>> dst = 's3://path/of/dir1' - >>> backend.copytree(src, dst) - 's3://path/of/dir1' - """ - src = self._format_path(self._map_path(src)) - dst = self._format_path(self._map_path(dst)) - - if self.exists(dst): - raise FileExistsError("dst should not exist") - - for path in self.list_dir_or_file(src, list_dir=False, recursive=True): - src_path = self.join_path(src, path) - dst_path = self.join_path(dst, path) - self.put(self.get(src_path), dst_path) - - return dst - - def copyfile_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Upload a local file src to dst and return the destination file. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = Boto3Backend() - >>> # dst is a file - >>> src = 'path/of/your/file' - >>> dst = 's3://path/of/file1' - >>> backend.copyfile_from_local(src, dst) - 's3://path/of/file1' - - >>> # dst is a directory - >>> dst = 's3://path/of/dir' - >>> backend.copyfile_from_local(src, dst) - 's3://path/of/dir/file' - """ - dst = self._format_path(self._map_path(dst)) - if self.isdir(dst): - dst = self.join_path(dst, os.path.basename(src)) - - with open(src, "rb") as f: - self.put(f.read(), dst) - - return dst - - def copytree_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = Boto3Backend() - >>> src = 'path/of/your/dir' - >>> dst = 's3://path/of/dir1' - >>> backend.copytree_from_local(src, dst) - 's3://path/of/dir1' - """ - dst = self._format_path(self._map_path(dst)) - if self.exists(dst): - raise FileExistsError("dst should not exist") - - src = str(src) - - for cur_dir, _, files in os.walk(src): - for f in files: - src_path = os.path.join(cur_dir, f) - dst_path = self.join_path(dst, src_path.replace(src, "")) - self.copyfile_from_local(src_path, dst_path) - - return dst - - def copyfile_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - dst_type: str, # Choose from ["file", "dir"] - ) -> Union[str, Path]: - """Copy the file src to local dst and return the destination file. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = Boto3Backend() - >>> # dst is a file - >>> src = 's3://path/of/file' - >>> dst = 'path/of/your/file' - >>> backend.copyfile_to_local(src, dst) - 'path/of/your/file' - - >>> # dst is a directory - >>> dst = 'path/of/your/dir' - >>> backend.copyfile_to_local(src, dst) - 'path/of/your/dir/file' - """ - assert dst_type in ["file", "dir"] - # There is no good way to detect whether dst is a directory or a file, so we make dst_type required - if dst_type == "dir": - basename = os.path.basename(src) - if isinstance(dst, str): - dst = os.path.join(dst, basename) - else: - assert isinstance(dst, Path) - dst = dst / basename - - # Create parent directory if it doesn't exist - parent_dir = os.path.dirname(dst) - os.makedirs(parent_dir, exist_ok=True) - - try: - with open(dst, "wb") as f: - data = self.get(src) - f.write(data) - except Exception as e: - log.error(f"Failed to write file: {e}") - raise - - return dst - - def copytree_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> backend = Boto3Backend() - >>> src = 's3://path/of/dir' - >>> dst = 'path/of/your/dir' - >>> backend.copytree_to_local(src, dst) - 'path/of/your/dir' - """ - for path in self.list_dir_or_file(src, list_dir=False, recursive=True): - dst_path = os.path.join(dst, path) - mkdir_or_exist(os.path.dirname(dst_path)) - with open(dst_path, "wb") as f: - f.write(self.get(self.join_path(src, path))) - - return dst - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - - Raises: - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - - Examples: - >>> backend = Boto3Backend() - >>> filepath = 's3://path/of/file' - >>> backend.remove(filepath) - """ - if not has_method(self._client, "delete"): - raise NotImplementedError( - "Current version of Boto3 Python SDK has not supported " - "the `delete` method, please use a higher version or dev " - "branch instead." - ) - - if not self.exists(filepath): - raise FileNotFoundError(f"filepath {filepath} does not exist") - - if self.isdir(filepath): - raise IsADirectoryError("filepath should be a file") - - filepath = self._map_path(filepath) - filepath = self._format_path(filepath) - filepath = self._replace_prefix(filepath) - self._client.delete(filepath) - - def rmtree(self, dir_path: Union[str, Path]) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - - Examples: - >>> backend = Boto3Backend() - >>> dir_path = 's3://path/of/dir' - >>> backend.rmtree(dir_path) - """ - for path in self.list_dir_or_file(dir_path, list_dir=False, recursive=True): - filepath = self.join_path(dir_path, path) - self.remove(filepath) - - def copy_if_symlink_fails( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> bool: - """Create a symbolic link pointing to src named dst. - - Directly copy src to dst because PetrelBacekend does not support create - a symbolic link. - - Args: - src (str or Path): A file or directory to be copied. - dst (str or Path): Copy a file or directory to dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - bool: Return False because Boto3Backend does not support create - a symbolic link. - - Examples: - >>> backend = Boto3Backend() - >>> src = 's3://path/of/file' - >>> dst = 's3://path/of/your/file' - >>> backend.copy_if_symlink_fails(src, dst) - False - >>> src = 's3://path/of/dir' - >>> dst = 's3://path/of/your/dir' - >>> backend.copy_if_symlink_fails(src, dst) - False - """ - if self.isfile(src): - self.copyfile(src, dst) - else: - self.copytree(src, dst) - return False - - def list_dir(self, dir_path: Union[str, Path]): - """List all folders in an S3 bucket with a given prefix. - - Args: - dir_path (str | Path): Path of the directory. - - Examples: - >>> backend = Boto3Backend() - >>> dir_path = 's3://path/of/dir' - >>> backend.list_dir(dir_path) - """ - dir_path = self._map_path(dir_path) - dir_path = self._format_path(dir_path) - dir_path = self._replace_prefix(dir_path) - return self._client.ls_dir(dir_path) - - def list_dir_or_file( # pylint: disable=too-many-arguments - self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - Boto3 has no concept of directories but it simulates the directory - hierarchy in the filesystem through public prefixes. In addition, - if the returned path ends with '/', it means the path is a public - prefix which is a logical directory. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - In addition, the returned path of directory will not contains the - suffix '/' which is consistent with other backends. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the - directory. Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> backend = Boto3Backend() - >>> dir_path = 's3://path/of/dir' - >>> # list those files and directories in current directory - >>> for file_path in backend.list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ # noqa: E501 - if not has_method(self._client, "list"): - raise NotImplementedError( - "Current version of Boto3 Python SDK has not supported " - "the `list` method, please use a higher version or dev" - " branch instead." - ) - - dir_path = self._map_path(dir_path) - dir_path = self._format_path(dir_path) - dir_path = self._replace_prefix(dir_path) - if list_dir and suffix is not None: - raise TypeError("`list_dir` should be False when `suffix` is not None") - - if list_dir and not list_file and not recursive: - raise TypeError( - "Please use `list_dir` instead of `list_dir_or_file` when you only want to list the first level directories." - ) - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError("`suffix` must be a string or tuple of strings") - - # Boto3's simulated directory hierarchy assumes that directory paths - # should end with `/` - if not dir_path.endswith("/"): - dir_path += "/" - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): - # Keep track of directories we've already yielded to avoid duplicates - yielded_dirs = set() if list_dir else None - - for path in self._client.list(dir_path): - # All paths returned by S3 list are file paths, never directory paths - absolute_path = self.join_path(dir_path, path) - rel_path = absolute_path[len(root) :] - - # If we want directories, extract directory prefixes from file paths - # boto3 client actually never return dir, it only return file paths - if list_dir and "/" in rel_path: - if not recursive: - # Non-recursive: only yield immediate child directory (first level) - first_slash_pos = rel_path.find("/") - immediate_child_dir = rel_path[:first_slash_pos] - - if immediate_child_dir not in yielded_dirs: - yielded_dirs.add(immediate_child_dir) - yield immediate_child_dir - else: - # Recursive: yield all directory levels - path_parts = rel_path.split("/")[:-1] # Exclude filename - current_dir = "" - for part in path_parts: - if current_dir: - current_dir += "/" + part - else: - current_dir = part - - if current_dir not in yielded_dirs: - yielded_dirs.add(current_dir) - yield current_dir - - # Handle file listing - if (suffix is None or rel_path.endswith(suffix)) and list_file: - yield rel_path - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) - - def generate_presigned_url(self, url: str, client_method: str = "get_object", expires_in: int = 3600) -> str: - """Generate the presigned url of video stream which can be passed to - mmcv.VideoReader. Now only work on Boto3 backend. - - Note: - Now only work on Boto3 backend. - - Args: - url (str): Url of video stream. - client_method (str): Method of client, 'get_object' or - 'put_object'. Default: 'get_object'. - expires_in (int): expires, in seconds. Default: 3600. - - Returns: - str: Generated presigned url. - """ - raise NotImplementedError("generate_presigned_url is not supported in Boto3Backend") - return self._client.generate_presigned_url(url, client_method, expires_in) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py deleted file mode 100644 index 829b1261ffb3854667ff673323a4b3ba4607010e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/boto3_client.py +++ /dev/null @@ -1,565 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import concurrent.futures -import io -import os -import time -from math import ceil -from multiprocessing import shared_memory -from typing import Any, Dict, Generator, List, Tuple - -import boto3 -import numpy as np -from botocore.config import Config as S3Config -from botocore.exceptions import ClientError - -import lyra_2._ext.imaginaire.utils.easy_io.backends.auto_auth as auto -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS - -try: - import aioboto3 - import aioboto3.session - from aiobotocore.config import AioConfig - from aiobotocore.session import AioSession -except ImportError: - aioboto3 = None - AioSession = None - -MAX_RETRIES = 5 -RETRY_DELAY = 1 # seconds - - -async def upload_single_part_async( - s3: AioSession, bucket: str, key: str, part_number: int, data: bytes, upload_id: str -) -> Dict[str, Any]: - """ - Uploads a single part of a file asynchronously to S3. - - Args: - s3 (S3): The S3 client. - bucket (str): The S3 bucket name. - key (str): The S3 key (file path). - part_number (int): The part number of the upload. - data (bytes): The data to upload. - upload_id (str): The upload ID for the multipart upload. - - Returns: - Dict[str, Any]: A dictionary containing the part number and ETag. - """ - for attempt in range(MAX_RETRIES): - try: - response = await s3.upload_part( - Bucket=bucket, Key=key, PartNumber=part_number, UploadId=upload_id, Body=data - ) - return {"PartNumber": part_number, "ETag": response["ETag"]} - except (ClientError, asyncio.TimeoutError, Exception) as e: - log.warning(f"Attempt {attempt + 1} failed for part {part_number}: {str(e)}", rank0_only=False) - if attempt < MAX_RETRIES - 1: - await asyncio.sleep(RETRY_DELAY * (2**attempt)) # Exponential backoff - else: - log.error(f"Failed to upload part {part_number} after {MAX_RETRIES} attempts", rank0_only=False) - raise - - -async def upload_parts_async( - part_size: int, - part_numbers: range, - upload_id: str, - data: bytes, - bucket: str, - key: str, - client_config: Dict[str, Any], -) -> List[Dict[str, Any]]: - """ - Uploads multiple parts of a file asynchronously to S3. - - Args: - part_size (int): The size of each part in bytes. - part_numbers (range): The range of part numbers to upload. - upload_id (str): The upload ID for the multipart upload. - data (bytes): The data to upload. - bucket (str): The S3 bucket name. - key (str): The S3 key (file path). - client_config (Dict[str, Any]): The S3 client configuration. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing part numbers and ETags. - """ - session = aioboto3.Session() - config = AioConfig(retries={"max_attempts": 3, "mode": "adaptive"}, connect_timeout=5, read_timeout=10) - start_idx = part_numbers[0] - async with session.client("s3", config=config, **client_config) as s3: - tasks = [] - for part_number in part_numbers: - start = (part_number - start_idx) * part_size - end = min(start + part_size, len(data)) - part_data = data[start:end] - tasks.append(upload_single_part_async(s3, bucket, key, part_number + 1, part_data, upload_id)) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - successful_parts = [] - failed_parts = [] - for part_number, result in enumerate(results, start=start_idx + 1): - if isinstance(result, Exception): - failed_parts.append(part_number) - else: - successful_parts.append(result) - - if failed_parts: - log.error(f"Failed to upload parts: {failed_parts}", rank0_only=False) - raise Exception(f"Failed to upload {len(failed_parts)} parts") - - successful_parts.sort(key=lambda part: part["PartNumber"]) - return successful_parts - - -def upload_parts_to_s3(args: Tuple[range, str, int, bytes, str, str, Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Uploads parts of a file to S3 using a new event loop. - - Args: - args (Tuple[range, str, int, bytes, str, str, Dict[str, Any]]): The arguments for uploading parts, including: - part_numbers (range): The range of part numbers to upload. - upload_id (str): The upload ID for the multipart upload. - part_size (int): The size of each part in bytes. - data (bytes): The data to upload. - bucket (str): The S3 bucket name. - key (str): The S3 key (file path). - client_config (Dict[str, Any]): The S3 client configuration. - - Returns: - List[Dict[str, Any]]: A list of dictionaries containing part numbers and ETags. - """ - part_numbers, upload_id, part_size, data, bucket, key, client_config = args - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - parts = loop.run_until_complete( - upload_parts_async(part_size, part_numbers, upload_id, data, bucket, key, client_config) - ) - loop.close() - return parts - - -async def download_single_part_async( - s3, bucket: str, key: str, part_number: int, start: int, end: int, shm_name: str, part_size: int -) -> None: - """ - Downloads a single part of a file asynchronously and writes it to shared memory. - - Args: - s3 (S3): The S3 client. - bucket (str): The S3 bucket name. - key (str): The S3 key (file path). - part_number (int): The part number. - start (int): The start byte of the part. - end (int): The end byte of the part. - shm_name (str): The name of the shared memory block. - part_size (int): The size of each part in bytes. - """ - for attempt in range(MAX_RETRIES): - try: - range_header = f"bytes={start}-{end}" - response = await s3.get_object(Bucket=bucket, Key=key, Range=range_header) - data = await response["Body"].read() - - shm = shared_memory.SharedMemory(name=shm_name) - offset = part_number * part_size - shm.buf[offset : offset + len(data)] = data - shm.close() - return - except (ClientError, asyncio.TimeoutError, Exception) as e: - log.warning(f"Attempt {attempt + 1} failed for part {part_number}: {str(e)}", rank0_only=False) - if attempt < MAX_RETRIES - 1: - await asyncio.sleep(RETRY_DELAY * (2**attempt)) # Exponential backoff - else: - log.error(f"Failed to download part {part_number} after {MAX_RETRIES} attempts", rank0_only=False) - raise - - -async def download_parts_async( - part_size: int, part_numbers: range, bucket: str, key: str, client_config: Dict[str, Any], shm_name: str -) -> None: - """ - Downloads multiple parts of a file asynchronously and writes them to shared memory. - - Args: - part_size (int): The size of each part in bytes. - part_numbers (range): The range of part numbers to download. - bucket (str): The S3 bucket name. - key (str): The S3 key (file path). - client_config (Dict[str, Any]): The S3 client configuration. - shm_name (str): The name of the shared memory block. - """ - session = aioboto3.Session() - config = AioConfig(retries={"max_attempts": 5, "mode": "adaptive"}, connect_timeout=10, read_timeout=30) - async with session.client("s3", config=config, **client_config) as s3: - tasks = [ - download_single_part_async( - s3, - bucket, - key, - part_number, - part_number * part_size, - (part_number + 1) * part_size - 1, - shm_name, - part_size, - ) - for part_number in part_numbers - ] - results = await asyncio.gather(*tasks, return_exceptions=True) - failed_parts = [part for part, result in zip(part_numbers, results) if isinstance(result, Exception)] - - if failed_parts: - log.error(f"Failed to download parts: {failed_parts}", rank0_only=False) - raise Exception(f"Failed to download {len(failed_parts)} parts") - - -def download_parts_to_s3(args: Tuple[range, int, str, str, Dict[str, Any], str]) -> bytes: - """ - Downloads parts of a file using a new event loop. - - Args: - args (Tuple[range, int, str, str, Dict[str, Any]]): The arguments for downloading parts, including: - part_numbers (range): The range of part numbers to download. - part_size (int): The size of each part in bytes. - bucket (str): The S3 bucket name. - key (str): The S3 key (file path). - client_config (Dict[str, Any]): The S3 client configuration. - - Returns: - bytes: The combined file data from all downloaded parts. - """ - part_numbers, part_size, bucket, key, client_config, shm_name = args - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(download_parts_async(part_size, part_numbers, bucket, key, client_config, shm_name)) - loop.close() - - -class Boto3Client: - def __init__( - self, - s3_credential_path: str, - max_attempt: int = 3, - ): - self.max_attempt = max_attempt - assert s3_credential_path, "s3_credential_path is required" - assert os.path.exists(s3_credential_path) or CRED_ENVS.APP_ENV in [ - "prod", - "dev", - "stg", - ], f"Credential file not found: {s3_credential_path}" - with auto.open_auth(s3_credential_path, "r") as f: - conf = auto.json_load_auth(f) - - s3_config = S3Config( - signature_version="s3v4", - s3={"addressing_style": "virtual"}, - response_checksum_validation="when_required", - request_checksum_calculation="when_required", - ) - - self._client = boto3.client("s3", **conf, config=s3_config) - self._s3_cred_info = conf - self._mc_kv_store = None - - def get(self, filepath): - filepath = self._check_path(filepath) - - if self._mc_kv_store and self._mc_kv_store.available: - if self._mc_kv_store.has(filepath): - return self._mc_kv_store.get(filepath) - - attempt = 0 - while attempt < self.max_attempt: - try: - buffer = io.BytesIO() - self._client.download_fileobj( - Bucket=filepath.split("/")[0], - Key="/".join(filepath.split("/")[1:]), - Fileobj=buffer, - ) - buffer.seek(0) - if self._mc_kv_store and self._mc_kv_store.available: - self._mc_kv_store.put(filepath, buffer.read()) - - return buffer.read() - except Exception as e: - attempt += 1 - log.error(f"Got an exception: attempt={attempt} - {e} - {filepath}", rank0_only=False) - - raise ConnectionError("Unable to read {} from. {} attempts tried.".format(filepath, attempt)) - - def _get_file_size(self, bucket, key, max_retries=10): - retries = 0 - while retries < max_retries: - try: - # Try to get the file size - file_size = self._client.head_object(Bucket=bucket, Key=key)["ContentLength"] - return file_size # Return file size if successful - except ClientError as e: - retries += 1 - log.error(f"Attempt {retries} failed for s3://{bucket}/{key}: {e}", rank0_only=False) - if retries >= max_retries: - raise # Re-raise the exception after max retries - time.sleep(2) # Wait for 2 seconds before retrying - except Exception as e: - retries += 1 - log.error( - f"Attempt {retries} failed for s3://{bucket}/{key}: due to an unexpected error: {e}", - rank0_only=False, - ) - if retries >= max_retries: - raise # Re-raise the exception after max retries - time.sleep(2) # Wait for 2 seconds before retrying - - def put(self, obj, filepath): - filepath = self._check_path(filepath) - bucket_name = filepath.split("/")[0] - key = "/".join(filepath.split("/")[1:]) - attempt = 0 - while attempt < self.max_attempt: - try: - # If obj is a string path to a local file, use upload_file instead - if isinstance(obj, str) and os.path.isfile(obj): - self._client.upload_file(Filename=obj, Bucket=bucket_name, Key=key) - return - if isinstance(obj, io.BytesIO): - obj.seek(0) - self._client.upload_fileobj(obj, Bucket=bucket_name, Key=key) - return - if isinstance(obj, bytes): - self._client.put_object(Body=obj, Bucket=bucket_name, Key=key) - return - else: - raise ValueError("Unsupported object type for upload") - except ClientError as e: - attempt += 1 - log.error(f"Got an exception: attempt={attempt} - {e} - {filepath}", rank0_only=False) - - raise ConnectionError("Unable to write {} to. {} attempts tried.".format(filepath, attempt)) - - def fast_put(self, obj, filepath, num_processes: int = 32): - assert aioboto3 is not None, "aioboto3 is required for fast_put" - original_filepath = filepath - filepath = self._check_path(filepath) - bucket = filepath.split("/")[0] - key = "/".join(filepath.split("/")[1:]) - part_size = 16 * 1024 * 1024 # 16 MB part size - - if isinstance(obj, bytes): - data = obj - elif isinstance(obj, str) and os.path.isfile(obj): - with open(obj, "rb") as f: - data = f.read() - elif isinstance(obj, io.BytesIO): - obj.seek(0) - data = obj.read() - else: - raise ValueError("Unsupported object type for upload") - - file_size = len(data) - if file_size <= part_size * num_processes: - return self.put(data, original_filepath) - num_parts = ceil(file_size / part_size) - upload_id = self._client.create_multipart_upload(Bucket=bucket, Key=key)["UploadId"] - - part_numbers = np.array_split(np.arange(num_parts), num_processes) - - with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor: - args = [] - for i in range(num_processes): - cur_parts = part_numbers[i].tolist() - cur_data = data[cur_parts[0] * part_size : min(cur_parts[-1] * part_size + part_size, file_size)] - args.append((cur_parts, upload_id, part_size, cur_data, bucket, key, self._s3_cred_info)) - results = executor.map(upload_parts_to_s3, args) - parts = [] - for result in results: - parts.extend(result) - - parts = sorted(parts, key=lambda part: part["PartNumber"]) - self._client.complete_multipart_upload( - Bucket=bucket, Key=key, UploadId=upload_id, MultipartUpload={"Parts": parts} - ) - - def contains(self, filepath: str, max_retries=10) -> bool: - """ - Checks if the specified object exists in the S3 bucket with retry logic for errors. - - Args: - filepath (str): The s3 path of the file to check, must start with "s3://". - - Returns: - bool: True if the object exists in the S3 bucket, False otherwise. - - Raises: - ClientError: If an error response other than "404 Not Found" is returned from the S3 service. - """ - filepath = self._check_path(filepath) - bucket = filepath.split("/")[0] - key = "/".join(filepath.split("/")[1:]) - - retries = 0 - while retries < max_retries: - try: - # Try to check if the object exists - self._client.head_object(Bucket=bucket, Key=key) - return True # Object exists - except ClientError as e: - if e.response["Error"]["Code"] == "404": - return False # Object does not exist - else: - retries += 1 - print(f"Attempt {retries} failed with error: {e}") - if retries >= max_retries: - raise # Re-raise the exception if max retries are reached - time.sleep(2) # Wait for 2 seconds before retrying - except Exception as e: - retries += 1 - print(f"Attempt {retries} failed due to an unexpected error: {e}") - if retries >= max_retries: - raise # Re-raise the exception if max retries are reached - time.sleep(2) # Wait for 2 seconds before retrying - - def isdir(self, filepath: str, max_retries=10) -> bool: - """ - Determines if the specified path corresponds to a directory in S3 with retry logic. - - A directory in S3 is implied if there are any objects stored with the given prefix, - which means this function checks for the existence of any objects at or under the specified path. - - Args: - filepath (str): The s3 path to check, must start with "s3://". - - Returns: - bool: True if the specified path corresponds to a directory in S3, False otherwise. - Directories in S3 are not physical entities but are implied by object keys. - - Raises: - ClientError: An error from the S3 API that isn't related to the absence of the directory - (logged but not raised further). - """ - filepath = self._check_path(filepath) - if not filepath.endswith("/"): - filepath += "/" - - bucket = filepath.split("/")[0] - prefix = "/".join(filepath.split("/")[1:]) - - retries = 0 - while retries < max_retries: - try: - # Try to check if any objects exist with the given prefix (i.e., directory in S3) - resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter="/", MaxKeys=1) - # Check if any content or prefixes exist under the given path - return "CommonPrefixes" in resp or "Contents" in resp - except ClientError as e: - retries += 1 - log.error(f"Attempt {retries} failed: {e}", rank0_only=False) - if retries >= max_retries: - return False # Return False if maximum retries are reached - time.sleep(2) # Wait for 2 seconds before retrying - except Exception as e: - retries += 1 - log.error(f"Attempt {retries} failed due to an unexpected error: {e}", rank0_only=False) - if retries >= max_retries: - return False # Return False if maximum retries are reached - time.sleep(2) # Wait for 2 seconds before retrying - - def delete(self, filepath): - filepath = self._check_path(filepath) - self._client.delete_object(Bucket=filepath.split("/")[0], Key="/".join(filepath.split("/")[1:])) - - def ls_dir(self, filepath: str) -> Generator[str, None, None]: - """ - List all folders in an S3 bucket with a given prefix. - - Args: - filepath (str): The S3 path of the folder to list. - - Yields: - str: The keys of the folders in the S3 bucket. - """ - filepath = self._check_path(filepath) - bucket = filepath.split("/")[0] - prefix = "/".join(filepath.split("/")[1:]) - continuation_token = None - if prefix and not prefix.endswith("/"): - prefix += "/" - - while True: - if continuation_token: - resp = self._client.list_objects_v2( - Bucket=bucket, Prefix=prefix, Delimiter="/", ContinuationToken=continuation_token - ) - else: - resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, Delimiter="/") - - if "CommonPrefixes" in resp: - for item in resp["CommonPrefixes"]: - yield item["Prefix"][len(prefix) :] - - # Check if there are more keys to retrieve - if resp.get("IsTruncated"): # If IsTruncated is True, there are more keys - continuation_token = resp.get("NextContinuationToken") - else: - break - - def list(self, filepath: str, exclude_prefix: str = None) -> Generator[str, None, None]: - """ - List all keys in an S3 bucket with a given prefix, excluding files that start with - specified prefix. - - Args: - filepath (str): The S3 path of the file to list. - exclude_prefix (str): Files starting with this prefix will be excluded from results. - Defaults to "real". - - Yields: - str: The keys of the files in the S3 bucket that don't start with exclude_prefix. - """ - filepath = self._check_path(filepath) - bucket = filepath.split("/")[0] - prefix = "/".join(filepath.split("/")[1:]) - - continuation_token = None - - while True: - if continuation_token: - resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix, ContinuationToken=continuation_token) - else: - resp = self._client.list_objects_v2(Bucket=bucket, Prefix=prefix) - - if "Contents" in resp: - for item in resp["Contents"]: - key = item["Key"][len(prefix) :] - # Skip files that start with the excluded prefix - if exclude_prefix is None or not key.startswith(exclude_prefix): - yield key - - # Check if there are more keys to retrieve - if resp.get("IsTruncated"): # If IsTruncated is True, there are more keys - continuation_token = resp.get("NextContinuationToken") - else: - break - - def _check_path(self, filepath: str): - assert filepath.startswith("s3://") - filepath = filepath[5:] - return filepath diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py deleted file mode 100644 index 5b094d88250f497434bd1c629922ea0d48a4f6b6..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/http_backend.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Union -from urllib.request import urlopen - -from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend - - -class HTTPBackend(BaseStorageBackend): - """HTTP and HTTPS storage bachend.""" - - def get(self, filepath: str) -> bytes: - """Read bytes from a given ``filepath``. - - Args: - filepath (str): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = HTTPBackend() - >>> backend.get('http://path/of/file') - b'hello world' - """ - return urlopen(filepath).read() - - def get_text(self, filepath, encoding="utf-8") -> str: - """Read text from a given ``filepath``. - - Args: - filepath (str): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = HTTPBackend() - >>> backend.get_text('http://path/of/file') - 'hello world' - """ - return urlopen(filepath).read().decode(encoding) - - @contextmanager - def get_local_path(self, filepath: str) -> Generator[Union[str, Path], None, None]: - """Download a file from ``filepath`` to a local temporary directory, - and return the temporary path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Args: - filepath (str): Download a file from ``filepath``. - - Yields: - Iterable[str]: Only yield one temporary path. - - Examples: - >>> backend = HTTPBackend() - >>> # After existing from the ``with`` clause, - >>> # the path will be removed - >>> with backend.get_local_path('http://path/of/file') as path: - ... # do something here - """ - try: - f = tempfile.NamedTemporaryFile(delete=False) - f.write(self.get(filepath)) - f.close() - yield f.name - finally: - os.remove(f.name) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py deleted file mode 100644 index 52d00c0dd4af0708e539a46845d5a54aa4b38234..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/local_backend.py +++ /dev/null @@ -1,550 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import os -import os.path as osp -import shutil -from contextlib import contextmanager -from pathlib import Path -from typing import Generator, Iterator, Optional, Tuple, Union - -from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend, mkdir_or_exist - - -class LocalBackend(BaseStorageBackend): - """Raw local storage backend.""" - - _allow_symlink = True - - def get(self, filepath: Union[str, Path]) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get(filepath) - b'hello world' - """ - with open(filepath, "rb") as f: - value = f.read() - return value - - def get_text(self, filepath: Union[str, Path], encoding: str = "utf-8") -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.get_text(filepath) - 'hello world' - """ - with open(filepath, encoding=encoding) as f: - text = f.read() - return text - - def put(self, obj: Union[bytes, io.BytesIO], filepath: Union[str, Path]) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put(b'hello world', filepath) - """ - mkdir_or_exist(osp.dirname(filepath)) - if isinstance(obj, io.BytesIO): - obj.seek(0) - obj = obj.getvalue() - with open(filepath, "wb") as f: - f.write(obj) - - def put_text(self, obj: str, filepath: Union[str, Path], encoding: str = "utf-8") -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` will create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.put_text('hello world', filepath) - """ - mkdir_or_exist(osp.dirname(filepath)) - with open(filepath, "w", encoding=encoding) as f: - f.write(obj) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.exists(filepath) - True - """ - return osp.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/dir' - >>> backend.isdir(filepath) - True - """ - return osp.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.isfile(filepath) - True - """ - return osp.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - - Examples: - >>> backend = LocalBackend() - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> backend.join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - return osp.join(filepath, *filepaths) - - @contextmanager - def get_local_path( - self, - filepath: Union[str, Path], - ) -> Generator[Union[str, Path], None, None]: - """Only for unified API and do nothing. - - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> backend = LocalBackend() - >>> with backend.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - """ - yield filepath - - def copyfile( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> backend.copyfile(src, dst) - '/path1/of/dir/file' - """ - return shutil.copy(src, dst) - - def copytree( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will - be raised. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree(src, dst) - '/path/of/dir2' - """ - return shutil.copytree(src, dst) - - def copyfile_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Copy a local file src to dst and return the destination file. Same - as :meth:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_from_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_from_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_from_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a - directory named dst and return the destination directory. Same as - :meth:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - - Returns: - str: The destination directory. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def copyfile_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - dst_type: Optional[str] = None, - ) -> str: - """Copy the file src to local dst and return the destination file. Same - as :meth:`copyfile`. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> backend = LocalBackend() - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> backend.copyfile_to_local(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to - >>> backend.copyfile_to_local(src, dst) - '/path1/of/dir/file' - """ - return self.copyfile(src, dst) - - def copytree_to_local( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> str: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> backend.copytree_from_local(src, dst) - '/path/of/dir2' - """ - return self.copytree(src, dst) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str or Path): Path to be removed. - - Raises: - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - - Examples: - >>> backend = LocalBackend() - >>> filepath = '/path/of/file' - >>> backend.remove(filepath) - """ - if not self.exists(filepath): - raise FileNotFoundError(f"filepath {filepath} does not exist") - - if self.isdir(filepath): - raise IsADirectoryError("filepath should be a file") - - os.remove(filepath) - - def rmtree(self, dir_path: Union[str, Path]) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - - Examples: - >>> dir_path = '/path/of/dir' - >>> backend.rmtree(dir_path) - """ - shutil.rmtree(dir_path) - - def copy_if_symlink_fails( - self, - src: Union[str, Path], - dst: Union[str, Path], - ) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directly copy src - to dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - - Returns: - bool: Return True if successfully create a symbolic link pointing - to src. Otherwise, return False. - - Examples: - >>> backend = LocalBackend() - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> backend.copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> backend.copy_if_symlink_fails(src, dst) - True - """ - try: - os.symlink(src, dst) - return True - except Exception: - if self.isfile(src): - self.copyfile(src, dst) - else: - self.copytree(src, dst) - return False - - def list_dir_or_file( - self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> backend = LocalBackend() - >>> dir_path = '/path/of/dir' - >>> # list those files and directories in current directory - >>> for file_path in backend.list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in backend.list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in backend.list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in backend.list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in backend.list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ # noqa: E501 - if list_dir and suffix is not None: - raise TypeError("`suffix` should be None when `list_dir` is True") - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError("`suffix` must be a string or tuple of strings") - - root = dir_path - - def _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive): - for entry in os.scandir(dir_path): - if not entry.name.startswith(".") and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - if (suffix is None or rel_path.endswith(suffix)) and list_file: - yield rel_path - elif osp.isdir(entry.path): - if list_dir: - rel_dir = osp.relpath(entry.path, root) - yield rel_dir - if recursive: - yield from _list_dir_or_file(entry.path, list_dir, list_file, suffix, recursive) - - return _list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py b/lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py deleted file mode 100644 index a071a85491202425775009310286d2b1ab175bae..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/backends/registry_utils.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Optional, Type, Union - -from lyra_2._ext.imaginaire.utils.easy_io.backends.base_backend import BaseStorageBackend -from lyra_2._ext.imaginaire.utils.easy_io.backends.boto3_backend import Boto3Backend -from lyra_2._ext.imaginaire.utils.easy_io.backends.http_backend import HTTPBackend -from lyra_2._ext.imaginaire.utils.easy_io.backends.local_backend import LocalBackend - -backends: dict = {} -prefix_to_backends: dict = {} - - -def _register_backend( - name: str, - backend: Type[BaseStorageBackend], - force: bool = False, - prefixes: Union[str, list, tuple, None] = None, -): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (BaseStorageBackend): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - """ - global backends, prefix_to_backends - - if not isinstance(name, str): - raise TypeError(f"the backend name should be a string, but got {type(name)}") - - if not inspect.isclass(backend): - raise TypeError(f"backend should be a class, but got {type(backend)}") - if not issubclass(backend, BaseStorageBackend): - raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") - - if name in backends and not force: - raise ValueError( - f'{name} is already registered as a storage backend, add "force=True" if you want to override it' - ) - backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - - for prefix in prefixes: - if prefix in prefix_to_backends and not force: - raise ValueError( - f'{prefix} is already registered as a storage backend, add "force=True" if you want to override it' - ) - - prefix_to_backends[prefix] = backend - - -def register_backend( - name: str, - backend: Optional[Type[BaseStorageBackend]] = None, - force: bool = False, - prefixes: Union[str, list, tuple, None] = None, -): - """Register a backend. - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool): Whether to override the backend if the name has already - been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefix - of the registered storage backend. Defaults to None. - - This method can be used as a normal method or a decorator. - - Examples: - - >>> class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... return filepath - >>> register_backend('new', NewBackend) - - >>> @register_backend('new') - ... class NewBackend(BaseStorageBackend): - ... def get(self, filepath): - ... return filepath - ... - ... def get_text(self, filepath): - ... return filepath - """ - if backend is not None: - _register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - _register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - -register_backend("local", LocalBackend, prefixes="") -# To avoid breaking backward Compatibility, 's3' is also used as a -# prefix for Boto3Backend -register_backend("s3", Boto3Backend, prefixes=["s3"]) -register_backend("http", HTTPBackend, prefixes=["http", "https"]) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py b/lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py deleted file mode 100644 index 673bbf29b28b78a3281bbbbfe1523de683e97cc4..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/easy_io.py +++ /dev/null @@ -1,1085 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import warnings -from contextlib import contextmanager -from io import BytesIO, StringIO -from pathlib import Path -from typing import IO, Any, Generator, Iterator, Optional, Tuple, Union - -from lyra_2._ext.imaginaire.utils.easy_io.backends import backends, prefix_to_backends -from lyra_2._ext.imaginaire.utils.easy_io.file_client import FileClient -from lyra_2._ext.imaginaire.utils.easy_io.handlers import file_handlers - -backend_instances: dict = {} - - -def is_filepath(filepath): - return isinstance(filepath, (str, Path)) - - -def _parse_uri_prefix(uri: Union[str, Path]) -> str: - """Parse the prefix of uri. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> _parse_uri_prefix('/home/path/of/your/file') - '' - >>> _parse_uri_prefix('s3://path/of/your/file') - 's3' - >>> _parse_uri_prefix('clusterName:s3://path/of/your/file') - 's3' - - Returns: - str: Return the prefix of uri if the uri contains '://'. Otherwise, - return ''. - """ - assert is_filepath(uri) - uri = str(uri) - # if uri does not contains '://', the uri will be handled by - # LocalBackend by default - if "://" not in uri: - return "" - else: - prefix, _ = uri.split("://") - # In the case of Boto3Backend, the prefix may contain the cluster - # name like clusterName:s3://path/of/your/file - if ":" in prefix: - _, prefix = prefix.split(":") - return prefix - - -def _get_file_backend(prefix: str, backend_args: dict): - """Return a file backend based on the prefix or backend_args. - - Args: - prefix (str): Prefix of uri. - backend_args (dict): Arguments to instantiate the corresponding - backend. - """ - # backend name has a higher priority - if "backend" in backend_args: - # backend_args should not be modified - backend_args_bak = backend_args.copy() - backend_name = backend_args_bak.pop("backend") - backend = backends[backend_name](**backend_args_bak) - else: - backend = prefix_to_backends[prefix](**backend_args) - return backend - - -def set_s3_backend( - key: str = "s3:{}", - backend_args: Optional[dict] = None, -): - """register s3 backend. - - Args: - key str: The key to register the s3 backend. Defaults to s3. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - """ - global backend_instances - if backend_args is None: - backend_args = {} - backend = _get_file_backend(key, backend_args) - backend_instances[key] = backend - return backend - - -def get_file_backend( - uri: Union[str, Path, None] = None, - *, - backend_args: Optional[dict] = None, - enable_singleton: bool = False, - backend_key: Optional[str] = None, -): - """Return a file backend based on the prefix of uri or backend_args. - - Args: - uri (str or Path): Uri to be parsed that contains the file prefix. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - enable_singleton (bool): Whether to enable the singleton pattern. - If it is True, the backend created will be reused if the - signature is same with the previous one. Defaults to False. - backend_key: str: The key to register the backend. Defaults to None. - - Returns: - BaseStorageBackend: Instantiated Backend object. - - Examples: - >>> # get file backend based on the prefix of uri - >>> uri = 's3://path/of/your/file' - >>> backend = get_file_backend(uri) - >>> # get file backend based on the backend_args - >>> backend = get_file_backend(backend_args={'backend': 's3'}) - >>> # backend name has a higher priority if 'backend' in backend_args - >>> backend = get_file_backend(uri, backend_args={'backend': 's3'}) - """ - global backend_instances - if backend_key is not None: - if backend_key in backend_instances: - return backend_instances[backend_key] - - if backend_args is None: - backend_args = {} - - if uri is None and "backend" not in backend_args and backend_key is None: - raise ValueError('uri should not be None when "backend" does not exist in backend_args and backend_key is None') - - if uri is not None: - prefix = _parse_uri_prefix(uri) - else: - prefix = "" - - if enable_singleton: - unique_key = f"{prefix}:{json.dumps(backend_args)}" - if unique_key in backend_instances: - return backend_instances[unique_key] - - backend = _get_file_backend(prefix, backend_args) - backend_instances[unique_key] = backend - if backend_key is not None: - backend_instances[backend_key] = backend - return backend - else: - backend = _get_file_backend(prefix, backend_args) - return backend - - -def get( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bytes: - """Read bytes from a given ``filepath`` with 'rb' mode. - - Args: - filepath (str or Path): Path to read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bytes: Expected bytes object. - - Examples: - >>> filepath = '/path/of/file' - >>> get(filepath) - b'hello world' - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.get(filepath) - - -def get_text( - filepath: Union[str, Path], - encoding="utf-8", - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> str: - """Read text from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - str: Expected text reading from ``filepath``. - - Examples: - >>> filepath = '/path/of/file' - >>> get_text(filepath) - 'hello world' - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.get_text(filepath, encoding) - - -def put( - obj: bytes, - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Write bytes to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Examples: - >>> filepath = '/path/of/file' - >>> put(b'hello world', filepath) - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.put(obj, filepath) - - -def put_text( - obj: str, - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Write text to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - ``filepath``. Defaults to 'utf-8'. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Examples: - >>> filepath = '/path/of/file' - >>> put_text('hello world', filepath) - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.put_text(obj, filepath) - - -def exists( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/file' - >>> exists(filepath) - True - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.exists(filepath) - - -def isdir( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - - Examples: - >>> filepath = '/path/of/dir' - >>> isdir(filepath) - True - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.isdir(filepath) - - -def isfile( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - - Examples: - >>> filepath = '/path/of/file' - >>> isfile(filepath) - True - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.isfile(filepath) - - -def join_path( - filepath: Union[str, Path], - *filepaths: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - *filepaths (str or Path): Other paths to be concatenated. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - str: The result of concatenation. - - Examples: - >>> filepath1 = '/path/of/dir1' - >>> filepath2 = 'dir2' - >>> filepath3 = 'path/of/file' - >>> join_path(filepath1, filepath2, filepath3) - '/path/of/dir/dir2/path/of/file' - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - return backend.join_path(filepath, *filepaths) - - -@contextmanager -def get_local_path( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Generator[Union[str, Path], None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself and it will - not be released (removed). - - Args: - filepath (str or Path): Path to be read data. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: Only yield one path. - - Examples: - >>> with get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - with backend.get_local_path(str(filepath)) as local_path: - yield local_path - - -def copyfile( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Copy a file src to dst and return the destination file. - - src and dst should have the same prefix. If dst specifies a directory, - the file will be copied into dst using the base filename from src. If - dst specifies a file that already exists, it will be replaced. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination file. - - Raises: - SameFileError: If src and dst are the same file, a SameFileError will - be raised. - - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> # src will be copied to '/path1/of/file1' - >>> copyfile(src, dst) - '/path1/of/file1' - - >>> # dst is a directory - >>> dst = '/path1/of/dir' - >>> # src will be copied to '/path1/of/dir/file' - >>> copyfile(src, dst) - '/path1/of/dir/file' - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copyfile(src, dst) - - -def copytree( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - src and dst should have the same prefix and dst must not already exist. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - backend_key (str, optional): The key to get the backend from register. - - Returns: - str: The destination directory. - - Raises: - FileExistsError: If dst had already existed, a FileExistsError will be - raised. - - Examples: - >>> src = '/path/of/dir1' - >>> dst = '/path/of/dir2' - >>> copytree(src, dst) - '/path/of/dir2' - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copytree(src, dst) - - -def copyfile_from_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Copy a local file src to dst and return the destination file. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A local file to be copied. - dst (str or Path): Copy file to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = '/path/of/file' - >>> dst = 's3://openmmlab/mmengine/file1' - >>> # src will be copied to 's3://openmmlab/mmengine/file1' - >>> copyfile_from_local(src, dst) - s3://openmmlab/mmengine/file1 - - >>> # dst is a directory - >>> dst = 's3://openmmlab/mmengine' - >>> # src will be copied to 's3://openmmlab/mmengine/file'' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/file' - """ - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copyfile_from_local(src, dst) - - -def copytree_from_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a directory - named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A local directory to be copied. - dst (str or Path): Copy directory to dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> src = '/path/of/dir' - >>> dst = 's3://openmmlab/mmengine/dir' - >>> copyfile_from_local(src, dst) - 's3://openmmlab/mmengine/dir' - """ - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copytree_from_local(src, dst) - - -def copyfile_to_local( - src: Union[str, Path], - dst: Union[str, Path], - dst_type: str, # Choose from ["file", "dir"] - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Copy the file src to local dst and return the destination file. - - If dst specifies a directory, the file will be copied into dst using - the base filename from src. If dst specifies a file that already - exists, it will be replaced. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copyfile`. - - Args: - src (str or Path): A file to be copied. - dst (str or Path): Copy file to to local dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: If dst specifies a directory, the file will be copied into dst - using the base filename from src. - - Examples: - >>> # dst is a file - >>> src = 's3://openmmlab/mmengine/file' - >>> dst = '/path/of/file' - >>> # src will be copied to '/path/of/file' - >>> copyfile_to_local(src, dst) - '/path/of/file' - - >>> # dst is a directory - >>> dst = '/path/of/dir' - >>> # src will be copied to '/path/of/dir/file' - >>> copyfile_to_local(src, dst) - '/path/of/dir/file' - """ - assert dst_type in ["file", "dir"] - Path(dst).parent.mkdir(parents=True, exist_ok=True) - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copyfile_to_local(src, dst, dst_type=dst_type) - - -def copytree_to_local( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Union[str, Path]: - """Recursively copy an entire directory tree rooted at src to a local - directory named dst and return the destination directory. - - Note: - If the backend is the instance of LocalBackend, it does the same - thing with :func:`copytree`. - - Args: - src (str or Path): A directory to be copied. - dst (str or Path): Copy directory to local dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: The destination directory. - - Examples: - >>> src = 's3://openmmlab/mmengine/dir' - >>> dst = '/path/of/dir' - >>> copytree_to_local(src, dst) - '/path/of/dir' - """ - Path(dst).parent.mkdir(parents=True, exist_ok=True) - backend = get_file_backend(dst, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copytree_to_local(src, dst) - - -def remove( - filepath: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Raises: - FileNotFoundError: If filepath does not exist, an FileNotFoundError - will be raised. - IsADirectoryError: If filepath is a directory, an IsADirectoryError - will be raised. - - Examples: - >>> filepath = '/path/of/file' - >>> remove(filepath) - """ - backend = get_file_backend( - filepath, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.remove(filepath) - - -def rmtree( - dir_path: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> None: - """Recursively delete a directory tree. - - Args: - dir_path (str or Path): A directory to be removed. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Examples: - >>> dir_path = '/path/of/dir' - >>> rmtree(dir_path) - """ - backend = get_file_backend( - dir_path, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - backend.rmtree(dir_path) - - -def copy_if_symlink_fails( - src: Union[str, Path], - dst: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> bool: - """Create a symbolic link pointing to src named dst. - - If failed to create a symbolic link pointing to src, directory copy src to - dst instead. - - Args: - src (str or Path): Create a symbolic link pointing to src. - dst (str or Path): Create a symbolic link named dst. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - bool: Return True if successfully create a symbolic link pointing to - src. Otherwise, return False. - - Examples: - >>> src = '/path/of/file' - >>> dst = '/path1/of/file1' - >>> copy_if_symlink_fails(src, dst) - True - >>> src = '/path/of/dir' - >>> dst = '/path1/of/dir1' - >>> copy_if_symlink_fails(src, dst) - True - """ - backend = get_file_backend(src, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.copy_if_symlink_fails(src, dst) - - -def list_dir( - dir_path: Union[str, Path], - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -): - """List all folders in an S3 bucket with a given prefix. - - Args: - dir_path (str | Path): Path of the directory. - - Examples: - >>> dir_path = '/path/of/dir' - >>> for file_path in list_dir(dir_path): - ... print(file_path) - """ - if not dir_path.endswith("/"): - dir_path += "/" - backend = get_file_backend( - dir_path, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - - return backend.list_dir(dir_path) - - -def list_dir_or_file( - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str or Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix that we are - interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the directory. - Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - - Examples: - >>> dir_path = '/path/of/dir' - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # list those files and directories in current directory - >>> for file_path in list_dir_or_file(dir_path): - ... print(file_path) - >>> # only list files - >>> for file_path in list_dir_or_file(dir_path, list_dir=False): - ... print(file_path) - >>> # only list directories - >>> for file_path in list_dir_or_file(dir_path, list_file=False): - ... print(file_path) - >>> # only list files ending with specified suffixes - >>> for file_path in list_dir_or_file(dir_path, suffix='.txt'): - ... print(file_path) - >>> # list all files and directory recursively - >>> for file_path in list_dir_or_file(dir_path, recursive=True): - ... print(file_path) - """ - backend = get_file_backend( - dir_path, - backend_args=backend_args, - enable_singleton=True, - backend_key=backend_key, - ) - yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) - - -def generate_presigned_url( - url: str, - client_method: str = "get_object", - expires_in: int = 3600, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> str: - """Generate the presigned url of video stream which can be passed to - mmcv.VideoReader. Now only work on s3 backend. - - Note: - Now only work on s3 backend. - - Args: - url (str): Url of video stream. - client_method (str): Method of client, 'get_object' or - 'put_object'. Defaults to 'get_object'. - expires_in (int): expires, in seconds. Defaults to 3600. - backend_args (dict, optional): Arguments to instantiate the - corresponding backend. Defaults to None. - - Returns: - str: Generated presigned url. - """ - backend = get_file_backend(url, backend_args=backend_args, enable_singleton=True, backend_key=backend_key) - return backend.generate_presigned_url(url, client_method, expires_in) - - -def load( - file: Union[str, Path, IO[Any]], - file_format: Optional[str] = None, - file_client_args: Optional[dict] = None, - fast_backend: bool = False, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, - **kwargs, -): - """Load data from json/yaml/pickle files. - - This method provides a unified api for loading data from serialized files. - - ``load`` supports loading data from serialized files those can be storaged - in different backends. - - Args: - file (str or :obj:`Path` or file-like object): Filename or a file-like - object. - file_format (str, optional): If not specified, the file format will be - inferred from the file extension, otherwise use the specified one. - Currently supported formats include "json", "yaml/yml" and - "pickle/pkl". - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - fast_backend: bool: Whether to use multiprocess. Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - - Examples: - >>> load('/path/of/your/file') # file is storaged in disk - >>> load('https://path/of/your/file') # file is storaged in Internet - >>> load('s3://path/of/your/file') # file is storaged in s3 - - Returns: - The content from the file. - """ - if isinstance(file, Path): - file = str(file) - if file_format is None and isinstance(file, str): - file_format = file.split(".")[-1] - # convert file_format to lower case - file_format = file_format.lower() - if file_format not in file_handlers: - raise TypeError(f"Unsupported format: {file_format}") - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - DeprecationWarning, - ) - if backend_args is not None: - raise ValueError('"file_client_args and "backend_args" cannot be set at the same time.') - - handler = file_handlers[file_format] - if isinstance(file, str): - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, file) - file_backend = file_client - else: - file_backend = get_file_backend( - file, - backend_args=backend_args, - backend_key=backend_key, - enable_singleton=True, - ) - - if handler.str_like: - with StringIO(file_backend.get_text(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - if fast_backend: - if hasattr(file_backend, "fast_get"): - with BytesIO(file_backend.fast_get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - warnings.warn( - f"fast_backend is not supported by the backend, type {type(file_backend)} fallback to normal get" - ) - with BytesIO(file_backend.get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - else: - with BytesIO(file_backend.get(file)) as f: - obj = handler.load_from_fileobj(f, **kwargs) - elif hasattr(file, "read"): - obj = handler.load_from_fileobj(file, **kwargs) - else: - raise TypeError('"file" must be a filepath str or a file-object') - return obj - - -def dump( - obj: Any, - file: Union[str, Path, IO[Any], None] = None, - file_format: Optional[str] = None, - file_client_args: Optional[dict] = None, - fast_backend: bool = False, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, - **kwargs, -): - """Dump data to json/yaml/pickle strings or files. - - This method provides a unified api for dumping data as strings or to files, - and also supports custom arguments for each file format. - - ``dump`` supports dumping data as strings or to files which is saved to - different backends. - - Args: - obj (any): The python object to be dumped. - file (str or :obj:`Path` or file-like object, optional): If not - specified, then the object is dumped to a str, otherwise to a file - specified by the filename or file-like object. - file_format (str, optional): Same as :func:`load`. - file_client_args (dict, optional): Arguments to instantiate a - FileClient. See :class:`mmengine.fileio.FileClient` for details. - Defaults to None. It will be deprecated in future. Please use - ``backend_args`` instead. - fast_backend: bool: Whether to use multiprocess. Defaults to False. - backend_args (dict, optional): Arguments to instantiate the - prefix of uri corresponding backend. Defaults to None. - New in v0.2.0. - backend_key: str: The key to register the backend. Defaults to None. - - Examples: - >>> dump('hello world', '/path/of/your/file') # disk - >>> dump('hello world', 's3://path/of/your/file') # ceph or s3 - - Returns: - bool: True for success, False otherwise. - """ - if isinstance(file, Path): - file = str(file) - if file_format is None: - if isinstance(file, str): - file_format = file.split(".")[-1] - elif file is None: - raise ValueError("file_format must be specified since file is None") - # convert file_format to lower case - file_format = file_format.lower() - if file_format not in file_handlers: - raise TypeError(f"Unsupported format: {file_format}") - - if file_client_args is not None: - warnings.warn( - '"file_client_args" will be deprecated in future. Please use "backend_args" instead', - DeprecationWarning, - ) - if backend_args is not None: - raise ValueError('"file_client_args" and "backend_args" cannot be set at the same time.') - - handler = file_handlers[file_format] - if file is None: - return handler.dump_to_str(obj, **kwargs) - elif isinstance(file, str): - if file_client_args is not None: - file_client = FileClient.infer_client(file_client_args, file) - file_backend = file_client - else: - file_backend = get_file_backend( - file, - backend_args=backend_args, - backend_key=backend_key, - enable_singleton=True, - ) - - if handler.str_like: - with StringIO() as f: - handler.dump_to_fileobj(obj, f, **kwargs) - file_backend.put_text(f.getvalue(), file) - else: - with BytesIO() as f: - handler.dump_to_fileobj(obj, f, **kwargs) - if fast_backend: - if hasattr(file_backend, "fast_put"): - file_backend.fast_put(f, file) - else: - warnings.warn("fast_backend is not supported by the backend, fallback to normal put") - file_backend.put(f, file) - else: - file_backend.put(f, file) - elif hasattr(file, "write"): - handler.dump_to_fileobj(obj, file, **kwargs) - else: - raise TypeError('"file" must be a filename str or a file-object') diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/file_client.py b/lyra_2/_ext/imaginaire/utils/easy_io/file_client.py deleted file mode 100644 index 8ea90af82c27789574f9481f4a5f198b978e9547..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/file_client.py +++ /dev/null @@ -1,458 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Generator, Iterator, Optional, Tuple, Union - -from lyra_2._ext.imaginaire.utils.easy_io.backends import ( - BaseStorageBackend, - Boto3Backend, - HTTPBackend, - LocalBackend, -) - - -def is_filepath(filepath): - return isinstance(filepath, (str, Path)) - - -class HardDiskBackend(LocalBackend): - """Raw hard disks storage backend.""" - - @property - def name(self): - return self.__class__.__name__ - - -class FileClient: - """A general file client to access files in different backends. - - The client loads a file or text in a specified backend from its path - and returns it as a binary or text file. There are two ways to choose a - backend, the name of backend and the prefix of path. Although both of them - can be used to choose a storage backend, ``backend`` has a higher priority - that is if they are all set, the storage backend will be chosen by the - backend argument. If they are all `None`, the disk backend will be chosen. - Note that It can also register other backend accessor with a given name, - prefixes, and backend class. In addition, We use the singleton pattern to - avoid repeated object creation. If the arguments are the same, the same - object will be returned. - - Warning: - `FileClient` will be deprecated in future. Please use io functions - in https://mmengine.readthedocs.io/en/latest/api/fileio.html#file-io - - Args: - backend (str, optional): The storage backend type. Options are "disk", - "memcached", "lmdb", "http" and "s3". Defaults to None. - prefix (str, optional): The prefix of the registered storage backend. - Options are "s3", "http", "https". Defaults to None. - - Examples: - >>> # only set backend - >>> file_client = FileClient(backend='s3') - >>> # only set prefix - >>> file_client = FileClient(prefix='s3') - >>> # set both backend and prefix but use backend to choose client - >>> file_client = FileClient(backend='s3', prefix='s3') - >>> # if the arguments are the same, the same object is returned - >>> file_client1 = FileClient(backend='s3') - >>> file_client1 is file_client - True - - Attributes: - client (:obj:`BaseStorageBackend`): The backend object. - """ - - _backends = { - "disk": HardDiskBackend, - "s3": Boto3Backend, - "http": HTTPBackend, - } - - _prefix_to_backends: dict = { - "s3": Boto3Backend, - "http": HTTPBackend, - "https": HTTPBackend, - } - - _instances: dict = {} - - client: Any - - def __new__(cls, backend=None, prefix=None, **kwargs): - if backend is None and prefix is None: - backend = "disk" - if backend is not None and backend not in cls._backends: - raise ValueError( - f"Backend {backend} is not supported. Currently supported ones are {list(cls._backends.keys())}" - ) - if prefix is not None and prefix not in cls._prefix_to_backends: - raise ValueError( - f"prefix {prefix} is not supported. Currently supported ones are {list(cls._prefix_to_backends.keys())}" - ) - - # concatenate the arguments to a unique key for determining whether - # objects with the same arguments were created - arg_key = f"{backend}:{prefix}" - for key, value in kwargs.items(): - arg_key += f":{key}:{value}" - - # if a backend was overridden, it will create a new object - if arg_key in cls._instances: - _instance = cls._instances[arg_key] - else: - # create a new object and put it to _instance - _instance = super().__new__(cls) - if backend is not None: - _instance.client = cls._backends[backend](**kwargs) - else: - _instance.client = cls._prefix_to_backends[prefix](**kwargs) - - cls._instances[arg_key] = _instance - - return _instance - - @property - def name(self): - return self.client.name - - @property - def allow_symlink(self): - return self.client.allow_symlink - - @staticmethod - def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]: - """Parse the prefix of a uri. - - Args: - uri (str | Path): Uri to be parsed that contains the file prefix. - - Examples: - >>> FileClient.parse_uri_prefix('s3://path/of/your/file') - 's3' - - Returns: - str | None: Return the prefix of uri if the uri contains '://' else - ``None``. - """ - assert is_filepath(uri) - uri = str(uri) - if "://" not in uri: - return None - else: - prefix, _ = uri.split("://") - # In the case of Boto3Backend, the prefix may contains the cluster - # name like clusterName:s3 - if ":" in prefix: - _, prefix = prefix.split(":") - return prefix - - @classmethod - def infer_client( - cls, - file_client_args: Optional[dict] = None, - uri: Optional[Union[str, Path]] = None, - ) -> "FileClient": - """Infer a suitable file client based on the URI and arguments. - - Args: - file_client_args (dict, optional): Arguments to instantiate a - FileClient. Defaults to None. - uri (str | Path, optional): Uri to be parsed that contains the file - prefix. Defaults to None. - - Examples: - >>> uri = 's3://path/of/your/file' - >>> file_client = FileClient.infer_client(uri=uri) - >>> file_client_args = {'backend': 's3'} - >>> file_client = FileClient.infer_client(file_client_args) - - Returns: - FileClient: Instantiated FileClient object. - """ - assert file_client_args is not None or uri is not None - if file_client_args is None: - file_prefix = cls.parse_uri_prefix(uri) # type: ignore - return cls(prefix=file_prefix) - else: - return cls(**file_client_args) - - @classmethod - def _register_backend(cls, name, backend, force=False, prefixes=None): - if not isinstance(name, str): - raise TypeError(f"the backend name should be a string, but got {type(name)}") - if not inspect.isclass(backend): - raise TypeError(f"backend should be a class but got {type(backend)}") - if not issubclass(backend, BaseStorageBackend): - raise TypeError(f"backend {backend} is not a subclass of BaseStorageBackend") - if not force and name in cls._backends: - raise KeyError( - f'{name} is already registered as a storage backend, add "force=True" if you want to override it' - ) - - if name in cls._backends and force: - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, cls._backends[name]): - cls._instances.pop(arg_key) - cls._backends[name] = backend - - if prefixes is not None: - if isinstance(prefixes, str): - prefixes = [prefixes] - else: - assert isinstance(prefixes, (list, tuple)) - for prefix in prefixes: - if prefix not in cls._prefix_to_backends: - cls._prefix_to_backends[prefix] = backend - elif (prefix in cls._prefix_to_backends) and force: - overridden_backend = cls._prefix_to_backends[prefix] - for arg_key, instance in list(cls._instances.items()): - if isinstance(instance.client, overridden_backend): - cls._instances.pop(arg_key) - else: - raise KeyError( - f"{prefix} is already registered as a storage backend," - ' add "force=True" if you want to override it' - ) - - @classmethod - def register_backend(cls, name, backend=None, force=False, prefixes=None): - """Register a backend to FileClient. - - This method can be used as a normal class method or a decorator. - - .. code-block:: python - - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - FileClient.register_backend('new', NewBackend) - - or - - .. code-block:: python - - @FileClient.register_backend('new') - class NewBackend(BaseStorageBackend): - - def get(self, filepath): - return filepath - - def get_text(self, filepath): - return filepath - - Args: - name (str): The name of the registered backend. - backend (class, optional): The backend class to be registered, - which must be a subclass of :class:`BaseStorageBackend`. - When this method is used as a decorator, backend is None. - Defaults to None. - force (bool, optional): Whether to override the backend if the name - has already been registered. Defaults to False. - prefixes (str or list[str] or tuple[str], optional): The prefixes - of the registered storage backend. Defaults to None. - `New in version 1.3.15.` - """ - if backend is not None: - cls._register_backend(name, backend, force=force, prefixes=prefixes) - return - - def _register(backend_cls): - cls._register_backend(name, backend_cls, force=force, prefixes=prefixes) - return backend_cls - - return _register - - def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]: - """Read data from a given ``filepath`` with 'rb' mode. - - Note: - There are two types of return values for ``get``, one is ``bytes`` - and the other is ``memoryview``. The advantage of using memoryview - is that you can avoid copying, and if you want to convert it to - ``bytes``, you can use ``.tobytes()``. - - Args: - filepath (str or Path): Path to read data. - - Returns: - bytes | memoryview: Expected bytes object or a memory view of the - bytes object. - """ - return self.client.get(filepath) - - def get_text(self, filepath: Union[str, Path], encoding="utf-8") -> str: - """Read data from a given ``filepath`` with 'r' mode. - - Args: - filepath (str or Path): Path to read data. - encoding (str): The encoding format used to open the ``filepath``. - Defaults to 'utf-8'. - - Returns: - str: Expected text reading from ``filepath``. - """ - return self.client.get_text(filepath, encoding) - - def put(self, obj: bytes, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'wb' mode. - - Note: - ``put`` should create a directory if the directory of ``filepath`` - does not exist. - - Args: - obj (bytes): Data to be written. - filepath (str or Path): Path to write data. - """ - self.client.put(obj, filepath) - - def put_text(self, obj: str, filepath: Union[str, Path]) -> None: - """Write data to a given ``filepath`` with 'w' mode. - - Note: - ``put_text`` should create a directory if the directory of - ``filepath`` does not exist. - - Args: - obj (str): Data to be written. - filepath (str or Path): Path to write data. - encoding (str, optional): The encoding format used to open the - `filepath`. Defaults to 'utf-8'. - """ - self.client.put_text(obj, filepath) - - def remove(self, filepath: Union[str, Path]) -> None: - """Remove a file. - - Args: - filepath (str, Path): Path to be removed. - """ - self.client.remove(filepath) - - def exists(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path exists. - - Args: - filepath (str or Path): Path to be checked whether exists. - - Returns: - bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise. - """ - return self.client.exists(filepath) - - def isdir(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a directory. - - Args: - filepath (str or Path): Path to be checked whether it is a - directory. - - Returns: - bool: Return ``True`` if ``filepath`` points to a directory, - ``False`` otherwise. - """ - return self.client.isdir(filepath) - - def isfile(self, filepath: Union[str, Path]) -> bool: - """Check whether a file path is a file. - - Args: - filepath (str or Path): Path to be checked whether it is a file. - - Returns: - bool: Return ``True`` if ``filepath`` points to a file, ``False`` - otherwise. - """ - return self.client.isfile(filepath) - - def join_path(self, filepath: Union[str, Path], *filepaths: Union[str, Path]) -> str: - r"""Concatenate all file paths. - - Join one or more filepath components intelligently. The return value - is the concatenation of filepath and any members of \*filepaths. - - Args: - filepath (str or Path): Path to be concatenated. - - Returns: - str: The result of concatenation. - """ - return self.client.join_path(filepath, *filepaths) - - @contextmanager - def get_local_path(self, filepath: Union[str, Path]) -> Generator[Union[str, Path], None, None]: - """Download data from ``filepath`` and write the data to local path. - - ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It - can be called with ``with`` statement, and when exists from the - ``with`` statement, the temporary path will be released. - - Note: - If the ``filepath`` is a local path, just return itself. - - .. warning:: - ``get_local_path`` is an experimental interface that may change in - the future. - - Args: - filepath (str or Path): Path to be read data. - - Examples: - >>> file_client = FileClient(prefix='s3') - >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path: - ... # do something here - - Yields: - Iterable[str]: Only yield one path. - """ - with self.client.get_local_path(str(filepath)) as local_path: - yield local_path - - def list_dir_or_file( # pylint: disable=too-many-arguments - self, - dir_path: Union[str, Path], - list_dir: bool = True, - list_file: bool = True, - suffix: Optional[Union[str, Tuple[str]]] = None, - recursive: bool = False, - ) -> Iterator[str]: - """Scan a directory to find the interested directories or files in - arbitrary order. - - Note: - :meth:`list_dir_or_file` returns the path relative to ``dir_path``. - - Args: - dir_path (str | Path): Path of the directory. - list_dir (bool): List the directories. Defaults to True. - list_file (bool): List the path of files. Defaults to True. - suffix (str or tuple[str], optional): File suffix - that we are interested in. Defaults to None. - recursive (bool): If set to True, recursively scan the - directory. Defaults to False. - - Yields: - Iterable[str]: A relative path to ``dir_path``. - """ - yield from self.client.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py deleted file mode 100644 index c12df057b379da77c451024de0a3f3e6d077d84d..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.json_handler import JsonHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.registry_utils import file_handlers, register_handler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler - -__all__ = [ - "BaseFileHandler", - "JsonHandler", - "PickleHandler", - "YamlHandler", - "register_handler", - "file_handlers", -] diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py deleted file mode 100644 index 6822157143d5c7001bacce3e4d1945f38d4ec649..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/base.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABCMeta, abstractmethod - - -class BaseFileHandler(metaclass=ABCMeta): - # `str_like` is a flag to indicate whether the type of file object is - # str-like object or bytes-like object. Pickle only processes bytes-like - # objects but json only processes str-like object. If it is str-like - # object, `StringIO` will be used to process the buffer. - str_like = True - - @abstractmethod - def load_from_fileobj(self, file, **kwargs): - pass - - @abstractmethod - def dump_to_fileobj(self, obj, file, **kwargs): - pass - - @abstractmethod - def dump_to_str(self, obj, **kwargs): - pass - - def load_from_path(self, filepath, mode="r", **kwargs): - with open(filepath, mode) as f: - return self.load_from_fileobj(f, **kwargs) - - def dump_to_path(self, obj, filepath, mode="w", **kwargs): - with open(filepath, mode) as f: - self.dump_to_fileobj(obj, f, **kwargs) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py deleted file mode 100644 index bae967a27971d80e3dd1d32337db24ffa650444c..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/byte_handler.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import IO - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class ByteHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file: IO[bytes], **kwargs): - file.seek(0) - # extra all bytes and return - return file.read() - - def dump_to_fileobj( - self, - obj: bytes, - file: IO[bytes], - **kwargs, - ): - # write all bytes to file - file.write(obj) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py deleted file mode 100644 index 23c490b672c01b48d790e8ea33b14bce9f71f3ad..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/csv_handler.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import csv -from io import StringIO - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class CsvHandler(BaseFileHandler): - def load_from_fileobj(self, file, **kwargs): - del kwargs - reader = csv.reader(file) - return list(reader) - - def dump_to_fileobj(self, obj, file, **kwargs): - del kwargs - writer = csv.writer(file) - if not all(isinstance(row, list) for row in obj): - raise ValueError("Each row must be a list") - writer.writerows(obj) - - def dump_to_str(self, obj, **kwargs): - del kwargs - output = StringIO() - writer = csv.writer(output) - if not all(isinstance(row, list) for row in obj): - raise ValueError("Each row must be a list") - writer.writerows(obj) - return output.getvalue() diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py deleted file mode 100644 index 0aad7442c87caedf1b4007ea76e20da23be10b96..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/gzip_handler.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gzip -import pickle -from io import BytesIO -from typing import Any - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler - - -class GzipHandler(PickleHandler): - str_like = False - - def load_from_fileobj(self, file: BytesIO, **kwargs): - with gzip.GzipFile(fileobj=file, mode="rb") as f: - return pickle.load(f) - - def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): - with gzip.GzipFile(fileobj=file, mode="wb") as f: - pickle.dump(obj, f) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py deleted file mode 100644 index 0c75827ae430015b8d956b56860951781040ce02..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/imageio_video_handler.py +++ /dev/null @@ -1,168 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import IO, Any, Dict, Tuple - -import imageio -import imageio.v3 as iio_v3 -import numpy as np -import torch - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class ImageioVideoHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj( - self, file: IO[bytes], format: str = "mp4", mode: str = "rgb", **kwargs - ) -> Tuple[np.ndarray, Dict[str, Any]]: - """ - Load video from a file-like object using imageio.v3 with specified format and color mode. - - Parameters: - file (IO[bytes]): A file-like object containing video data. - format (str): Format of the video file (default 'mp4'). - mode (str): Color mode of the video, 'rgb' or 'gray' (default 'rgb'). - - Returns: - tuple: A tuple containing an array of video frames and metadata about the video. - """ - file.seek(0) - - # The plugin argument in v3 replaces the format argument in v2 - plugin = kwargs.pop("plugin", "pyav") - - # Load all frames at once using v3 API - video_frames = iio_v3.imread(file, plugin=plugin, **kwargs) - - # Handle grayscale conversion if needed - if mode == "gray": - import cv2 - - if len(video_frames.shape) == 4: # (frames, height, width, channels) - gray_frames = [] - for frame in video_frames: - gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - gray_frame = np.expand_dims(gray_frame, axis=2) # Keep dimensions consistent - gray_frames.append(gray_frame) - video_frames = np.array(gray_frames) - - # Extract metadata - # Note: iio_v3.imread doesn't return metadata directly like v2 did - # We need to extract it separately - file.seek(0) - metadata = self._extract_metadata(file, plugin=plugin) - - return video_frames, metadata - - def _extract_metadata(self, file: IO[bytes], plugin: str = "pyav") -> Dict[str, Any]: - """ - Extract metadata from a video file. - - Parameters: - file (IO[bytes]): File-like object containing video data. - plugin (str): Plugin to use for reading. - - Returns: - dict: Video metadata. - """ - try: - # Create a generator to read frames and metadata - metadata = iio_v3.immeta(file, plugin=plugin) - - # Add some standard fields similar to v2 metadata format - if "fps" not in metadata and "duration" in metadata: - # Read the first frame to get shape information - file.seek(0) - first_frame = iio_v3.imread(file, plugin=plugin, index=0) - metadata["size"] = first_frame.shape[1::-1] # (width, height) - metadata["source_size"] = metadata["size"] - - # Create a consistent metadata structure with v2 - metadata["plugin"] = plugin - if "codec" not in metadata: - metadata["codec"] = "unknown" - if "pix_fmt" not in metadata: - metadata["pix_fmt"] = "unknown" - - # Calculate nframes if possible - if "fps" in metadata and "duration" in metadata: - metadata["nframes"] = int(metadata["fps"] * metadata["duration"]) - else: - metadata["nframes"] = float("inf") - - return metadata - - except Exception as e: - # Fallback to basic metadata - return { - "plugin": plugin, - "nframes": float("inf"), - "codec": "unknown", - "fps": 30.0, # Default values - "duration": 0, - "size": (0, 0), - } - - def dump_to_fileobj( - self, - obj: np.ndarray | torch.Tensor, - file: IO[bytes], - format: str = "mp4", # pylint: disable=redefined-builtin - fps: int = 17, - quality: int = 5, - ffmpeg_params=None, - **kwargs, - ): - """ - Save an array of video frames to a file-like object using imageio. - - Parameters: - obj (Union[np.ndarray, torch.Tensor]): An array of frames to be saved as video. - file (IO[bytes]): A file-like object to which the video data will be written. - format (str): Format of the video file (default 'mp4'). - fps (int): Frames per second of the output video (default 17). - quality (int): Quality of the video (0-10, default 5). - ffmpeg_params (list): Additional parameters to pass to ffmpeg. - - """ - if isinstance(obj, torch.Tensor): - assert obj.dtype == torch.uint8, "Tensor must be of type uint8" - obj = obj.cpu().numpy() - h, w = obj.shape[1:-1] - - # Default ffmpeg params that ensure width and height are set - default_ffmpeg_params = ["-s", f"{w}x{h}"] - - # Use provided ffmpeg_params if any, otherwise use defaults - final_ffmpeg_params = ffmpeg_params if ffmpeg_params is not None else default_ffmpeg_params - - mimsave_kwargs = { - "fps": fps, - "quality": quality, - "macro_block_size": 1, - "ffmpeg_params": final_ffmpeg_params, - "output_params": ["-f", "mp4"], - } - # Update with any other kwargs - mimsave_kwargs.update(kwargs) - log.debug(f"mimsave_kwargs: {mimsave_kwargs}") - - imageio.mimsave(file, obj, format, **mimsave_kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py deleted file mode 100644 index 967b2b09ef2a926e947513e92d9ab7e115365aa8..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/json_handler.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import numpy as np - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -def set_default(obj): - """Set default json values for non-serializable values. - - It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. - It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, - etc.) into plain numbers of plain python built-in types. - """ - if isinstance(obj, (set, range)): - return list(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - raise TypeError(f"{type(obj)} is unsupported for json dump") - - -class JsonHandler(BaseFileHandler): - def load_from_fileobj(self, file): - return json.load(file) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault("default", set_default) - json.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("default", set_default) - return json.dumps(obj, **kwargs) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py deleted file mode 100644 index 0fb18a1f78be05b7c1ca758f5c25a5673fa6e463..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/jsonl_handler.py +++ /dev/null @@ -1,80 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import IO - -import numpy as np - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -def set_default(obj): - """Set default json values for non-serializable values. - - It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list. - It also converts ``np.generic`` (including ``np.int32``, ``np.float32``, - etc.) into plain numbers of plain python built-in types. - """ - if isinstance(obj, (set, range)): - return list(obj) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.generic): - return obj.item() - raise TypeError(f"{type(obj)} is unsupported for json dump") - - -class JsonlHandler(BaseFileHandler): - """Handler for JSON lines (JSONL) files.""" - - def load_from_fileobj(self, file: IO[bytes]): - """Load JSON objects from a newline-delimited JSON (JSONL) file object. - - Returns: - A list of Python objects loaded from each JSON line. - """ - data = [] - for line in file: - line = line.strip() - if not line: - continue # skip empty lines if any - data.append(json.loads(line)) - return data - - def dump_to_fileobj(self, obj: IO[bytes], file, **kwargs): - """Dump a list of objects to a newline-delimited JSON (JSONL) file object. - - Args: - obj: A list (or iterable) of objects to dump line by line. - """ - kwargs.setdefault("default", set_default) - for item in obj: - file.write(json.dumps(item, **kwargs) + "\n") - - def dump_to_str(self, obj, **kwargs): - """Dump a list of objects to a newline-delimited JSON (JSONL) string.""" - kwargs.setdefault("default", set_default) - lines = [json.dumps(item, **kwargs) for item in obj] - return "\n".join(lines) - - -if __name__ == "__main__": - from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - easy_io.dump([1, 2, 3], "test.jsonl", file_format="jsonl") - print(easy_io.load("test.jsonl")) - easy_io.dump([{"key1": 1, "key2": 2}, {"key1": 3, "key2": 4}], "test.jsonl", file_format="jsonl") - print(easy_io.load("test.jsonl")) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py deleted file mode 100644 index 3e8966992992c9963139c9b63968be34e4f82329..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/np_handler.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from io import BytesIO -from typing import IO, Any - -import numpy as np - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class NumpyHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file: IO[bytes], **kwargs) -> Any: - """ - Load a NumPy array from a file-like object. - - Parameters: - file (IO[bytes]): The file-like object containing the NumPy array data. - **kwargs: Additional keyword arguments passed to `np.load`. - - Returns: - numpy.ndarray: The loaded NumPy array. - """ - return np.load(file, **kwargs) - - def load_from_path(self, filepath: str, **kwargs) -> Any: - """ - Load a NumPy array from a file path. - - Parameters: - filepath (str): The path to the file to load. - **kwargs: Additional keyword arguments passed to `np.load`. - - Returns: - numpy.ndarray: The loaded NumPy array. - """ - return super().load_from_path(filepath, mode="rb", **kwargs) - - def dump_to_str(self, obj: np.ndarray, **kwargs) -> str: - """ - Serialize a NumPy array to a string in binary format. - - Parameters: - obj (np.ndarray): The NumPy array to serialize. - **kwargs: Additional keyword arguments passed to `np.save`. - - Returns: - str: The serialized NumPy array as a string. - """ - with BytesIO() as f: - np.save(f, obj, **kwargs) - return f.getvalue() - - def dump_to_fileobj(self, obj: np.ndarray, file: IO[bytes], **kwargs): - """ - Dump a NumPy array to a file-like object. - - Parameters: - obj (np.ndarray): The NumPy array to dump. - file (IO[bytes]): The file-like object to which the array is dumped. - **kwargs: Additional keyword arguments passed to `np.save`. - """ - np.save(file, obj, **kwargs) - - def dump_to_path(self, obj: np.ndarray, filepath: str, **kwargs): - """ - Dump a NumPy array to a file path. - - Parameters: - obj (np.ndarray): The NumPy array to dump. - filepath (str): The file path where the array should be saved. - **kwargs: Additional keyword arguments passed to `np.save`. - """ - with open(filepath, "wb") as f: - np.save(f, obj, **kwargs) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py deleted file mode 100644 index 04077f3f0a2c456cf06499c7d2371f2adda42974..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pandas_handler.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip - - -class PandasHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return pd.read_csv(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - obj.to_csv(file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError("PandasHandler does not support dumping to str") diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pickle_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pickle_handler.py deleted file mode 100644 index ff38b135efcd6bd81c772b6e0b53ec2aed6f821c..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pickle_handler.py +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle -from io import BytesIO -from typing import Any - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class PickleHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file: BytesIO, **kwargs): - return pickle.load(file, **kwargs) - - def load_from_path(self, filepath, **kwargs): - return super().load_from_path(filepath, mode="rb", **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("protocol", 2) - return pickle.dumps(obj, **kwargs) - - def dump_to_fileobj(self, obj: Any, file: BytesIO, **kwargs): - kwargs.setdefault("protocol", 2) - pickle.dump(obj, file, **kwargs) - - def dump_to_path(self, obj, filepath, **kwargs): - with open(filepath, "wb") as f: - pickle.dump(obj, f, **kwargs) diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pil_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pil_handler.py deleted file mode 100644 index 35c9dc79425634ece0be98c2478333ac673ac101..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/pil_handler.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import IO, Optional, Tuple, Union - -import numpy as np - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - -try: - from PIL import Image -except ImportError: - Image = None - - -class PILHandler(BaseFileHandler): - format: str - str_like = False - - def load_from_fileobj( - self, - file: IO[bytes], - fmt: str = "pil", - size: Optional[Union[int, Tuple[int, int]]] = None, - **kwargs, - ): - """ - Load an image from a file-like object and return it in a specified format. - - Args: - file (IO[bytes]): A file-like object containing the image data. - fmt (str): The format to convert the image into. Options are \ - 'numpy', 'np', 'npy', 'type' (all return numpy arrays), \ - 'pil' (returns PIL Image), 'th', 'torch' (returns a torch tensor). - size (Optional[Union[int, Tuple[int, int]]]): The new size of the image as a single integer \ - or a tuple of (width, height). If specified, the image is resized accordingly. - **kwargs: Additional keyword arguments that can be passed to conversion functions. - - Returns: - Image data in the format specified by `fmt`. - - Raises: - IOError: If the image cannot be loaded or processed. - ValueError: If the specified format is unsupported. - """ - try: - img = Image.open(file) - img.load() # Explicitly load the image data - if size is not None: - if isinstance(size, int): - size = ( - size, - size, - ) # create a tuple if only one integer is provided - img = img.resize(size, Image.ANTIALIAS) - - # Return the image in the requested format - if fmt in ["numpy", "np", "npy"]: - return np.array(img, **kwargs) - if fmt == "pil": - return img - if fmt in ["th", "torch"]: - import torch - - # Convert to tensor - img_tensor = torch.from_numpy(np.array(img, **kwargs)) - # Convert image from HxWxC to CxHxW - if img_tensor.ndim == 3: - img_tensor = img_tensor.permute(2, 0, 1) - return img_tensor - raise ValueError( - "Unsupported format. Supported formats are 'numpy', 'np', 'npy', 'pil', 'th', and 'torch'." - ) - except Exception as e: - raise IOError(f"Unable to load image: {e}") from e - - def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): - if "format" not in kwargs: - kwargs["format"] = self.format - kwargs["format"] = "JPEG" if self.format.lower() == "jpg" else self.format.upper() - obj.save(file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/registry_utils.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/registry_utils.py deleted file mode 100644 index a5e7d6329a43814383f3883bcabad3cc796a5aa6..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/registry_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.byte_handler import ByteHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.csv_handler import CsvHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.gzip_handler import GzipHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.imageio_video_handler import ImageioVideoHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.json_handler import JsonHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.jsonl_handler import JsonlHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.np_handler import NumpyHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.pandas_handler import PandasHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.pickle_handler import PickleHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.pil_handler import PILHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.tarfile_handler import TarHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.torch_handler import TorchHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.torchjit_handler import TorchJitHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.txt_handler import TxtHandler -from lyra_2._ext.imaginaire.utils.easy_io.handlers.yaml_handler import YamlHandler - -file_handlers = { - "json": JsonHandler(), - "yaml": YamlHandler(), - "yml": YamlHandler(), - "pickle": PickleHandler(), - "pkl": PickleHandler(), - "tar": TarHandler(), - "jit": TorchJitHandler(), - "npy": NumpyHandler(), - "txt": TxtHandler(), - "csv": CsvHandler(), - "pandas": PandasHandler(), - "gz": GzipHandler(), - "jsonl": JsonlHandler(), - "byte": ByteHandler(), -} - -for torch_type in ["pt", "pth", "ckpt"]: - file_handlers[torch_type] = TorchHandler() -for img_type in ["jpg", "jpeg", "png", "bmp", "gif"]: - file_handlers[img_type] = PILHandler() - file_handlers[img_type].format = img_type -try: - from lyra_2._ext.imaginaire.utils.easy_io.handlers.trimesh_handler import TrimeshHandler - - for mesh_type in ["ply", "stl", "obj", "glb"]: - file_handlers[mesh_type] = TrimeshHandler() - file_handlers[mesh_type].format = mesh_type -except ImportError: - pass -for video_type in ["mp4", "avi", "mov", "webm", "flv", "wmv"]: - file_handlers[video_type] = ImageioVideoHandler() - - -def _register_handler(handler, file_formats): - """Register a handler for some file extensions. - - Args: - handler (:obj:`BaseFileHandler`): Handler to be registered. - file_formats (str or list[str]): File formats to be handled by this - handler. - """ - if not isinstance(handler, BaseFileHandler): - raise TypeError(f"handler must be a child of BaseFileHandler, not {type(handler)}") - if isinstance(file_formats, str): - file_formats = [file_formats] - if not all([isinstance(item, str) for item in file_formats]): - raise TypeError("file_formats must be a str or a list of str") - for ext in file_formats: - file_handlers[ext] = handler - - -def register_handler(file_formats, **kwargs): - def wrap(cls): - _register_handler(cls(**kwargs), file_formats) - return cls - - return wrap diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/tarfile_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/tarfile_handler.py deleted file mode 100644 index 44fd6fa84d040d67e2b99c5c444c69e2ee6bcfc0..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/tarfile_handler.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tarfile - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class TarHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, mode="r|*", **kwargs): - return tarfile.open(fileobj=file, mode=mode, **kwargs) - - def load_from_path(self, filepath, mode="r|*", **kwargs): - return tarfile.open(filepath, mode=mode, **kwargs) - - def dump_to_fileobj(self, obj, file, mode="w", **kwargs): - with tarfile.open(fileobj=file, mode=mode) as tar: - tar.add(obj, **kwargs) - - def dump_to_path(self, obj, filepath, mode="w", **kwargs): - with tarfile.open(filepath, mode=mode) as tar: - tar.add(obj, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/torch_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/torch_handler.py deleted file mode 100644 index 62a78f29ce37152fcdcf6a5f1948bab9f14ded72..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/torch_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - import torch -except ImportError: - torch = None - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class TorchHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return torch.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - torch.save(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/torchjit_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/torchjit_handler.py deleted file mode 100644 index e2acee6b3db8f56a120a43b3f6e0293ef7d23379..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/torchjit_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - import torch -except ImportError: - torch = None - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class TorchJitHandler(BaseFileHandler): - str_like = False - - def load_from_fileobj(self, file, **kwargs): - return torch.jit.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - torch.jit.save(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/trimesh_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/trimesh_handler.py deleted file mode 100644 index e84fc4eb77667f9356b6bad42198f101bbe53f20..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/trimesh_handler.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import IO - -import trimesh - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class TrimeshHandler(BaseFileHandler): - format: str - str_like = False - - def load_from_fileobj(self, file: IO[bytes], **kwargs) -> trimesh.Trimesh: - file = trimesh.load(file_obj=file, file_type=self.format) - return file - - def dump_to_fileobj(self, obj, file: IO[bytes], **kwargs): - obj.export(file_obj=file, file_type=self.format) - return file - - def dump_to_str(self, obj, **kwargs): - raise NotImplementedError diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/txt_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/txt_handler.py deleted file mode 100644 index abed95b97784889578ca80411e4746d43b4330f0..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/txt_handler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler - - -class TxtHandler(BaseFileHandler): - def load_from_fileobj(self, file, **kwargs): - del kwargs - return file.read() - - def dump_to_fileobj(self, obj, file, **kwargs): - del kwargs - if not isinstance(obj, str): - obj = str(obj) - file.write(obj) - - def dump_to_str(self, obj, **kwargs): - del kwargs - if not isinstance(obj, str): - obj = str(obj) - return obj diff --git a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/yaml_handler.py b/lyra_2/_ext/imaginaire/utils/easy_io/handlers/yaml_handler.py deleted file mode 100644 index 4b384dac71757fe2783b96251e643b889bf58140..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/easy_io/handlers/yaml_handler.py +++ /dev/null @@ -1,38 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import yaml - -try: - from yaml import CDumper as Dumper # type: ignore - from yaml import CLoader as Loader # type: ignore -except ImportError: - from yaml import Dumper, Loader # type: ignore - -from lyra_2._ext.imaginaire.utils.easy_io.handlers.base import BaseFileHandler # isort:skip - - -class YamlHandler(BaseFileHandler): - def load_from_fileobj(self, file, **kwargs): - kwargs.setdefault("Loader", Loader) - return yaml.load(file, **kwargs) - - def dump_to_fileobj(self, obj, file, **kwargs): - kwargs.setdefault("Dumper", Dumper) - yaml.dump(obj, file, **kwargs) - - def dump_to_str(self, obj, **kwargs): - kwargs.setdefault("Dumper", Dumper) - return yaml.dump(obj, **kwargs) diff --git a/lyra_2/_ext/imaginaire/utils/ema.py b/lyra_2/_ext/imaginaire/utils/ema.py deleted file mode 100644 index b5e82411fea0194485df02b57e733c9d7c3e6ca7..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/ema.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union - -import numpy as np -import torch - -try: - from megatron.core import parallel_state - - USE_MEGATRON = True -except ImportError: - USE_MEGATRON = False - -from lyra_2._ext.imaginaire.utils import distributed, log - -if TYPE_CHECKING: - from lyra_2._ext.imaginaire.model import ImaginaireModel - - -class FastEmaModelUpdater: - """ - This class is used to update target model~(EMA) given source model~(regular model) and beta. - The method interaface mimic :class:`EMAModelTracker` and :class:`PowerEMATracker`. - Different from two classes, this class does not maintain the EMA model weights as buffers. It expects the user to have two module with same architecture and weights shape. - The class is proposed to work with FSDP model where above two classes are not working as expected. Besides, it is strange to claim model weights as buffers and do unnecessary name changing in :class:`EMAModelTracker` and :class:`PowerEMATracker`. Moeving forward, we should use this class instead of above two classes. - """ - - def __init__(self): - # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite - self.is_cached = False - - def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None: - target_list = [] - source_list = [] - for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - assert tgt_params.dtype == torch.float32, ( - f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead." - ) - target_list.append(tgt_params) - source_list.append(src_params.data) - torch._foreach_mul_(target_list, beta) - torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) - - def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: - for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - tgt_params.data.copy_(src_params.data) - - def cache(self, parameters: Any, is_cpu: bool = False) -> None: - """Save the current parameters for restoring later. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. - """ - assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" - device = "cpu" if is_cpu else "cuda" - self.collected_params = [param.clone().to(device) for param in parameters] - self.is_cached = True - - def restore(self, parameters: Any) -> None: - """Restore the parameters in self.collected_params. - - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before copy_to(). - After validation (or model saving), use this to restore the former parameters. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. - """ - assert self.is_cached, "EMA cache is not taken yet." - for c_param, param in zip(self.collected_params, parameters, strict=False): - param.data.copy_(c_param.data.type_as(param.data)) - self.collected_params = [] - # Release the cache after we call restore - self.is_cached = False - - -def get_buffer_name(param_name: str, torch_compile_buffer_renaming: bool = False) -> str: - """ - This function creates buffer name used by EMA from parameter's name - - Args: - param_name (str): Model's parameter name - Returns: - buffer_name (str): buffer name to be used for given parameter name - """ - - buffer_name = param_name.replace(".", "-") - - if torch_compile_buffer_renaming: - # torch.compile() adds _orig_mod to state dict names, this way we get original name - buffer_name = buffer_name.replace("_orig_mod-", "") - - return buffer_name - - -class EMAModelTracker(torch.nn.Module): - """This is a class to track the EMA model weights. - - The EMA weights are registered as buffers, which are extractable as state dicts. The names follow those of the - regular weights, except all "." are replaced with "-" (limitation of register_buffer()). This is similar to SDXL's - implementation of EMA. There are no optimizable parameters. - - Attributes: - collected_params (list): temporarily stores the regular weights while in EMA mode. - beta (float): EMA decay rate. (default: 0.9999). - torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used - """ - - def __init__(self, model: ImaginaireModel, beta: float = 0.9999, torch_compile_buffer_renaming: bool = False): - """Constructor of the EMA model weight tracker. - - Args: - model (ImaginaireModel): The PyTorch model. - beta (float): EMA decay rate. (default: 0.9999). - """ - super().__init__() - self.torch_compile_buffer_renaming: bool = torch_compile_buffer_renaming - if not 0.0 <= beta <= 1.0: - raise ValueError("Decay must be between 0 and 1") - self.beta = beta - for name, param in model.named_parameters(): - if param.requires_grad: - buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) - self.register_buffer(buffer_name, param.clone().detach().data) - self.collected_params = [] - # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite - self.is_cached = False - - @torch.no_grad() - def update_average(self, model: ImaginaireModel, iteration: Optional[int] = None) -> None: - del iteration - target_list = [] - source_list = [] - ema_buffers = self.state_dict() - for name, param in model.named_parameters(): - if param.requires_grad: - buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) - buffer = ema_buffers[buffer_name] - assert buffer.dtype == torch.float32, f"EMA model only works in FP32 dtype, got {buffer.dtype} instead." - target_list.append(buffer) - source_list.append(param.data) - torch._foreach_mul_(target_list, self.beta) - torch._foreach_add_(target_list, source_list, alpha=1.0 - self.beta) - - def copy_to(self, model: ImaginaireModel) -> None: - ema_buffers = self.state_dict() - for name, param in model.named_parameters(): - if param.requires_grad: - buffer_name = get_buffer_name(name, self.torch_compile_buffer_renaming) - buffer = ema_buffers[buffer_name] - param.data.copy_(buffer.data) - - def cache(self, parameters: Any, is_cpu: bool = False) -> None: - """Save the current parameters for restoring later. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be temporarily stored. - """ - assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" - device = "cpu" if is_cpu else "cuda" - self.collected_params = [param.clone().to(device) for param in parameters] - self.is_cached = True - - def restore(self, parameters: Any) -> None: - """Restore the parameters in self.collected_params. - - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before copy_to(). - After validation (or model saving), use this to restore the former parameters. - - Args: - parameters (iterable): Iterable of torch.nn.Parameter to be updated with the stored parameters. - """ - assert self.is_cached, "EMA cache is not taken yet." - for c_param, param in zip(self.collected_params, parameters, strict=False): - param.data.copy_(c_param.data.type_as(param.data)) - self.collected_params = [] - # Release the cache after we call restore - self.is_cached = False - - @classmethod - def initialize_multi_rank_ema( - cls, model: torch.nn.Module, rate: Union[float, List[float]], num: int = 1, enabled: bool = True - ) -> Optional[EMAModelTracker]: - """ - Class method to initialize per rank EMA Model Tracker with different rate. - Each rank will have a different rate based on the given configuration, resulting in different EMA weights. - - Args: - model (torch.nn.Module): The neural network model to be tracked. - rate (Union[float, List[float]]): The decay rate(s) for the EMA. If a list is provided, - it corresponds to rates for different ranks. - num (int, optional): The number of leading ranks to consider for different rates. - Defaults to 1. - enabled (bool, optional): Flag to enable or disable the creation of the tracker. - If False, returns None. Defaults to True. - - Returns: - Optional[EMAModelTracker]: An instance of EMAModelTracker if enabled, otherwise None. - - Example: - >>> model = torch.nn.Linear(10, 2) - >>> tracker = EMAModelTracker.initialize_ema_from_settings(model, rate=[0.1, 0.2], num=2) - >>> print(tracker) - - Notes: - If `rate` is a list and the current rank is less than `num`, the rate for the current rank - is used. If the current rank exceeds `num`, the first rate in the list is used by default. - """ - if not enabled: - return None - if USE_MEGATRON and parallel_state.is_initialized(): - cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - log.warning("It should not used together with FSDP!") - else: - cur_dp_rank = distributed.get_rank() - log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - rate = rate if isinstance(rate, list) else [rate] - num = min(num, len(rate)) - rate = rate[cur_dp_rank] if cur_dp_rank < num else rate[0] - if cur_dp_rank < num: - print(f"EMAModelTracker: rank {cur_dp_rank}, rate {rate}") - return cls(model, rate) - - -class PowerEMATracker(EMAModelTracker): - def __init__(self, model: ImaginaireModel, s: float = 0.1, torch_compile_buffer_renaming: bool = False): - """Constructor of the EMA model weight tracker. - - Args: - model (ImaginaireModel): The PyTorch model. - s (float): EMA decay rate. See EDM2 paper - torch_compile_buffer_renaming (bool): whether to remove '_orig_mod-' from buffer names when torch.compile is used - """ - super().__init__(model=model, beta=0.0, torch_compile_buffer_renaming=torch_compile_buffer_renaming) - self.exp = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() - - @torch.no_grad() - def update_average(self, model: ImaginaireModel, iteration: Optional[int] = None) -> None: - if iteration == 0: - beta = 0.0 - else: - i = iteration + 1 - beta = (1 - 1 / i) ** (self.exp + 1) - self.beta = beta - - super().update_average(model, iteration) - - @classmethod - def initialize_multi_rank_ema( - cls, model: torch.nn.Module, rate: float, num: int, enabled: bool = True - ) -> Optional[PowerEMATracker]: - """ - Class method to initialize per rank EMA Model Tracker with different rate. - Each rank will have a different rate based on the given configuration, resulting in different EMA weights. - - Args: - model (torch.nn.Module): The neural network model for which the EMA tracker is being set up. - num (int): The number of ranks for which the rate adjustment is applied. Beyond this, the rate remains unchanged. - rate (float): The base decay rate for the EMA calculation. - enabled (bool, optional): Flag to enable or disable the initialization of the tracker. If False, returns None. - Defaults to True. - - Returns: - Optional[PowerEMATracker]: An instance of PowerEMATracker with adjusted rate if enabled, otherwise None. - - Raises: - None - - Example: - >>> model = torch.nn.Linear(10, 2) - >>> tracker = PowerEMATracker.initialize_multi_rank_ema(model, num=3, rate=0.99) - >>> print(tracker) - - Notes: - The decay rate is modified by dividing it by 2 raised to the power of the rank for each rank less than `num`. - If the rank is greater than or equal to `num`, the base rate is used without modification. This approach - allows higher ranked processes to have a less aggressive decay, potentially reflecting their delayed synchronization - in a distributed training scenario. - """ - if not enabled: - return None - if USE_MEGATRON and parallel_state.is_initialized(): - cur_dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - log.critical(f"using MCore parallel_state for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - log.warning("It should not used together with FSDP!") - else: - cur_dp_rank = distributed.get_rank() - log.critical(f"using torch.distributed for EMA initialization. DP RANK: {cur_dp_rank}", rank0_only=False) - - divider = 2**cur_dp_rank if cur_dp_rank < num else 1 - if cur_dp_rank < num: - print(f"PowerEMATracker: rank {cur_dp_rank}, rate {rate / divider}") - return cls(model, rate / divider) - - -@contextmanager -def ema_scope(model: ImaginaireModel, enabled: bool = False) -> Generator[None, None, None]: - """Context manager for switching between regular and EMA model weights. - - Args: - model (ImaginaireModel): The PyTorch model. - enabled (bool): Whether switching to EMA weights is enabled (default: False). - """ - if enabled: - assert hasattr(model, "ema") and isinstance(model.ema, (FastEmaModelUpdater, EMAModelTracker, PowerEMATracker)) - model.ema.cache(model.parameters()) - model.ema.copy_to(model) - log.info("EMA: switched to EMA weights.") - try: - yield None - finally: - if enabled: - model.ema.restore(model.parameters()) - log.info("EMA: restored regular weights.") diff --git a/lyra_2/_ext/imaginaire/utils/env_parsers/cred_env_parser.py b/lyra_2/_ext/imaginaire/utils/env_parsers/cred_env_parser.py deleted file mode 100644 index 6d7694227e03ebd6d9b32275af7bfc15e6dcc6de..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/env_parsers/cred_env_parser.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lyra_2._ext.imaginaire.utils.env_parsers.env_parser import EnvParser -from lyra_2._ext.imaginaire.utils.validator import String - - -class CredentialEnvParser(EnvParser): - APP_ENV = String(default="") - PROD_FT_AWS_CREDS_ACCESS_KEY_ID = String(default="") - PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY = String(default="") - PROD_FT_AWS_CREDS_ENDPOINT_URL = String(default="https://s3.us-west-2.amazonaws.com") - PROD_FT_AWS_CREDS_REGION_NAME = String(default="us-west-2") - - PROD_S3_CHECKPOINT_ACCESS_KEY_ID = String(default="") - PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY = String(default="") - PROD_S3_CHECKPOINT_ENDPOINT_URL = String(default="") - PROD_S3_CHECKPOINT_REGION_NAME = String(default="") - - PROD_TEAM_DIR_ACCESS_KEY_ID = String(default="") - PROD_TEAM_DIR_SECRET_ACCESS_KEY = String(default="") - PROD_TEAM_DIR_ENDPOINT_URL = String(default="") - PROD_TEAM_DIR_REGION_NAME = String(default="") - - -CRED_ENVS = CredentialEnvParser() -CRED_ENVS_DICT = { - "PROD_FT_AWS_CREDS": { - "aws_access_key_id": CRED_ENVS.PROD_FT_AWS_CREDS_ACCESS_KEY_ID, - "aws_secret_access_key": CRED_ENVS.PROD_FT_AWS_CREDS_SECRET_ACCESS_KEY, - "endpoint_url": CRED_ENVS.PROD_FT_AWS_CREDS_ENDPOINT_URL, - "region_name": CRED_ENVS.PROD_FT_AWS_CREDS_REGION_NAME, - }, - "PROD_S3_CHECKPOINT": { - "aws_access_key_id": CRED_ENVS.PROD_S3_CHECKPOINT_ACCESS_KEY_ID, - "aws_secret_access_key": CRED_ENVS.PROD_S3_CHECKPOINT_SECRET_ACCESS_KEY, - "endpoint_url": CRED_ENVS.PROD_S3_CHECKPOINT_ENDPOINT_URL, - "region_name": CRED_ENVS.PROD_S3_CHECKPOINT_REGION_NAME, - }, - "PROD_TEAM_DIR": { - "aws_access_key_id": CRED_ENVS.PROD_TEAM_DIR_ACCESS_KEY_ID, - "aws_secret_access_key": CRED_ENVS.PROD_TEAM_DIR_SECRET_ACCESS_KEY, - "endpoint_url": CRED_ENVS.PROD_TEAM_DIR_ENDPOINT_URL, - "region_name": CRED_ENVS.PROD_TEAM_DIR_REGION_NAME, - }, -} diff --git a/lyra_2/_ext/imaginaire/utils/env_parsers/env_parser.py b/lyra_2/_ext/imaginaire/utils/env_parsers/env_parser.py deleted file mode 100644 index 662a27c276be113243fa7876abaf58a9d4041aac..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/env_parsers/env_parser.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import json -import os - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.validator import JsonDict, Validator - -""" -Base class for parsing environment variables using validators. -Class will go through its list of validators and retrieve values from same named environment variables. -Validators provide: -- default value -- typed parsing -- enforments of mandatory values - -Additionally the environment variables can be passed as single base64 encoded string. - -we cannot enforce that a component isn't directly using the environment variables. -so evaluation of params should throw error to make sure actual env var is correct. -""" - - -class EnvParser: - def __init__(self, b64_str=None): - if b64_str: - log.critical(f"b64_str recieved: {b64_str}") - self.from_b64(b64_str) - else: - self.from_env() - - def from_env(self): - validators = self.get_val_dict() - for key in validators.keys(): - val = os.getenv(key.upper()) - # log.debug(f"getting env var {key.upper()}: {val}") - if val: - setattr(self, key, val) - self.check_mandatory_values() - - def from_json(self, file_name): - with open(file_name, "r") as f: - log.info(f"Reading env params from {file_name}") - dict = json.load(f) - for key, value in dict.items(): - setattr(self, key, value) - self.check_mandatory_values() - - def to_b64(self): - json_str = self.to_json() - # create bytes-like object for b64 encoder - json_str_bytes = json_str.encode() - b64_str = base64.b64encode(json_str_bytes).decode() - - print(b64_str) - return b64_str - - def from_b64(self, b64_str): - json_str = base64.b64decode(b64_str).decode() - dict = json.loads(json_str) - for key, value in dict.items(): - setattr(self, key, value) - self.check_mandatory_values() - - def check_mandatory_values(self): - for key, validator in self.get_val_dict().items(): - if getattr(self, key) is None and validator.default is None: - raise ValueError(f"Missing mandatory env var: {key}") - - @classmethod - def get_val_dict(cls): - log.debug(f"getting val dict of {cls.__name__}") - val_dict = {} - val_dict.update({key: value for key, value in cls.__dict__.items() if isinstance(value, Validator)}) - - return val_dict - - def dump_validators(self): - validators = self.get_val_dict() - for key, value in validators.items(): - log.debug(f"{key}: {value.__get__(self)}") - - def to_json(self, file_name=None): - dict = { - key.upper(): value.__get__(self) - for key, value in EnvParser.__dict__.items() - if isinstance(value, Validator) - } - json_str = json.dumps(dict, indent=4) - print(json_str) - - if file_name: - with open(file_name, "w") as f: - log.info(f"Writing env params to {file_name}") - f.write(json_str) - - return json_str - - def to_string_dict(self): - result = {} - for key, validator in self.get_val_dict().items(): - value = getattr(self, key) - if value is None: - value = validator.default - if isinstance(validator, JsonDict): - value = json.dumps(value) - else: - value = str(value) - result[key] = value - return result - - def __str__(self): - return ", ".join(f"{key}={value}" for key, value in self.__dict__.items()) diff --git a/lyra_2/_ext/imaginaire/utils/fsdp_helper.py b/lyra_2/_ext/imaginaire/utils/fsdp_helper.py deleted file mode 100644 index 0a6b744fafeb655fb052290e8eacc3d55325a364..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/fsdp_helper.py +++ /dev/null @@ -1,159 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from contextlib import contextmanager -from functools import partial - -import torch -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointImpl, - apply_activation_checkpointing, - checkpoint_wrapper, -) -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp._runtime_utils import ( - _post_forward, - _post_forward_reshard, - _pre_forward, - _pre_forward_unshard, - _root_pre_forward, -) -from torch.distributed.utils import _p_assert - -from lyra_2._ext.imaginaire.utils import distributed, log - - -def apply_fsdp_checkpointing(model, list_block_cls): - """apply activation checkpointing to model - returns None as model is updated directly - """ - log.critical("--> applying fdsp activation checkpointing...") - non_reentrant_wrapper = partial( - checkpoint_wrapper, - # offload_to_cpu=False, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - - def check_fn(submodule): - result = False - for block_cls in list_block_cls: - if isinstance(submodule, block_cls): - result = True - break - return result - - apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn) - - -@contextmanager -def possible_fsdp_scope( - model: torch.nn.Module, -): - enabled = isinstance(model, FSDP) - if enabled: - assert not torch.is_grad_enabled(), "FSDP context should be entered with grad disabled" - handle = model._handle - args, kwargs = [0], dict(dummy=0) - with torch.autograd.profiler.record_function("FullyShardedDataParallel.possible_fsdp_scope"): - args, kwargs = _root_pre_forward(model, model, args, kwargs) - unused = None - args, kwargs = _pre_forward( - model, - handle, - _pre_forward_unshard, - model._fsdp_wrapped_module, - args, - kwargs, - ) - if handle: - _p_assert( - handle.flat_param.device == model.compute_device, - "Expected `FlatParameter` to be on the compute device " - f"{model.compute_device} but got {handle.flat_param.device}", - ) - try: - yield None - finally: - if enabled: - output = {"output": 1} - _post_forward(model, handle, _post_forward_reshard, model, unused, output) - - -def hsdp_device_mesh(replica_group_size=None, sharding_group_size=None, device=None): - """ - Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. - - This function requires explicit sizes for replica and sharding groups to accommodate models - whose GPU fit is unknown, providing flexibility in distributed training setups. - - Args: - replica_group_size (int): The size of each replica group. Must be provided to ensure - the model fits within the available resources. - sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to - ensure the correct distribution of model parameters. - device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" - with the local rank as the device index. - - Returns: - A device mesh object compatible with FSDP. - - Raises: - ValueError: If replica_group_size or sharding_group_size are not provided, or if the - world size is not evenly divisible by the sharding group size. - RuntimeError: If a valid device mesh cannot be created. - - Usage: - If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: - Sharding_Group_Size = 4 - Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups - >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) - >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) - """ - - # world_size = int(os.getenv("WORLD_SIZE", "1")) - world_size = distributed.get_world_size() - if sharding_group_size is None: - sharding_group_size = min(world_size, 8) - sharding_group_size = min(sharding_group_size, world_size) - if replica_group_size is None: - replica_group_size = world_size // sharding_group_size - - device = device or "cuda" - - if world_size % sharding_group_size != 0: - raise ValueError( - f"World size {world_size} is not evenly divisible by sharding group size {sharding_group_size}." - ) - - if (world_size // sharding_group_size) % replica_group_size != 0: - raise ValueError( - f"The calculated number of replica groups is not evenly divisible by " - f"replica_group_size {replica_group_size}." - ) - - device_mesh = init_device_mesh( - device, (replica_group_size, sharding_group_size), mesh_dim_names=("replicate", "shard") - ) - if device_mesh is None: - raise RuntimeError("Failed to create a valid device mesh.") - - log.info( - f"Device mesh initialized with replica group size {replica_group_size} and sharding group size {sharding_group_size}" - ) - - return device_mesh diff --git a/lyra_2/_ext/imaginaire/utils/log.py b/lyra_2/_ext/imaginaire/utils/log.py deleted file mode 100644 index cb99def71d2bce8af264011518da1321b6d4eead..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/log.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import atexit -import os -import sys -from typing import Any, Optional - -import torch.distributed as dist -from loguru._logger import Core, Logger - -RANK0_ONLY = True -LEVEL = os.environ.get("LOGURU_LEVEL", "INFO") - - -def make_new_logger(depth: int = 1) -> Logger: - return Logger( - core=Core(), - exception=None, - depth=depth, - record=False, - lazy=False, - colors=False, - raw=False, - capture=True, - patchers=[], - extra={}, - ) - - -logger = make_new_logger(depth=1) -atexit.register(logger.remove) - - -def _add_relative_path(record: dict[str, Any]) -> None: - start = os.getcwd() - record["extra"]["relative_path"] = os.path.relpath(record["file"].path, start) - - -*options, _, extra = logger._options # type: ignore -logger._options = tuple([*options, [_add_relative_path], extra]) # type: ignore - - -def init_loguru_stdout() -> None: - logger.remove() - datetime_format = get_datetime_format() - machine_format = get_machine_format() - message_format = get_message_format() - logger.add( - sys.stdout, - level=LEVEL, - format=f"{datetime_format}{machine_format}{message_format}", - filter=_rank0_only_filter, - ) - - -def init_loguru_file(path: str) -> None: - datetime_format = get_datetime_format() - machine_format = get_machine_format() - message_format = get_message_format() - logger.add( - path, - encoding="utf8", - level=LEVEL, - format=f"{datetime_format}{machine_format}{message_format}", - rotation="100 MB", - filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY, - enqueue=True, - ) - - -def get_datetime_format() -> str: - return "[{time:MM-DD HH:mm:ss}|" - - -def get_machine_format() -> str: - node_id = os.environ.get("NGC_ARRAY_INDEX", "0") - num_nodes = int(os.environ.get("NGC_ARRAY_SIZE", "1")) - machine_format = "" - rank = 0 - if dist.is_available(): - if not RANK0_ONLY and dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - machine_format = ( - f"[Node{node_id:<3}/{num_nodes:<3}][RANK{rank:<5}/{world_size:<5}]" + "[{process.name:<8}]| " - ) - return machine_format - - -def get_message_format() -> str: - message_format = "{level}|{extra[relative_path]}:{line}:{function}] {message}" - return message_format - - -def _rank0_only_filter(record: Any) -> bool: - is_rank0 = record["extra"].get("rank0_only", True) - if _get_rank() == 0 and is_rank0: - return True - if not is_rank0: - record["message"] = f"[RANK {_get_rank()}]" + record["message"] - return not is_rank0 - - -def trace(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).trace(message) - - -def debug(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).debug(message) - - -def info(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).info(message) - - -def success(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).success(message) - - -def warning(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).warning(message) - - -def error(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).error(message) - - -def critical(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).critical(message) - - -def exception(message: str, rank0_only: bool = True) -> None: - logger.opt(depth=1).bind(rank0_only=rank0_only).exception(message) - - -def _get_rank(group: Optional[dist.ProcessGroup] = None) -> int: - """Get the rank (GPU device) of the worker. - - Returns: - rank (int): The rank of the worker. - """ - rank = 0 - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank(group) - return rank - - -# Execute at import time. -init_loguru_stdout() diff --git a/lyra_2/_ext/imaginaire/utils/misc.py b/lyra_2/_ext/imaginaire/utils/misc.py deleted file mode 100644 index c5fd3242d669a115c9f793b9cb55a87a42b622d4..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/misc.py +++ /dev/null @@ -1,620 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import collections -import collections.abc -import functools -import json -import os -import random -import time -from contextlib import ContextDecorator, nullcontext -from dataclasses import fields -from typing import Any, Callable, List, Tuple, TypeVar, Union - -import numpy as np -from loguru import logger as logging - -try: - import straggler -except ImportError: - straggler = None -import termcolor -import torch -from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed._tensor.api import DTensor - -from lyra_2._ext.imaginaire.utils import distributed, log -from lyra_2._ext.imaginaire.utils.distributed import all_gather_tensor -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - -def requires_grad(model: torch.nn.Module, value: bool = True) -> None: - """Set a model to require gradients or not. - - Args: - model (torch.nn.Module): Neural network model. - value (bool): Whether the network requires gradients or not. - """ - for p in model.parameters(): - p.requires_grad = value - - -def to( - data: Any, - device: str | torch.device | None = None, - dtype: torch.dtype | None = None, - memory_format: torch.memory_format = torch.preserve_format, -) -> Any: - """Recursively cast data into the specified device, dtype, and/or memory_format. - - The input data can be a tensor, a list of tensors, a dict of tensors. - See the documentation for torch.Tensor.to() for details. - - Args: - data (Any): Input data. - device (str | torch.device): GPU device (default: None). - dtype (torch.dtype): data type (default: None). - memory_format (torch.memory_format): memory organization format (default: torch.preserve_format). - - Returns: - data (Any): Data cast to the specified device, dtype, and/or memory_format. - """ - assert device is not None or dtype is not None or memory_format is not None, ( - "at least one of device, dtype, memory_format should be specified" - ) - - if isinstance(data, torch.Tensor): - if ( - memory_format == torch.channels_last - and data.dim() != 4 - or memory_format == torch.channels_last_3d - and data.dim() != 5 - ): - memory_format = torch.preserve_format # do not change the memory format - is_cpu = (isinstance(device, str) and device == "cpu") or ( - isinstance(device, torch.device) and device.type == "cpu" - ) - data = data.to( - device=device, - dtype=dtype, - memory_format=memory_format, - non_blocking=(not is_cpu), - ) - return data - elif isinstance(data, collections.abc.Mapping): - return type(data)({key: to(data[key], device=device, dtype=dtype, memory_format=memory_format) for key in data}) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): - return type(data)([to(elem, device=device, dtype=dtype, memory_format=memory_format) for elem in data]) - else: - return data - - -def serialize(data: Any) -> Any: - """Serialize data by hierarchically traversing through iterables. - - Args: - data (Any): Input data. - - Returns: - data (Any): Serialized data. - """ - if isinstance(data, collections.abc.Mapping): - return type(data)({key: serialize(data[key]) for key in data}) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)): - return type(data)([serialize(elem) for elem in data]) - else: - try: - json.dumps(data) - except TypeError: - data = str(data) - return data - - -def print_environ_variables(env_vars: list[str]) -> None: - """Print a specific list of environment variables. - - Args: - env_vars (list[str]): List of specified environment variables. - """ - for env_var in env_vars: - if env_var in os.environ: - log.info(f"Environment variable {Color.green(env_var)}: {Color.yellow(os.environ[env_var])}") - else: - log.warning(f"Environment variable {Color.green(env_var)} not set!") - - -def set_random_seed(seed: int, by_rank: bool = False) -> None: - """Set random seed. This includes random, numpy, Pytorch. - - Args: - seed (int): Random seed. - by_rank (bool): if true, each GPU will use a different random seed. - """ - if by_rank: - seed += distributed.get_rank() - log.info(f"Using random seed {seed}.") - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) # sets seed on the current CPU & all GPUs - - -def arch_invariant_rand( - shape: List[int] | Tuple[int], dtype: torch.dtype, device: str | torch.device, seed: int | None = None -): - """Produce a GPU-architecture-invariant randomized Torch tensor. - - Args: - shape (list or tuple of ints): Output tensor shape. - dtype (torch.dtype): Output tensor type. - device (torch.device): Device holding the output. - seed (int): Optional randomization seed. - - Returns: - tensor (torch.tensor): Randomly-generated tensor. - """ - # Create a random number generator, optionally seeded - rng = np.random.RandomState(seed) - - # Generate random numbers using the generator - random_array = rng.standard_normal(shape).astype(np.float32) # Use standard_normal for normal distribution - - # Convert to torch tensor and return - return torch.from_numpy(random_array).to(dtype=dtype, device=device) - - -def get_data_batch_size(data: dict[str, torch.Tensor] | torch.Tensor) -> int: - """Get the batch size from a data batch, a (possibly hierarchical) dictionary of tensors. - - Args: - data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). - - Returns: - batch_size (int): Data batch size. - """ - - def _get_batch_size(input_data: Any) -> Union[int, None]: - """ - Helper function that recursively finds a tensor in the input data - (could be a nested dictionary) and returns its batch size. - """ - if isinstance(input_data, torch.Tensor): - return len(input_data) - elif isinstance(input_data, collections.abc.Mapping): - for key, value in input_data.items(): - batch_size = _get_batch_size(value) - if batch_size is not None: - return batch_size - return None - - batch_size = _get_batch_size(data) - if not isinstance(batch_size, int): - raise ValueError(f"Batch size ({batch_size}) obtained from invalid data: {data}") - return batch_size - - -def parameters_to_buffer(module: torch.nn.Module, persistent: bool = True): - """Convert parameters in a module to buffers. - Buffers do not have its own gradients and thus not updated by backpropagation. - - Args: - module (torch.nn.Module): a module to convert parameters - persistent (bool): If True, buffers are included in state_dict. - """ - named_params = dict() - - for name, param in module.named_parameters(): - named_params[name] = param - - for name, param in named_params.items(): - module_hierarchy = name.split(".") - submodule_name = ".".join(module_hierarchy[:-1]) - submodule = module.get_submodule(submodule_name) - subname = module_hierarchy[-1] - delattr(submodule, subname) - submodule.register_buffer(subname, param, persistent=persistent) - - return - - -T = TypeVar("T", bound=Callable[..., Any]) - - -class timer(ContextDecorator): # noqa: N801 - """Simple timer for timing the execution of code. - - It can be used as either a context manager or a function decorator. The timing result will be logged upon exit. - - Example: - def func_a(): - time.sleep(1) - with timer("func_a"): - func_a() - - @timer("func_b) - def func_b(): - time.sleep(1) - func_b() - """ - - def __init__(self, context: str, debug: bool = False): - self.context = context - self.debug = debug - - def __enter__(self) -> None: - self.tic = time.time() - - def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 - time_spent = time.time() - self.tic - if self.debug: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - else: - log.info(f"Time spent on {self.context}: {time_spent:.4f} seconds") - - def __call__(self, func: T) -> T: - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - tic = time.time() - result = func(*args, **kwargs) - time_spent = time.time() - tic - if self.debug: - log.debug(f"Time spent on {self.context}: {time_spent:.4f} seconds") - else: - log.info(f"Time spent on {self.context}: {time_spent:.4f} seconds") - return result - - return wrapper # type: ignore - - -class memory_checker(ContextDecorator): # noqa: N801 - """Simple memory checker for a given block of code. - - It can be used as either a context manager or a function decorator. The memory usage will be logged upon exit. - Example: - def func_a(): - torch.rand([int(1024**2)]).float().cuda() - with memory_checker("func_a"): - func_a() - >>> 0.004GB memory used - - @memory_checker("func_b") - def func_b(): - random_var = torch.rand([int(1024**2)]).cuda() - func_b() - """ - - def __init__(self, context: str, debug: bool = False): - self.context = context - self.debug = debug - - def __enter__(self) -> None: - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats() - self.initial_memory = torch.cuda.max_memory_allocated() - - def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 - torch.cuda.synchronize() - final_memory = torch.cuda.max_memory_allocated() - message = f"Memory used within {self.context}: {(final_memory - self.initial_memory) / 1024**3:.4f} GB" - if self.debug: - log.debug(message) - else: - log.info(message) - - def __call__(self, func: T) -> T: - @functools.wraps(func) - def wrapper(*args, **kwargs): # noqa: ANN202 - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats() - initial_memory = torch.cuda.max_memory_allocated() - result = func(*args, **kwargs) - torch.cuda.synchronize() - final_memory = torch.cuda.max_memory_allocated() - message = f"Memory used within {self.context}: {(final_memory - initial_memory) / 1024**3:.4f} GB" - if self.debug: - log.debug(message) - else: - log.info(message) - return result - - return wrapper # type: ignore - - -class TrainingTimer: - """Timer for timing the execution of code, aggregating over multiple training iterations. - - It is used as a context manager to measure the execution time of code and store the timing results - for each function. The context managers can be nested. - - Attributes: - results (dict): A dictionary to store timing results for various code. - - Example: - timer = Timer() - for i in range(100): - with timer("func_a"): - func_a() - avg_time = sum(timer.results["func_a"]) / len(timer.results["func_a"]) - print(f"func_a() took {avg_time} seconds.") - """ - - def __init__(self) -> None: - self.results = dict() - self.average_results = dict() - self.start_time = [] - self.func_stack = [] - self.reset() - - def reset(self) -> None: - self.results = {key: [] for key in self.results} - - def __enter__(self) -> TrainingTimer: - self.start_time.append(time.time()) - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: # noqa: ANN001 - end_time = time.time() - result = end_time - self.start_time.pop() - key = self.func_stack.pop() - self.results.setdefault(key, []) - self.results[key].append(result) - - def __call__(self, func_name: str) -> TrainingTimer: - self.func_stack.append(func_name) - return self - - def __getattr__(self, func_name: str) -> TrainingTimer: - return self.__call__(func_name) - - def nested(self, func_name: str) -> TrainingTimer: - return self.__call__(func_name) - - def compute_average_results(self) -> dict[str, float]: - results = dict() - for key, value_list in self.results.items(): - results[key] = sum(value_list) / len(value_list) - return results - - -def timeout_handler(timeout_period: float, signum: int, frame: int) -> None: - # What to do when the process gets stuck. For now, we simply end the process. - error_message = f"Timeout error: more than {timeout_period} seconds passed since the last iteration." - raise TimeoutError(error_message) - - -class Color: - """A convenience class to colorize strings in the console. - - Example: - import - print("This is {Color.red('important')}.") - """ - - @staticmethod - def red(x: str) -> str: - return termcolor.colored(str(x), color="red") - - @staticmethod - def green(x: str) -> str: - return termcolor.colored(str(x), color="green") - - @staticmethod - def blue(x: str) -> str: - return termcolor.colored(str(x), color="blue") - - @staticmethod - def cyan(x: str) -> str: - return termcolor.colored(str(x), color="cyan") - - @staticmethod - def yellow(x: str) -> str: - return termcolor.colored(str(x), color="yellow") - - @staticmethod - def magenta(x: str) -> str: - return termcolor.colored(str(x), color="magenta") - - @staticmethod - def grey(x: str) -> str: - return termcolor.colored(str(x), color="grey") - - -class BufferCnt: - """ - Buffer counter which keeps track of the condition when called and returns True when the condition in met "thres" - amount of times, otherwise returns False. - - Example usage: - buf = BufferCnt(thres=3) - for _ in range(5): - if buf(random.random() > 0.5): - print("We got lucky 3 times out of 5.") - - Args: - thres (int): The amount of times the expression needs to be True before returning True. - reset_over_thres (bool): Whether to reset the buffer after returning True. - """ - - def __init__(self, thres=10, reset_over_thres=False): - self._cnt = 0 - self.thres = thres - self.reset_over_thres = reset_over_thres - - def __call__(self, expre, thres=None): - if expre is True: - self._cnt += 1 - else: - self._cnt = 0 - - if thres is None: - thres = self.thres - - if self._cnt >= thres: - if self.reset_over_thres: - self.reset() - return True - - return False - - @property - def cnt(self): - return self._cnt - - def reset(self): - self._cnt = 0 - - -def dataclass_instance_to_dict(dataclass: Any) -> dict: - """Convert a dataclass to a dictionary. - - Args: - dataclass (Any): Dataclass object. - - Returns: - dict: Dictionary representation of the dataclass. - """ - return {f.name: getattr(dataclass, f.name) for f in fields(dataclass)} - - -def get_local_tensor_if_DTensor(tensor: torch.Tensor | DTensor) -> torch.tensor: - if isinstance(tensor, DTensor): - local = tensor.to_local() - # As per PyTorch documentation, if the communication is not finished yet, we need to wait for it to finish - # https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.DTensor.to_local - if isinstance(local, AsyncCollectiveTensor): - return local.wait() - else: - return local - return tensor - - -class NVTXRangeContext: - """ - Context manager which inserts NVTX range around the current context and optionally calls torch.cuda.synchronize - at the start and the end of the context. - - Args: - name (str): Name of the NVTX range. - enabled (bool): Whether the context manager is enabled. When disabled, it does nothing. Default: True. - synchronize (bool): Whether to call torch.cuda.synchronize() at the start and the end of the context. Default: True. - """ - - def __init__(self, name: str, enabled: bool = True, synchronize: bool = True): - self.name = name - self.enabled = enabled - self.synchronize = synchronize - - def __enter__(self): - if not self.enabled: - return - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_push(self.name) - - def __exit__(self, exc_type, exc_val, exc_tb): - if not self.enabled: - return - if self.synchronize: - torch.cuda.synchronize() - torch.cuda.nvtx.range_pop() - - -class StragglerDetectorV2: - """StragglerDetectorV2 detects stragglers using low-level CUPTI tool, which can gather kernel execution time - with very low overhead. The execution times are compared across different ranks, as well as to the execution - time of the exact same kernels in the past. This tool can be easily integrated, as it's resilient to any - synchronizations, since it captures kernels execution time. It means that we can wrap the entire forward or - backward passes and the stragglers will be identified regardless of synchronizations happening during the iteration. - - Args: - enabled (bool): Whether the straggler detection is enabled. When disabled, it does nothing. Default: True. - report_freq (int): Generate a report each report_freq iterations that analyzes the GPUs performance. Defaults to 100. - profile_freq (int): Enable the CUPTI profiling each profile_freq iterations. Since the overhead is very low, - the default value is 1. - max_diff (float): Defines the maximum relative difference between the fastest and the slowest rank to determine the slowdown. Defaults to 2.0 - raise_error (bool): Whether to raise error when stragglers are detected enough times. Defaults to True.""" - - def __init__( - self, - enabled: bool = True, - report_freq: int = 100, - profile_freq: int = 1, - max_diff: float = 2.0, - raise_error: bool = True, - ): - self.enabled = enabled - self.report_freq = report_freq - self.profile_freq = profile_freq - self.name = self.__class__.__name__ - self.slowdown_count = BufferCnt(thres=10, reset_over_thres=True) - self.max_diff = max_diff - self.raise_error = raise_error - - def initialize(self): - if self.enabled: - if not straggler: - raise RuntimeError( - "Please install the straggler package before using StragglerDetectionV2." - ) - - straggler.Detector.initialize( - scores_to_compute=["relative_perf_scores", "individual_perf_scores"], - gather_on_rank0=False, # all ranks results will be available on rank 0 - profiling_interval=self.profile_freq, - ) - - def profile_section(self, name: str, section_enabled: bool, profile_cuda: bool = True): - if section_enabled and self.enabled: - return straggler.Detector.detection_section(name, profile_cuda=profile_cuda) - else: - return nullcontext() - - def _aggregate_section_results(self, local_section_summaries): - data = [] - for key in local_section_summaries: - # straggler reports time in ms - data.append(local_section_summaries[key][straggler.Statistic.MAX] / 1000) - return distributed.all_gather_tensor(torch.tensor(data).cuda()) - - def generate_report(self, iteration): - if self.enabled and iteration % self.report_freq == 0: - report = straggler.Detector.generate_report() - gpu_relative_perf_score = report.gpu_relative_perf_scores[distributed.get_rank()] - gpu_relative_perf_score_gather_list = distributed.all_gather_tensor( - torch.tensor([gpu_relative_perf_score]).cuda() - ) - local_section_data = self._aggregate_section_results(report.local_section_summaries) - if distributed.get_rank() == 0: - stragglers = report.identify_stragglers(gpu_rel_threshold=1 / self.max_diff) - data_tensor = torch.tensor(gpu_relative_perf_score_gather_list) - slowest_rank_id = torch.argmin(data_tensor) - # Which GPUs are slower than other GPUs, based on the execution time of kernels - relative_stragglers = stragglers["straggler_gpus_relative"] - # Which GPUs are slower than itself in the past, based on the past execution time of kernels. - individual_stragglers = stragglers["straggler_gpus_individual"] - is_slowdown = relative_stragglers or individual_stragglers - if is_slowdown: - hostname = torch.ByteTensor(bytearray(os.uname().nodename, "utf-8")).cuda() - whole_hostname = all_gather_tensor(hostname) - slowest_hostname = whole_hostname[slowest_rank_id].cpu().numpy().tobytes().decode("utf-8") - logging.critical(f"Slowest rank hostname: {slowest_hostname}") - - if self.slowdown_count(is_slowdown) and self.raise_error: - raise RuntimeError( - f"Detected GPU {slowest_rank_id} to be too slow compared to other GPUs." - f" The relative performance of {slowest_rank_id} rank was {report.gpu_relative_perf_scores[slowest_rank_id]}. Terminating the training." - ) diff --git a/lyra_2/_ext/imaginaire/utils/object_store.py b/lyra_2/_ext/imaginaire/utils/object_store.py deleted file mode 100644 index 09485cea1ecf7f37c5a1d71524954020cb04352e..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/object_store.py +++ /dev/null @@ -1,462 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import io -import json -import os -import pickle -import random -import time -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional -from urllib.parse import urlparse - -import boto3 -import botocore -import numpy as np -import torch -import yaml -from botocore.config import Config -from botocore.exceptions import ClientError -from PIL import Image - -import lyra_2._ext.imaginaire.utils.easy_io.backends.auto_auth as auto -from lyra_2._ext.imaginaire.utils import distributed, log -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - -GLOBAL_S3_CONFIG = Config(retries={"max_attempts": 20, "mode": "adaptive"}, connect_timeout=10, read_timeout=60) -Image.MAX_IMAGE_PIXELS = None - -if TYPE_CHECKING: - from lyra_2._ext.imaginaire.config import ObjectStoreConfig - - -class ObjectStore: - """This is the interface class for object store, used for interacting with S3-compatible storage. - - Attributes: - client (botocore.client.S3): Object store client object. - bucket (str): Object store bucket name. - """ - - def __init__(self, config_object_storage: ObjectStoreConfig): - with auto.open_auth(config_object_storage.credentials, "r") as file: - object_storage_config = auto.json_load_auth(file) - self.client = Boto3Wrapper( - "s3", - **object_storage_config, - ) - self.bucket = config_object_storage.bucket - - def load_object( - self, - key: str, - type: str | None = None, - load_func: Callable | None = None, - encoding: str = "UTF-8", - max_attempts: int = 10, - ) -> Any: - """Helper function for loading object from S3. - - Args: - key (str): The key of the object. - type (str): Specified for some common data types. If not provided, `load_func` should be specified. - The predefined types currently supported are: - - "torch": PyTorch model checkpoints, opened with torch.load(). - - "torch.jit": A JIT-compiled TorchScript model, loaded with torch.jit.load(). - - "image": Image objects, opened with PIL.Image.open(). - - "json": JSON files, opened with json.load(). - - "pickle": Picklable objects, opened with pickle.load(). - - "yaml": YAML files, opened with yaml.safe_load(). - - "text": Pure text files. - - "numpy": Numpy arrays, opened with np.load(). - - "bytes": Raw bytes. - load_func (Callable): a custom function for reading the buffer if `type` were not provided. - encoding (str): Text encoding standard (default: "UTF-8"). - max_attempts (int): Max number of attempts to load the object if there is a failure. - - Returns: - object (Any): The downloaded object. - """ - - for attempt in range(max_attempts): - try: - return self._load_object( - key, - type=type, - load_func=load_func, - encoding=encoding, - ) - except botocore.exceptions.ClientError as e: - retry_interval = min(0.1 * 2**attempt + random.uniform(0, 1), 30) - log.exception( - f"Failed to load ({self.bucket}) {key}, attempt {attempt}. {e}. Retrying in {retry_interval}s." - ) - if attempt < max_attempts - 1: - time.sleep(retry_interval) - raise ConnectionError(f"Unable to read ({self.bucket}) {key} after {max_attempts} attempts.") - - def _load_object( - self, key: str, type: str | None = None, load_func: Callable | None = None, encoding: str = "UTF-8" - ) -> Any: - """Helper function for loading object from S3. - - Args: - key (str): The key of the object. - type (str): Specified for some common data types. If not provided, `load_func` should be specified. - load_func (Callable): a custom function for reading the buffer if `type` were not provided. - encoding (str): Text encoding standard (default: "UTF-8"). - - Returns: - object (Any): The downloaded object. - """ - assert type is not None or load_func is not None, "Either type or load_func should be specified." - with io.BytesIO() as buffer: - self.client.download_fileobj(Bucket=self.bucket, Key=key, Fileobj=buffer) - buffer.seek(0) - # Read from buffer for common data types. - if type == "torch": - object = torch.load(buffer, map_location=lambda storage, loc: storage, weights_only=False) - elif type == "torch.jit": - object = torch.jit.load(buffer) - elif type == "image": - object = Image.open(buffer) - object.load() - elif type == "json": - object = json.load(buffer) - elif type == "pickle": - object = pickle.load(buffer) - elif type == "yaml": - object = yaml.safe_load(buffer) - elif type == "text": - object = buffer.read().decode(encoding) - elif type == "numpy": - object = np.load(buffer, allow_pickle=True) - # Read from buffer as raw bytes. - elif type == "bytes": - object = buffer.read() - # Customized load_func should be provided. - else: - object = load_func(buffer) - return object - - def save_object( - self, object: Any, key: str, type: str | None = None, save_func: Callable | None = None, encoding: str = "UTF-8" - ) -> None: - """Helper function for saving object to S3. - - Args: - object (Any): The object to upload. - key (str): The key of the object. - type (str): Specified for some common data types. If not provided, `save_func` should be specified. - The predefined types currently supported are: - - "torch": PyTorch model checkpoints, saved with torch.save(). - - "torch.jit": A JIT-compiled TorchScript model, exported with torch.jit.save(). - - "image": Image objects, saved with PIL.Image.save(). - - "json": JSON files, saved with json.dumps(). - - "pickle": Picklable objects, saved with pickle.dump(). - - "yaml": YAML files, saved with yaml.safe_dump(). - - "text": Pure text files. - - "numpy": Numpy arrays, saved with np.save(). - - "bytes": Raw bytes. - save_func (Callable): a custom function for writing the buffer if `type` were not provided. - encoding (str): Text encoding standard (default: "UTF-8"). - """ - assert type is not None or save_func is not None - with io.BytesIO() as buffer: - # Write to buffer for common data types. - if type == "torch": - torch.save(object, buffer) - elif type == "torch.jit": - torch.jit.save(object, buffer) - elif type == "image": - type = os.path.basename(key).split(".")[-1] - object.save(buffer, format=type) - elif type == "json": - buffer.write(json.dumps(object).encode(encoding)) - elif type == "pickle": - pickle.dump(object, buffer) - elif type == "yaml": - buffer.write(yaml.safe_dump(object).encode(encoding)) - elif type == "text": - buffer.write(object.encode(encoding)) - elif type == "numpy": - np.save(buffer, object) - # Write to buffer as raw bytes. - elif type == "bytes": - buffer.write(bytes(object)) - # Customized save_func should be provided. - else: - save_func(object, buffer) - buffer.seek(0) - self.client.upload_fileobj(Bucket=self.bucket, Key=key, Fileobj=buffer) - - def object_exists(self, key: str, max_retries: int = 10, retry_delay: float = 2.0) -> bool: - """ - Check whether an object exists in the storage, with retry logic for transient errors. - - Args: - key (str): The key of the object. - max_retries (int): The maximum number of retry attempts in case of errors. - retry_delay (float): The delay (in seconds) between retry attempts. - - Returns: - bool: True if the object exists, False if not or if an error persists after retries. - """ - for attempt in range(max_retries): - try: - # Attempt to check if the object exists - self.client.head_object(Bucket=self.bucket, Key=key) - return True - except ClientError as e: - if e.response["Error"]["Code"] == "404": - return False # Object does not exist - # Log or print the error for troubleshooting - log.error(f"Attempt {attempt + 1} failed: {e}", rank0_only=False) - - # If this is the last attempt, return False - if attempt == max_retries - 1: - return False - - # Wait for the specified delay before retrying - time.sleep(retry_delay) - except Exception as e: - # Handle other unexpected exceptions - log.error(f"Unexpected error on attempt {attempt + 1}: {e}", rank0_only=False) - - # If this is the last attempt, return False - if attempt == max_retries - 1: - return False - - # Wait for the specified delay before retrying - time.sleep(retry_delay) - - # If all retries fail, return False - return False - - -class Boto3Wrapper: - """ - This class serves as a wrapper around boto3.client in order to make boto3.client serializable. It's required to use - spawn method of creating DataLoader workers, which is in turn required to avoid segfaults when using Triton, e.g. - for torch.compile or custom kernels. - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - self.client = None - - def __setstate__(self, state): - self.__dict__ = state - - def __getattr__(self, item): - is_worker = torch.utils.data.get_worker_info() is not None - client = ( - boto3.client(*self._args, **self._kwargs, config=GLOBAL_S3_CONFIG) if self.client is None else self.client - ) - if is_worker: - self.client = client - return getattr(client, item) - - -def sync_s3_dir_to_local( - s3_dir: str, - s3_credential_path: str, - cache_dir: Optional[str] = None, - rank_sync: bool = True, - local_rank_sync: bool = False, -) -> str: - """ - Download an entire directory from S3 to the local cache directory. - - Args: - s3_dir (str): The AWS S3 directory to download. - s3_credential_path (str): The path to the AWS S3 credentials file. - rank_sync (bool, optional): Whether to synchronize download across - ALL distributed workers using `distributed.barrier()`. Defaults to True. - cache_dir (str, optional): The cache folder to sync the S3 directory to. - If None, the environment variable `IMAGINAIRE_CACHE_DIR` (defaulting - to "~/.cache/imaginaire") will be used. - local_rank_sync (bool, optional): Whether to synchronize download across - workers within the same node using a node-level barrier. This is useful - when the cache directory is not shared across nodes. Defaults to False. - Note: rank_sync and local_rank_sync cannot both be True. - - Returns: - local_dir (str): The path to the local directory. - """ - if local_rank_sync and rank_sync: - raise ValueError("rank_sync and local_rank_sync cannot be True at the same time.") - - if not s3_dir.startswith("s3://"): - # If the directory exists locally, return the local path - assert os.path.exists(s3_dir), f"{s3_dir} is not a S3 path or a local path." - return s3_dir - - # Get local rank for node-level synchronization - local_rank = int(os.getenv("LOCAL_RANK", 0)) if local_rank_sync else None - - # Load AWS credentials from the file - with open(s3_credential_path, "r") as f: - credentials = json.load(f) - - # Create an S3 client - s3 = boto3.client( - "s3", - **credentials, - ) - - # Parse the S3 URL - parsed_url = urlparse(s3_dir) - source_bucket = parsed_url.netloc - source_prefix = parsed_url.path.lstrip("/") - - # If the local directory is not specified, use the default cache directory - cache_dir = ( - os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/imaginaire")) - if cache_dir is None - else cache_dir - ) - cache_dir = os.path.expanduser(cache_dir) - Path(cache_dir).mkdir(parents=True, exist_ok=True) - - # List objects in the bucket with the given prefix - response = s3.list_objects_v2(Bucket=source_bucket, Prefix=source_prefix) - # Download each matching object - for obj in response.get("Contents", []): - if obj["Key"].startswith(source_prefix): - # Create the full path for the destination file, preserving the directory structure - rel_path = os.path.relpath(obj["Key"], source_prefix) - dest_path = os.path.join(cache_dir, source_prefix, rel_path) - - # Ensure the directory exists - os.makedirs(os.path.dirname(dest_path), exist_ok=True) - - # Check if the file already exists - if os.path.exists(dest_path): - continue - else: - log.info(f"Downloading {obj['Key']} to {dest_path}") - # Download the file - if rank_sync: - # Only rank 0 downloads when using global rank sync - if distributed.get_rank() == 0: - s3.download_file(source_bucket, obj["Key"], dest_path) - elif local_rank_sync: - # Only local rank 0 (first rank on each node) downloads when using local rank sync - if local_rank == 0: - s3.download_file(source_bucket, obj["Key"], dest_path) - else: - # No synchronization - every rank downloads - s3.download_file(source_bucket, obj["Key"], dest_path) - # Synchronize after downloads complete - if rank_sync or local_rank_sync: - distributed.barrier() - - local_dir = os.path.join(cache_dir, source_prefix) - return local_dir - - -def download_from_s3_with_cache( - s3_path: str, - s3_credential_path: str, - cache_fp: Optional[str] = None, - cache_dir: Optional[str] = None, - rank_sync: bool = True, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> str: - """download data from S3 with optional caching. - - This function first attempts to load the data from a local cache file. If - the cache file doesn't exist, it downloads the data from S3 to the cache - location. Caching is performed in a rank-aware manner - using `distributed.barrier()` to ensure only one download occurs across - distributed workers (if `rank_sync` is True). - - Args: - s3_path (str): The S3 path of the data to load. - cache_fp (str, optional): The path to the local cache file. If None, - a filename will be generated based on `s3_path` within `cache_dir`. - cache_dir (str, optional): The directory to store the cache file. If - None, the environment variable `IMAGINAIRE_CACHE_DIR` (defaulting - to "/tmp") will be used. - rank_sync (bool, optional): Whether to synchronize download across - distributed workers using `distributed.barrier()`. Defaults to True. - backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. - backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. - - Returns: - cache_fp (str): The path to the local cache file. - - Raises: - FileNotFoundError: If the data cannot be found in S3 or the cache. - """ - if not s3_path.startswith("s3://"): - # If the file exists locally, return the local path - assert os.path.exists(s3_path), f"{s3_path} is not a S3 path nor a local path." - return s3_path - - easy_io.set_s3_backend( - backend_args={ - "backend": "s3", - "path_mapping": None, - "s3_credential_path": s3_credential_path, - } - ) - cache_dir = ( - os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/imaginaire")) - if cache_dir is None - else cache_dir - ) - cache_dir = os.path.expanduser(cache_dir) - if cache_fp is None: - cache_fp = os.path.join(cache_dir, s3_path.replace("s3://", "")) - if not cache_fp.startswith("/"): - cache_fp = os.path.join(cache_dir, cache_fp) - - if rank_sync: - if distributed.get_rank() == 0: - if os.path.exists(cache_fp): - # check the size of cache_fp - if os.path.getsize(cache_fp) < 1: - os.remove(cache_fp) - log.warning(f"Removed empty cache file {cache_fp}.") - - if not os.path.exists(cache_fp): - easy_io.copyfile_to_local( - s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key - ) - log.info(f"Downloaded {s3_path} to {cache_fp}.") - else: - log.info(f"The cache file {cache_fp} already exists.") - distributed.barrier() - else: - if os.path.exists(cache_fp): - # check the size of cache_fp - if os.path.getsize(cache_fp) < 1: - os.remove(cache_fp) - log.warning(f"Removed empty cache file {cache_fp}.") - if not os.path.exists(cache_fp): - easy_io.copyfile_to_local( - s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key - ) - log.info(f"Downloaded {s3_path} to {cache_fp}.") - else: - log.info(f"The cache file {cache_fp} already exists") - return cache_fp diff --git a/lyra_2/_ext/imaginaire/utils/optim_instantiate.py b/lyra_2/_ext/imaginaire/utils/optim_instantiate.py deleted file mode 100644 index 95f6cc9f01c8be89b41fe58d685fdc377ad87a3b..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/optim_instantiate.py +++ /dev/null @@ -1,82 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import hydra -import torch -from torch import nn - -from lyra_2._ext.imaginaire.utils import log - - -def get_regular_param_group(net: nn.Module): - """ - seperate the parameters of the network into two groups: decay and no_decay. - based on nano_gpt codebase. - """ - param_dict = {pn: p for pn, p in net.named_parameters()} - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - - decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] - nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] - return decay_params, nodecay_params - - -def get_base_optimizer( - model: nn.Module, - lr: float, - weight_decay: float, - optim_type: str = "adamw", - sharding: bool = False, - **kwargs, -) -> torch.optim.Optimizer: - net_decay_param, net_nodecay_param = get_regular_param_group(model) - - num_decay_params = sum(p.numel() for p in net_decay_param) - num_nodecay_params = sum(p.numel() for p in net_nodecay_param) - net_param_total = num_decay_params + num_nodecay_params - log.critical(f"total num parameters : {net_param_total:,}") - - param_group = [ - { - "params": net_decay_param + net_nodecay_param, - "lr": lr, - "weight_decay": weight_decay, - }, - ] - - if optim_type == "adamw": - opt_cls = torch.optim.AdamW - else: - raise ValueError(f"Unknown optimizer type: {optim_type}") - - return opt_cls(param_group, **kwargs) - - -def get_base_scheduler( - optimizer: torch.optim.Optimizer, - model: nn.Module, - scheduler_config: dict, -): - net_scheduler = hydra.utils.instantiate(scheduler_config) - net_scheduler.model = model - num_param_groups = len(optimizer.param_groups) - - return torch.optim.lr_scheduler.LambdaLR( - optimizer, - lr_lambda=[ - net_scheduler.schedule, - ] - * num_param_groups, - ) diff --git a/lyra_2/_ext/imaginaire/utils/parallel_state_helper.py b/lyra_2/_ext/imaginaire/utils/parallel_state_helper.py deleted file mode 100644 index 0c26066b24021f95286b14f90df8ca18d60eab12..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/parallel_state_helper.py +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module contains various helper functions designed to extend the functionality of parallel states within the MCore library. - -MCore is a third-party library that is infrequently updated and may introduce backward compatibility issues in our codebase, such as changes in function signatures or missing / new functions in new versions. - -To mitigate these issues, this module provides stable functions that ensure the lyra_2._ext.imaginaire codebase remains compatible with different versions of MCore. -""" - -try: - from megatron.core import parallel_state -except ImportError: - print("Megatron is not installed, is_tp_cp_pp_rank0 functions will not work.") - - -def is_tp_cp_pp_rank0(): - return ( - parallel_state.get_tensor_model_parallel_rank() == 0 - and parallel_state.get_pipeline_model_parallel_rank() == 0 - and parallel_state.get_context_parallel_rank() == 0 - ) diff --git a/lyra_2/_ext/imaginaire/utils/profiling.py b/lyra_2/_ext/imaginaire/utils/profiling.py deleted file mode 100644 index 2ae4cf7ac5fadd0cd347d8e26cd90264c1d818a4..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/profiling.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import os -import time - -import torch - -from lyra_2._ext.imaginaire.utils import distributed, log -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - -# (qsh 2024-11-23) credits -# https://github.com/pytorch/torchtitan/blob/main/torchtitan/profiling.py - -# the number of warmup steps before the active step in each profiling cycle -TORCH_TRACE_WARMUP = 3 - -# how much memory allocation/free ops to record in memory snapshots -MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 - - -@contextlib.contextmanager -def maybe_enable_profiling(config, *, global_step: int = 0): - # get user defined profiler settings - enable_profiling = config.trainer.profiling.enable_profiling - profile_freq = config.trainer.profiling.profile_freq - - if enable_profiling: - trace_dir = os.path.join(config.job.path_local, "torch_trace") - if distributed.get_rank() == 0: - os.makedirs(trace_dir, exist_ok=True) - - rank = distributed.get_rank() - - def trace_handler(prof): - curr_trace_dir_name = "iteration_" + str(prof.step_num) - curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) - if not os.path.exists(curr_trace_dir): - os.makedirs(curr_trace_dir, exist_ok=True) - - log.info(f"Dumping traces at step {prof.step_num}") - begin = time.monotonic() - if rank in config.trainer.profiling.target_ranks: - prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json.gz") - log.info(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds") - - log.info(f"Profiling active. Traces will be saved at {trace_dir}") - - if not os.path.exists(trace_dir): - os.makedirs(trace_dir, exist_ok=True) - - warmup, active = TORCH_TRACE_WARMUP, 1 - wait = profile_freq - (active + warmup) - assert wait >= 0, "profile_freq must be greater than or equal to warmup + active" - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), - on_trace_ready=trace_handler, - record_shapes=config.trainer.profiling.record_shape, - profile_memory=config.trainer.profiling.profile_memory, - with_stack=config.trainer.profiling.with_stack, - with_modules=config.trainer.profiling.with_modules, - ) as torch_profiler: - torch_profiler.step_num = global_step - yield torch_profiler - else: - torch_profiler = contextlib.nullcontext() - yield None - - -@contextlib.contextmanager -def maybe_enable_memory_snapshot(config, *, global_step: int = 0): - enable_snapshot = config.trainer.profiling.enable_memory_snapshot - if enable_snapshot: - if config.trainer.profiling.save_s3: - snapshot_dir = "s3://rundir" - else: - snapshot_dir = os.path.join(config.job.path_local, "memory_snapshot") - if distributed.get_rank() == 0: - os.makedirs(snapshot_dir, exist_ok=True) - - rank = torch.distributed.get_rank() - - class MemoryProfiler: - def __init__(self, step_num: int, freq: int): - torch.cuda.memory._record_memory_history(max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES) - # when resume training, we start from the last step - self.step_num = step_num - self.freq = freq - - def step(self, exit_ctx: bool = False): - self.step_num += 1 - if not exit_ctx and self.step_num % self.freq != 0: - return - if not exit_ctx: - curr_step = self.step_num - dir_name = f"iteration_{curr_step}" - else: - # dump as iteration_0_exit if OOM at iter 1 - curr_step = self.step_num - 1 - dir_name = f"iteration_{curr_step}_exit" - curr_snapshot_dir = os.path.join(snapshot_dir, dir_name) - if not config.trainer.profiling.save_s3 and not os.path.exists(curr_snapshot_dir): - os.makedirs(curr_snapshot_dir, exist_ok=True) - log.info(f"Dumping memory snapshot at step {curr_step}") - begin = time.monotonic() - - if rank in config.trainer.profiling.target_ranks: - easy_io.dump( - torch.cuda.memory._snapshot(), - f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", - ) - log.info(f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds") - - log.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}") - profiler = MemoryProfiler(global_step, config.trainer.profiling.profile_freq) - try: - yield profiler - except torch.cuda.OutOfMemoryError as e: - profiler.step(exit_ctx=True) - else: - yield None diff --git a/lyra_2/_ext/imaginaire/utils/s3_utils.py b/lyra_2/_ext/imaginaire/utils/s3_utils.py deleted file mode 100644 index eae454fa56a995eb665ccbfe841e389cf61ed3f6..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/s3_utils.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import Any, Optional - -from lyra_2._ext.imaginaire.utils import distributed, log -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - -def download_from_s3_with_cache( - s3_path: str, - cache_fp: Optional[str] = None, - cache_dir: Optional[str] = None, - rank_sync: bool = True, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, -) -> str: - """download data from S3 with optional caching. - - This function first attempts to load the data from a local cache file. If - the cache file doesn't exist, it downloads the data from S3 to the cache - location. Caching is performed in a rank-aware manner - using `distributed.barrier()` to ensure only one download occurs across - distributed workers (if `rank_sync` is True). - - Args: - s3_path (str): The S3 path of the data to load. - cache_fp (str, optional): The path to the local cache file. If None, - a filename will be generated based on `s3_path` within `cache_dir`. - cache_dir (str, optional): The directory to store the cache file. If - None, the environment variable `IMAGINAIRE_CACHE_DIR` (defaulting - to "/tmp") will be used. - rank_sync (bool, optional): Whether to synchronize download across - distributed workers using `distributed.barrier()`. Defaults to True. - backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. - backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. - - Returns: - cache_fp (str): The path to the local cache file. - - Raises: - FileNotFoundError: If the data cannot be found in S3 or the cache. - """ - cache_dir = os.environ.get("TORCH_HOME") if cache_dir is None else cache_dir - cache_dir = ( - os.environ.get("IMAGINAIRE_CACHE_DIR", os.path.expanduser("~/.cache/imaginaire")) - if cache_dir is None - else cache_dir - ) - cache_dir = os.path.expanduser(cache_dir) - if cache_fp is None: - cache_fp = os.path.join(cache_dir, s3_path.replace("s3://", "")) - if not cache_fp.startswith("/"): - cache_fp = os.path.join(cache_dir, cache_fp) - - if distributed.get_rank() == 0: - if os.path.exists(cache_fp): - # check the size of cache_fp - if os.path.getsize(cache_fp) < 1: - os.remove(cache_fp) - log.warning(f"Removed empty cache file {cache_fp}.") - - if rank_sync: - if not os.path.exists(cache_fp): - log.critical(f"Local cache {cache_fp} Not exist! Downloading {s3_path} to {cache_fp}.") - log.info(f"backend_args: {backend_args}") - log.info(f"backend_key: {backend_key}") - - easy_io.copyfile_to_local( - s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key - ) - log.info(f"Downloaded {s3_path} to {cache_fp}.") - else: - log.info(f"Local cache {cache_fp} already exist! {s3_path} -> {cache_fp}.") - - distributed.barrier() - else: - if not os.path.exists(cache_fp): - easy_io.copyfile_to_local( - s3_path, cache_fp, dst_type="file", backend_args=backend_args, backend_key=backend_key - ) - - log.info(f"Downloaded {s3_path} to {cache_fp}.") - return cache_fp - - -def load_from_s3_with_cache( - s3_path: str, - cache_fp: Optional[str] = None, - cache_dir: Optional[str] = None, - rank_sync: bool = True, - backend_args: Optional[dict] = None, - backend_key: Optional[str] = None, - easy_io_kwargs: Optional[dict] = None, -) -> Any: - """Loads data from S3 with optional caching. - - This function first attempts to load the data from a local cache file. If - the cache file doesn't exist, it downloads the data from S3 to the cache - location and then loads it. Caching is performed in a rank-aware manner - using `distributed.barrier()` to ensure only one download occurs across - distributed workers (if `rank_sync` is True). - - Args: - s3_path (str): The S3 path of the data to load. - cache_fp (str, optional): The path to the local cache file. If None, - a filename will be generated based on `s3_path` within `cache_dir`. - cache_dir (str, optional): The directory to store the cache file. If - None, the environment variable `IMAGINAIRE_CACHE_DIR` (defaulting - to "/tmp") will be used. - rank_sync (bool, optional): Whether to synchronize download across - distributed workers using `distributed.barrier()`. Defaults to True. - backend_args (dict, optional): The backend arguments passed to easy_io to construct the backend. - backend_key (str, optional): The backend key passed to easy_io to registry the backend or retrieve the backend if it is already registered. - - Returns: - Any: The loaded data from the S3 path or cache file. - - Raises: - FileNotFoundError: If the data cannot be found in S3 or the cache. - """ - cache_fp = download_from_s3_with_cache(s3_path, cache_fp, cache_dir, rank_sync, backend_args, backend_key) - - if easy_io_kwargs is None: - easy_io_kwargs = {} - return easy_io.load(cache_fp, **easy_io_kwargs) diff --git a/lyra_2/_ext/imaginaire/utils/validator.py b/lyra_2/_ext/imaginaire/utils/validator.py deleted file mode 100644 index f10386271a7374783fe4c0b10790f33878e8576d..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/utils/validator.py +++ /dev/null @@ -1,514 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import binascii -import itertools -import json -import os -from abc import ABC, abstractmethod -from io import BytesIO -from typing import Any - -# Sentinel value to indicate that no default was explicitly set by the user -# we want to mimic usage of function parameters: if no default is provided, the parameter is mandatory -_UNSET = object() - - -# from https://docs.python.org/3/howto/descriptor.html#validator-class -# For usage of hidden flag see the ModelParams class in apis/utils/model_params.py - - -# validators can be customized to very specific needs, e.g. see HumanAttributes below -class Validator(ABC): - def __init__(self, default=_UNSET, hidden=False): - self.default = default - self.hidden = hidden - - # set name is called when the validator is created as class variable - # name is the name of the variable in the owner class, so here we create the name for the backing variable - def __set_name__(self, owner, name): - self.private_name = "_" + name - - def __get__(self, obj, objtype=None): - value = getattr(obj, self.private_name, self.default) - if value is _UNSET: - # If we reach here, it means a mandatory parameter was accessed without being set - attr_name = getattr(self, "private_name", "unknown").lstrip("_") - raise ValueError( - f"Parameter '{attr_name}' is mandatory but has not been set. " - f"No default value was provided and no value was assigned." - ) - return value - - def __set__(self, obj, value): - value = self.validate(value) - setattr(obj, self.private_name, value) - - @abstractmethod - def validate(self, value): - pass - - def json(self): - pass - - -class Bool(Validator): - def __init__(self, default=_UNSET, hidden=False, tooltip=None): - super().__init__(default, hidden) - self.default = default - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value): - if isinstance(value, int): - value = value != 0 - elif isinstance(value, str): - value = value.lower() - if value in ["true", "1"]: - value = True - elif value in ["false", "0"]: - value = False - else: - raise ValueError(f"Expected {value!r} to be one of ['True', 'False', '1', '0']") - elif not isinstance(value, bool): - raise TypeError(f"Expected {value!r} to be an bool") - - return value - - def get_range_iterator(self): - return [True, False] - - def __repr__(self) -> str: - return f"Bool({self.private_name=} {self.default=} {self.hidden=})" - - def json(self): - return { - "type": bool.__name__, - "default": self.default, - "tooltip": self.tooltip, - } - - -class Int(Validator): - def __init__(self, default=_UNSET, min=None, max=None, step=1, hidden=False, tooltip=None): - self.min = min - self.max = max - self.default = default - self.step = step - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value): - if isinstance(value, str): - value = int(value) - elif not isinstance(value, int): - raise TypeError(f"Expected {value!r} to be an int") - - if self.min is not None and value < self.min: - raise ValueError(f"Expected {value!r} to be at least {self.min!r}") - if self.max is not None and value > self.max: - raise ValueError(f"Expected {value!r} to be no more than {self.max!r}") - return value - - def get_range_iterator(self): - if self.default is _UNSET: - default_val = 0 - else: - default_val = int(self.default) if isinstance(self.default, (int, float, str)) else 0 - iter_min = self.min if self.min is not None else default_val - iter_max = self.max if self.max is not None else (default_val + 100) - return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step)) - - def __repr__(self) -> str: - return f"Int({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" - - def json(self): - return { - "type": int.__name__, - "default": self.default, - "min": self.min, - "max": self.max, - "step": self.step, - "tooltip": self.tooltip, - } - - -class Float(Validator): - def __init__(self, default=_UNSET, min=None, max=None, step=0.5, hidden=False, tooltip=None): - self.min = min - self.max = max - self.default = default - self.step = step - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value): - if isinstance(value, str) or isinstance(value, int): - value = float(value) - elif not isinstance(value, float): - raise TypeError(f"Expected {value!r} to be float") - - if self.min is not None and value < self.min: - raise ValueError(f"Expected {value!r} to be at least {self.min!r}") - if self.max is not None and value > self.max: - raise ValueError(f"Expected {value!r} to be no more than {self.max!r}") - return value - - def get_range_iterator(self): - if self.default is _UNSET: - default_val = 0.0 - else: - default_val = float(self.default) if isinstance(self.default, (int, float, str)) else 0.0 - iter_min = self.min if self.min is not None else default_val - iter_max = self.max if self.max is not None else (default_val + 100.0) - return itertools.takewhile(lambda x: x <= iter_max, itertools.count(iter_min, self.step)) - - def __repr__(self) -> str: - return f"Float({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" - - def json(self): - return { - "type": float.__name__, - "default": self.default, - "min": self.min, - "max": self.max, - "step": self.step, - "tooltip": self.tooltip, - } - - -class String(Validator): - def __init__(self, default=_UNSET, min=None, max=None, predicate=None, hidden=False, tooltip=None): - self.min = min - self.max = max - self.predicate = predicate - self.default = default - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value): - if value is None: - return value # Allow None as a valid value to be compatible with existing code - # this breaks strict typing, so do this only for strings - if not isinstance(value, str): - raise TypeError(f"Expected {value!r} to be an str or None") - if self.min is not None and len(value) < self.min: - raise ValueError(f"Expected {value!r} to be no smaller than {self.min!r}") - if self.max is not None and len(value) > self.max: - raise ValueError(f"Expected {value!r} to be no bigger than {self.max!r}") - if self.predicate is not None and not self.predicate(value): - raise ValueError(f"Expected {self.predicate} to be true for {value!r}") - return value - - def get_range_iterator(self): - return iter([self.default]) - - def __repr__(self) -> str: - return f"String({self.private_name=} {self.default=}, {self.min=}, {self.max=} {self.hidden=})" - - def json(self): - return { - "type": str.__name__, - "default": self.default, - "tooltip": self.tooltip, - } - - -class Path(Validator): - def __init__(self, default=_UNSET, hidden=False, tooltip=None): - self.default = default - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value): - if value is None: - return value - if not isinstance(value, str): - raise TypeError(f"{self.private_name} validator: Expected {value!r} to be an str") - if not os.path.exists(value): - raise ValueError(f"{self.private_name} validator: Expected {value!r} to be a valid path") - - return value - - def get_range_iterator(self): - return iter([self.default]) - - def __repr__(self) -> str: - return f"String({self.private_name=} {self.default=}, {self.hidden=})" - - -class InputImage(Validator): - def __init__( - self, default=_UNSET, hidden=False, tooltip=None, supported_formats=["jpeg", "jpg", "png", "bmp", "gif"] - ): - self.default = default - self.hidden = hidden - self.tooltip = tooltip - self.supported_formats = supported_formats - - def validate(self, value): - ext = os.path.splitext(value)[1].lower() - - if ext not in self.supported_formats: - raise ValueError(f"Unsupported image format: {ext}") - - if not isinstance(value, str): - raise TypeError(f"Expected {value!r} to be an str") - if not os.path.exists(value): - raise ValueError(f"Expected {value!r} to be a valid path") - return value - - def get_range_iterator(self): - return iter([self.default]) - - def __repr__(self) -> str: - return f"String({self.private_name=} {self.default=} {self.hidden=})" - - def json(self): - return { - "type": InputImage.__name__, - "default": self.default, - "values": self.supported_formats, - "tooltip": self.tooltip, - } - - -class JsonDict(Validator): - """ - JSON stringified version of a python dict. - Example: '{"ema_customization_iter.pt": "ema_customization_iter.pt"}' - """ - - def __init__(self, default=_UNSET, hidden=False): - self.default = default - self.hidden = hidden - - def validate(self, value): - if not value: - return {} - try: - dict = json.loads(value) - return dict - except json.JSONDecodeError as e: - raise ValueError(f"Expected {value!r} to be json stringified dict. Error: {str(e)}") - - def __repr__(self) -> str: - return f"Dict({self.default=} {self.hidden=})" - - -class Dict(Validator): - """ - Python dict. - Example: {'key': 'value'} - - This allows a single level of parameter nesting, but not a full nested dict. - For now we validate the individual keys here and store the dict as is. - Alternatively, we could have a validator that gets/sets another ValidatorParams class. - """ - - def __init__(self, default=_UNSET, hidden=False): - self.default = default - self.hidden = hidden - - def validate(self, value): - if not isinstance(value, dict): - raise TypeError(f"Expected {value!r} to be an dict") - return value - - def __repr__(self) -> str: - value = getattr(self, self.private_name, self.default) - - return f"Dict({self.private_name=} {self.default=} {self.hidden=} value={json.dumps(value, indent=4)})" - - -class OneOf(Validator): - def __init__(self, default=_UNSET, options=None, type_cast=None, hidden=False, tooltip=None): - self.options = set(options) if options is not None else set() - self.default = default - self.type_cast = type_cast # Cast the value to this type before checking if it's in options - self.tooltip = tooltip - self.hidden = hidden - - def validate(self, value): - if self.type_cast: - try: - value = self.type_cast(value) - except ValueError: - raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") - - if value not in self.options: - raise ValueError(f"Expected {value!r} to be one of {self.options!r}") - - return value - - def get_range_iterator(self): - return self.options - - def __repr__(self) -> str: - return f"OneOf({self.private_name=} {self.options=} {self.hidden=})" - - def json(self): - return { - "type": OneOf.__name__, - "default": self.default, - "values": list(self.options), - "tooltip": self.tooltip, - } - - -class MultipleOf(Validator): - def __init__(self, default=_UNSET, multiple_of: int = 1, type_cast=None, hidden=False, tooltip=None): - if type(multiple_of) is not int: - raise ValueError(f"Expected {multiple_of!r} to be an int") - self.multiple_of = multiple_of - self.default = default - self.type_cast = type_cast - - # For usage of hidden flag see the ModelParams class in apis/utils/model_params.py - # if a parameter is hidden then probe() can't expose the param - # and the param can't be set anymore - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value): - if self.type_cast: - try: - value = self.type_cast(value) - except ValueError: - raise ValueError(f"Expected {value!r} to be castable to {self.type_cast!r}") - - if value % self.multiple_of != 0: - raise ValueError(f"Expected {value!r} to be a multiple of {self.multiple_of!r}") - - return value - - def get_range_iterator(self): - return itertools.count(0, self.multiple_of) - - def __repr__(self) -> str: - return f"MultipleOf({self.private_name=} {self.multiple_of=} {self.hidden=})" - - def json(self): - return { - "type": MultipleOf.__name__, - "default": self.default, - "multiple_of": self.multiple_of, - "tooltip": self.tooltip, - } - - -class HumanAttributes(Validator): - def __init__(self, default=_UNSET, hidden=False, tooltip=None): - self.default = default - self.hidden = hidden - self.tooltip = tooltip - - # hard code the options for now - # we extend this to init parameter as needed - valid_attributes = { - "emotion": ["angry", "contemptful", "disgusted", "fearful", "happy", "neutral", "sad", "surprised"], - "race": ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"], - "gender": ["male", "female"], - "age group": [ - "young", - "teen", - "adult early twenties", - "adult late twenties", - "adult early thirties", - "adult late thirties", - "adult middle aged", - "older adult", - ], - } - - def get_range_iterator(self): - # create a list of all possible combinations - l1 = self.valid_attributes["emotion"] - l2 = self.valid_attributes["race"] - l3 = self.valid_attributes["gender"] - l4 = self.valid_attributes["age group"] - all_combinations = list(itertools.product(l1, l2, l3, l4)) - return iter(all_combinations) - - def validate(self, value): - human_attributes = value.lower() - if human_attributes not in ["none", "random"]: - # In this case, we need for custom attribute string - - attr_string = human_attributes - for attr_key in ["emotion", "race", "gender", "age group"]: - attr_detected = False - for attr_label in self.valid_attributes[attr_key]: - if attr_string.startswith(attr_label): - attr_string = attr_string[len(attr_label) + 1 :] # noqa: E203 - attr_detected = True - break - - if attr_detected is False: - raise ValueError(f"Expected {value!r} to be one of {self.valid_attributes!r}") - - return value - - def __repr__(self) -> str: - return f"HumanAttributes({self.private_name=} {self.hidden=})" - - def json(self): - return { - "type": HumanAttributes.__name__, - "default": self.default, - "values": self.valid_attributes, - "tooltip": self.tooltip, - } - - -class BytesIOType(Validator): - """ - Validator class for BytesIO. Valid inputs are either: - - bytes - - objects of class BytesIO - - str which can be successfully decoded into BytesIO - """ - - def __init__(self, default=_UNSET, hidden=False, tooltip=None): - self.default = default - self.hidden = hidden - self.tooltip = tooltip - - def validate(self, value: Any) -> BytesIO: - if isinstance(value, str): - try: - # Decode the Base64 string - decoded_bytes = base64.b64decode(value) - # Create a BytesIO stream from the decoded bytes - return BytesIO(decoded_bytes) - except (binascii.Error, ValueError) as e: - raise ValueError(f"Invalid Base64 encoded string: {e}") - elif isinstance(value, bytes): - return BytesIO(value) - elif isinstance(value, BytesIO): - return value - else: - raise TypeError(f"Expected {value!r} to be a Base64 encoded string, bytes, or BytesIO") - - def __repr__(self) -> str: - return f"BytesIOValidator({self.default=}, {self.hidden=})" - - def json(self): - return { - "type": BytesIO.__name__, - "default": self.default, - "tooltip": self.tooltip, - } diff --git a/lyra_2/_ext/imaginaire/visualize/__init__.py b/lyra_2/_ext/imaginaire/visualize/__init__.py deleted file mode 100644 index b3b051edb4aee99b5ae0ab0a7d420112b5524f55..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/visualize/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/lyra_2/_ext/imaginaire/visualize/video.py b/lyra_2/_ext/imaginaire/visualize/video.py deleted file mode 100644 index cb2cad516209b961e404076e5005cff73adfd154..0000000000000000000000000000000000000000 --- a/lyra_2/_ext/imaginaire/visualize/video.py +++ /dev/null @@ -1,97 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import IO, Any, Union - -import cv2 -import numpy as np -import torch -from einops import rearrange -from PIL import Image as PILImage -from torch import Tensor - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - -try: - import ffmpegcv -except Exception as e: # ImportError cannot catch all problems - log.info(e) - ffmpegcv = None - - -def save_video(grid, video_name, fps=30): - grid = (grid * 255).astype(np.uint8) - grid = np.transpose(grid, (1, 2, 3, 0)) - with ffmpegcv.VideoWriter(video_name, "h264", fps) as writer: - for frame in grid: - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - - writer.write(frame) - - -def save_img_or_video( - sample_C_T_H_W_in01: Tensor, save_fp_wo_ext: Union[str, IO[Any]], fps: int = 24, quality=None, ffmpeg_params=None -) -> None: - """ - Save a tensor as an image or video file based on shape - - Args: - sample_C_T_H_W_in01 (Tensor): Input tensor with shape (C, T, H, W) in [0, 1] range. - save_fp_wo_ext (Union[str, IO[Any]]): File path without extension or file-like object. - fps (int): Frames per second for video. Default is 24. - """ - assert sample_C_T_H_W_in01.ndim == 4, "Only support 4D tensor" - assert isinstance(save_fp_wo_ext, str) or hasattr(save_fp_wo_ext, "write"), ( - "save_fp_wo_ext must be a string or file-like object" - ) - - if torch.is_floating_point(sample_C_T_H_W_in01): - sample_C_T_H_W_in01 = sample_C_T_H_W_in01.clamp(0, 1) - else: - assert sample_C_T_H_W_in01.dtype == torch.uint8, "Only support uint8 tensor" - sample_C_T_H_W_in01 = sample_C_T_H_W_in01.float().div(255) - - kwargs = {} - if quality is not None: - kwargs["quality"] = quality - if ffmpeg_params is not None: - kwargs["ffmpeg_params"] = ffmpeg_params - - if sample_C_T_H_W_in01.shape[1] == 1: - save_obj = PILImage.fromarray( - rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c 1 h w -> h w c").astype(np.uint8), - mode="RGB", - ) - ext = ".jpg" if isinstance(save_fp_wo_ext, str) else "" - easy_io.dump( - save_obj, - f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, - file_format="jpg", - format="JPEG", - quality=85, - **kwargs, - ) - else: - save_obj = rearrange((sample_C_T_H_W_in01.cpu().float().numpy() * 255), "c t h w -> t h w c").astype(np.uint8) - ext = ".mp4" if isinstance(save_fp_wo_ext, str) else "" - easy_io.dump( - save_obj, - f"{save_fp_wo_ext}{ext}" if isinstance(save_fp_wo_ext, str) else save_fp_wo_ext, - file_format="mp4", - format="mp4", - fps=fps, - **kwargs, - ) diff --git a/lyra_2/_src/__init__.py b/lyra_2/_src/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/callbacks/__init__.py b/lyra_2/_src/callbacks/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/callbacks/model_weights_stats.py b/lyra_2/_src/callbacks/model_weights_stats.py deleted file mode 100644 index 896f387715736b672af0b9dd3e8594b5175eb147..0000000000000000000000000000000000000000 --- a/lyra_2/_src/callbacks/model_weights_stats.py +++ /dev/null @@ -1,65 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any - -import torch -from torch import nn - - -@dataclass -class TrainingStats: - """Data class to hold training statistics.""" - - video_samples: int = 0 - image_samples: int = 0 - iterations: int = 0 - training_hours: float = 0.0 - - -class WeightTrainingStat(nn.Module, ABC): - """Abstract base class for tracking training statistics.""" - - def __init__(self) -> None: - super().__init__() - self._initialize_tracking_buffers() - - def _initialize_tracking_buffers(self) -> None: - """Initialize tracking buffers with default values.""" - tracking_buffers = { - "accum_video_sample_counter": torch.tensor(0, dtype=torch.int64), - "accum_image_sample_counter": torch.tensor(0, dtype=torch.int64), - "accum_iteration": torch.tensor(0, dtype=torch.int64), - "accum_train_in_hours": torch.tensor(0.0, dtype=torch.float32), - } - - for name, tensor in tracking_buffers.items(): - self.register_buffer(name, tensor) - - def get_training_stats(self) -> TrainingStats: - """Return current training statistics.""" - return TrainingStats( - video_samples=self.accum_video_sample_counter.item(), - image_samples=self.accum_image_sample_counter.item(), - iterations=self.accum_iteration.item(), - training_hours=self.accum_train_in_hours.item(), - ) - - @abstractmethod - def forward(self, *args, **kwargs) -> Any: - pass diff --git a/lyra_2/_src/configs/__init__.py b/lyra_2/_src/configs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/configs/config.py b/lyra_2/_src/configs/config.py deleted file mode 100644 index b47858f1d297182c0592400b280b761a6a8b89b0..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/config.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Any, List - -import attrs - -from lyra_2._ext.imaginaire import config -from lyra_2._ext.imaginaire.trainer import ImaginaireTrainer as Trainer -from lyra_2._ext.imaginaire.utils.config_helper import import_all_modules_from_package -from lyra_2._src.configs.defaults.common.checkpoint import register_checkpoint -from lyra_2._src.configs.defaults.common.ckpt_type import register_ckpt_type -from lyra_2._src.configs.defaults.common.ema import register_ema -from lyra_2._src.configs.defaults.common.optimizer import register_optimizer -from lyra_2._src.configs.defaults.common.scheduler import register_scheduler -from lyra_2._src.configs.defaults.common.tokenizer import register_tokenizer -from lyra_2._src.configs.defaults.conditioner import lyra_register_conditioner -from lyra_2._src.configs.defaults.dataloader import lyra_register_dataloaders -from lyra_2._src.configs.defaults.model import lyra_register_model -from lyra_2._src.configs.defaults.net import lyra_register_net - - -@attrs.define(slots=False) -class Config(config.Config): - defaults: List[Any] = attrs.field( - factory=lambda: [ - "_self_", - {"data_train": None}, - {"data_val": None}, - {"optimizer": "adamw"}, - {"scheduler": "lambdalinear"}, - {"model": "fsdp_wan2pt1_lyra2_spatial"}, - {"net": "wan2pt1_14B_i2v_lyra2"}, - {"conditioner": "lyra2_conditioner"}, - {"ema": "power"}, - {"tokenizer": "wan2pt1_tokenizer"}, - {"checkpoint": "local"}, - {"ckpt_type": "dummy"}, - {"experiment": None}, - ] - ) - - -def make_config() -> Config: - c = Config( - model=None, - optimizer=None, - scheduler=None, - dataloader_train=None, - dataloader_val=None, - ) - - c.job.project = "lyra_2" - c.job.group = "debug" - c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}" - - c.trainer.type = Trainer - c.trainer.straggler_detection.enabled = False - c.trainer.max_iter = 400_000 - c.trainer.logging_iter = 10 - c.trainer.validation_iter = 100 - c.trainer.run_validation = False - c.trainer.callbacks = None - - # Register common defaults - register_optimizer() - register_scheduler() - register_ema() - register_tokenizer() - register_checkpoint() - register_ckpt_type() - - # Register lyra_2-specific configs - lyra_register_model() - lyra_register_net() - lyra_register_conditioner() - lyra_register_dataloaders() - - # Register experiment configs - import_all_modules_from_package("lyra_2._src.configs", reload=True) - - return c diff --git a/lyra_2/_src/configs/defaults/__init__.py b/lyra_2/_src/configs/defaults/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/configs/defaults/common/__init__.py b/lyra_2/_src/configs/defaults/common/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/configs/defaults/common/checkpoint.py b/lyra_2/_src/configs/defaults/common/checkpoint.py deleted file mode 100644 index 4ee855dd9c913610cfa41216a134b3492a69cb64..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/common/checkpoint.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from hydra.core.config_store import ConfigStore - -from lyra_2._ext.imaginaire.config import CheckpointConfig, ObjectStoreConfig - -object_store = ObjectStoreConfig( - enabled=False, - credentials="", - bucket="", -) - -object_store_s3_vsr = ObjectStoreConfig( - enabled=True, - credentials="credentials/s3_checkpoint.secret", - bucket="a3-upsampler", -) - - -CHECKPOINT_LOCAL = CheckpointConfig( - save_to_object_store=object_store, - save_iter=1000, - load_from_object_store=object_store, - load_path="", - load_training_state=False, - strict_resume=True, -) - - -CHECKPOINT_S3_VSR = CheckpointConfig( - save_to_object_store=object_store_s3_vsr, - save_iter=1000, - load_from_object_store=object_store_s3_vsr, - load_path="", - load_training_state=False, - strict_resume=True, -) - - -def register_checkpoint(): - cs = ConfigStore.instance() - cs.store(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL) - cs.store(group="checkpoint", package="checkpoint", name="s3_vsr", node=CHECKPOINT_S3_VSR) diff --git a/lyra_2/_src/configs/defaults/common/ckpt_type.py b/lyra_2/_src/configs/defaults/common/ckpt_type.py deleted file mode 100644 index 08df0f8917e5e810adeb0d276b9ed944cee3c236..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/common/ckpt_type.py +++ /dev/null @@ -1,32 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Dict - -from hydra.core.config_store import ConfigStore - -from lyra_2._ext.imaginaire.checkpointer.dcp import DistributedCheckpointer -from lyra_2._ext.imaginaire.checkpointer.dummy import Checkpointer as DummyCheckpointer -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L - -DUMMY_CHECKPOINTER: Dict[str, str] = L(DummyCheckpointer)() -DISTRIBUTED_CHECKPOINTER: Dict[str, str] = L(DistributedCheckpointer)() - - -def register_ckpt_type(): - cs = ConfigStore.instance() - cs.store(group="ckpt_type", package="checkpoint.type", name="dummy", node=DUMMY_CHECKPOINTER) - cs.store(group="ckpt_type", package="checkpoint.type", name="dcp", node=DISTRIBUTED_CHECKPOINTER) diff --git a/lyra_2/_src/configs/defaults/common/ema.py b/lyra_2/_src/configs/defaults/common/ema.py deleted file mode 100644 index dfc627bd298930f30ee774f15b5518353133c550..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/common/ema.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from hydra.core.config_store import ConfigStore - -from lyra_2._src.models.wan_t2v_model import EMAConfig - -PowerEMAConfig: EMAConfig = EMAConfig( - enabled=True, - rate=0.10, - iteration_shift=0, -) - - -def register_ema(): - cs = ConfigStore.instance() - cs.store(group="ema", package="model.config.ema", name="power", node=PowerEMAConfig) diff --git a/lyra_2/_src/configs/defaults/common/optimizer.py b/lyra_2/_src/configs/defaults/common/optimizer.py deleted file mode 100644 index cff2901ffbb5934c03ffc4fd9e0303039f610a6f..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/common/optimizer.py +++ /dev/null @@ -1,47 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from hydra.core.config_store import ConfigStore - -from lyra_2._ext.imaginaire.lazy_config import PLACEHOLDER -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._src.utils.optim_instantiate import get_base_optimizer, get_multiple_optimizer - -AdamWConfig = L(get_base_optimizer)( - model=PLACEHOLDER, - lr=1e-4, - weight_decay=1e-3, - betas=[0.9, 0.999], - optim_type="adamw", - eps=1e-8, - fused=True, -) - -AdamWConfigHighLR = L(get_multiple_optimizer)( - model=PLACEHOLDER, - lr=1e-4, - weight_decay=1e-3, - optim_type="adamw", - eps=1e-8, - fused=True, - lr_overrides=[], -) - - -def register_optimizer(): - cs = ConfigStore.instance() - cs.store(group="optimizer", package="optimizer", name="adamw", node=AdamWConfig) - cs.store(group="optimizer", package="optimizer", name="adamw_multiple_lr", node=AdamWConfigHighLR) diff --git a/lyra_2/_src/configs/defaults/common/scheduler.py b/lyra_2/_src/configs/defaults/common/scheduler.py deleted file mode 100644 index 612c9fd186ca33b71e995ef8a7bedb3dbd16ce56..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/common/scheduler.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from hydra.core.config_store import ConfigStore - -from lyra_2._ext.imaginaire.functional.lr_scheduler import LambdaLinearScheduler -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._ext.imaginaire.lazy_config import LazyDict - -LambdaLinearSchedulerConfig: LazyDict = L(LambdaLinearScheduler)( - warm_up_steps=[1000], - cycle_lengths=[10000000000000], - f_start=[1.0e-6], - f_max=[1.0], - f_min=[1.0], -) - - -def register_scheduler(): - cs = ConfigStore.instance() - cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearSchedulerConfig) diff --git a/lyra_2/_src/configs/defaults/common/tokenizer.py b/lyra_2/_src/configs/defaults/common/tokenizer.py deleted file mode 100644 index 73c8e2a83f54928c7a7230ba8b60a26ac79a1d7f..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/common/tokenizer.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from hydra.core.config_store import ConfigStore - -from lyra_2._src.tokenizers.wan2pt1 import ( - Wan2pt1VAEConfig, -) - - -def register_tokenizer(): - cs = ConfigStore.instance() - cs.store(group="tokenizer", package="model.config.tokenizer", name="wan2pt1_tokenizer", node=Wan2pt1VAEConfig) diff --git a/lyra_2/_src/configs/defaults/conditioner.py b/lyra_2/_src/configs/defaults/conditioner.py deleted file mode 100644 index 5eff1bc543ae78f193a27c595734dd45b4545d81..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/conditioner.py +++ /dev/null @@ -1,83 +0,0 @@ -from dataclasses import dataclass -from typing import Dict, Optional - -import torch -from hydra.core.config_store import ConfigStore - -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._ext.imaginaire.lazy_config import LazyDict -from lyra_2._src.models.lyra2_model import WAN2PT1_I2V_COND_LATENT_KEY -from lyra_2._src.networks.clip_lyra2 import Wan2pt1CLIPEmbLyra2 -from lyra_2._src.modules.conditioner import ( - BaseCondition, - GeneralConditioner, - ReMapkey, - T2VCondition, - TextAttrEmptyStringDrop, - broadcast_condition, -) -from lyra_2._src.utils.context_parallel import broadcast - - -@dataclass(frozen=True) -class Img2VidWan2pt1ConditionLyra2(T2VCondition): - frame_cond_crossattn_emb_B_L_D: Optional[torch.Tensor] = None - y_B_C_T_H_W: Optional[torch.Tensor] = None - y_buffer_B_C_T_H_W: Optional[torch.Tensor] = None - - def broadcast(self, process_group: torch.distributed.ProcessGroup) -> BaseCondition: - if self.is_broadcasted: - return self - kwargs = self.to_dict(skip_underscore=False) - y = kwargs.pop("y_B_C_T_H_W") - y_buffer = kwargs.pop("y_buffer_B_C_T_H_W") - new_cond = T2VCondition.broadcast(type(self)(**kwargs), process_group) - kwargs = new_cond.to_dict(skip_underscore=False) - kwargs["y_B_C_T_H_W"] = broadcast(y, process_group) - kwargs["y_buffer_B_C_T_H_W"] = broadcast(y_buffer, process_group) if y_buffer is not None else None - return type(self)(**kwargs) - - -class Img2VidWan2pt1ConditionerLyra2(GeneralConditioner): - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> Img2VidWan2pt1ConditionLyra2: - output = super()._forward(batch, override_dropout_rate) - return Img2VidWan2pt1ConditionLyra2(**output) - - -Lyra2ConditionerConfig: LazyDict = L(Img2VidWan2pt1ConditionerLyra2)( - text=L(TextAttrEmptyStringDrop)( - input_key=["t5_text_embeddings"], - dropout_rate=0.2, - ), - fps=L(ReMapkey)( - input_key="fps", - output_key="fps", - dropout_rate=0.0, - dtype=None, - ), - padding_mask=L(ReMapkey)( - input_key="padding_mask", - output_key="padding_mask", - dropout_rate=0.0, - dtype=None, - ), - wanclip=L(Wan2pt1CLIPEmbLyra2)( - input_key=["last_hist_frame", "video", WAN2PT1_I2V_COND_LATENT_KEY, "cond_latent_mask", "cond_latent_buffer"], - dropout_rate=0.0, - dtype="bfloat16", - ), -) - - -def lyra_register_conditioner(): - cs = ConfigStore.instance() - cs.store( - group="conditioner", - package="model.config.conditioner", - name="lyra2_conditioner", - node=Lyra2ConditionerConfig, - ) diff --git a/lyra_2/_src/configs/defaults/dataloader.py b/lyra_2/_src/configs/defaults/dataloader.py deleted file mode 100644 index 1611330a9d57eb16e85bb6c385270dbe95b9b3ce..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/dataloader.py +++ /dev/null @@ -1,16 +0,0 @@ -from hydra.core.config_store import ConfigStore -from lyra_2._src.datasets.depth_warp_dataloader import get_gen3c_multiple_video_dataloader - - -def lyra_register_dataloaders(): - """Register lyra_2 dataloaders.""" - cs = ConfigStore.instance() - - lyra2_dl3dv_long_480p_dav3_hsg = get_gen3c_multiple_video_dataloader( - dataset_list=["dl3dv_long_moge_chunk_81_480p_dav3_hsg"], - dataset_weight_list=[1], - num_workers=2, - prefetch_factor=2, - ) - cs.store(group="data_train", package="dataloader_train", name="lyra2_dl3dv_long_moge_chunk_81_480p_dav3_hsg", node=lyra2_dl3dv_long_480p_dav3_hsg) - cs.store(group="data_val", package="dataloader_val", name="lyra2_dl3dv_long_moge_chunk_81_480p_dav3_hsg", node=lyra2_dl3dv_long_480p_dav3_hsg) diff --git a/lyra_2/_src/configs/defaults/model.py b/lyra_2/_src/configs/defaults/model.py deleted file mode 100644 index a639948211ebdb31acbf83573b6138925354ce2f..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/model.py +++ /dev/null @@ -1,28 +0,0 @@ -from hydra.core.config_store import ConfigStore -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._src.models.lyra2_model import ( - Lyra2Model, - Lyra2T2VConfig, -) - - -fsdp_wan2pt1_lyra2_spatial_config = dict( - trainer=dict(distributed_parallelism="fsdp"), - model=L(Lyra2Model)( - config=Lyra2T2VConfig(fsdp_shard_size=8, state_t=20), - _recursive_=False, - ), -) - -ddp_wan2pt1_lyra2_spatial_config = dict( - trainer=dict(distributed_parallelism="ddp"), - model=L(Lyra2Model)( - config=Lyra2T2VConfig(state_t=20), - _recursive_=False, - ), -) - -def lyra_register_model(): - cs = ConfigStore.instance() - cs.store(group="model", package="_global_", name="fsdp_wan2pt1_lyra2_spatial", node=fsdp_wan2pt1_lyra2_spatial_config) - cs.store(group="model", package="_global_", name="ddp_wan2pt1_lyra2_spatial", node=ddp_wan2pt1_lyra2_spatial_config) diff --git a/lyra_2/_src/configs/defaults/net.py b/lyra_2/_src/configs/defaults/net.py deleted file mode 100644 index ed340cd2955fd71ae3b123ff460d7f30ea77bbb5..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/defaults/net.py +++ /dev/null @@ -1,44 +0,0 @@ -from hydra.core.config_store import ConfigStore -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._ext.imaginaire.lazy_config import LazyDict -from lyra_2._src.networks.wan2pt1_lyra2 import Lyra2WanModel -from lyra_2._src.modules.selective_activation_checkpoint import SACConfig - - -WAN2PT1_1PT3B_I2V_LYRA2: LazyDict = L(Lyra2WanModel)( - dim=1536, - eps=1e-06, - ffn_dim=8960, - freq_dim=256, - in_dim=36, - model_type="i2v", - num_heads=12, - num_layers=30, - out_dim=16, - text_len=512, - cp_comm_type="p2p", - sac_config=L(SACConfig)(mode="block_wise"), - postpone_checkpoint=False, -) - -WAN2PT1_14B_I2V_LYRA2: LazyDict = L(Lyra2WanModel)( - dim=5120, - eps=1e-06, - ffn_dim=13824, - freq_dim=256, - in_dim=36, - model_type="i2v", - num_heads=40, - num_layers=40, - out_dim=16, - text_len=512, - cp_comm_type="p2p", - sac_config=L(SACConfig)(mode="block_wise"), - postpone_checkpoint=False, -) - - -def lyra_register_net(): - cs = ConfigStore.instance() - cs.store(group="net", package="model.config.net", name="wan2pt1_1pt3B_i2v_lyra2", node=WAN2PT1_1PT3B_I2V_LYRA2) - cs.store(group="net", package="model.config.net", name="wan2pt1_14B_i2v_lyra2", node=WAN2PT1_14B_I2V_LYRA2) diff --git a/lyra_2/_src/configs/experiment.py b/lyra_2/_src/configs/experiment.py deleted file mode 100644 index 0d8e302c981c9e8ed5c8eca52124373c076a085d..0000000000000000000000000000000000000000 --- a/lyra_2/_src/configs/experiment.py +++ /dev/null @@ -1,83 +0,0 @@ -from hydra.core.config_store import ConfigStore - -cs = ConfigStore.instance() - - -def register_lyra2(): - """Fully-flattened lyra_2 spatial training experiment. - - Effective config equivalent to: - two_buffers_dl3dv_image_tokens_correspondence_finetune_kq_only_multibuffer_add_depth_hsg - in the source repo. - """ - experiment_config = dict( - defaults=[ - {"override /model": "fsdp_wan2pt1_lyra2_spatial"}, - {"override /net": "wan2pt1_14B_i2v_lyra2"}, - {"override /conditioner": "lyra2_conditioner"}, - {"override /data_train": "lyra2_dl3dv_long_moge_chunk_81_480p_dav3_hsg"}, - {"override /data_val": "lyra2_dl3dv_long_moge_chunk_81_480p_dav3_hsg"}, - "_self_", - ], - job=dict( - project="lyra_2", - group="lyra2", - name="lyra2", - ), - model=dict( - config=dict( - ema=dict(enabled=False), - framepack_type="f1k1f4s2f1s1f16k4f2k2f1k1_g20", - max_segments=13, - apply_corruption_to_spatial_region="noise_with_sigma", - augment_sigma_sample_p_mean=-3.0, - augment_sigma_sample_p_std=2.0, - augment_sigma_sample_multiplier=1.0, - self_aug_enabled=True, - self_aug_steps=1, - self_aug_guidance=1.0, - self_aug_scheduler_shift=1.0, - self_aug_every_k=2, - self_aug_prob=1.0, - self_aug_max_T=500, - self_aug_copy_chunk=True, - self_aug_encode_gt_with_clean_history=True, - starting_frame_ratio=0.0, - use_mp_policy_fsdp=True, - keep_original_net_dtype=True, - spatial_memory_use_image=True, - spatial_memory_stride=8, - spatial_memory_skip_recent=16, - warp_chunk_size=16, - framepack_trainable_modules="cam_encoder,buffer_encoder,self_attn,clean_patch_embeddings,patch_embedding", - ), - ), - model_parallel=dict( - context_parallel_size=1, - ), - optimizer=dict( - lr=3e-5, - ), - checkpoint=dict( - save_iter=100, - save_to_object_store=dict(enabled=False), - load_from_object_store=dict(enabled=False), - load_path="", - load_training_state=False, - strict_resume=False, - ), - trainer=dict( - max_iter=1000000, - callbacks=None, - ), - ) - - cs.store( - group="experiment", - package="_global_", - name="lyra2", - node=experiment_config, - ) - - -register_lyra2() diff --git a/lyra_2/_src/datasets/__init__.py b/lyra_2/_src/datasets/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/datasets/base.py b/lyra_2/_src/datasets/base.py deleted file mode 100644 index 218eed0068108fb5d3269243fca6fb95a60a5877..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/base.py +++ /dev/null @@ -1,147 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, List, Optional - -from lyra_2._src.datasets.data_field import DataField - - -class BaseDataset(ABC): - """ - Base class for all datasets. - Note that this is not directly wrap-able by a dataloader. It is meant to be - subclassed / included in another dataset class. - """ - - def __init__(self): - pass - - @abstractmethod - def available_data_fields(self) -> List[DataField]: - """ - Return a list of available data fields in the dataset. - """ - pass - - @abstractmethod - def num_videos(self) -> int: - """ - Returns: - Number of videos in the dataset. - """ - pass - - @abstractmethod - def num_views(self, video_idx: int) -> int: - """ - Args: - video_idx: Index of the video. - - Returns: - Number of views in the video. - """ - pass - - @abstractmethod - def num_frames(self, video_idx: int, view_idx: int = 0) -> int: - """ - Args: - video_idx: Index of the video. - view_idx: Index of the view. - - Returns: - Number of frames in the given view. - """ - pass - - def read_video_metadata(self, video_idx: int) -> dict[str, Any]: - """ - Read metadata of the video. - - Args: - video_idx: Index of the video. - - Returns: - A dictionary containing metadata of the video. - """ - return {} - - def read_view_metadata(self, video_idx: int, view_idx: int) -> dict[str, Any]: - """ - Read metadata of the view. - - Args: - video_idx: Index of the video. - view_idx: Index of the view. - - Returns: - A dictionary containing metadata of the view. - """ - return {} - - def read( - self, - video_idx: int, - frame_idxs: Optional[List[int]] = None, - view_idxs: Optional[List[int]] = None, - data_fields: Optional[List[DataField]] = None, - ) -> dict[DataField, Any]: - """ - Read data from the dataset. - Args: - video_idx: Index of the video. - frame_idxs: List of frame indices. - view_idxs: List of view indices. - data_fields: List of data fields to read. If None, read all data fields. - - Example: - if frame_idxs is None, view_idxs is None, read all frames from all views. - if frame_idxs is not None, view_idxs is None, read frames from the first view. - if frame_idxs is None, view_idxs is not None, read all frames from the specified views. - if frame_idxs is not None, view_idxs is not None, read frames from the specified views. - - Returns: - A dictionary mapping data fields to their values. - """ - - if data_fields is None: - data_fields = self.available_data_fields() - - if frame_idxs is None: - # Frame not provided, default read all frames. - if view_idxs is None: - view_iterator = range(self.num_views(video_idx)) - else: - view_iterator = view_idxs - - new_frame_idxs, new_view_idxs = [], [] - for view_idx in view_iterator: - num_frames = self.num_frames(video_idx, view_idx) - new_frame_idxs.extend(list(range(num_frames))) - new_view_idxs.extend([view_idx] * num_frames) - frame_idxs, view_idxs = new_frame_idxs, new_view_idxs - - elif view_idxs is None: - # View not provided, but frame is provided, we only read the first view. - view_idxs = [0] * len(frame_idxs) - - else: - # Both frame_idxs and view_idxs provided, do sanity check. - assert len(frame_idxs) == len(view_idxs), ( - "Frame and view indices must match." - ) - - return self._read_data( - video_idx=video_idx, - frame_idxs=frame_idxs, - view_idxs=view_idxs, - data_fields=data_fields, - ) - - @abstractmethod - def _read_data( - self, - video_idx: int, - frame_idxs: List[int], - view_idxs: List[int], - data_fields: List[DataField], - ) -> dict[DataField, Any]: - pass diff --git a/lyra_2/_src/datasets/config_dataverse.py b/lyra_2/_src/datasets/config_dataverse.py deleted file mode 100644 index 00245690d6783f5726c02b717ed973e16ffb4026..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/config_dataverse.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -DATAVERSE_CONFIG = { - "dl3dv_long_moge_chunk_81_480p_dav3_hsg": { - "dataset_cfg": { - "target": "lyra_2._src.datasets.radym.Radym", - "params": { - "root_path": "", - "filter_list_path": "", - }, - }, - "data_name": "dl3dv_long_moge_chunk_81", - "sample_n_frames": 1000, - "video_mirror": True, - "video_mirror_when_short_only": True, - "sample_size": [480, 854], - "crop_size": [480, 832], - "t5_embedding_path": "", - } -} diff --git a/lyra_2/_src/datasets/data_sources/__init__.py b/lyra_2/_src/datasets/data_sources/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/datasets/data_sources/item_datasets_for_validation.py b/lyra_2/_src/datasets/data_sources/item_datasets_for_validation.py deleted file mode 100644 index 0e33eac6332937b3703ed0c9e72edb7d53112c7a..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/data_sources/item_datasets_for_validation.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import dataclasses -import os - - -@dataclasses.dataclass -class ItemDatasetConfig: - path: str - length: int - - -def get_itemdataset_option_local(name: str) -> ItemDatasetConfig: - return ITEMDATASET_OPTIONS_LOCAL[name] - - -_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) - -ITEMDATASET_OPTIONS_LOCAL = { - "empty_string_umt5": ItemDatasetConfig( - path=os.path.join(_REPO_ROOT, "checkpoints", "empty_string_umt5.pt"), - length=1, - ), -} diff --git a/lyra_2/_src/datasets/depth_warp_dataloader.py b/lyra_2/_src/datasets/depth_warp_dataloader.py deleted file mode 100644 index f64403fb3fe3ded28e0c2e66848b120c2d10c79e..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/depth_warp_dataloader.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import os -import pickle -import random - -import numpy as np -import torch -import torchvision.transforms as transforms -from torch.utils.data import DataLoader - -try: - from lyra_2._src.datasets.base import DataField -except ImportError: - pass -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L, instantiate -from lyra_2._src.datasets.config_dataverse import DATAVERSE_CONFIG -from lyra_2._src.datasets.utils import VIDEO_RES_SIZE_INFO - -try: - from megatron.core import parallel_state -except ImportError: - pass - -import omegaconf -from omegaconf import OmegaConf - - -# --------------------------------------------------------------------------- -# Utilities -# --------------------------------------------------------------------------- - -def _get_obj_from_str(string): - module, cls = string.rsplit(".", 1) - return getattr(importlib.import_module(module, package=None), cls) - - -def _instantiate_from_config(config, **additional_kwargs): - if "target" not in config: - raise KeyError("Expected key `target` to instantiate.") - additional_kwargs.update(config.get("params", dict())) - return _get_obj_from_str(config["target"])(**additional_kwargs) - - -def _resize_intrinsics(intrinsics, old_size, new_size, crop_size=None): - """intrinsics: (N, 3, 3), sizes: (h, w)""" - intrinsics_copy = intrinsics.clone() if isinstance(intrinsics, torch.Tensor) else np.copy(intrinsics) - intrinsics_copy[:, 0, :] *= new_size[1] / old_size[1] - intrinsics_copy[:, 1, :] *= new_size[0] / old_size[0] - if crop_size is not None: - intrinsics_copy[:, 0, -1] -= (new_size[1] - crop_size[1]) / 2 - intrinsics_copy[:, 1, -1] -= (new_size[0] - crop_size[0]) / 2 - return intrinsics_copy - - -def _intrinsics_from_fxfycxcy_batch(intrinsics): - m = torch.zeros((intrinsics.shape[0], 3, 3), device=intrinsics.device) - m[:, 0, 0] = intrinsics[:, 0] - m[:, 1, 1] = intrinsics[:, 1] - m[:, 0, 2] = intrinsics[:, 2] - m[:, 1, 2] = intrinsics[:, 3] - m[:, 2, 2] = 1 - return m - - -def _dict_collation_fn(samples): - """Collate a list of sample dicts into a batched dict.""" - batched = {key: [s[key] for s in samples] for key in samples[0]} - result = {} - for key, vals in batched.items(): - if isinstance(vals[0], bool): - result[key] = vals[0] - elif isinstance(vals[0], (int, float)): - result[key] = torch.from_numpy(np.array(vals)) - elif isinstance(vals[0], torch.Tensor): - result[key] = torch.stack(vals) - else: - result[key] = vals - return result - - -def _sample_frame_indices(total_frames, N, video_mirror=False): - """Sample N frame indices starting from 0 with stride 1. - If video_mirror is True, extends short clips by mirroring before sampling. - """ - if video_mirror: - mapping = list(range(total_frames)) - n_repeat = max((N - total_frames) // (total_frames - 1), 0) + 1 - mapping_repeat = mapping.copy() - for i in range(n_repeat): - mapping_repeat += mapping[-2::-1] if i % 2 == 0 else mapping[1:] - return [mapping_repeat[i] for i in range(N)] - else: - if total_frames < N: - idx = list(range(total_frames)) - idx += [total_frames - 1] * (N - total_frames) - return idx - return list(range(N)) - - -# --------------------------------------------------------------------------- -# Dataloader -# --------------------------------------------------------------------------- - -class IterativeGEN3CDataLoader: - """Wraps multiple dataloaders with ratio-based sampling.""" - - def __init__(self, dataloaders): - self.dataloader_list, self.dataset_name_list, self.data_ratios = [], [], [] - for dataset_name, dataloader_data in dataloaders.items(): - if dataset_name in ("image_data", "video_data"): - continue - self.dataset_name_list.append(dataset_name) - self.dataloader_list.append(instantiate(dataloader_data["dataloader"])) - self.data_ratios.append(dataloader_data["ratio"]) - self.ratio_sum = sum(self.data_ratios) - self.data_len = sum(len(d) for d in self.dataloader_list) - self.dataloaders = [iter(dl) for dl in self.dataloader_list] - - def __len__(self) -> int: - return self.data_len - - def __iter__(self): - while True: - data_id = random.randint(0, self.ratio_sum - 1) - cumsum = 0 - for i, r in enumerate(self.data_ratios): - cumsum += r - if data_id < cumsum: - break - output = next(self.dataloaders[i]) - output["dataset_name"] = self.dataset_name_list[i] - yield output - - -def get_gen3c_multiple_video_dataloader( - dataset_list: list[str], - dataset_weight_list: list[float], - shuffle=True, - num_workers=4, - prefetch_factor=4, - mode="random", -) -> omegaconf.dictconfig.DictConfig: - dataloader_dict = { - name: { - "dataloader": L(MyDataLoader)( - dataset=L(get_depth_warp_dataset)(dataset_name=name), - batch_size=1, - num_workers=num_workers, - shuffle=shuffle, - prefetch_factor=prefetch_factor, - ), - "ratio": weight, - } - for name, weight in zip(dataset_list, dataset_weight_list) - } - return L(IterativeGEN3CDataLoader)(dataloaders=dataloader_dict) - - -class MyDataLoader(DataLoader): - def __init__(self, dataset, batch_size: int = 1, *args, **kw): - kw.pop("dataloaders", None) - super().__init__(dataset.build_dataset(), batch_size, collate_fn=_dict_collation_fn, *args, **kw) - - -def get_depth_warp_dataset(dataset_name="dl3dv_long_moge_chunk_81_480p_dav3_hsg", resolution="720", chunk_size=256, **kwargs): - return DepthWarpDataset(dataset_name, resolution, chunk_size, **kwargs) - - -class DepthWarpDataset: - def __init__(self, dataset_name, resolution, chunk_size, **kwargs): - self.video_size = VIDEO_RES_SIZE_INFO[resolution] - self.dataset_config = OmegaConf.merge( - OmegaConf.create(DATAVERSE_CONFIG[dataset_name]), - OmegaConf.create(kwargs), - ) - - def build_dataset(self): - return InfiniteCommonDataset(**OmegaConf.to_container(self.dataset_config, resolve=True)) - - -class InfiniteCommonDataset: - def __init__( - self, - dataset_cfg, - data_name="", - batch_size=1, - sample_n_frames=8, - sample_size=[320, 512], - crop_size=None, - video_mirror=False, - video_mirror_when_short_only=False, - t5_embedding_path=None, - ): - self.dataset = _instantiate_from_config(dataset_cfg) - self.data_name = data_name - self.n_data = self.dataset.num_videos() - self.t5_embedding_path = t5_embedding_path - - if parallel_state.is_initialized(): - dp_group_id = parallel_state.get_data_parallel_rank() - dp_world_size = parallel_state.get_data_parallel_world_size() - log.critical( - f"Using parallelism size CP :{parallel_state.get_context_parallel_world_size()}, " - + f"TP :{parallel_state.get_tensor_model_parallel_world_size()} for video dataset, " - + f"DP: {dp_group_id}, DP World size: {dp_world_size}" - ) - else: - dp_world_size = 1 - dp_group_id = 0 - self.n_data_per_node = self.n_data // dp_world_size - self.data_start_idx = dp_group_id * self.n_data_per_node - - self.multiplier = (2000000 * batch_size) // self.n_data_per_node - self.sample_n_frames = sample_n_frames - self.sample_size = sample_size - self.crop_size = crop_size if crop_size is not None else sample_size - self.video_mirror = video_mirror - self.video_mirror_when_short_only = video_mirror_when_short_only - - self.img_transform = transforms.Compose([ - transforms.Resize(sample_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True), - transforms.CenterCrop(self.crop_size), - ]) - self.norm_image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) - - def __len__(self): - return self.multiplier * self.n_data_per_node - - def _get_frame_indices(self, n_total_frames): - if self.video_mirror_when_short_only: - use_mirror = self.video_mirror and (n_total_frames < self.sample_n_frames) - else: - use_mirror = self.video_mirror - return _sample_frame_indices(n_total_frames, self.sample_n_frames, use_mirror) - - def _transform(self, images, w2c, depths, intrinsics): - _, _, H, W = images.shape - return { - "video": self.norm_image(self.img_transform(images)).permute(1, 0, 2, 3).contiguous(), - "camera_w2c": w2c, - "depth": self.img_transform(depths), - "intrinsics": _resize_intrinsics(intrinsics, [H, W], self.sample_size, self.crop_size), - "is_preprocessed": True, - } - - def _to_w2c_and_intrinsics(self, data): - data[DataField.METRIC_DEPTH] = data[DataField.METRIC_DEPTH].unsqueeze(1) - data[DataField.CAMERA_C2W_TRANSFORM] = torch.from_numpy( - np.stack([np.linalg.inv(m.numpy()) for m in data[DataField.CAMERA_C2W_TRANSFORM]]) - ) - data[DataField.CAMERA_INTRINSICS] = _intrinsics_from_fxfycxcy_batch(data[DataField.CAMERA_INTRINSICS]) - return data - - def __getitem__(self, idx): - data_idx = (idx % self.n_data_per_node) + self.data_start_idx - assert data_idx < self.n_data - - frame_indices = self._get_frame_indices(self.dataset.num_frames(data_idx)) - - data = self.dataset._read_data( - video_idx=data_idx, - data_fields=[DataField.IMAGE_RGB, DataField.CAMERA_C2W_TRANSFORM, DataField.CAMERA_INTRINSICS], - frame_idxs=frame_indices, - view_idxs=[0], - ) - depth_data = self.dataset._read_data( - video_idx=data_idx, - data_fields=[DataField.METRIC_DEPTH], - frame_idxs=frame_indices, - view_idxs=[0], - ) - data.update(depth_data) - data = self._to_w2c_and_intrinsics(data) - - N = self.sample_n_frames - sample = self._transform( - data[DataField.IMAGE_RGB][-N:].clone(), - data[DataField.CAMERA_C2W_TRANSFORM][-N:].clone(), - data[DataField.METRIC_DEPTH][-N:].clone(), - data[DataField.CAMERA_INTRINSICS][-N:].clone(), - ) - - # T5 chunk embeddings - t5_path = os.path.join(self.t5_embedding_path, data["__key__"] + ".pkl") - if not os.path.exists(t5_path): - print(f"t5 embedding path {t5_path} does not exist") - return self.__getitem__(np.random.randint(0, len(self))) - t5_pickle = pickle.load(open(t5_path, "rb")) - keys = [int(k) for k in t5_pickle] - embeddings = [torch.as_tensor(t5_pickle[k]["embedding"]) for k in t5_pickle] - order = np.argsort(keys) - sorted_keys = np.asarray(keys, dtype=np.int64)[order] - sorted_embs = [embeddings[i] for i in order] - cutoff = int(np.searchsorted(sorted_keys, int(frame_indices[-1]), side="right")) - sorted_keys = sorted_keys[:cutoff] - sorted_embs = sorted_embs[:cutoff] - assert len(sorted_embs) > 0 - chunk_emb = torch.zeros(len(sorted_embs), 512, 4096) - chunk_mask = torch.zeros(len(sorted_embs), 512) - for i, e in enumerate(sorted_embs): - s, d = min(e.shape[0], 512), min(e.shape[1], 4096) - chunk_emb[i, :s, :d] = e[:s, :d] - chunk_mask[i, :s] = 1.0 - del t5_pickle - sample["t5_chunk_keys"] = torch.from_numpy(sorted_keys) - sample["t5_chunk_embeddings"] = chunk_emb - sample["t5_chunk_mask"] = chunk_mask - - sample["sample_frame_indices"] = torch.as_tensor(frame_indices, dtype=torch.long) - sample["num_frames"] = N - sample["image_size"] = torch.as_tensor(self.crop_size) - sample["fps"] = 24 - sample["__key__"] = data["__key__"] - sample["clip_name"] = f"{self.data_name}-{data['__key__']}-{data_idx:d}-000-001" - sample["padding_mask"] = torch.zeros(1, self.crop_size[0], self.crop_size[1]) - del data - return sample diff --git a/lyra_2/_src/datasets/forward_warp_utils_pytorch.py b/lyra_2/_src/datasets/forward_warp_utils_pytorch.py deleted file mode 100644 index 2c110c4a88ac8c06c2ed9721dcd42e3dc677c3e8..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/forward_warp_utils_pytorch.py +++ /dev/null @@ -1,428 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Tuple, Union - -import torch -from einops import rearrange - -def get_max_exponent_for_dtype(dtype): - # Set the maximum exponent based on dtype - if dtype == torch.bfloat16: - return 80.0 # Safe maximum exponent for bfloat16 - elif dtype == torch.float16: - return 10.0 # Safe maximum exponent for float16 - elif dtype == torch.float32: - return 80.0 # Safe maximum exponent for float32 - elif dtype == torch.float64: - return 700.0 # Safe maximum exponent for float64 - else: - return 80.0 # Default safe value - -def inverse_with_conversion(mtx): - return torch.linalg.inv(mtx.to(torch.float32)).to(mtx.dtype) - - -def reliable_depth_mask_range_batch(depth, window_size=5, ratio_thresh=0.05, eps=1e-6): - assert window_size % 2 == 1, "Window size must be odd." - if depth.dim() == 3: # Input shape: (B, H, W) - depth_unsq = depth.unsqueeze(1) - elif depth.dim() == 4: # Already has shape (B, 1, H, W) - depth_unsq = depth - else: - raise ValueError("depth tensor must be of shape (B, H, W) or (B, 1, H, W)") - - local_max = torch.nn.functional.max_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) - local_min = -torch.nn.functional.max_pool2d(-depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) - local_mean = torch.nn.functional.avg_pool2d(depth_unsq, kernel_size=window_size, stride=1, padding=window_size // 2) - ratio = (local_max - local_min) / (local_mean + eps) - reliable_mask = (ratio < ratio_thresh) & (depth_unsq > 0) - reliable_mask = reliable_mask - - return reliable_mask - - -def forward_warp_multiframes( - frame1: torch.Tensor, - mask1: Optional[torch.Tensor], - depth1: Optional[torch.Tensor], - transformation1: Optional[torch.Tensor], - transformation2: torch.Tensor, - intrinsic1: Optional[torch.Tensor], - intrinsic2: Optional[torch.Tensor], - is_image=True, - render_depth=False, - world_points1=None, - clean_points: bool = False, - clean_points_continuity: bool = False, # clean points based on depth continuity -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - :param frame1: (b, v, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling - bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting() - method accordingly. - :param mask1: (b, v, 1, h, w) - 1 for known, 0 for unknown. Optional - :param depth1: (b, v, 1, h, w). Optional if world_points1 is provided. - :param transformation1: (b, v, 4, 4) source view w2c. Required if depth1 is provided and world_points1 is None. - :param transformation2: (b, 4, 4) extrinsic transformation matrix of target view: [R, t; 0, 1] - :param intrinsic1: (b, v, 3, 3) source intrinsics. Required if depth1 is provided and world_points1 is None. - :param intrinsic2: (b, 3, 3) camera intrinsic matrix for target view. Optional - :param world_points1: (b, v, h, w, 3) optional precomputed world points. - :param clean_points: bool, enable point cleaning. - :param clean_points_continuity: bool, use depth continuity for cleaning. - """ - device = frame1.device - b, v, c, h, w = frame1.shape - if mask1 is None: - mask1 = torch.ones(size=(b, v, 1, h, w), device=device, dtype=frame1.dtype) - - # If world_points1 isn't provided, build it from RGBD + per-view camera parameters. - if world_points1 is None: - assert depth1 is not None, "depth1 must be provided when world_points1 is None" - assert transformation1 is not None, "transformation1 (w2c) must be provided when world_points1 is None" - assert intrinsic1 is not None, "intrinsic1 must be provided when world_points1 is None" - assert depth1.shape[:2] == (b, v) - assert transformation1.shape[:2] == (b, v) - assert intrinsic1.shape[:2] == (b, v) - - depth1 = torch.nan_to_num(depth1, nan=1e4) - depth1 = torch.clamp(depth1, min=0, max=1e4) - - # Valid mask: depth>0 plus optional continuity cleaning. - mask_valid = (depth1 > 0).to(dtype=mask1.dtype, device=device) - if clean_points and clean_points_continuity: - depth_flat = rearrange(depth1, "b v c h w -> (b v) c h w") - cont_mask_flat = reliable_depth_mask_range_batch(depth_flat).to(dtype=mask1.dtype, device=device) - cont_mask = rearrange(cont_mask_flat, "(b v) c h w -> b v c h w", b=b, v=v) - mask_valid = mask_valid * cont_mask - mask1 = mask1 * mask_valid - - depth_flat = rearrange(depth1, "b v c h w -> (b v) c h w") - w2c_flat = rearrange(transformation1, "b v c d -> (b v) c d") - K_flat = rearrange(intrinsic1, "b v c d -> (b v) c d") - mask_flat = rearrange(mask1 > 0.5, "b v c h w -> (b v) c h w") - world_pts_flat = unproject_points( - depth=depth_flat, - w2c=w2c_flat, - intrinsic=K_flat, - is_depth=True, - is_ftheta=False, - mask=mask_flat, - return_sparse=False, - ) # [(b*v), h, w, 3] - world_points1 = rearrange(world_pts_flat, "(b v) h w c -> b v h w c", b=b, v=v) - - assert world_points1 is not None and world_points1.shape == (b, v, h, w, 3) - - frame1 = frame1.reshape(b * v, c, h, w) - transformation2 = transformation2.unsqueeze(1).repeat(1, v, 1, 1).view(-1, 4, 4) - intrinsic2 = intrinsic2.unsqueeze(1).repeat(1, v, 1, 1).view(-1, 3, 3) - world_points1 = rearrange(world_points1, "b v h w c-> (b v) h w c") - mask1 = rearrange(mask1, "b v c h w-> (b v) c h w") - # Avoid in-place ops on potentially expanded/broadcasted mask tensors - mask1 = mask1.clone() - - trans_points1 = project_points(world_points1, transformation2, intrinsic2) - mask1 = mask1 * (trans_points1[:, :, :, 2, 0].unsqueeze(1) > 0) - trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7) - trans_coordinates = trans_coordinates.permute(0, 3, 1, 2) # b, 2, h, w - trans_depth1 = trans_points1[:, :, :, 2, 0].unsqueeze(1) - - grid = create_grid(b * v, h, w, device=device) # .to(trans_coordinates) - flow12 = trans_coordinates - grid - warped_frame2, mask2 = bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image, n_views=v) - if render_depth: - warped_depth2 = bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False, n_views=v)[ - 0 - ][:, 0] - return warped_frame2, mask2, warped_depth2, flow12 - return warped_frame2, mask2, None, flow12 - - -def unproject_points(depth: torch.Tensor, - w2c: torch.Tensor, - intrinsic: torch.Tensor, - is_depth: bool = True, - is_ftheta: bool = False, - mask: Optional[torch.Tensor] = None, - return_sparse: bool = False) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Unprojects depth values into 3D world points. - - If is_ftheta is False the pinhole model is used; if True then the intrinsic - is interpreted as [cx, cy, width, height, poly_coeffs..., is_bw_poly] and the - fθ model is used. - - If return_sparse is True, returns a list of B tensors, where each tensor is of - shape (N, 3) and contains the unprojected points for that batch. - """ - b, _, h, w = depth.shape - device = depth.device - dtype = depth.dtype - if mask is None: - mask = depth > 0 - if mask.dim() == depth.dim() and mask.shape[1] == 1: - mask = mask[:, 0] - - idx = torch.nonzero(mask) - if idx.numel() == 0: - # No valid points: return an empty list (sparse) or a zero tensor (dense) - if return_sparse: - return [torch.empty((0, 3), device=device, dtype=dtype) for _ in range(b)] - else: - return torch.zeros((b, h, w, 3), device=device, dtype=dtype) - - b_idx, y_idx, x_idx = idx[:, 0], idx[:, 1], idx[:, 2] - - if not is_ftheta: - # ---- Pinhole model (sparse computation) ---- - intrinsic_inv = inverse_with_conversion(intrinsic) # (b, 3, 3) - - x_valid = x_idx.to(dtype) - y_valid = y_idx.to(dtype) - ones = torch.ones_like(x_valid) - pos = torch.stack([x_valid, y_valid, ones], dim=1).unsqueeze(-1) # (N, 3, 1) - - intrinsic_inv_valid = intrinsic_inv[b_idx] # (N, 3, 3) - unnormalized_pos = torch.matmul(intrinsic_inv_valid, pos) # (N, 3, 1) - - depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1, 1) - if is_depth: - world_points_cam = depth_valid * unnormalized_pos - else: - norm_val = torch.norm(unnormalized_pos, dim=1, keepdim=True) - direction = unnormalized_pos / (norm_val + 1e-8) - world_points_cam = depth_valid * direction - - ones_h = torch.ones((world_points_cam.shape[0], 1, 1), - device=device, dtype=dtype) - world_points_homo = torch.cat([world_points_cam, ones_h], dim=1) # (N, 4, 1) - - trans = inverse_with_conversion(w2c) # (b, 4, 4) - trans_valid = trans[b_idx] # (N, 4, 4) - world_points_transformed = torch.matmul(trans_valid, world_points_homo) # (N, 4, 1) - sparse_points = world_points_transformed[:, :3, 0] # (N, 3) - else: - # ---- fθ model (sparse computation) ---- - x_valid = x_idx.to(dtype) - y_valid = y_idx.to(dtype) - - cx_valid = intrinsic[b_idx, 0].view(-1) - cy_valid = intrinsic[b_idx, 1].view(-1) - xd = x_valid - cx_valid - yd = y_valid - cy_valid - norm_xy = torch.sqrt(xd**2 + yd**2 + 1e-8) - - poly_coeffs_valid = intrinsic[b_idx, 4:-1] # (N, d) - d_coeff = poly_coeffs_valid.shape[1] - - powers = torch.arange(d_coeff, device=device, dtype=dtype).view(1, d_coeff) - norm_powers = norm_xy.view(-1, 1).pow(powers) # (N, d) - alpha = (poly_coeffs_valid * norm_powers).sum(dim=1) - sin_alpha = torch.sin(alpha) - cos_alpha = torch.cos(alpha) - scale = sin_alpha / (norm_xy + 1e-8) - ray_x = scale * xd - ray_y = scale * yd - ray_z = cos_alpha - - near_zero = norm_xy < 1e-6 - ray_x[near_zero] = 0.0 - ray_y[near_zero] = 0.0 - ray_z[near_zero] = 1.0 - rays = torch.stack([ray_x, ray_y, ray_z], dim=1) # (N, 3) - rays = rays / ray_z.unsqueeze(-1) - - depth_valid = depth[b_idx, 0, y_idx, x_idx].view(-1, 1) - if is_depth: - world_points_cam = depth_valid * rays - else: - ray_norm = torch.norm(rays, dim=1, keepdim=True) - world_points_cam = depth_valid * (rays / (ray_norm + 1e-8)) - - ones_h = torch.ones((world_points_cam.shape[0], 1), - device=device, dtype=dtype) - world_points_homo = torch.cat([world_points_cam, ones_h], dim=1) # (N, 4) - world_points_homo = world_points_homo.unsqueeze(-1) # (N, 4, 1) - trans = inverse_with_conversion(w2c) # (b, 4, 4) - trans_valid = trans[b_idx] # (N, 4, 4) - world_points_transformed = torch.matmul(trans_valid, world_points_homo) # (N, 4, 1) - sparse_points = world_points_transformed[:, :3, 0] # (N, 3) - - if return_sparse: - counts = torch.bincount(b_idx, minlength=b).tolist() - sparse_list = [] - offset = 0 - for count in counts: - if count > 0: - sparse_list.append(sparse_points[offset:offset+count]) - else: - sparse_list.append(torch.empty((0, 3), device=device, dtype=dtype)) - offset += count - return sparse_list - else: - out_points = torch.zeros((b, h, w, 3), device=device, dtype=dtype) - out_points[b_idx, y_idx, x_idx, :] = sparse_points - return out_points - - -def project_points(world_points: torch.Tensor, w2c: torch.Tensor, intrinsic: torch.Tensor): - """ - Projects 3D world points back into 2D pixel space. - """ - world_points = world_points.unsqueeze(-1) # (b, h, w, 3) -> # (b, h, w, 3, 1) - b, h, w, _, _ = world_points.shape - - ones_4d = torch.ones((b, h, w, 1, 1), device=world_points.device, dtype=world_points.dtype) # (b, h, w, 1, 1) - world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1) - - # Apply transformation2 to convert world points to camera space - trans_4d = w2c[:, None, None] # (b, 1, 1, 4, 4) - camera_points_homo = torch.matmul(trans_4d, world_points_homo) # (b, h, w, 4, 1) - - # Remove homogeneous coordinate and project to image plane - camera_points = camera_points_homo[:, :, :, :3] # (b, h, w, 3, 1) - intrinsic_4d = intrinsic[:, None, None] # (b, 1, 1, 3, 3) - projected_points = torch.matmul(intrinsic_4d, camera_points) # (b, h, w, 3, 1) - - return projected_points - - -def bilinear_splatting( - frame1: torch.Tensor, - mask1: Optional[torch.Tensor], - depth1: torch.Tensor, - flow12: torch.Tensor, - flow12_mask: Optional[torch.Tensor], - is_image: bool = False, - n_views=1, - depth_weight_scale=50, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Bilinear splatting - :param frame1: (b,c,h,w) - :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional - :param depth1: (b,1,h,w) - :param flow12: (b,2,h,w) - :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional - :param is_image: if true, output will be clipped to (-1,1) range - :return: warped_frame2: (b,c,h,w) - mask2: (b,1,h,w): 1 for known and 0 for unknown - """ - b, c, h, w = frame1.shape - device = frame1.device - dtype = frame1.dtype - if mask1 is None: - mask1 = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype) # .to(frame1) - if flow12_mask is None: - flow12_mask = torch.ones(size=(b, 1, h, w), device=device, dtype=dtype) # .to(flow12) - grid = create_grid(b, h, w, device=device, dtype=dtype).to(dtype) # .to(frame1) - trans_pos = flow12 + grid - - trans_pos_offset = trans_pos + 1 - trans_pos_floor = torch.floor(trans_pos_offset).long() - trans_pos_ceil = torch.ceil(trans_pos_offset).long() - trans_pos_offset = torch.stack( - [torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], - dim=1, - ) - trans_pos_floor = torch.stack( - [torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], - dim=1, - ) - trans_pos_ceil = torch.stack( - [torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1), torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], - dim=1, - ) - - prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * ( - 1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]) - ) - prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * ( - 1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]) - ) - prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * ( - 1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]) - ) - prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * ( - 1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]) - ) - - # Calculate depth weights, preventing overflow and removing saturation - # Clamp depth to be non-negative before log1p - clamped_depth1 = torch.clamp(depth1, min=0) - log_depth1 = torch.log1p(clamped_depth1) # Use log1p for better precision near 0 - # Normalize and scale log depth - exponent = log_depth1 / (log_depth1.max() + 1e-7) * depth_weight_scale - # Clamp exponent before exp to prevent overflow - max_exponent = get_max_exponent_for_dtype(depth1.dtype) - clamped_exponent = torch.clamp(exponent, max=max_exponent) - # Compute depth weights with added epsilon for stability when dividing later - depth_weights = torch.exp(clamped_exponent) + 1e-7 - - - weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) - weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) - weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) - weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2]) - - warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=dtype, device=device) # .to(frame1) - warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=dtype, device=device) # .to(frame1) - - frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2]) - batch_indices = torch.arange(b, device=device, dtype=torch.long)[:, None, None] # .to(frame1.device) - warped_frame.index_put_( - (batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), frame1_cl * weight_nw, accumulate=True - ) - warped_frame.index_put_( - (batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), frame1_cl * weight_sw, accumulate=True - ) - warped_frame.index_put_( - (batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), frame1_cl * weight_ne, accumulate=True - ) - warped_frame.index_put_( - (batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), frame1_cl * weight_se, accumulate=True - ) - - warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]), weight_nw, accumulate=True) - warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]), weight_sw, accumulate=True) - warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]), weight_ne, accumulate=True) - warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]), weight_se, accumulate=True) - if n_views > 1: - warped_frame = warped_frame.reshape(b // n_views, n_views, h + 2, w + 2, c).sum(1) - warped_weights = warped_weights.reshape(b // n_views, n_views, h + 2, w + 2, 1).sum(1) - - warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1]) - warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1]) - cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1] - cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1] - cropped_weights = torch.nan_to_num(cropped_weights, nan=1000.0) - - mask = cropped_weights > 0 - zero_value = -1 if is_image else 0 - zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device) - warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor) - mask2 = mask.to(frame1) - if is_image: - warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1) - return warped_frame2, mask2 - - -def create_grid(b: int, h: int, w: int, device="cpu", dtype=torch.float) -> torch.Tensor: - """ - Create a dense grid of (x,y) coordinates of shape (b, 2, h, w). - """ - x = torch.arange(0, w, device=device, dtype=dtype).view(1, 1, 1, w).expand(b, 1, h, w) - y = torch.arange(0, h, device=device, dtype=dtype).view(1, 1, h, 1).expand(b, 1, h, w) - return torch.cat([x, y], dim=1) diff --git a/lyra_2/_src/datasets/plucker_embed_corrupter.py b/lyra_2/_src/datasets/plucker_embed_corrupter.py deleted file mode 100644 index 7cc89ebe5977726945f844c1e5be637fea930dfa..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/plucker_embed_corrupter.py +++ /dev/null @@ -1,74 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from packaging import version as pver - - -def custom_meshgrid(*args): - # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid - if pver.parse(torch.__version__) < pver.parse("1.10"): - return torch.meshgrid(*args) - else: - return torch.meshgrid(*args, indexing="ij") - - -def ray_condition(K, c2w, H, W, device, flip_flag=None, use_ray_o=False): - # c2w: B, V, 4, 4 - # K: B, V, 4 - # If K is None, use constant forward z-direction in camera space - B, V = c2w.shape[:2] - - j, i = custom_meshgrid( - torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), - torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), - ) - i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW] - j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW] - - n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0 - if n_flip > 0: - j_flip, i_flip = custom_meshgrid( - torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), - torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype), - ) - i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5 - j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5 - i[:, flip_flag, ...] = i_flip - j[:, flip_flag, ...] = j_flip - - if K is None: - directions = torch.zeros(B, V, H * W, 3, device=device, dtype=c2w.dtype) - directions[..., 2] = 1.0 - else: - fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 - zs = torch.ones_like(i) # [B, V, HxW] - xs = (i - cx) / fx * zs - ys = (j - cy) / fy * zs - zs = zs.expand_as(ys) - directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 - directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 - - rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3 - rays_o = c2w[..., :3, 3] # B, V, 3 - rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3 - # c2w @ dirctions - if use_ray_o: - plucker = torch.cat([rays_o, rays_d], dim=-1) - else: - rays_dxo = torch.cross(rays_o, rays_d) # B, V, HW, 3 - plucker = torch.cat([rays_dxo, rays_d], dim=-1) - plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 - return plucker diff --git a/lyra_2/_src/datasets/radym.py b/lyra_2/_src/datasets/radym.py deleted file mode 100644 index f6c635f2ff37362b0ad3018ef462568cd04d89ad..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/radym.py +++ /dev/null @@ -1,303 +0,0 @@ -import json -import os -import zipfile -from pathlib import Path -from typing import Any, List, Optional - -import cv2 -import numpy as np -try: - import OpenEXR -except ImportError: - OpenEXR = None -import torch -import torch.utils.data -from decord import VideoReader -from lru import LRU - -from lyra_2._src.datasets.base import BaseDataset, DataField - - -class Radym(BaseDataset): - MAX_ZIP_DESCRIPTORS = 10 - MAX_MP4_READERS = 2 - - def __init__( - self, root_path, filter_list_path: Optional[str] = None, num_views: int = -1, depth_folder: str = "depth", - custom_folders: Optional[List[str]] = None, custom_fields: Optional[List[str]] = None - ): - # For multi-view datasets, root_path is the path to camera idx 0. - self.root_path = root_path - - # filter_list_path is a text file containing the list of mp4 files to load. - # Each line in the file should contain the name of the mp4 file with or without the extension. - if filter_list_path is None: - self.filter_set = None - else: - self.filter_list_path = filter_list_path if os.path.isabs(filter_list_path) else os.path.join(root_path, filter_list_path) - with open(self.filter_list_path, "r") as f: - self.filter_set = [line.strip() for line in f.readlines()] - self.filter_set = set([x.split(".")[0] for x in self.filter_set]) - self.n_views = num_views - - # Recursively grab all mp4 files in subfolders with name 'rgb'. - self.mp4_file_paths = [] - for rgb_root in Path(root_path).rglob("rgb"): - if not rgb_root.is_dir(): - continue - print(rgb_root) - for mp4_file in rgb_root.glob("*.mp4"): - if self.filter_set is None or mp4_file.stem in self.filter_set: - self.mp4_file_paths.append(mp4_file) - - # Process-dependent LRU cache for file handles of the tar files. - self.worker_id = None - self.zip_descriptors = LRU( - self.MAX_ZIP_DESCRIPTORS, callback=self._evict_zip_handle - ) - # self.mp4_readers = LRU(self.MAX_MP4_READERS, callback=self._evict_mp4_reader) - self.depth_folder = depth_folder - self.custom_folders = custom_folders - self.custom_fields = custom_fields - - @staticmethod - def _evict_zip_handle(_, zip_handle): - zip_handle.close() - - @staticmethod - def _evict_mp4_reader(_, mp4_reader: VideoReader): - # This is no-op, just a placeholder. - del mp4_reader - - def _check_worker_id(self): - # Protect handle boundary: - worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: - if self.worker_id is not None: - assert self.worker_id == worker_info.id, "Worker id mismatch" - else: - self.worker_id = worker_info.id - - def _get_zip_handle(self, idx, attr, view_idx): - self._check_worker_id() - - if self.n_views != -1: - dict_key = f"{idx}_{view_idx}_{attr}" - else: - dict_key = f"{idx}_{attr}" - if dict_key in self.zip_descriptors: - return self.zip_descriptors[dict_key] - - rgb_path = self.mp4_file_paths[idx] - root_path, zip_name = rgb_path.parent.parent, rgb_path.stem + ".zip" - if self.n_views != -1: - root_path = root_path.parent / str(view_idx) - - zip_handle = zipfile.ZipFile(root_path / attr / zip_name, "r") - self.zip_descriptors[dict_key] = zip_handle - return zip_handle - - def _get_mp4_reader(self, idx, attr, view_idx): - rgb_path = self.mp4_file_paths[idx] - if self.n_views != -1: - root_path, mp4_name = rgb_path.parent.parent.parent, rgb_path.name - else: - root_path, mp4_name = rgb_path.parent.parent, rgb_path.name - if self.n_views != -1: - root_path = root_path / str(view_idx) - - mp4_reader = VideoReader(str(root_path / attr / mp4_name), num_threads=4) - # self.mp4_readers[dict_key] = mp4_reader - return mp4_reader - - def available_data_fields(self) -> list[DataField]: - return [ - DataField.IMAGE_RGB, - DataField.CAMERA_C2W_TRANSFORM, - DataField.CAMERA_INTRINSICS, - DataField.METRIC_DEPTH, - DataField.DYNAMIC_INSTANCE_MASK, - DataField.BACKWARD_FLOW, - DataField.OBJECT_BBOX, - DataField.CAPTION, - ] - - def num_videos(self) -> int: - return len(self.mp4_file_paths) - - def num_views(self, video_idx: int) -> int: - return 1 if self.n_views == -1 else self.n_views - - def num_frames(self, video_idx: int, view_idx: int = 0) -> int: - return len(self._get_mp4_reader(video_idx, "rgb", view_idx)) - - def _read_data( - self, - video_idx: int, - frame_idxs: List[int], - view_idxs: List[int], - data_fields: List[DataField], - ): - frame_indices = np.asarray(frame_idxs).astype(np.int64) - rgb_path = self.mp4_file_paths[video_idx] - data_base_path, data_key = rgb_path.parent.parent, rgb_path.stem - if self.n_views != -1: - # Currently support only load at most one camera. - assert len(set(view_idxs)) == 1, "Currently support only one view" - view_idx = view_idxs[0] - data_base_path = data_base_path.parent / str(view_idx) - else: - view_idx = 0 - - output_dict: dict[str | DataField, Any] = {"__key__": data_key} - - for data_field in data_fields: - if data_field == DataField.IMAGE_RGB: - rgb_reader = self._get_mp4_reader(video_idx, "rgb", view_idx) - rgb_read = rgb_reader.get_batch(frame_indices) - try: - rgb_np = rgb_read.asnumpy() - except AttributeError: - rgb_np = rgb_read.numpy() - rgb_np = rgb_np.astype(np.float32) / 255.0 - rgb_torch = torch.from_numpy(rgb_np).moveaxis(-1, 1).contiguous() - output_dict[data_field] = rgb_torch - - rgb_reader.seek(0) # set video reader point back to 0 to clean up cache - del rgb_reader - - elif data_field == DataField.CAMERA_C2W_TRANSFORM: - c2w_data = np.load(data_base_path / "pose" / f"{data_key}.npz") - f_idx = np.searchsorted(c2w_data["inds"], frame_indices) - assert np.all( - c2w_data["inds"][f_idx] == frame_indices - ), "Pose not found" - c2w_np = c2w_data["data"][f_idx].astype(np.float32) - c2w_torch = torch.from_numpy(c2w_np).contiguous() - output_dict[data_field] = c2w_torch - - elif data_field == DataField.CAMERA_INTRINSICS: - intrinsics_data = np.load( - data_base_path / "intrinsics" / f"{data_key}.npz" - ) - f_idx = np.searchsorted(intrinsics_data["inds"], frame_indices) - assert np.all( - intrinsics_data["inds"][f_idx] == frame_indices - ), "Intrinsics not found" - intrinsics_np = intrinsics_data["data"][f_idx].astype(np.float32) - intrinsics_torch = torch.from_numpy(intrinsics_np).contiguous() - output_dict[data_field] = intrinsics_torch - - elif data_field == DataField.METRIC_DEPTH: - depth_zip_handle = self._get_zip_handle(video_idx, self.depth_folder, view_idx) - depth_np = [] - for frame_idx in frame_indices: - frame_name = f"{frame_idx:05d}.exr" - with depth_zip_handle.open(frame_name, "r") as f: - exr_file = OpenEXR.InputFile(f) - exr_dw = exr_file.header()["dataWindow"] - depth_np.append( - np.frombuffer(exr_file.channel("Z"), np.float16).reshape( - exr_dw.max.y - exr_dw.min.y + 1, - exr_dw.max.x - exr_dw.min.x + 1, - ) - ) - depth_np = np.stack(depth_np, axis=0).astype(np.float32) - depth_torch = torch.from_numpy(depth_np).contiguous() - output_dict[data_field] = depth_torch - - elif data_field == DataField.OBJECT_BBOX: - bbox_zip_handle = self._get_zip_handle( - video_idx, "object_info", view_idx - ) - bbox_list = [] - for frame_idx in frame_indices: - frame_name = f"{frame_idx:05d}.json" - with bbox_zip_handle.open(frame_name, "r") as f: - bbox_data = json.load(f) - bbox_list.append(bbox_data) - output_dict[data_field] = bbox_list - - elif data_field == DataField.DYNAMIC_INSTANCE_MASK: - mask_zip_handle = self._get_zip_handle(video_idx, "mask", view_idx) - mask_np = [] - for frame_idx in frame_indices: - frame_name = f"{frame_idx:05d}.png" - with mask_zip_handle.open(frame_name, "r") as f: - mask_np.append( - cv2.imdecode( - np.frombuffer(f.read(), np.uint8), cv2.IMREAD_UNCHANGED - ) - ) - mask_np = np.stack(mask_np, axis=0).astype(np.uint8) - mask_torch = torch.from_numpy(mask_np).contiguous() - output_dict[data_field] = mask_torch - - elif data_field == DataField.BACKWARD_FLOW: - flow_zip_handle = self._get_zip_handle(video_idx, "flow", view_idx) - flow_np = [] - for frame_idx in frame_indices: - frame_name = f"{frame_idx:05d}.exr" - with flow_zip_handle.open(frame_name, "r") as f: - exr_file = OpenEXR.InputFile(f) - exr_dw = exr_file.header()["dataWindow"] - height, width = ( - exr_dw.max.y - exr_dw.min.y + 1, - exr_dw.max.x - exr_dw.min.x + 1, - ) - flow_np.append( - np.stack( - [ - np.frombuffer( - exr_file.channel(f"{channel}"), np.float16 - ) - for channel in ["U", "V"] - ], - axis=-1, - ).reshape(height, width, 2) - ) - flow_np = np.stack(flow_np, axis=0).astype(np.float32) - flow_torch = torch.from_numpy(flow_np).contiguous() - output_dict[data_field] = flow_torch - - elif data_field == DataField.CAPTION: - caption_path = data_base_path / "caption" / f"{data_key}.txt" - with open(caption_path, "r") as f: - caption = f.read() - output_dict[data_field] = caption - elif data_field == "custom": - if self.custom_folders is not None: - output_dict[data_field] = {} - assert len(self.custom_folders) == len(self.custom_fields), "Custom folders and types must have the same length" - for custom_folder, custom_fields in zip(self.custom_folders, self.custom_fields): - if custom_fields == "ftheta_intrinsic": - intrinsics_data = np.load( - data_base_path / custom_folder / f"{data_key}.npz" - ) - f_idx = np.searchsorted(intrinsics_data["inds"], frame_indices) - assert np.all( - intrinsics_data["inds"][f_idx] == frame_indices - ), "Intrinsics not found" - intrinsics_np = intrinsics_data["data"][f_idx].astype(np.float32) - intrinsics_torch = torch.from_numpy(intrinsics_np).contiguous() - output_dict[data_field][custom_fields] = intrinsics_torch - elif custom_fields in ["hdmap"]: - mp4_reader = self._get_mp4_reader(video_idx, custom_folder, view_idx) - mp4_read = mp4_reader.get_batch(frame_indices) - try: - mp4_np = mp4_read.asnumpy() - except AttributeError: - mp4_np = mp4_read.numpy() - mp4_np = mp4_np.astype(np.float32) / 255.0 - mp4_torch = torch.from_numpy(mp4_np).moveaxis(-1, 1).contiguous() - output_dict[data_field][custom_fields] = mp4_torch - - mp4_reader.seek(0) # set video reader point back to 0 to clean up cache - del mp4_reader - - else: - raise NotImplementedError(f"Can't handle custom data field {data_field}") - else: - raise NotImplementedError(f"Can't handle data field {data_field}") - - return output_dict diff --git a/lyra_2/_src/datasets/utils.py b/lyra_2/_src/datasets/utils.py deleted file mode 100644 index 2f2139797aad32f26694e830f3a5f508539ba03b..0000000000000000000000000000000000000000 --- a/lyra_2/_src/datasets/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -IMAGE_RES_SIZE_INFO: dict[str, tuple[int, int]] = { - "1080": { - "1,1": (1024, 1024), - "4,3": (1440, 1056), - "3,4": (1056, 1440), - "16,9": (1920, 1056), - "9,16": (1056, 1920), - }, - "1024": {"1,1": (1024, 1024), "4,3": (1168, 880), "3,4": (880, 1168), "16,9": (1360, 768), "9,16": (768, 1360)}, - "720": {"1,1": (960, 960), "4,3": (960, 704), "3,4": (704, 960), "16,9": (1280, 704), "9,16": (704, 1280)}, - "512": {"1,1": (512, 512), "4,3": (640, 512), "3,4": (512, 640), "16,9": (640, 384), "9,16": (384, 640)}, - "480": {"1,1": (480, 480), "4,3": (640, 480), "3,4": (480, 640), "16,9": (768, 432), "9,16": (432, 768)}, - "480p": {"1,1": (640, 640), "4,3": (640, 480), "3,4": (480, 640), "16,9": (832, 480), "9,16": (480, 832)}, - "720robocasa": {"1,1": (720, 720), "4,3": (960, 720), "3,4": (720, 960), "16,9": (1280, 720), "9,16": (720, 1280)}, - "256": { - "1,1": (256, 256), - "4,3": (320, 256), - "3,4": (256, 320), - "16,9": (320, 192), - "9,16": (192, 320), - }, -} - - -VIDEO_RES_SIZE_INFO: dict[str, tuple[int, int]] = { - "1080": { - "1,1": (1024, 1024), - "4,3": (1440, 1072), - "3,4": (1072, 1440), - "16,9": (1920, 1072), - "9,16": (1072, 1920), - }, - "1024": {"1,1": (1024, 1024), "4,3": (1280, 1024), "3,4": (1024, 1280), "16,9": (1280, 768), "9,16": (768, 1280)}, - "720": {"1,1": (960, 960), "4,3": (960, 704), "3,4": (704, 960), "16,9": (1280, 704), "9,16": (704, 1280)}, - "512": {"1,1": (512, 512), "4,3": (640, 512), "3,4": (512, 640), "16,9": (640, 384), "9,16": (384, 640)}, - "480": {"1,1": (480, 480), "4,3": (640, 480), "3,4": (480, 640), "16,9": (768, 432), "9,16": (432, 768)}, - "480p": {"1,1": (640, 640), "4,3": (640, 480), "3,4": (480, 640), "16,9": (832, 480), "9,16": (480, 832)}, - "720p": {"1,1": (960, 960), "4,3": (960, 720), "3,4": (720, 960), "16,9": (1280, 720), "9,16": (720, 1280)}, - "720robocasa": {"1,1": (720, 720), "4,3": (960, 720), "3,4": (720, 960), "16,9": (1280, 720), "9,16": (720, 1280)}, - "256": { - "1,1": (256, 256), - "4,3": (320, 256), - "3,4": (256, 320), - "16,9": (320, 192), - "9,16": (192, 320), - }, -} diff --git a/lyra_2/_src/inference/__init__.py b/lyra_2/_src/inference/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/inference/camera_traj_utils.py b/lyra_2/_src/inference/camera_traj_utils.py deleted file mode 100644 index b8479135d15dee5c54e2d253129725c4a3c33765..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/camera_traj_utils.py +++ /dev/null @@ -1,453 +0,0 @@ -# ----------------------------------------------------------------------------- -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# ----------------------------------------------------------------------------- - -import argparse -from typing import Tuple, Optional, Sequence - -import torch - -from lyra_2._src.inference.camera_utils import ( # type: ignore - create_spiral_trajectory, - create_horizontal_trajectory, - create_horizontal_with_noise_trajectory, - create_horizontal_zoom_with_bend_trajectory, - create_horizontal_zoom_with_noise_and_bend_trajectory, - create_back_trajectory, - create_dolly_zoom_trajectory, - create_spiral_horizontal_trajectory, - create_orbit_trajectory, - create_rotate_then_zoom_trajectory, - create_rotate_spot_trajectory, - create_rotate_spot_with_noise_trajectory, -) - - -# Canonical list of supported camera trajectories shared across inference scripts. -CAMERA_TRAJECTORY_CHOICES: Tuple[str, ...] = ( - "original", - "spiral", - "spiral_center", - "spiral_outwards", - "horizontal", - "horizontal_noise", - "horizontal_lift", - "horizontal_lift_noise", - "horizontal_zoom", - "horizontal_zoom_noise", - "horizontal_zoom_bend", - "horizontal_zoom_noise_bend", - "horizontal_zoom_still", - "horizontal_still", - "horizontal_simple", - "vertical_simple", - "horizontal_outward", - "back", - "back_simple", - "dolly_zoom", - "horizontal_spiral", - "orbit_horizontal", - "orbit_vertical", - "rotate_zoom_in", - "rotate_zoom_out", - "rotate_spot", - "rotate_spot_noise", -) - - -def add_camera_traj_args( - parser: argparse.ArgumentParser, - *, - with_video_len: bool = True, - video_len_flag: str = "video_len", - video_len_default: int = 93, - video_len_help: Optional[str] = None, - with_fps: bool = True, - fps_default: int = 16, - trajectory_default: str = "original", - strength_default: float = 0.2, -) -> None: - """Attach shared camera trajectory CLI arguments to an argparse parser. - - """ - if with_video_len: - help_text = ( - video_len_help - if video_len_help is not None - else "Video length (number of frames) for camera trajectory." - ) - parser.add_argument( - f"--{video_len_flag}", - type=int, - default=video_len_default, - help=help_text, - ) - if with_fps: - parser.add_argument( - "--fps", - type=int, - default=fps_default, - help="Output video frame rate.", - ) - parser.add_argument( - "--trajectory", - type=str, - default=trajectory_default, - choices=list(CAMERA_TRAJECTORY_CHOICES), - ) - parser.add_argument( - "--direction", - type=str, - default="right", - choices=["left", "right", "up", "down"], - ) - parser.add_argument( - "--strength", - type=float, - default=strength_default, - ) - - -def build_camera_trajectory( - initial_w2c_44: torch.Tensor, - K_33: torch.Tensor, - center_depth: float, - video_len: int, - trajectory: str, - direction: str, - strength: float, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Shared camera-trajectory builder. - - Returns: - w2cs: [T,4,4] world-to-camera matrices - Ks: [T,3,3] intrinsics per frame - """ - device = initial_w2c_44.device - if trajectory == "original": - w2cs = initial_w2c_44.unsqueeze(0).repeat(video_len, 1, 1) - Ks = K_33.unsqueeze(0).repeat(video_len, 1, 1) - return w2cs, Ks - - if trajectory == "spiral": - radius_x = 0.15 * strength - radius_y = 0.10 * strength - new_w2cs = create_spiral_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - radius_x=radius_x, - radius_y=radius_y, - right=(direction == "right"), - num_circles=2, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "spiral_center": - new_w2cs = create_spiral_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - radius_x=0.03, - radius_y=0.02, - right=(direction == "right"), - start_from_zero=False, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "spiral_outwards": - radius_x = 0.3 * strength - radius_y = 0.2 * strength - new_w2cs = create_spiral_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - inwards=False, - radius_x=radius_x, - radius_y=radius_y, - right=(direction == "right"), - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal": - new_w2cs = create_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="x", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_noise": - new_w2cs = create_horizontal_with_noise_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="x", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_lift": - new_w2cs = create_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="y", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_lift_noise": - new_w2cs = create_horizontal_with_noise_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="y", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_zoom": - new_w2cs = create_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="z", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_zoom_noise": - new_w2cs = create_horizontal_with_noise_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="z", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_zoom_bend": - new_w2cs = create_horizontal_zoom_with_bend_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="z", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_zoom_noise_bend": - new_w2cs = create_horizontal_zoom_with_noise_and_bend_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="z", - distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_spiral": - radius_x = 0.15 * strength * 0.25 * 0.5 - radius_y = 0.10 * strength * 0.25 * 0.5 - num_circles = 2 - new_w2cs = create_spiral_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="z", - distance=strength, - radius_x=radius_x, - radius_y=radius_y, - num_circles=num_circles, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_zoom_still": - half = video_len // 2 - seq1 = create_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=half, - right=(direction == "right"), - axis="z", - distance=strength, - ) - seq2 = seq1[-1:].repeat(video_len - seq1.shape[0], 1, 1) - new_w2cs = torch.cat([seq1, seq2], dim=0) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_still": - half = video_len // 2 - seq1 = create_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=half, - right=(direction == "right"), - axis="x", - distance=strength, - ) - seq2 = seq1[-1:].repeat(video_len - seq1.shape[0], 1, 1) - new_w2cs = torch.cat([seq1, seq2], dim=0) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_simple": - new_w2cs = initial_w2c_44.unsqueeze(0).repeat(video_len, 1, 1) - shift = torch.linspace(0.0, strength, video_len, device=device) * center_depth - if direction == "right": - new_w2cs[:, 0, 3] -= shift - else: - new_w2cs[:, 0, 3] += shift - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "vertical_simple": - new_w2cs = initial_w2c_44.unsqueeze(0).repeat(video_len, 1, 1) - shift = torch.linspace(0.0, strength, video_len, device=device) * center_depth - if direction == "up": - new_w2cs[:, 1, 3] += shift - else: - new_w2cs[:, 1, 3] -= shift - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "horizontal_outward": - new_w2cs = create_horizontal_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - right=(direction == "right"), - axis="x", - distance=strength, - outwards=True, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "back": - seq = create_back_trajectory( - initial_w2c_44.unsqueeze(0).repeat(video_len, 1, 1), - center_depth, - right=(direction == "right"), - ) - Ks = K_33.unsqueeze(0).repeat(seq.shape[0], 1, 1) - return seq, Ks - - if trajectory == "back_simple": - seq = create_back_trajectory( - initial_w2c_44.unsqueeze(0).repeat(video_len, 1, 1), - center_depth, - right=(direction == "right"), - invert_pos=True, - radius_x=0.15, - radius_y=0.1, - ) - Ks = K_33.unsqueeze(0).repeat(seq.shape[0], 1, 1) - return seq, Ks - - if trajectory == "dolly_zoom": - seq, Ks = create_dolly_zoom_trajectory(initial_w2c_44, K_33, center_depth, n_steps=video_len) - return seq, Ks - - if trajectory == "orbit_horizontal": - new_w2cs = create_orbit_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - angle=strength, - axis="y", - direction=direction, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "orbit_vertical": - new_w2cs = create_orbit_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - angle=strength, - axis="x", - direction=direction, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "rotate_zoom_in": - new_w2cs = create_rotate_then_zoom_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - rotate_direction=direction, - zoom_direction="right", - rotation_angle=20.0, - zoom_distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "rotate_zoom_out": - new_w2cs = create_rotate_then_zoom_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - rotate_direction=direction, - zoom_direction="left", - rotation_angle=20.0, - zoom_distance=strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "rotate_spot": - new_w2cs = create_rotate_spot_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - rotate_direction=direction, - rotation_angle=1.0 * strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - if trajectory == "rotate_spot_noise": - new_w2cs = create_rotate_spot_with_noise_trajectory( - initial_w2c_44, - center_depth, - n_steps=video_len, - rotate_direction=direction, - rotation_angle=1.0 * strength, - ) - Ks = K_33.unsqueeze(0).repeat(new_w2cs.shape[0], 1, 1) - return new_w2cs, Ks - - raise NotImplementedError(f"Unsupported trajectory: {trajectory}") - - - diff --git a/lyra_2/_src/inference/camera_utils.py b/lyra_2/_src/inference/camera_utils.py deleted file mode 100644 index 4e0abc4702b1555dc433ecbe3a135931fc55e3db..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/camera_utils.py +++ /dev/null @@ -1,1027 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import torch - - -def apply_transformation(Bx4x4, another_matrix): - B = Bx4x4.shape[0] - if another_matrix.dim() == 2: - another_matrix = another_matrix.unsqueeze(0).expand(B, -1, -1) # Make another_matrix compatible with batch size - transformed_matrix = torch.bmm(Bx4x4, another_matrix) # Shape: (B, 4, 4) - - return transformed_matrix - - -def look_at_matrix(camera_pos, target, invert_pos=True): - """Creates a 4x4 look-at matrix, keeping the camera pointing towards a target.""" - forward = (target - camera_pos).float() - forward = forward / torch.norm(forward) - - up = torch.tensor([0.0, 1.0, 0.0], device=camera_pos.device) # assuming Y-up coordinate system - right = torch.cross(up, forward, dim=0) - right = right / torch.norm(right) - up = torch.cross(forward, right, dim=0) - - look_at = torch.eye(4, device=camera_pos.device) - look_at[0, :3] = right - look_at[1, :3] = up - look_at[2, :3] = forward - if invert_pos: - # Proper world-to-camera translation: t = -R @ C - look_at[:3, 3] = -look_at[:3, :3] @ camera_pos - else: - look_at[:3, 3] = camera_pos - - return look_at - - -def slerp(quat1, quat2, t): - """Spherical linear interpolation (SLERP) between two quaternions.""" - dot_product = torch.dot(quat1, quat2) - - if dot_product < 0.0: - quat2 = -quat2 - dot_product = -dot_product - - dot_product = torch.clamp(dot_product, -1.0, 1.0) - theta_0 = torch.acos(dot_product) - sin_theta_0 = torch.sin(theta_0) - - if sin_theta_0 > 1e-6: - theta = theta_0 * t - sin_theta = torch.sin(theta) - s1 = torch.sin(theta_0 - theta) / sin_theta_0 - s2 = sin_theta / sin_theta_0 - return s1 * quat1 + s2 * quat2 - else: - return (1.0 - t) * quat1 + t * quat2 - - -def matrix_to_quaternion(matrix, device): - """Converts a 3x3 rotation matrix to a quaternion.""" - m = matrix[:3, :3] - trace = torch.trace(m) - - if trace > 0.0: - s = torch.sqrt(trace + 1.0) * 2.0 - w = 0.25 * s - x = (m[2, 1] - m[1, 2]) / s - y = (m[0, 2] - m[2, 0]) / s - z = (m[1, 0] - m[0, 1]) / s - elif (m[0, 0] > m[1, 1]) and (m[0, 0] > m[2, 2]): - s = torch.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2]) * 2.0 - w = (m[2, 1] - m[1, 2]) / s - x = 0.25 * s - y = (m[0, 1] + m[1, 0]) / s - z = (m[0, 2] + m[2, 0]) / s - elif m[1, 1] > m[2, 2]: - s = torch.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2]) * 2.0 - w = (m[0, 2] - m[2, 0]) / s - x = (m[0, 1] + m[1, 0]) / s - y = 0.25 * s - z = (m[1, 2] + m[2, 1]) / s - else: - s = torch.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1]) * 2.0 - w = (m[1, 0] - m[0, 1]) / s - x = (m[0, 2] + m[2, 0]) / s - y = (m[1, 2] + m[2, 1]) / s - z = 0.25 * s - - return torch.tensor([w, x, y, z], device=device) - - -def quaternion_to_matrix(quat, device): - """Converts a quaternion to a 4x4 rotation matrix.""" - w, x, y, z = quat - rotation = torch.eye(4, device=device) - - rotation[0, 0] = 1 - 2 * (y * y + z * z) - rotation[0, 1] = 2 * (x * y - z * w) - rotation[0, 2] = 2 * (x * z + y * w) - - rotation[1, 0] = 2 * (x * y + z * w) - rotation[1, 1] = 1 - 2 * (x * x + z * z) - rotation[1, 2] = 2 * (y * z - x * w) - - rotation[2, 0] = 2 * (x * z - y * w) - rotation[2, 1] = 2 * (y * z + x * w) - rotation[2, 2] = 1 - 2 * (x * x + y * y) - - return rotation - - -def get_translation_matrix(dx, dy, dz, device): - """Creates a 4x4 translation matrix.""" - translation = torch.eye(4, device=device) - translation[0, 3] = dx - translation[1, 3] = dy - translation[2, 3] = dz - return translation - - -def interpolate(trajectory, n_steps_per_segment, device): - interpolated_trajectory = [trajectory[0]] - - for i in range(len(trajectory) - 1): - start_matrix = trajectory[i] - end_matrix = trajectory[i + 1] - - start_pos = start_matrix[:3, 3] - end_pos = end_matrix[:3, 3] - - start_rot = matrix_to_quaternion(start_matrix, device) - end_rot = matrix_to_quaternion(end_matrix, device) - - for t in torch.linspace(0, 1, n_steps_per_segment + 1, device=device)[1:]: - interp_pos = (1 - t) * start_pos + t * end_pos - interp_rot = slerp(start_rot, end_rot, t) - - interp_matrix = torch.eye(4, device=device) - interp_matrix[:3, :3] = quaternion_to_matrix(interp_rot, device)[:3, :3] - interp_matrix[:3, 3] = interp_pos - - interpolated_trajectory.append(interp_matrix) - interpolated_trajectory = torch.stack(interpolated_trajectory) - - -def create_spiral_trajectory( - world_to_camera_matrix, - center_depth, - radius_x=0.03, - radius_y=0.02, - radius_z=0.0, - right=True, - inwards=True, - n_steps=13, - device="cuda", - start_from_zero=True, - num_circles=1, -): - """ - Create a spiral camera trajectory that follows a given motion, keeps the camera looking at a point, - and interpolates between trajectory points to create a smooth movement between camera positions. - - Parameters: - - world_to_camera_matrix (torch.Tensor): 4x4 camera-to-world matrix. - - num_points (int): Number of key points in the spiral motion. - - radius (float): Spiral radius for the camera motion. - - look_at (torch.Tensor): 3D point the camera should look at. - - n_steps_per_segment (int): Number of steps to interpolate between each key point. - - device (str): The device on which to perform the calculations (e.g., 'cpu' or 'cuda'). - - Returns: - - interpolated_trajectory (list): List of 4x4 matrices representing the interpolated camera positions. - """ - # Move all inputs to the specified device - # world_to_camera_matrix = world_to_camera_matrix.to(device) - look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) - - # Spiral motion key points - trajectory = [] - spiral_positions = [] - initial_camera_pos = torch.tensor([0, 0, 0], device=device) # world_to_camera_matrix[:3, 3].clone() - - example_scale = 1.0 - - theta_max = 2 * math.pi * num_circles - - for i in range(n_steps): - # theta = 2 * math.pi * i / (n_steps-1) # angle for each point - theta = theta_max * i / (n_steps - 1) # angle for each point - if start_from_zero: - x = radius_x * (math.cos(theta) - 1) * (1 if right else -1) * (center_depth / example_scale) - else: - x = radius_x * (math.cos(theta)) * (center_depth / example_scale) - - y = radius_y * math.sin(theta) * (center_depth / example_scale) - z = radius_z * math.sin(theta) * (center_depth / example_scale) - spiral_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in spiral_positions: - if inwards: - view_matrix = look_at_matrix(initial_camera_pos + pos, look_at) - else: - view_matrix = look_at_matrix(initial_camera_pos, look_at + pos) - trajectory.append(view_matrix) - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_orbit_trajectory( - world_to_camera_matrix, - center_depth, - n_steps=13, - angle=math.pi / 4, - axis="y", - direction="right", - device="cuda", -): - """ - Create a constant-radius orbit around the fixed look-at point [0, 0, center_depth]. - The camera always looks at the point while moving along a circular arc. - - Args: - world_to_camera_matrix (torch.Tensor): Base 4x4 world-to-camera transform to post-multiply. - center_depth (float or tensor): Z of the look-at point; also equals initial radius. - n_steps (int): Number of frames in the orbit. - angle (float): Angular sweep (radians) away from the initial pose (theta0 = pi). - axis (str): 'y' for horizontal (left/right) orbit, 'x' for vertical (up/down) orbit. - direction (str): 'right'/'left' for horizontal, 'up'/'down' for vertical; sets sweep sign. - device (str): Device where tensors are created. - - Returns: - torch.Tensor: [n_steps, 4, 4] sequence of world-to-camera matrices. - """ - # Resolve scalar radius - try: - r = float(center_depth) - except Exception: - r = center_depth.item() if hasattr(center_depth, "item") else float(center_depth) - - look_at = torch.tensor([0.0, 0.0, r], device=device) - - # Determine sweep direction - if axis == "y": - sweep_sign = -1.0 if direction == "right" else 1.0 - elif axis == "x": - sweep_sign = 1.0 if direction == "up" else -1.0 - else: - raise ValueError("axis must be 'x' or 'y'") - - trajectory = [] - for i in range(n_steps): - frac = 0.0 if n_steps <= 1 else (i / (n_steps - 1)) - # Sweep from 0 -> angle relative to the initial heading (theta0=pi) - theta = math.pi + sweep_sign * (angle * frac) - - if axis == "y": - # Horizontal orbit (rotate around Y axis) - cx = r * math.sin(theta) - cy = 0.0 - cz = r + r * math.cos(theta) - else: - # Vertical orbit (rotate around X axis) - cx = 0.0 - cy = r * math.sin(theta) - cz = r + r * math.cos(theta) - - camera_pos = torch.tensor([cx, cy, cz], device=device) - view_matrix = look_at_matrix(camera_pos, look_at) - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_horizontal_trajectory( - world_to_camera_matrix, center_depth, right=True, n_steps=13, distance=0.1, device="cuda", axis="x", outwards=False -): - if axis == "z": - look_at = torch.tensor([0.0, 0.0, center_depth * (distance+1.0)]).to(device) - else: - look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) - # Spiral motion key points - trajectory = [] - translation_positions = [] - initial_camera_pos = torch.tensor([0, 0, 0], device=device) - - for i in range(n_steps): - if axis == "x": - x = i * distance * center_depth / n_steps * (1 if right else -1) - y = 0 - z = 0 - elif axis == "y": - x = 0 - y = i * distance * center_depth / n_steps * (1 if right else -1) - z = 0 - elif axis == "z": - x = 0 - y = 0 - z = i * distance * center_depth / n_steps * (1 if right else -1) - else: - raise ValueError("Axis should be x, y or z") - - translation_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in translation_positions: - camera_pos = initial_camera_pos + pos - if outwards: - _look_at = look_at + pos * 2 - else: - _look_at = look_at - view_matrix = look_at_matrix(camera_pos, _look_at) - trajectory.append(view_matrix) - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_horizontal_with_noise_trajectory( - world_to_camera_matrix, center_depth, right=True, n_steps=13, distance=0.1, device="cuda", axis="x", outwards=False, noise_percentage=0.0001 -): - """Create a horizontal trajectory with the same primary axis movement as create_horizontal_trajectory, - but with random noise added to the two perpendicular dimensions. - - Args: - world_to_camera_matrix: Transformation matrix from world to camera space - center_depth: Depth at the center - right: Direction of movement along primary axis - n_steps: Number of steps in the trajectory - distance: Base distance (used for movement and noise scaling) - device: Device to create tensors on - axis: Primary axis ("x", "y", or "z") - movement along this axis - outwards: Whether to move look_at outwards - noise_percentage: Percentage of distance * center_depth to use as noise range (default 0.0001) - - Returns: - Trajectory tensor with primary axis movement plus noise in perpendicular dimensions - """ - if axis == "z": - look_at = torch.tensor([0.0, 0.0, center_depth * (distance+1.0)]).to(device) - else: - look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) - - trajectory = [] - translation_positions = [] - initial_camera_pos = torch.tensor([0, 0, 0], device=device) - - # Calculate noise magnitude as percentage of distance * center_depth - noise_magnitude = noise_percentage * distance * center_depth - - for i in range(n_steps): - # Primary axis movement (same as create_horizontal_trajectory) - primary_movement = i * distance * center_depth / n_steps * (1 if right else -1) - - # Sample random noise for the two perpendicular dimensions independently per timestep - # Using Gaussian distribution centered around zero with std = noise_magnitude - noise1 = torch.randn(1, device=device).item() * noise_magnitude - noise2 = torch.randn(1, device=device).item() * noise_magnitude - - if axis == "x": - x = primary_movement - y = noise1 - z = noise2 - elif axis == "y": - x = noise1 - y = primary_movement - z = noise2 - elif axis == "z": - x = noise1 - y = noise2 - z = primary_movement - else: - raise ValueError("Axis should be x, y or z") - - translation_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in translation_positions: - camera_pos = initial_camera_pos + pos - if outwards: - _look_at = look_at + pos * 2 - else: - _look_at = look_at - view_matrix = look_at_matrix(camera_pos, _look_at) - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_horizontal_zoom_with_bend_trajectory( - world_to_camera_matrix, center_depth, right=True, n_steps=13, distance=0.1, device="cuda", axis="z", outwards=False, bend_percentage_in=0.04, bend_percentage_out=0.12, axis_bend="y" -): - """Create a horizontal trajectory with bend applied to a perpendicular axis. - - Args: - world_to_camera_matrix: Transformation matrix from world to camera space - center_depth: Depth at the center - right: Direction of movement along primary axis (True = zoom in, False = zoom out) - n_steps: Number of steps in the trajectory - distance: Base distance for movement - device: Device to create tensors on - axis: Primary axis ("x", "y", or "z") - movement along this axis - outwards: Whether to move look_at outwards - bend_percentage_in: Percentage of distance * center_depth to use as bend for zoom in (default 0.04) - bend_percentage_out: Percentage of distance * center_depth to use as bend for zoom out (default 0.12) - axis_bend: Axis to apply bend to (default "y" for vertical bend when axis="z") - When right=True, bend goes positive; when right=False, bend goes negative - - Returns: - Trajectory tensor with primary axis movement plus bend on perpendicular axis - """ - if axis == "z": - look_at = torch.tensor([0.0, 0.0, center_depth * (distance+1.0)]).to(device) - else: - look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) - - trajectory = [] - translation_positions = [] - initial_camera_pos = torch.tensor([0, 0, 0], device=device) - - # Select bend percentage based on direction - bend_percentage = bend_percentage_in if right else bend_percentage_out - # Calculate bend magnitude as percentage of distance * center_depth - bend_magnitude = bend_percentage * distance * center_depth - - for i in range(n_steps): - # Primary axis movement (same as create_horizontal_trajectory) - primary_movement = i * distance * center_depth / n_steps * (1 if right else -1) - - # Bend: positive when right=True, negative when right=False - bend_base = bend_magnitude * i / n_steps - bend_value = bend_base if right else -bend_base - - if axis == "x": - x = primary_movement - if axis_bend == "y": - y = bend_value - z = 0 - elif axis_bend == "z": - y = 0 - z = bend_value - else: - raise ValueError(f"axis_bend must be perpendicular to axis. For axis='x', use 'y' or 'z'") - elif axis == "y": - y = primary_movement - if axis_bend == "x": - x = bend_value - z = 0 - elif axis_bend == "z": - x = 0 - z = bend_value - else: - raise ValueError(f"axis_bend must be perpendicular to axis. For axis='y', use 'x' or 'z'") - elif axis == "z": - z = primary_movement - if axis_bend == "x": - x = bend_value - y = 0 - elif axis_bend == "y": - x = 0 - y = bend_value - else: - raise ValueError(f"axis_bend must be perpendicular to axis. For axis='z', use 'x' or 'y'") - else: - raise ValueError("Axis should be x, y or z") - - translation_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in translation_positions: - camera_pos = initial_camera_pos + pos - if outwards: - _look_at = look_at + pos * 2 - else: - _look_at = look_at - view_matrix = look_at_matrix(camera_pos, _look_at) - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_horizontal_zoom_with_noise_and_bend_trajectory( - world_to_camera_matrix, center_depth, right=True, n_steps=13, distance=0.1, device="cuda", axis="z", outwards=False, noise_percentage=0.0001, bend_percentage_in=0.04, bend_percentage_out=0.12, axis_bend="y" -): - """Create a horizontal trajectory with Gaussian noise and bend applied to a perpendicular axis. - - Args: - world_to_camera_matrix: Transformation matrix from world to camera space - center_depth: Depth at the center - right: Direction of movement along primary axis (True = zoom in, False = zoom out) - n_steps: Number of steps in the trajectory - distance: Base distance for movement - device: Device to create tensors on - axis: Primary axis ("x", "y", or "z") - movement along this axis - outwards: Whether to move look_at outwards - noise_percentage: Percentage of distance * center_depth to use as noise range (default 0.0001) - bend_percentage_in: Percentage of distance * center_depth to use as bend for zoom in (default 0.04) - bend_percentage_out: Percentage of distance * center_depth to use as bend for zoom out (default 0.12) - axis_bend: Axis to apply bend to (default "y" for vertical bend when axis="z") - When right=True, bend goes positive; when right=False, bend goes negative - - Returns: - Trajectory tensor with primary axis movement plus noise and bend on perpendicular axis - """ - if axis == "z": - look_at = torch.tensor([0.0, 0.0, center_depth * (distance+1.0)]).to(device) - else: - look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) - - trajectory = [] - translation_positions = [] - initial_camera_pos = torch.tensor([0, 0, 0], device=device) - - # Calculate noise magnitude - noise_magnitude = noise_percentage * distance * center_depth - # Select bend percentage based on direction - bend_percentage = bend_percentage_in if right else bend_percentage_out - # Calculate bend magnitude as percentage of distance * center_depth - bend_magnitude = bend_percentage * distance * center_depth - - for i in range(n_steps): - # Primary axis movement (same as create_horizontal_trajectory) - primary_movement = i * distance * center_depth / n_steps * (1 if right else -1) - - # Bend: positive when right=True, negative when right=False - bend_base = bend_magnitude * i / n_steps - bend_value = bend_base if right else -bend_base - - # Sample random noise for the two perpendicular dimensions independently per timestep - # Using Gaussian distribution centered around zero - noise1 = torch.randn(1, device=device).item() * noise_magnitude - noise2 = torch.randn(1, device=device).item() * noise_magnitude - - if axis == "x": - x = primary_movement - if axis_bend == "y": - y = bend_value + noise1 - z = noise2 - elif axis_bend == "z": - y = noise1 - z = bend_value + noise2 - else: - raise ValueError(f"axis_bend must be perpendicular to axis. For axis='x', use 'y' or 'z'") - elif axis == "y": - y = primary_movement - if axis_bend == "x": - x = bend_value + noise1 - z = noise2 - elif axis_bend == "z": - x = noise1 - z = bend_value + noise2 - else: - raise ValueError(f"axis_bend must be perpendicular to axis. For axis='y', use 'x' or 'z'") - elif axis == "z": - z = primary_movement - if axis_bend == "x": - x = bend_value + noise1 - y = noise2 - elif axis_bend == "y": - x = noise1 - y = bend_value + noise2 - else: - raise ValueError(f"axis_bend must be perpendicular to axis. For axis='z', use 'x' or 'y'") - else: - raise ValueError("Axis should be x, y or z") - - translation_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in translation_positions: - camera_pos = initial_camera_pos + pos - if outwards: - _look_at = look_at + pos * 2 - else: - _look_at = look_at - view_matrix = look_at_matrix(camera_pos, _look_at) - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_back_trajectory( - forward_trajectory, - center_depth, - radius_x=0.3, - radius_y=0.2, - radius_z=0.0, - inwards=True, - right=True, - device="cuda", - invert_pos=False, -): - look_at = torch.tensor([0.0, 0.0, center_depth]).to(device) - if not inwards: - look_at *= -1 - - # Spiral motion key points - trajectory = [] - spiral_positions = [] - initial_camera_pos = torch.tensor([0, 0, 0], device=device) # world_to_camera_matrix[:3, 3].clone() - n_steps = forward_trajectory.shape[0] - 1 - for i in range(n_steps): - theta = 2 * math.pi * i / (n_steps - 1) # angle for each point - x = radius_x * (math.cos(theta) - 1) * (1 if right else -1) - y = radius_y * math.sin(theta) - z = radius_z * math.sin(theta) - spiral_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in spiral_positions: - camera_pos = initial_camera_pos + pos - view_matrix = look_at_matrix(camera_pos, look_at, invert_pos=invert_pos) - trajectory.append(view_matrix) - trajectory = torch.stack(trajectory) - backward_trajectory = apply_transformation(trajectory, forward_trajectory[:n_steps].flip(0)) - return torch.cat([forward_trajectory, backward_trajectory]) - - -def create_dolly_zoom_trajectory( - world_to_camera_matrix, intrinsic, center_depth, n_steps=13, shift_z=-3, device="cuda" -): - center_depth = center_depth * 0.185 - - look_at = torch.tensor([0.0, 0.0, center_depth], device=device) - - trajectory = [] - intrinsics_list = [] - translation_positions = [] - initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device) - - f0 = intrinsic[0, 0] - f1 = intrinsic[1, 1] - - Z_subject = center_depth # The Z position of the subject. now use 1.0 as default - z0 = initial_camera_pos[2].item() # also use default - - D0 = Z_subject - z0 # Initial distance to the subject - - for i in range(n_steps): - x = 0.0 - y = 0.0 - z = shift_z * i / (n_steps - 1) - translation_positions.append(torch.tensor([x, y, z], device=device)) - - for pos in translation_positions: - camera_pos = initial_camera_pos - pos - zi = -camera_pos[2].item() - Di = Z_subject - zi - - # Create new intrinsic matrix - new_intrinsic = intrinsic.clone() - new_intrinsic[0, 0] = f0 * (D0 / Di) - new_intrinsic[1, 1] = f1 * (D0 / Di) - intrinsics_list.append(new_intrinsic) - - # Create the view matrix - view_matrix = look_at_matrix(camera_pos, look_at) - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - intrinsics_list = torch.stack(intrinsics_list) - - # Apply transformation to trajectory - transformed_trajectory = apply_transformation(trajectory, world_to_camera_matrix) - return transformed_trajectory, intrinsics_list - - -def create_spiral_horizontal_trajectory( - world_to_camera_matrix, - center_depth, - radius_x=0.03, - radius_y=0.02, - distance=0.1, # zoom or horizontal move distance - right=True, - inwards=True, - n_steps=20, - num_circles=1, - device="cuda", - axis="x", # "x", "y", or "z" (z = zoom) -): - """ - Combine spiral motion with horizontal/zoom translation. - - axis="x" or "y" → spiral + horizontal pan - - axis="z" → spiral + zoom in/out - """ - # Match your other functions' convention - if axis == "z": - look_at = torch.tensor([0.0, 0.0, center_depth * (distance + 1.0)], device=device) - else: - look_at = torch.tensor([0.0, 0.0, center_depth], device=device) - - trajectory = [] - initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device) - theta_max = 2 * math.pi * num_circles - - for i in range(n_steps): - t = i / (n_steps - 1) - theta = theta_max * t - - # --- Spiral (on-the-spot orbit) --- - if axis == "z": - # Spiral in x/y plane while zooming along z - x_spiral = radius_x * (math.cos(theta) - 1) * (1 if right else -1) * center_depth - y_spiral = radius_y * math.sin(theta) * center_depth - z_spiral = 0.0 - else: - x_spiral = radius_x * math.cos(theta) * center_depth - y_spiral = radius_y * math.sin(theta) * center_depth - z_spiral = 0.0 - - # --- Linear motion / zoom --- - offset = t * distance * center_depth * (1 if right else -1) - if axis == "x": - x = x_spiral + offset - y = y_spiral - z = z_spiral - elif axis == "y": - x = x_spiral - y = y_spiral + offset - z = z_spiral - elif axis == "z": - x = x_spiral - y = y_spiral - z = z_spiral + offset # zoom motion - else: - raise ValueError("Axis should be 'x', 'y', or 'z'") - - camera_pos = initial_camera_pos + torch.tensor([x, y, z], device=device) - - # Keep same look_at logic as your original spiral/horizontal versions - if inwards: - view_matrix = look_at_matrix(camera_pos, look_at) - else: - view_matrix = look_at_matrix(initial_camera_pos, look_at + camera_pos) - - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_rotate_then_zoom_trajectory( - world_to_camera_matrix, - center_depth, - n_steps=13, - rotate_direction="right", - zoom_direction="right", - rotation_angle=20.0, - zoom_distance=0.1, - device="cuda", -): - """ - Create a trajectory that first rotates on the spot (no translation), then zooms in/out. - - Args: - world_to_camera_matrix (torch.Tensor): Base 4x4 world-to-camera transform. - center_depth (float): Z of the look-at point. - n_steps (int): Number of frames in the trajectory. - rotate_direction (str): 'right' or 'left' - rotation direction (default: 'right'). - zoom_direction (str): 'right' or 'left' - zoom direction (default: 'right'). - 'right' = zoom in (forward), 'left' = zoom out (backward). - rotation_angle (float): Total rotation angle in degrees (default: 20.0). - zoom_distance (float): Zoom distance as fraction of center_depth. - device (str): Device where tensors are created. - - Returns: - torch.Tensor: [n_steps, 4, 4] sequence of world-to-camera matrices. - """ - # Convert rotation angle from degrees to radians - rotation_angle_rad = math.radians(rotation_angle) - - look_at = torch.tensor([0.0, 0.0, center_depth], device=device) - initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device) - - trajectory = [] - half_steps = n_steps // 2 - - # Determine rotation direction - rotation_sign = 1.0 if rotate_direction == "right" else -1.0 - - # Determine zoom direction: right = zoom in (forward), left = zoom out (backward) - zoom_sign = 1.0 if zoom_direction == "right" else -1.0 - - # Pre-compute the view matrix at the transition point (end of rotation, start of zoom) - # This ensures smooth continuity between the two phases - final_angle = rotation_sign * rotation_angle_rad - cos_final = math.cos(final_angle) - sin_final = math.sin(final_angle) - final_rotation_matrix = torch.tensor( - [ - [cos_final, 0, sin_final, 0], - [0, 1, 0, 0], - [-sin_final, 0, cos_final, 0], - [0, 0, 0, 1], - ], - device=device, - ) - transition_view_matrix = look_at_matrix(initial_camera_pos, look_at) - transition_view_matrix = final_rotation_matrix @ transition_view_matrix - - # Extract the forward direction from the transition view matrix - # In look_at_matrix, the third row (index 2) is the forward direction in world space - rotation_part = transition_view_matrix[:3, :3] - rotated_forward = rotation_part[2, :3].clone() # Forward direction in world space - rotated_forward = rotated_forward / torch.norm(rotated_forward) - - # Compute the effective look_at point that the camera is looking at after rotation - # The camera at origin with rotated orientation is looking along rotated_forward - # So the look_at point is: origin + rotated_forward * distance_to_look_at - # The distance is the same as center_depth (distance from origin to original look_at) - zoom_look_at = initial_camera_pos + rotated_forward * center_depth - - for i in range(n_steps): - if i < half_steps: - # First half: rotate on the spot (no translation) - # Rotate the camera's view direction around Y-axis while keeping position fixed - # The look_at point stays fixed, but we rotate the camera's orientation - t = i / max(half_steps - 1, 1) # Normalize to [0, 1] - angle = rotation_sign * rotation_angle_rad * t - - # Create rotation matrix around Y-axis - cos_a = math.cos(angle) - sin_a = math.sin(angle) - rotation_matrix = torch.tensor( - [ - [cos_a, 0, sin_a, 0], - [0, 1, 0, 0], - [-sin_a, 0, cos_a, 0], - [0, 0, 0, 1], - ], - device=device, - ) - - # Camera stays at origin, but rotates its orientation - # Start with initial view matrix - view_matrix = look_at_matrix(initial_camera_pos, look_at) - # Apply rotation to the view matrix (this rotates the camera's orientation) - view_matrix = rotation_matrix @ view_matrix - - else: - # Second half: zoom along forward direction - # Keep the rotation from the last rotation state, only change translation - # The look_at point (zoom_look_at) is computed from rotation state for reference - - # Zoom progress in second half (starts from 0 at transition) - t_zoom = (i - half_steps) / max(n_steps - half_steps - 1, 1) # Normalize to [0, 1] - zoom_offset = zoom_sign * zoom_distance * center_depth * t_zoom - - # Move camera along the forward direction from the transition point - camera_pos = initial_camera_pos + rotated_forward * zoom_offset - - # Keep the exact same rotation as at transition, only update translation - # Translation in world-to-camera: -R^T @ camera_pos - # This ensures smooth continuity: at zoom_offset=0, view_matrix = transition_view_matrix - view_matrix = transition_view_matrix.clone() - view_matrix[:3, 3] = -rotation_part.T @ camera_pos - - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_rotate_spot_trajectory( - world_to_camera_matrix, - center_depth, - n_steps=13, - rotate_direction="right", - rotation_angle=65.0, - device="cuda", -): - """ - Create a trajectory that rotates on the spot (no translation). - - Args: - world_to_camera_matrix (torch.Tensor): Base 4x4 world-to-camera transform. - center_depth (float): Z of the look-at point. - n_steps (int): Number of frames in the trajectory. - rotate_direction (str): 'right' or 'left' - rotation direction (default: 'right'). - rotation_angle (float): Total rotation angle in degrees (default: 65.0). - device (str): Device where tensors are created. - - Returns: - torch.Tensor: [n_steps, 4, 4] sequence of world-to-camera matrices. - """ - # Convert rotation angle from degrees to radians - rotation_angle_rad = math.radians(rotation_angle) - - look_at = torch.tensor([0.0, 0.0, center_depth], device=device) - initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device) - - trajectory = [] - - # Determine rotation direction - rotation_sign = 1.0 if rotate_direction == "right" else -1.0 - - for i in range(n_steps): - # Rotate on the spot (no translation) - # Rotate the camera's view direction around Y-axis while keeping position fixed - # The look_at point stays fixed, but we rotate the camera's orientation - t = i / max(n_steps - 1, 1) # Normalize to [0, 1] - angle = rotation_sign * rotation_angle_rad * t - - # Create rotation matrix around Y-axis - cos_a = math.cos(angle) - sin_a = math.sin(angle) - rotation_matrix = torch.tensor( - [ - [cos_a, 0, sin_a, 0], - [0, 1, 0, 0], - [-sin_a, 0, cos_a, 0], - [0, 0, 0, 1], - ], - device=device, - ) - - # Camera stays at origin, but rotates its orientation - # Start with initial view matrix - view_matrix = look_at_matrix(initial_camera_pos, look_at) - # Apply rotation to the view matrix (this rotates the camera's orientation) - view_matrix = rotation_matrix @ view_matrix - - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) - - -def create_rotate_spot_with_noise_trajectory( - world_to_camera_matrix, - center_depth, - n_steps=13, - rotate_direction="right", - rotation_angle=65.0, - device="cuda", - noise_percentage=0.0001, - distance=0.1, -): - """ - Create a trajectory that mimics rotate_spot but adds tiny position offsets to avoid - degenerate covariance errors in DA3 pose alignment during autoregressive generation. - - This is identical to rotate_spot except it adds minimal smooth position variation - that's imperceptible visually but provides enough variation for pose alignment algorithms. - - Args: - world_to_camera_matrix (torch.Tensor): Base 4x4 world-to-camera transform. - center_depth (float): Z of the look-at point. - n_steps (int): Number of frames in the trajectory. - rotate_direction (str): 'right' or 'left' - rotation direction (default: 'right'). - rotation_angle (float): Total rotation angle in degrees (default: 65.0). - device (str): Device where tensors are created. - noise_percentage: Percentage of distance * center_depth for position offset (default 0.0001). - distance: Base distance (used for offset scaling, default 0.1) - - Returns: - torch.Tensor: [n_steps, 4, 4] sequence of world-to-camera matrices. - """ - # Convert rotation angle from degrees to radians - rotation_angle_rad = math.radians(rotation_angle) - - look_at = torch.tensor([0.0, 0.0, center_depth], device=device) - initial_camera_pos = torch.tensor([0.0, 0.0, 0.0], device=device) - - trajectory = [] - - # Determine rotation direction - rotation_sign = 1.0 if rotate_direction == "right" else -1.0 - - # Use a tiny smooth orbit instead of noise - this provides smooth, deterministic variation - # The orbit radius is very small to avoid visible movement, but large enough for pose alignment - # Use noise_percentage to scale the orbit radius - orbit_radius = noise_percentage * distance * center_depth # Small orbit radius - - for i in range(n_steps): - # Rotate on the spot (with tiny orbit for pose alignment) - # Rotate the camera's view direction around Y-axis while keeping position nearly fixed - t = i / max(n_steps - 1, 1) # Normalize to [0, 1] - angle = rotation_sign * rotation_angle_rad * t - - # Create rotation matrix around Y-axis - cos_a = math.cos(angle) - sin_a = math.sin(angle) - rotation_matrix = torch.tensor( - [ - [cos_a, 0, sin_a, 0], - [0, 1, 0, 0], - [-sin_a, 0, cos_a, 0], - [0, 0, 0, 1], - ], - device=device, - ) - - # Add tiny smooth orbit in X-Y plane (circular motion) - # This provides smooth, deterministic variation for pose alignment - # Use absolute frame index with very slow frequency to ensure smoothness across all frames - # The orbit should be imperceptibly slow to be visually identical to pure rotation - orbit_angle = 2 * math.pi * i * 0.001 # Very slow orbit based on absolute frame index - orbit_x = orbit_radius * math.cos(orbit_angle) - orbit_y = orbit_radius * math.sin(orbit_angle) - orbit_z = 0.0 - - # Camera position with tiny orbit (visually identical to pure rotation) - camera_pos = initial_camera_pos + torch.tensor([orbit_x, orbit_y, orbit_z], device=device) - - # Start with view matrix at orbiting position - view_matrix = look_at_matrix(camera_pos, look_at) - # Apply rotation to the view matrix (this rotates the camera's orientation) - view_matrix = rotation_matrix @ view_matrix - - trajectory.append(view_matrix) - - trajectory = torch.stack(trajectory) - return apply_transformation(trajectory, world_to_camera_matrix) diff --git a/lyra_2/_src/inference/depth_utils.py b/lyra_2/_src/inference/depth_utils.py deleted file mode 100644 index 8c74e9c26886ac7c8af9401b96b2915239da3733..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/depth_utils.py +++ /dev/null @@ -1,217 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Depth model utilities: MoGe monocular depth and Depth Anything 3 loaders.""" - -import os -import re -import sys -from typing import Tuple - -import cv2 -import torch - -from lyra_2._ext.imaginaire.utils import log - - -# --------------------------------------------------------------------------- -# MoGe -# --------------------------------------------------------------------------- - -def load_moge_model(device: torch.device): - # Disable xformers for MoGe's DINOv2 backbone so it doesn't dispatch to - # flash_attn CUDA kernels that segfault on aarch64 / Grace Hopper. - os.environ["XFORMERS_DISABLED"] = "1" - try: - from moge.model.v1 import MoGeModel # type: ignore - except Exception as e: - raise ImportError("MoGe is required for --input_image_path flow. Please install `moge`. Error: " + str(e)) - - for _name, _mod in sys.modules.items(): - if "moge" in _name and hasattr(_mod, "XFORMERS_ENABLED"): - _mod.XFORMERS_ENABLED = False - - model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device) - model.eval() - return model - - -def moge_infer_depth_intrinsics( - moge_model, - img_rgb_uint8: torch.Tensor, - depth_pred_hw: Tuple[int, int] = (720, 1280), - target_hw: Tuple[int, int] = (704, 1280), -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - img_rgb_uint8: [H,W,3] uint8 on CPU - Returns: - image_chw_norm: [1,3,Ht,Wt] float in [0,1] - depth_hw: [Ht,Wt] float (large for invalid) - intrinsics_33: [3,3] pixel units - mask_hw: [Ht,Wt] bool mask - """ - device = next(moge_model.parameters()).device - Ht, Wt = target_hw - - img_resized = cv2.resize(img_rgb_uint8.numpy(), (Wt, Ht), interpolation=cv2.INTER_LINEAR) - img_chw_0_1 = torch.tensor(img_resized / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1) - - with torch.no_grad(): - out = moge_model.infer(img_chw_0_1) - - depth_hw = out["depth"].to(device) - mask_hw = out["mask"].to(device) - K_norm = out["intrinsics"].to(device) # 3x3 normalized - - depth_hw = torch.nan_to_num(depth_hw, nan=1e4).clamp(min=0, max=1e4) - depth_hw = torch.where(mask_hw == 0, torch.tensor(1000.0, device=device, dtype=depth_hw.dtype), depth_hw) - - # Scale intrinsics to pixel units for (Wt, Ht) - K = K_norm.clone() - K[0, 0] *= Wt - K[1, 1] *= Ht - K[0, 2] *= Wt - K[1, 2] *= Ht - - return img_chw_0_1.unsqueeze(0), depth_hw, K, mask_hw - - -# --------------------------------------------------------------------------- -# Depth Anything 3 -# --------------------------------------------------------------------------- - -def _import_da3_api(): - """Lazy-import DepthAnything3 from the vendored depth_anything_3 submodule.""" - da3_src_root = os.path.join( - os.path.dirname(__file__), - "depth_anything_3", - "src", - ) - if da3_src_root not in sys.path: - sys.path.insert(0, da3_src_root) - stale_keys = [k for k in sys.modules if k == "depth_anything_3" or k.startswith("depth_anything_3.")] - for k in stale_keys: - del sys.modules[k] - try: - from depth_anything_3.api import DepthAnything3 # type: ignore - except Exception as e: - raise ImportError( - "Failed to import DepthAnything3 from vendored depth_anything_3. " - "Make sure the depth_anything_3 submodule is present and its " - "dependencies are available." - ) from e - return DepthAnything3 - - -def _resolve_da3_local_model_name(model_name: str) -> str: - """Map a Hub-style model id to a local DA3 config name when possible.""" - da3_src_root = os.path.join( - os.path.dirname(__file__), - "depth_anything_3", - "src", - ) - if da3_src_root not in sys.path: - sys.path.insert(0, da3_src_root) - from depth_anything_3.registry import MODEL_REGISTRY # type: ignore - - candidates = [str(model_name).strip()] - tail = candidates[0].split("/")[-1] - candidates.extend( - [ - tail, - tail.lower(), - tail.lower().replace("_", "-"), - re.sub(r"[-_]\d+(?:\.\d+)*$", "", tail.lower().replace("_", "-")), - ] - ) - for candidate in candidates: - if candidate in MODEL_REGISTRY: - return candidate - raise KeyError( - "Unable to map DA3 model name to a local config: " - f"{model_name}. Available local configs: {', '.join(MODEL_REGISTRY.keys())}" - ) - - -def load_da3_from_custom_checkpoint( - ckpt_path: str, - pretrained_path: str = "depth-anything/DA3NESTED-GIANT-LARGE-1.1", - device: str = "cuda", - strict: bool = True, -): - """Load DepthAnything3 from a custom finetuned checkpoint.""" - DepthAnything3 = _import_da3_api() - local_model_name = _resolve_da3_local_model_name(pretrained_path) - log.info( - f"Initializing DA3 architecture from local config: {local_model_name} " - f"(requested={pretrained_path})" - ) - model = DepthAnything3(model_name=local_model_name) - - log.info(f"Loading custom checkpoint from {ckpt_path}") - checkpoint = torch.load(ckpt_path, map_location="cpu") - - if "module" in checkpoint: - state_dict = checkpoint["module"] - elif "model" in checkpoint: - state_dict = checkpoint["model"] - else: - state_dict = checkpoint - - # Strip "model." prefix - prefix = "model." - converted = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} - - missing, unexpected = model.model.load_state_dict(converted, strict=strict) - if missing and not strict: - log.info( - f"Custom checkpoint is missing {len(missing)} tensors; " - f"falling back to pretrained base weights from {pretrained_path}." - ) - model = DepthAnything3.from_pretrained(pretrained_path) - missing, unexpected = model.model.load_state_dict(converted, strict=strict) - if missing: - log.info(f"Missing keys when loading custom checkpoint: {len(missing)}") - if unexpected: - log.info(f"Unexpected keys when loading custom checkpoint: {len(unexpected)}") - log.info(f"Loaded {len(converted)} parameters from custom checkpoint") - - model = model.to(device=device) - model.eval() - return model - - -def load_da3_model( - da3_model_name: str, - da3_model_path_custom: str = None, - device: str = "cuda", -): - """Load a DepthAnything3 model, optionally from a custom checkpoint.""" - if da3_model_path_custom is not None: - log.info(f"Loading DA3 model with custom checkpoint: {da3_model_path_custom}") - model = load_da3_from_custom_checkpoint( - ckpt_path=da3_model_path_custom, - pretrained_path=da3_model_name, - device=device, - strict=False, - ) - else: - log.info(f"Loading DA3 model from pretrained: {da3_model_name}") - DepthAnything3 = _import_da3_api() - model = DepthAnything3.from_pretrained(da3_model_name).to(device) - - model.eval() - return model diff --git a/lyra_2/_src/inference/get_t5_emb.py b/lyra_2/_src/inference/get_t5_emb.py deleted file mode 100644 index 525d261c9ef4da5a9a6dea91627ac1f1b441fa55..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/get_t5_emb.py +++ /dev/null @@ -1,577 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import html -import math -import string -from typing import List, Optional, Union - -import ftfy -import regex as re -import torch -import torch.distributed.checkpoint as dcp -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed.checkpoint import FileSystemReader -from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict -from transformers import AutoTokenizer - -from lyra_2._ext.imaginaire.utils import distributed, log, misc -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def canonicalize(text, keep_punctuation_exact_string=None): - text = text.replace("_", " ") - if keep_punctuation_exact_string: - text = keep_punctuation_exact_string.join( - part.translate(str.maketrans("", "", string.punctuation)) - for part in text.split(keep_punctuation_exact_string) - ) - else: - text = text.translate(str.maketrans("", "", string.punctuation)) - text = text.lower() - text = re.sub(r"\s+", " ", text) - return text.strip() - - -class HuggingfaceTokenizer: - def __init__(self, name, seq_len=None, clean=None, **kwargs): - assert clean in (None, "whitespace", "lower", "canonicalize") - self.name = name - self.seq_len = seq_len - self.clean = clean - - # init tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) - self.vocab_size = self.tokenizer.vocab_size - - def __call__(self, sequence, **kwargs): - return_mask = kwargs.pop("return_mask", False) - - # arguments - _kwargs = {"return_tensors": "pt"} - if self.seq_len is not None: - _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) - _kwargs.update(**kwargs) - - # tokenization - if isinstance(sequence, str): - sequence = [sequence] - if self.clean: - sequence = [self._clean(u) for u in sequence] - ids = self.tokenizer(sequence, **_kwargs) - - # output - if return_mask: - return ids.input_ids, ids.attention_mask - else: - return ids.input_ids - - def _clean(self, text): - if self.clean == "whitespace": - text = whitespace_clean(basic_clean(text)) - elif self.clean == "lower": - text = whitespace_clean(basic_clean(text)).lower() - elif self.clean == "canonicalize": - text = canonicalize(basic_clean(text)) - return text - - -def fp16_clamp(x): - if x.dtype == torch.float16 and torch.isinf(x).any(): - clamp = torch.finfo(x.dtype).max - 1000 - x = torch.clamp(x, min=-clamp, max=clamp) - return x - - -def init_weights(m): - if isinstance(m, T5LayerNorm): - nn.init.ones_(m.weight) - elif isinstance(m, T5Model): - nn.init.normal_(m.token_embedding.weight, std=1.0) - elif isinstance(m, T5FeedForward): - nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) - nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) - nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) - elif isinstance(m, T5Attention): - nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) - nn.init.normal_(m.k.weight, std=m.dim**-0.5) - nn.init.normal_(m.v.weight, std=m.dim**-0.5) - nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) - elif isinstance(m, T5RelativeEmbedding): - nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) - - -class GELU(nn.Module): - def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) - - -class T5LayerNorm(nn.Module): - def __init__(self, dim, eps=1e-6): - super(T5LayerNorm, self).__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.type_as(self.weight) - return self.weight * x - - -class T5Attention(nn.Module): - def __init__(self, dim, dim_attn, num_heads, dropout=0.1): - assert dim_attn % num_heads == 0 - super(T5Attention, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.num_heads = num_heads - self.head_dim = dim_attn // num_heads - - # layers - self.q = nn.Linear(dim, dim_attn, bias=False) - self.k = nn.Linear(dim, dim_attn, bias=False) - self.v = nn.Linear(dim, dim_attn, bias=False) - self.o = nn.Linear(dim_attn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, context=None, mask=None, pos_bias=None): - context = x if context is None else context - b, n, c = x.size(0), self.num_heads, self.head_dim - - q = self.q(x).view(b, -1, n, c) - k = self.k(context).view(b, -1, n, c) - v = self.v(context).view(b, -1, n, c) - - attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) - if pos_bias is not None: - attn_bias += pos_bias - if mask is not None: - assert mask.ndim in [2, 3] - mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) - attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) - - attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias - attn = F.softmax(attn.float(), dim=-1).type_as(attn) - x = torch.einsum("bnij,bjnc->binc", attn, v) - - x = x.reshape(b, -1, n * c) - x = self.o(x) - x = self.dropout(x) - return x - - -class T5FeedForward(nn.Module): - def __init__(self, dim, dim_ffn, dropout=0.1): - super(T5FeedForward, self).__init__() - self.dim = dim - self.dim_ffn = dim_ffn - - self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) - self.fc1 = nn.Linear(dim, dim_ffn, bias=False) - self.fc2 = nn.Linear(dim_ffn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.fc1(x) * self.gate(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - - -class T5SelfAttention(nn.Module): - def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): - super(T5SelfAttention, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - self.norm1 = T5LayerNorm(dim) - self.attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) - - def forward(self, x, mask=None, pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) - x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.ffn(self.norm2(x))) - return x - - -class T5CrossAttention(nn.Module): - def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): - super(T5CrossAttention, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - self.norm1 = T5LayerNorm(dim) - self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm3 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) - - def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) - x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) - x = fp16_clamp(x + self.ffn(self.norm3(x))) - return x - - -class T5RelativeEmbedding(nn.Module): - def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): - super(T5RelativeEmbedding, self).__init__() - self.num_buckets = num_buckets - self.num_heads = num_heads - self.bidirectional = bidirectional - self.max_dist = max_dist - - self.embedding = nn.Embedding(num_buckets, num_heads) - - def forward(self, lq, lk): - device = self.embedding.weight.device - rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) - rel_pos = self._relative_position_bucket(rel_pos) - rel_pos_embeds = self.embedding(rel_pos) - rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] - return rel_pos_embeds.contiguous() - - def _relative_position_bucket(self, rel_pos): - if self.bidirectional: - num_buckets = self.num_buckets // 2 - rel_buckets = (rel_pos > 0).long() * num_buckets - rel_pos = torch.abs(rel_pos) - else: - num_buckets = self.num_buckets - rel_buckets = 0 - rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) - - max_exact = num_buckets // 2 - rel_pos_large = ( - max_exact - + ( - torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact) - ).long() - ) - rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) - rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) - return rel_buckets - - -class T5Encoder(nn.Module): - def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): - super(T5Encoder, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) - self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None - self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList( - [ - T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) - for _ in range(num_layers) - ] - ) - self.norm = T5LayerNorm(dim) - - self.apply(init_weights) - - def forward(self, ids, mask=None): - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Decoder(nn.Module): - def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): - super(T5Decoder, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) - self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None - self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList( - [ - T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) - for _ in range(num_layers) - ] - ) - self.norm = T5LayerNorm(dim) - - self.apply(init_weights) - - def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): - b, s = ids.size() - - if mask is None: - mask = torch.tril(torch.ones(1, s, s).to(ids.device)) - elif mask.ndim == 2: - mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) - - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Model(nn.Module): - def __init__( - self, - vocab_size, - dim, - dim_attn, - dim_ffn, - num_heads, - encoder_layers, - decoder_layers, - num_buckets, - shared_pos=True, - dropout=0.1, - ): - super(T5Model, self).__init__() - self.vocab_size = vocab_size - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.encoder_layers = encoder_layers - self.decoder_layers = decoder_layers - self.num_buckets = num_buckets - - self.token_embedding = nn.Embedding(vocab_size, dim) - self.encoder = T5Encoder( - self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout - ) - self.decoder = T5Decoder( - self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout - ) - self.head = nn.Linear(dim, vocab_size, bias=False) - - self.apply(init_weights) - - def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): - x = self.encoder(encoder_ids, encoder_mask) - x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) - x = self.head(x) - return x - - -def _t5(name, encoder_only=False, decoder_only=False, dtype=torch.float32, device="cpu", **kwargs): - assert not (encoder_only and decoder_only) - - if encoder_only: - model_cls = T5Encoder - kwargs["vocab"] = kwargs.pop("vocab_size") - kwargs["num_layers"] = kwargs.pop("encoder_layers") - _ = kwargs.pop("decoder_layers") - elif decoder_only: - model_cls = T5Decoder - kwargs["vocab"] = kwargs.pop("vocab_size") - kwargs["num_layers"] = kwargs.pop("decoder_layers") - _ = kwargs.pop("encoder_layers") - else: - model_cls = T5Model - - with torch.device(device): - model = model_cls(**kwargs) - - model = model.to(dtype=dtype, device=device) - return model - - -def umt5_xxl(**kwargs): - cfg = dict( - vocab_size=256384, - dim=4096, - dim_attn=4096, - dim_ffn=10240, - num_heads=64, - encoder_layers=24, - decoder_layers=24, - num_buckets=32, - shared_pos=False, - dropout=0.1, - ) - cfg.update(**kwargs) - return _t5("umt5-xxl", **cfg) - - -def load_model_dcp(model, ckpt_path): - storage_reader = FileSystemReader(ckpt_path) - _state_dict = get_model_state_dict(model) - dcp.load(_state_dict, storage_reader=storage_reader, planner=DefaultLoadPlanner(allow_partial_load=True)) - log.info(set_model_state_dict(model, _state_dict, options=StateDictOptions(strict=False))) - return model - - -def load_model_torch(model, ckpt_path): - if distributed.is_rank0(): - ckpt = easy_io.load( - ckpt_path, - map_location="cuda", - fast_backend=True, - ) - model.load_state_dict(ckpt) - - distributed.sync_model_states(model, src=0) - return model - - -def load_model_torch_cpu(model, ckpt_path): - """CPU-only state_dict load to avoid any CUDA allocation during init.""" - if distributed.is_rank0(): - ckpt = easy_io.load( - ckpt_path, - map_location="cpu", - fast_backend=True, - ) - model.load_state_dict(ckpt) - - return model - - -class UMT5EncoderModel: - def __init__( - self, - text_len=512, - dtype=torch.bfloat16, - device=torch.cuda.current_device(), - checkpoint_path="./checkpoints/text_encoder/encoder.pth", - tokenizer_path="google/umt5-xxl", - enable_fsdp_shard: bool = False, - load_on_cpu: bool = False, - ): - assert not enable_fsdp_shard, "FSDP is not supported for UMT5" - self.text_len = text_len - self.dtype = dtype - self.device = device - - model = umt5_xxl(encoder_only=True, dtype=dtype, device=device).eval().requires_grad_(False) - log.info(f"loading {checkpoint_path}") - if checkpoint_path.endswith(".dcp"): - model = load_model_dcp(model, checkpoint_path) - else: - assert checkpoint_path.endswith(".pth"), "only .pth or .dcp are supported" - if load_on_cpu: - model = load_model_torch_cpu(model, checkpoint_path) - else: - model = load_model_torch(model, checkpoint_path) - self.model = model - self.model.to(self.device) - self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") - - def __call__(self, texts, device: Optional[torch.device] = None): - if device is None: - device = self.device - ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) - ids = ids.to(device) - mask = mask.to(device) - seq_lens = mask.gt(0).sum(dim=1).long() - context = self.model(ids, mask) - stack_emb = [] - for u, length in zip(context, seq_lens): - if length > self.text_len: - stack_emb.append(u[: self.text_len]) - else: - zeros = torch.zeros(self.text_len - length, u.shape[1]).to(u) - stack_emb.append(torch.cat([u[:length], zeros], dim=0)) - return torch.stack(stack_emb) - - -t5_encoder: Optional[UMT5EncoderModel] = None -_t5_offloaded: Optional[UMT5EncoderModel] = None - - -def get_umt5_embedding( - prompts: Union[str, List[str]], - device: str = "cuda", - max_length: int = 512, -) -> torch.Tensor: - global t5_encoder - if t5_encoder is None: - t5_encoder = UMT5EncoderModel(device=device) - return t5_encoder(prompts, device=device) - - -@torch.no_grad() -def get_umt5_embedding_offloaded( - prompts: Union[str, List[str]], - device: str = "cuda", - max_length: int = 512, -) -> torch.Tensor: - """ - Load UMT5 encoder on CPU only, move to CUDA just for the call, then back to CPU and clear CUDA cache. - """ - global _t5_offloaded - if _t5_offloaded is None: - _t5_offloaded = UMT5EncoderModel(device="cpu", load_on_cpu=True) - _t5_offloaded.model.to(device) - _t5_offloaded.device = device - emb = _t5_offloaded(prompts, device=device) - _t5_offloaded.model.to("cpu") - _t5_offloaded.device = "cpu" - if torch.cuda.is_available(): - try: - torch.cuda.empty_cache() - except Exception: - pass - return emb diff --git a/lyra_2/_src/inference/lyra2_ar_inference.py b/lyra_2/_src/inference/lyra2_ar_inference.py deleted file mode 100644 index c8f457c642f93bd1769899f354ef7f9e92dc733b..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/lyra2_ar_inference.py +++ /dev/null @@ -1,1325 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lyra_2._ext.imaginaire.visualize.video import save_img_or_video -from lyra_2._ext.imaginaire.utils import log, misc -import os -import re -import sys -from pathlib import Path -from typing import List, Optional, Tuple - -import numpy as np -from megatron.core import parallel_state -from einops import rearrange, repeat -import torch -import tqdm -from lyra_2._src.models.lyra2_model import Sparse3DCache -import gc -from lyra_2._src.datasets.forward_warp_utils_pytorch import ( - reliable_depth_mask_range_batch, -) - -torch.enable_grad(False) - -def _get_vae_handles(model): - vae_iface = model.tokenizer - vae_wrap = vae_iface.model # WanVAE wrapper - vae_core = vae_wrap.model # WanVAE_ core - return vae_iface, vae_wrap, vae_core - -def _prime_encoder_cache_with_history(init_video, vae_wrap, vae_core, model=None, enable_offload=False): - """Advance encoder cache through the history pixels and return history_latents plus the live cache.""" - # Offload diffusion model to CPU before VAE operations - _offload_diffusion_to_cpu(model, enable_offload) - - # Fresh encode using model helpers, then clone caches and normalize - vae_core.clear_cache() - with vae_wrap.context: - video_cast = init_video.to(vae_wrap.dtype) if not vae_wrap.is_amp else init_video - feats = model._vae_encode_range_stream(video_cast, 0, video_cast.shape[2], skip_first_frame=False) - enc_feat_cache = model._clone_vae_cache(vae_core._enc_feat_map) - history_latents = model._encoder_feats_to_normalized_latents(feats).contiguous().to(init_video.dtype) - - # Restore diffusion model to GPU after VAE operations - _restore_diffusion_to_gpu(model, enable_offload) - - return history_latents, enc_feat_cache - - -def _decode_new_latent_chunk(vae_wrap, vae_core, dec_feat_cache, latent_chunk, latent_offset, model=None, enable_offload=False): - """Stream-decode new latent chunk given current decoder cache; return pixel frames for this chunk.""" - # Offload diffusion model to CPU before VAE operations - _offload_diffusion_to_cpu(model, enable_offload) - - # Unnormalize per-frame to mu using offset, then apply channel unscale, conv2, and stream through decoder - B, C, T_new, H, W = latent_chunk.shape - # Build per-frame stats slice for offset positions - if T_new == 1 and latent_offset == 0: - mu = latent_chunk * vae_wrap.img_std.type_as(latent_chunk) + vae_wrap.img_mean.type_as(latent_chunk) - else: - mu = latent_chunk * vae_wrap.video_std[:, :, :1].type_as(latent_chunk) \ - + vae_wrap.video_mean[:, :, :1].type_as(latent_chunk) - # Channel unscale - mean_c, inv_std_c = vae_wrap.scale[0], vae_wrap.scale[1] - if torch.is_tensor(mean_c): - z = mu / inv_std_c.view(1, vae_core.z_dim, 1, 1, 1).type_as(mu) + mean_c.view(1, vae_core.z_dim, 1, 1, 1).type_as(mu) - else: - z = mu / inv_std_c + mean_c - with vae_wrap.context: - if not vae_wrap.is_amp: - z = z.to(vae_wrap.dtype) - x = vae_core.conv2(z) - # Decode one temporal slice at a time to mirror encoder streaming and keep memory low - outs = [] - for t in range(T_new): - feat_idx = [0] - out_t = vae_core.decoder(x[:, :, t : t + 1, :, :], feat_cache=dec_feat_cache, feat_idx=feat_idx) - outs.append(out_t) - video_chunk = torch.cat(outs, dim=2) - - # Restore diffusion model to GPU after VAE operations - _restore_diffusion_to_gpu(model, enable_offload) - - return video_chunk - - -def _add_tiny_offsets_to_extrinsics_for_pose_alignment( - extrinsics_np: np.ndarray, - variance_threshold: float = 1e-10, - offset_scale: float = 1e-6, - offset_freq_x: float = 0.00001, - offset_freq_y: float = 0.000013, - offset_freq_z: float = 0.000007, -) -> np.ndarray: - """ - Add tiny offsets to extrinsics to avoid degenerate covariance in pose alignment. - - This function detects when camera positions are nearly identical (which causes - degenerate covariance errors in Umeyama alignment) and adds minimal offsets - only when needed. - - Args: - extrinsics_np: [N, 4, 4] numpy array of extrinsics (world-to-camera matrices) - variance_threshold: Threshold for position variance to detect degenerate case (default 1e-10) - offset_scale: Scale of offsets to add (default 1e-6) - offset_freq_x: Frequency for X-axis offset pattern (default 0.00001) - offset_freq_y: Frequency for Y-axis offset pattern (default 0.000013) - offset_freq_z: Frequency for Z-axis offset pattern (default 0.000007) - - Returns: - Modified extrinsics_np with tiny offsets added if needed - """ - # Extract camera positions from extrinsics (world-to-camera: translation is -R^T @ camera_pos) - camera_positions = [] - for ext in extrinsics_np: - R = ext[:3, :3] - t = ext[:3, 3] - # Recover camera position: camera_pos = -R^T @ t - camera_pos = -R.T @ t - camera_positions.append(camera_pos) - camera_positions = np.stack(camera_positions, axis=0) - - # Check if positions are nearly identical (small variance) - pos_variance = np.var(camera_positions, axis=0).sum() - if pos_variance < variance_threshold: - # Add tiny offsets to provide minimal variation for pose alignment - # Using slow sinusoidal patterns with different frequencies per axis: - # - Different frequencies ensure independent variation on each axis - # - Mixing sin/cos ensures smooth, continuous variation - # - Very slow frequencies keep offsets imperceptible - # - Based on absolute frame index 'i' for consistency across autoregressive chunks - extrinsics_modified = extrinsics_np.copy() - for i, ext in enumerate(extrinsics_modified): - offset = np.array([ - offset_scale * np.sin(2 * np.pi * i * offset_freq_x), # X: sin with configurable frequency - offset_scale * np.cos(2 * np.pi * i * offset_freq_y), # Y: cos with configurable frequency - offset_scale * np.sin(2 * np.pi * i * offset_freq_z), # Z: sin with configurable frequency - ]) - R = ext[:3, :3] - # Modify translation: new_t = old_t - R @ offset - extrinsics_modified[i, :3, 3] = ext[:3, 3] - R @ offset - return extrinsics_modified - - return extrinsics_np - - -def _predict_da3_depth_window( - *, - da3_model, - history_frames: torch.Tensor, - start_index: int, - abs_last_idx: int, - cam_w2c: torch.Tensor, - intrinsics: torch.Tensor, - frame_interval: int, - max_history_frames: int, - process_res: int = 504, - process_res_method: str = "upper_bound_resize", - add_pose_alignment_offsets: bool = False, - include_ar_chunk_last_frames: bool = False, - ar_chunk_size_frames: int | None = None, - return_raw_predicted_pose: bool = False, -): - """Depth Anything 3 inference for a temporal window ending at abs_last_idx. - - - Select frames: last, last-N, last-2N, ... going backwards, up to max_history_frames and >= 0. - - Optionally also include the last frame of each AR chunk: 0, N, 2N, ... (de-duplicated). - When enabled, these extra frames are added on top of the base selection (no total-count budget enforcement). - - Uses images from history_frames (in [-1,1]) converted to uint8, and corresponding cameras. - - Returns numpy prediction plus the selected absolute frame indices. - """ - assert history_frames.dim() == 5, "history_frames must be [B,C,T,H,W]" - B, C, T_total, H, W = history_frames.shape - if B != 1: - raise ValueError("DA3 backend currently supports batch size B=1 only.") - - if abs_last_idx < 0: - return {"frame_indices": [], "prediction": None} - - # Build absolute frame indices along the global video timeline - selected_frames: List[int] = [] - step = max(int(frame_interval), 1) - for k in range(int(max_history_frames)): - f = abs_last_idx - k * step - if f < 0: - break - selected_frames.append(f) - if include_ar_chunk_last_frames: - if ar_chunk_size_frames is None: - raise ValueError("ar_chunk_size_frames must be provided when include_ar_chunk_last_frames=True") - chunk = max(int(ar_chunk_size_frames), 1) - - selected_frames.extend(list(range(0, int(abs_last_idx) + 1, chunk))) - - selected_frames = sorted(set(int(x) for x in selected_frames)) - if not selected_frames: - return {"frame_indices": [], "prediction": None} - - images: List[np.ndarray] = [] - exts: List[np.ndarray] = [] - ixts: List[np.ndarray] = [] - - hist_cpu = history_frames[0].detach().cpu() # [C,T,H,W] - cam_cpu = cam_w2c.detach().cpu() - intr_cpu = intrinsics.detach().cpu() - - for f in selected_frames: - t = start_index + f - if t < 0 or t >= T_total: - continue - frame_chw = hist_cpu[:, t] # [C,H,W] in [-1,1] - frame_0_1 = (frame_chw * 0.5 + 0.5).clamp(0.0, 1.0) - frame_hwc = frame_0_1.permute(1, 2, 0).float().numpy() - images.append(np.clip(frame_hwc * 255.0 + 0.5, 0, 255).astype(np.uint8)) - exts.append(cam_cpu[0, f].numpy().astype(np.float32)) - ixts.append(intr_cpu[0, f].numpy().astype(np.float32)) - - if not images: - return {"frame_indices": [], "prediction": None} - - extrinsics_np = np.stack(exts, axis=0).astype(np.float32) - intrinsics_np = np.stack(ixts, axis=0).astype(np.float32) - - # Use the input image long side as DA3 process_res; with upper_bound_resize this avoids resizing. - process_res = int(max(images[0].shape[0], images[0].shape[1])) - process_res_method = "upper_bound_resize" - - # Optionally add tiny offsets to extrinsics to avoid degenerate covariance in pose alignment - # This only affects DA3 inference, not the actual camera trajectory - if add_pose_alignment_offsets: - extrinsics_np = _add_tiny_offsets_to_extrinsics_for_pose_alignment( - extrinsics_np, - ) - - prediction = da3_model.inference( - image=images, - extrinsics=extrinsics_np, - intrinsics=intrinsics_np, - align_to_input_extrinsics=not return_raw_predicted_pose, - align_to_input_ext_scale=not return_raw_predicted_pose, - infer_gs=False, - process_res=process_res, - process_res_method=process_res_method, - reorder_cam_token_by_reference=True, - export_dir=None, - export_format="mini_npz", - ) - - return { - "frame_indices": selected_frames, - "prediction": prediction, - } - - - -def _camera_centers_from_w2c(w2c: torch.Tensor) -> torch.Tensor: - """Compute camera centers from world-to-camera matrices.""" - R = w2c[:, :3, :3] - t = w2c[:, :3, 3] - return -(R.transpose(1, 2) @ t.unsqueeze(-1)).squeeze(-1) - - -def _intrinsics_vec_to_k33(intrinsics_vec: torch.Tensor) -> torch.Tensor: - """Convert VIPE intrinsics (fx,fy,cx,cy,...) to 3x3 matrices.""" - if intrinsics_vec.ndim != 2 or intrinsics_vec.shape[1] < 4: - raise ValueError(f"Expected intrinsics shape (T,>=4), got {tuple(intrinsics_vec.shape)}") - fx, fy, cx, cy = (intrinsics_vec[:, 0], intrinsics_vec[:, 1], intrinsics_vec[:, 2], intrinsics_vec[:, 3]) - T = intrinsics_vec.shape[0] - K = torch.zeros((T, 3, 3), dtype=intrinsics_vec.dtype, device=intrinsics_vec.device) - K[:, 0, 0] = fx - K[:, 1, 1] = fy - K[:, 0, 2] = cx - K[:, 1, 2] = cy - K[:, 2, 2] = 1.0 - return K - - -def _offload_diffusion_to_cpu(model, enable_offload: bool): - """Move diffusion model to CPU if offload is enabled.""" - if enable_offload and hasattr(model, 'net'): - model.net.cpu() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def _restore_diffusion_to_gpu(model, enable_offload: bool): - """Move diffusion model back to GPU if offload is enabled.""" - if enable_offload and hasattr(model, 'net'): - model.net.to(model.tensor_kwargs.get("device", "cuda")) - - -def safe_to(obj, device=None, dtype=None, skip_keys: set | None = None): - """Recursively move tensors to device/dtype while skipping dtype conversion for specific keys. - - - skip_keys: keys in dict for which we only move to device (keep original dtype) - """ - if skip_keys is None: - skip_keys = set() - - def _move_tensor(t: torch.Tensor, force_dtype: bool) -> torch.Tensor: - if device is None and (dtype is None or not force_dtype): - return t - if device is not None and dtype is not None and force_dtype: - return t.to(device=device, dtype=dtype, non_blocking=True) - if device is not None: - return t.to(device=device, non_blocking=True) - if dtype is not None and force_dtype: - return t.to(dtype=dtype) - return t - - if torch.is_tensor(obj): - return _move_tensor(obj, force_dtype=True) - if isinstance(obj, dict): - out = {} - for k, v in obj.items(): - if torch.is_tensor(v): - out[k] = _move_tensor(v, force_dtype=(k not in skip_keys)) - else: - out[k] = safe_to(v, device=device, dtype=dtype, skip_keys=skip_keys) - return out - if isinstance(obj, (list, tuple)): - seq = [safe_to(v, device=device, dtype=dtype, skip_keys=skip_keys) for v in obj] - return type(obj)(seq) - return obj - - -def save_output(to_show, vid_save_path): - legancy_to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0 # [n, b, c, t, h, w] - - video_array = (rearrange(legancy_to_show, "n b c t h w -> t (n h) (b w) c") * 255).to(torch.uint8).cpu().numpy() - log.info( - f"video_array.shape: {video_array.shape} value: {video_array.max()}, {video_array.min()}, save to {vid_save_path}" - ) - base_stem, _ = os.path.splitext(vid_save_path) - save_img_or_video( - rearrange(legancy_to_show, "n b c t h w -> c t (n h) (b w)"), - base_stem, - fps=16, - ) - # Also save a subsampled preview that keeps every 8th frame. - stride = 8 - subsampled = legancy_to_show[:, :, :, ::stride] - subsampled_stem = f"{base_stem}_stride{stride}" - save_img_or_video( - rearrange(subsampled, "n b c t h w -> c t (n h) (b w)"), - subsampled_stem, - fps=16, - ) - log.info(f"save video to {vid_save_path}", rank0_only=True) - - -class Lyra2InferencePipeline: - """Stateful pipeline for Lyra2 autoregressive inference.""" - - def __init__( - self, - *, - model, - args, - first_frame, - first_depth, - first_cam_w2c, - first_intrinsics, - da3_model=None, - cp_group=None, - base_t5_text_embeddings=None, - base_neg_t5_text_embeddings=None, - padding_mask=None, - fps=None, - vipe_input_dump_dir: Optional[str] = None, - vipe_input_dump_prefix: Optional[str] = None, - multiview_data: Optional[dict] = None, - ): - self.model = model - self.args = args - self.cp_group = cp_group - self.frames_per_latent = model.framepack_num_frames_per_latent - self.tokens_per_step = model.framepack_num_new_latent_frames - self.T_hist = model.framepack_total_max_num_latent_frames - self.tokens_per_step - self.start_index = (self.T_hist - 1) * self.frames_per_latent - self.repeat_pixels = (self.T_hist - 1) * self.frames_per_latent + 1 - - init_video = first_frame.repeat(1, 1, self.repeat_pixels, 1, 1) - self.history_frames = init_video - _, self.vae_wrap, self.vae_core = _get_vae_handles(model) - self.history_latents, self.enc_feat_cache = _prime_encoder_cache_with_history( - init_video, self.vae_wrap, self.vae_core, model, args.offload - ) - # Align latent dtype/device to model tensor kwargs for downstream layers. - self.history_latents = misc.to(self.history_latents, **self.model.tensor_kwargs) - self.vae_core.clear_cache() - self.dec_feat_cache = [None] * self.vae_core._conv_num - _ = _decode_new_latent_chunk( - self.vae_wrap, - self.vae_core, - self.dec_feat_cache, - self.history_latents, - latent_offset=0, - model=model, - enable_offload=args.offload, - ) - self.first_latent = self.history_latents[:, :, :1] - if args.offload: - self.history_latents = self.history_latents.cpu() - self.history_frames = self.history_frames.cpu() - - self.last_hist_frame = first_frame[:, :, 0] - - cfg = model.config - # Lyra2 is collapsed to the pose-conditioned target branch. - self.use_pose = True - self.use_plucker = True - self.use_plucker_relative = False - self.use_plucker_no_intrinsics = False - self.use_image_spatial = bool(getattr(cfg, "spatial_memory_use_image", False)) - self.merge_history_buffers = False - self.num_retrieval_views: int = int(getattr(args, "num_retrieval_views", 1)) - self.warp_video_collect: List[torch.Tensor] = [] - self.vipe_input_dump_dir = vipe_input_dump_dir - self.vipe_input_dump_prefix = vipe_input_dump_prefix - - cam_w2c_first = first_cam_w2c - if cam_w2c_first.dim() == 3: - cam_w2c_first = cam_w2c_first.unsqueeze(1) - intrinsics_first = first_intrinsics - if intrinsics_first.dim() == 3: - intrinsics_first = intrinsics_first.unsqueeze(1) - self.cam_w2c = cam_w2c_first.to(torch.float32) - self.intrinsics = intrinsics_first.to(torch.float32) - - # Pose-conditioning state: - # - We keep a single Sparse3DCache (downsample=4, store_values=True) used for: - # 1) spatial overlap retrieval - # 2) depth lookup by frame_id for accumulated-PD warping - self.retrieval_cache: Optional[Sparse3DCache] = None - # Most recent "buffer" depth for warping (B,1,H,W). We store it explicitly to avoid requiring - # the latest history frame to be present in the cache. - self.buffer_depth_latest: Optional[torch.Tensor] = None - # Optional per-pixel validity mask (B,1,H,W) for UI/visualization (e.g. non-sky). - self.buffer_mask_latest: Optional[torch.Tensor] = None - self.buffer_depth_latest_frame_idx: Optional[int] = None - - self.depth_backend = getattr(args, "depth_backend", "da3") - self.local_da3_model = da3_model if self.depth_backend == "da3" else None - self.vipe = None - # MoGe backend is intentionally not supported in this inference script. - - if self.use_pose: - assert first_depth is not None, "first_depth is required when pose conditioning is enabled" - B_cam, T_cam = self.cam_w2c.shape[0], self.cam_w2c.shape[1] - assert T_cam >= 1, "need at least first camera for pose mode" - first_img_bchw = first_frame[:, :, 0] - first_depth_b1hw = first_depth - if first_depth_b1hw.dim() == 3: - first_depth_b1hw = first_depth_b1hw.unsqueeze(1) - first_w2c = self.cam_w2c[:, 0] - first_K = self.intrinsics[:, 0] - if self.depth_backend == "da3": - if self.local_da3_model is None: - from lyra_2._src.inference.depth_utils import load_da3_model - da3_device = model.tensor_kwargs.get( - "device", "cuda" if torch.cuda.is_available() else "cpu" - ) - self.local_da3_model = load_da3_model( - da3_model_name=args.da3_model_name, - da3_model_path_custom=args.da3_model_path_custom, - device=da3_device, - ) - self.local_da3_model.eval() - else: - raise ValueError(f"Unsupported depth_backend='{self.depth_backend}' for this inference script.") - - store_device = "cpu" if args.offload else str(first_img_bchw.device.type) - # For inference we optionally store original depth values in the same cache used for retrieval. - # This is required when use_accumulated_pcd=True (warping needs per-frame depth lookup by frame_id). - self.retrieval_cache = Sparse3DCache( - downsample=4, - store_device=store_device, - store_values=True, - ) - - mv_ids = getattr(args, "multiview_ids", None) - if mv_ids and multiview_data is not None: - # Multiview input: seed cache with specified frames using negative IDs. - mv_video = multiview_data["video"] # [B, C, T, H, W] - mv_depth = multiview_data["depth"] # [B, T, ...] or [B, T, 1, H, W] - mv_w2c = multiview_data["camera_w2c"] # [B, T, 4, 4] - mv_K = multiview_data["intrinsics"] # [B, T, 3, 3] - for i, src_idx in enumerate(mv_ids): - neg_id = -(i + 1) - d = mv_depth[:, src_idx].to(torch.float32) - if d.dim() == 3: - d = d.unsqueeze(1) - w = mv_w2c[:, src_idx].to(torch.float32) - k = mv_K[:, src_idx].to(torch.float32) - self.retrieval_cache.add(d, w, k, latent_index=neg_id, frame_id=neg_id) - rgb = mv_video[:, :, src_idx].to(torch.float32) # [B, C, H, W] - self.retrieval_cache.store_rgb(neg_id, rgb) - log.info(f"Multiview cache: added frame src_idx={src_idx} as cache id={neg_id}", rank0_only=True) - # Use the first multiview frame's depth as the initial buffer depth. - first_mv_depth = mv_depth[:, mv_ids[0]].to(torch.float32) - if first_mv_depth.dim() == 3: - first_mv_depth = first_mv_depth.unsqueeze(1) - self.buffer_depth_latest = first_mv_depth - self.buffer_depth_latest_frame_idx = 0 - else: - # Default: seed cache with the first frame (frame_id=0). - self.retrieval_cache.add( - first_depth_b1hw.to(torch.float32), - first_w2c.to(torch.float32), - first_K.to(torch.float32), - latent_index=0, - frame_id=0, - ) - self.buffer_depth_latest = first_depth_b1hw.to(torch.float32) - self.buffer_depth_latest_frame_idx = 0 - # Seed mask (best-effort): valid where depth > 0. DA3 sky mask is applied later during updates. - self.buffer_mask_latest = (self.buffer_depth_latest > 0).to(torch.float32) - if args.offload: - self.buffer_depth_latest = self.buffer_depth_latest.cpu() - self.buffer_mask_latest = self.buffer_mask_latest.cpu() - else: - self.local_da3_model = None - - self.tokens_generated = 0 - self.ar_idx = 0 - - # Predicted-pose update state (populated by _update_depth_cache when da3_use_predicted_pose=True). - self._predicted_pose_last_w2c: Optional[torch.Tensor] = None - self._predicted_pose_updated_seed_depth: Optional[torch.Tensor] = None - self._predicted_pose_updated_seed_mask: Optional[torch.Tensor] = None - self._predicted_pose_is_first_segment: bool = False - - self.base_t5_text_embeddings = base_t5_text_embeddings - self.base_neg_t5_text_embeddings = base_neg_t5_text_embeddings - self.padding_mask = padding_mask - self.fps = fps - - # Snapshot for one-step undo (populated by save_snapshot). - self._snapshot: dict | None = None - - # ------------------------------------------------------------------ # - # Snapshot / revert helpers (one-level undo) - # ------------------------------------------------------------------ # - - def save_snapshot(self) -> None: - """Save the current pipeline state so that the next generation can be reverted.""" - - def _clone_cache_list(cache_list): - if cache_list is None: - return None - return [x.clone() if isinstance(x, torch.Tensor) else x for x in cache_list] - - snap: dict = {} - - # Tensors that get appended – store their current temporal size so we can crop. - snap["history_frames_T"] = int(self.history_frames.shape[2]) - snap["history_latents_T"] = int(self.history_latents.shape[2]) - snap["cam_w2c_T"] = int(self.cam_w2c.shape[1]) - snap["intrinsics_T"] = int(self.intrinsics.shape[1]) - - # VAE caches – must deep-clone (list of tensors or Nones). - snap["enc_feat_cache"] = _clone_cache_list(self.enc_feat_cache) - snap["dec_feat_cache"] = _clone_cache_list(self.dec_feat_cache) - - # Scalar / small-tensor state. - snap["ar_idx"] = self.ar_idx - snap["tokens_generated"] = self.tokens_generated - snap["last_hist_frame"] = self.last_hist_frame.clone() - - # Depth state. - snap["buffer_depth_latest"] = self.buffer_depth_latest.clone() if self.buffer_depth_latest is not None else None - snap["buffer_mask_latest"] = self.buffer_mask_latest.clone() if self.buffer_mask_latest is not None else None - snap["buffer_depth_latest_frame_idx"] = self.buffer_depth_latest_frame_idx - - # Sparse3DCache – record current list lengths for trimming. - if self.retrieval_cache is not None: - snap["cache_len"] = len(self.retrieval_cache._world_points) - snap["cache_rgb_keys"] = set(self.retrieval_cache._rgbs.keys()) - else: - snap["cache_len"] = 0 - snap["cache_rgb_keys"] = set() - - # Warp video collect. - snap["warp_video_collect_len"] = len(self.warp_video_collect) - - # Predicted-pose state. - snap["_predicted_pose_last_w2c"] = ( - self._predicted_pose_last_w2c.clone() if isinstance(self._predicted_pose_last_w2c, torch.Tensor) else None - ) - snap["_predicted_pose_updated_seed_depth"] = ( - self._predicted_pose_updated_seed_depth.clone() - if isinstance(self._predicted_pose_updated_seed_depth, torch.Tensor) - else None - ) - snap["_predicted_pose_updated_seed_mask"] = ( - self._predicted_pose_updated_seed_mask.clone() - if isinstance(self._predicted_pose_updated_seed_mask, torch.Tensor) - else None - ) - snap["_predicted_pose_is_first_segment"] = self._predicted_pose_is_first_segment - - self._snapshot = snap - - def revert_to_snapshot(self) -> bool: - """Revert pipeline state to the last saved snapshot. Returns True on success.""" - snap = self._snapshot - if snap is None: - return False - - # Crop appendable tensors. - self.history_frames = self.history_frames[:, :, : snap["history_frames_T"]].contiguous() - self.history_latents = self.history_latents[:, :, : snap["history_latents_T"]].contiguous() - self.cam_w2c = self.cam_w2c[:, : snap["cam_w2c_T"]].contiguous() - self.intrinsics = self.intrinsics[:, : snap["intrinsics_T"]].contiguous() - - # Restore VAE caches. - self.enc_feat_cache = snap["enc_feat_cache"] - self.dec_feat_cache = snap["dec_feat_cache"] - - # Scalars. - self.ar_idx = snap["ar_idx"] - self.tokens_generated = snap["tokens_generated"] - self.last_hist_frame = snap["last_hist_frame"] - - # Depth state. - self.buffer_depth_latest = snap["buffer_depth_latest"] - self.buffer_mask_latest = snap["buffer_mask_latest"] - self.buffer_depth_latest_frame_idx = snap["buffer_depth_latest_frame_idx"] - - # Trim Sparse3DCache back to its snapshot length. - if self.retrieval_cache is not None: - old_len = snap["cache_len"] - self.retrieval_cache._world_points = self.retrieval_cache._world_points[:old_len] - self.retrieval_cache._latent_indices = self.retrieval_cache._latent_indices[:old_len] - self.retrieval_cache._frame_ids = self.retrieval_cache._frame_ids[:old_len] - if self.retrieval_cache._store_values: - self.retrieval_cache._depths = self.retrieval_cache._depths[:old_len] - self.retrieval_cache._w2cs = self.retrieval_cache._w2cs[:old_len] - self.retrieval_cache._Ks = self.retrieval_cache._Ks[:old_len] - kept_rgb_keys = snap["cache_rgb_keys"] - self.retrieval_cache._rgbs = { - k: v for k, v in self.retrieval_cache._rgbs.items() if k in kept_rgb_keys - } - - # Warp video collect. - self.warp_video_collect = self.warp_video_collect[: snap["warp_video_collect_len"]] - - # Predicted-pose state. - self._predicted_pose_last_w2c = snap["_predicted_pose_last_w2c"] - self._predicted_pose_updated_seed_depth = snap["_predicted_pose_updated_seed_depth"] - self._predicted_pose_updated_seed_mask = snap["_predicted_pose_updated_seed_mask"] - self._predicted_pose_is_first_segment = snap["_predicted_pose_is_first_segment"] - - # Invalidate snapshot after revert (single-level undo). - self._snapshot = None - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return True - - def _append_cameras(self, cam_w2c_chunk: torch.Tensor, intrinsics_chunk: torch.Tensor): - cam_chunk = cam_w2c_chunk - intr_chunk = intrinsics_chunk - if cam_chunk.dim() == 3: - cam_chunk = cam_chunk.unsqueeze(1) - if intr_chunk.dim() == 3: - intr_chunk = intr_chunk.unsqueeze(1) - cam_chunk = cam_chunk.to(torch.float32) - intr_chunk = intr_chunk.to(torch.float32) - self.cam_w2c = torch.cat([self.cam_w2c, cam_chunk], dim=1) - self.intrinsics = torch.cat([self.intrinsics, intr_chunk], dim=1) - - def _prepare_text_embeddings(self, t5_text_embeddings, neg_t5_text_embeddings): - pos = t5_text_embeddings if t5_text_embeddings is not None else self.base_t5_text_embeddings - neg = neg_t5_text_embeddings if neg_t5_text_embeddings is not None else self.base_neg_t5_text_embeddings - return pos, neg - - def autoregressive_step( - self, - *, - cam_w2c_chunk, - intrinsics_chunk, - t5_text_embeddings=None, - neg_t5_text_embeddings=None, - is_last_step=False, - ): - self._append_cameras(cam_w2c_chunk, intrinsics_chunk) - start_px_idx = 1 + self.ar_idx * self.model.framepack_num_new_video_frames - end_px_idx = start_px_idx + self.model.framepack_num_new_video_frames - - total_latents_now = int(self.history_latents.shape[2]) - # Use cached counts computed in model._init_lyra2_metadata. - num_temporal_hist = int(self.model.framepack_num_temporal_hist) - num_spatial_hist = int(self.model.framepack_num_spatial_hist) - - temporal_selected: List[int] = self.model._select_temporal_history_indices(total_latents_now, num_temporal_hist) - - cfg = self.model.config - use_image_spatial = bool(cfg.spatial_memory_use_image) - - # Move history latents to the model device/dtype for selection/inference. - history_full = misc.to(self.history_latents, **self.model.tensor_kwargs) - # Unified input preparation: reuse Lyra2Model._prepare_lyra2_inputs. - # For now we only support pose-conditioned inference here. - assert self.retrieval_cache is not None, "retrieval_cache must be initialized for pose mode." - assert self.buffer_depth_latest is not None, "buffer_depth_latest must be initialized for pose mode." - - device = history_full.device - video_hist_abs = self.history_frames[:, :, self.start_index : ] - video_all = misc.to(video_hist_abs, **self.model.tensor_kwargs) - # Build a virtual video_indices timeline: [0 repeated prefix] + [1..end_px_idx-1] - video_indices_t = torch.tensor( - [0] * int(self.repeat_pixels) + list(range(1, int(end_px_idx))), - device=device, - dtype=torch.long, - ) - - # Dummy generation tail; `_prepare_lyra2_inputs()` overwrites it with pose conditioning. - B, C_lat, _T_hist, H_lat, W_lat = history_full.shape - T_new_lat = int(self.model.framepack_num_new_latent_frames) - gen_cond_dummy = torch.zeros((B, C_lat, T_new_lat, H_lat, W_lat), device=device, dtype=history_full.dtype) - - # Buffer depth (most recent history pixel frame) for warping. - buffer_depth = self.buffer_depth_latest.to(device=device, dtype=torch.float32) - if buffer_depth.dim() == 3: - buffer_depth = buffer_depth.unsqueeze(1) - - # Keep original skip behavior from the previous implementation. - spatial_cache_skip_last_n = 0 - - # Collect warped pixels for visualization if pose conditioning is enabled. - prev_collect = bool(getattr(self.model, "_collect_return_condition_state", False)) - try: - self.model._collect_return_condition_state = True - latents_full, cond_latent, _mask, buffer_cond_latents = self.model._prepare_lyra2_inputs( - history_full=history_full, - gen_cond=gen_cond_dummy, - spatial_cache=self.retrieval_cache, - video=video_all, - buffer_depth_B_1_H_W=buffer_depth, - camera_w2c=self.cam_w2c, - intrinsics=self.intrinsics, - video_indices=video_indices_t, - is_training=False, - spatial_cache_skip_last_n=int(spatial_cache_skip_last_n), - num_retrieval_views=self.num_retrieval_views, - ) - finally: - self.model._collect_return_condition_state = prev_collect - gc.collect() - torch.cuda.empty_cache() - warp_pixels = getattr(self.model, "_latest_condition_state_pixels", None) - if self.use_pose and warp_pixels is not None: - if isinstance(warp_pixels, torch.Tensor) and warp_pixels.dim() == 5: - if int(warp_pixels.shape[1]) > 3: - warp_pixels = warp_pixels[:, :3] - self.warp_video_collect.append(warp_pixels.detach().float().cpu()) - history_window = latents_full[:, :, : -T_new_lat] - - self._restore_model_to_gpu() - pos_text, neg_text = self._prepare_text_embeddings(t5_text_embeddings, neg_t5_text_embeddings) - last_hist_frame_cast = misc.to(self.last_hist_frame, **self.model.tensor_kwargs) - padding_mask_cast = misc.to(self.padding_mask, **self.model.tensor_kwargs) if self.padding_mask is not None else None - if not self.args.use_dmd_scheduler: - gen_chunk = self.model.inference( - history_latents=history_window, - cond_latent=cond_latent, - cond_latent_mask=_mask, - cond_latent_buffer=buffer_cond_latents, - guidance=self.args.guidance, - seed=int(self.args.seed + self.ar_idx), - num_steps=self.args.num_sampling_step, - shift=self.args.shift, - t5_text_embeddings=pos_text, - neg_t5_text_embeddings=neg_text, - last_hist_frame=last_hist_frame_cast, - fps=self.fps, - padding_mask=padding_mask_cast, - ) - else: - gen_chunk = self.model.inference_dmd( - history_latents=history_window, - cond_latent=cond_latent, - cond_latent_mask=_mask, - cond_latent_buffer=buffer_cond_latents, - guidance=self.args.guidance, - seed=int(self.args.seed + self.ar_idx), - num_steps=self.args.num_sampling_step, - shift=self.args.shift, - t5_text_embeddings=pos_text, - neg_t5_text_embeddings=neg_text, - last_hist_frame=last_hist_frame_cast, - fps=self.fps, - padding_mask=padding_mask_cast, - ) - gen_chunk = gen_chunk[:, :, :self.model.framepack_num_new_latent_frames] - new_generated_frames = _decode_new_latent_chunk( - self.vae_wrap, - self.vae_core, - self.dec_feat_cache, - gen_chunk, - latent_offset=self.history_latents.shape[2], - model=self.model, - enable_offload=self.args.offload, - ) - if self.args.offload: - self.history_frames = torch.cat( - [self.history_frames, new_generated_frames.to(self.history_frames.dtype).cpu()], - dim=2, - ) - else: - self.history_frames = torch.cat( - [self.history_frames, new_generated_frames.to(self.history_frames.dtype)], - dim=2, - ) - - with self.vae_wrap.context: - px_cast = new_generated_frames.to(self.vae_wrap.dtype) if not self.vae_wrap.is_amp else new_generated_frames - feats_gen, enc_cache_out = self.model.vae_encode_with_cache( - enc_cache=self.enc_feat_cache, - video=px_cast, - start_t=0, - end_t=px_cast.shape[2], - return_cache=True, - ) - self.enc_feat_cache[:] = enc_cache_out - gen_chunk_reencoded = self.model._encoder_feats_to_normalized_latents(feats_gen).to(self.history_latents.dtype) - if self.args.offload: - self.history_latents = torch.cat( - [self.history_latents, gen_chunk_reencoded.to(self.history_latents.dtype).cpu()], - dim=2, - ) - else: - self.history_latents = torch.cat( - [self.history_latents, gen_chunk_reencoded.to(self.history_latents.dtype)], - dim=2, - ) - self.last_hist_frame = new_generated_frames[:, :, -1] - self.tokens_generated += self.model.framepack_num_new_latent_frames - - if self.use_pose and not is_last_step: - self._update_depth_cache(end_px_idx) - del gen_chunk, new_generated_frames, gen_chunk_reencoded - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - self.ar_idx += 1 - return {"abort": False} - - def _restore_model_to_gpu(self): - _restore_diffusion_to_gpu(self.model, self.args.offload) - - def _update_depth_cache(self, end_px_idx): - if self.depth_backend == "da3": - assert self.local_da3_model is not None, "DA3 model must be initialized for pose mode." - assert self.retrieval_cache is not None, "retrieval_cache must exist for pose mode." - assert self.buffer_depth_latest is not None, "buffer_depth_latest must exist for pose mode." - - offload_da3 = bool(getattr(self.args, "offload_da3_diffusion", False)) - if offload_da3: - _offload_diffusion_to_cpu(self.model, True) - try: - da3_out = _predict_da3_depth_window( - da3_model=self.local_da3_model, - history_frames=self.history_frames, - start_index=self.start_index, - abs_last_idx=end_px_idx - 1, - cam_w2c=self.cam_w2c, - intrinsics=self.intrinsics, - frame_interval=int(self.args.da3_frame_interval), - max_history_frames=int(self.args.da3_max_history_frames), - add_pose_alignment_offsets=True, - include_ar_chunk_last_frames=bool(getattr(self.args, "da3_include_ar_chunk_last_frames", False)), - ar_chunk_size_frames=int(self.model.framepack_num_new_video_frames), - return_raw_predicted_pose=( - bool(getattr(self.args, "da3_use_predicted_pose", False)) - and (self.ar_idx == 0 or bool(getattr(self.args, "da3_predicted_pose_continuation", False))) - ), - ) - finally: - if offload_da3: - _restore_diffusion_to_gpu(self.model, True) - da3_frames: List[int] = da3_out["frame_indices"] - da3_pred = da3_out["prediction"] - assert da3_pred is not None and len(da3_frames) > 0, "DA3 depth window prediction required for cache update" - - depths_np = da3_pred.depth - sky_np = getattr(da3_pred, "sky", None) - H0 = int(self.history_frames.shape[-2]) - W0 = int(self.history_frames.shape[-1]) - - # --- Predicted-pose alignment (if enabled) --- - use_predicted_pose = bool(getattr(self.args, "da3_use_predicted_pose", False)) - aligned_w2c_per_local: Optional[List[torch.Tensor]] = None - depth_scale_factor: float = 1.0 - is_first_segment = False - - if use_predicted_pose: - pred_ext = getattr(da3_pred, "extrinsics", None) - if pred_ext is not None: - pred_ext_t = torch.as_tensor(np.asarray(pred_ext), dtype=torch.float32).to(self.cam_w2c.device) - if pred_ext_t.dim() == 3: - pred_ext_t = pred_ext_t.unsqueeze(0) - pred_w2c_all = pred_ext_t[0] # [N, 3or4, 4] - if pred_w2c_all.shape[-2] == 3: - pad = torch.zeros((pred_w2c_all.shape[0], 4, 4), dtype=pred_w2c_all.dtype, device=pred_w2c_all.device) - pad[:, :3, :4] = pred_w2c_all - pad[:, 3, 3] = 1.0 - pred_w2c_all = pad - - step_size = int(self.model.framepack_num_new_video_frames) - new_frame_start = end_px_idx - step_size - hist_local = [i for i, f in enumerate(da3_frames) if f < new_frame_start] - new_local = [i for i, f in enumerate(da3_frames) if f >= new_frame_start] - - is_first_segment = (len(hist_local) == 1 and da3_frames[hist_local[0]] == 0) - - if is_first_segment: - # Case a): history = frame 0 only. - # Normalise by inv(pred[0]) so frame 0 becomes identity, - # then scale the w2c translations by a trajectory-length - # ratio (so that relative camera distances match the - # pipeline), and finally transform into pipeline world space. - all_local = list(range(len(da3_frames))) - - inv_p0 = torch.linalg.inv(pred_w2c_all[hist_local[0]]) - aligned_all = pred_w2c_all @ inv_p0.unsqueeze(0) # frame 0 → identity - - # Compute depth/pose scale via trajectory-length ratio. - assert len(all_local) >= 2, "Expected at least two frames for trajectory-length scale calculation." - ref_c2w_pos = torch.stack([ - torch.linalg.inv(self.cam_w2c[0, da3_frames[i]].to(torch.float32))[:3, 3] - for i in all_local - ]) - pred_c2w_pos = torch.stack([ - torch.linalg.inv(pred_w2c_all[i])[:3, 3] - for i in all_local - ]) - ref_traj_len = (ref_c2w_pos[1:] - ref_c2w_pos[:-1]).norm(dim=-1).sum().item() - pred_traj_len = (pred_c2w_pos[1:] - pred_c2w_pos[:-1]).norm(dim=-1).sum().item() - depth_scale_factor = ref_traj_len / pred_traj_len if pred_traj_len > 1e-8 else 1.0 - - - # Scale w2c translations so relative distances match pipeline. - # (frame 0 has t=0 after normalisation, so it stays exactly at origin.) - aligned_all[:, :3, 3] = aligned_all[:, :3, 3] * depth_scale_factor - - # Transform into pipeline world space via pipeline's frame-0 w2c. - pipeline_w2c_0 = self.cam_w2c[0, da3_frames[hist_local[0]]].to(torch.float32) - aligned_all = aligned_all @ pipeline_w2c_0.unsqueeze(0) - - aligned_w2c_per_local = [aligned_all[i].unsqueeze(0) for i in range(len(da3_frames))] - log.info( - f"[da3_use_predicted_pose] Case A (first segment): " - f"frame0-normalise + traj-length scale={depth_scale_factor:.6f} " - f"on {len(all_local)} frames", - rank0_only=True, - ) - for li in range(len(da3_frames)): - f_abs_i = da3_frames[li] - pipeline_pos = torch.linalg.inv(self.cam_w2c[0, f_abs_i].to(torch.float32))[:3, 3] - aligned_pos = torch.linalg.inv(aligned_all[li])[:3, 3] - residual = (pipeline_pos - aligned_pos).norm().item() - is_hist = "hist" if li in hist_local else "new " - log.info( - f" [{is_hist}] frame {f_abs_i}: " - f"pipeline_pos={pipeline_pos.cpu().numpy()}, " - f"aligned_pos={aligned_pos.cpu().numpy()}, " - f"residual={residual:.6f}", - rank0_only=True, - ) - elif bool(getattr(self.args, "da3_predicted_pose_continuation", False)): - assert len(hist_local) >= 2, "Expected at least one history frame for continuation segment." - # Case b): continuation segment. Traj-length scale, - # normalise at history-last, then anchor to pipeline pose. - all_local = list(range(len(da3_frames))) - hist_last_local = hist_local[-1] - - # Trajectory-length scale on history frames only (already - # DA3-aligned in prior chunks, so reliable for scale). - - ref_c2w_pos = torch.stack([ - torch.linalg.inv(self.cam_w2c[0, da3_frames[i]].to(torch.float32))[:3, 3] - for i in hist_local - ]) - pred_c2w_pos = torch.stack([ - torch.linalg.inv(pred_w2c_all[i])[:3, 3] - for i in hist_local - ]) - ref_traj_len = (ref_c2w_pos[1:] - ref_c2w_pos[:-1]).norm(dim=-1).sum().item() - pred_traj_len = (pred_c2w_pos[1:] - pred_c2w_pos[:-1]).norm(dim=-1).sum().item() - depth_scale_factor = ref_traj_len / pred_traj_len if pred_traj_len > 1e-8 else 1.0 - - - # Scale predicted w2c translations to match pipeline magnitude. - scaled_pred = pred_w2c_all.clone() - scaled_pred[:, :3, 3] = scaled_pred[:, :3, 3] * depth_scale_factor - - # Align: map scaled pred[hist_last] → pipeline[hist_last]. - pipeline_w2c_anchor = self.cam_w2c[0, da3_frames[hist_last_local]].to(torch.float32) - T_align = torch.linalg.inv(scaled_pred[hist_last_local]) @ pipeline_w2c_anchor - aligned_all = scaled_pred @ T_align.unsqueeze(0) - - aligned_w2c_per_local = [aligned_all[i].unsqueeze(0) for i in range(len(da3_frames))] - log.info( - f"[da3_use_predicted_pose] Case B (continuation): " - f"traj-length scale={depth_scale_factor:.6f}, " - f"anchor=hist_last (local={hist_last_local}, abs={da3_frames[hist_last_local]}), " - f"aligned {len(da3_frames)} frames using {len(hist_local)} history frames. " - f"hist_frame_ids={[da3_frames[i] for i in hist_local]}, " - f"new_frame_ids={[da3_frames[i] for i in new_local]}", - rank0_only=True, - ) - for li in range(len(da3_frames)): - f_abs_i = da3_frames[li] - pipeline_pos = torch.linalg.inv(self.cam_w2c[0, f_abs_i].to(torch.float32))[:3, 3] - aligned_pos = torch.linalg.inv(aligned_all[li])[:3, 3] - residual = (pipeline_pos - aligned_pos).norm().item() - is_hist = "hist" if li in hist_local else "new " - log.info( - f" [{is_hist}] frame {f_abs_i}: " - f"pipeline_pos={pipeline_pos.cpu().numpy()}, " - f"aligned_pos={aligned_pos.cpu().numpy()}, " - f"residual={residual:.6f}", - rank0_only=True, - ) - - # Update self.cam_w2c for new frames with aligned predicted poses. - if aligned_w2c_per_local is not None: - updated_abs_ids = [] - for li in new_local: - f_abs_li = da3_frames[li] - self.cam_w2c[:, f_abs_li] = aligned_w2c_per_local[li].to(self.cam_w2c.dtype) - updated_abs_ids.append(f_abs_li) - log.info( - f"[da3_use_predicted_pose] Updated cam_w2c for frames: {updated_abs_ids}", - rank0_only=True, - ) - else: - log.warning( - "[da3_use_predicted_pose] DA3 prediction has no extrinsics; " - "falling back to pipeline poses.", - rank0_only=True, - ) - - existing_ids = set(getattr(self.retrieval_cache, "_latent_indices", [])) - stride = max(int(self.model.config.spatial_memory_stride), 1) - - for local_i, f_abs in enumerate(da3_frames): - depth_np = depths_np[local_i] - depth_t = torch.from_numpy(depth_np).to(self.cam_w2c.device, dtype=torch.float32) - if depth_t.dim() == 2: - depth_t = depth_t.unsqueeze(0).unsqueeze(0) - elif depth_t.dim() == 3: - depth_t = depth_t.unsqueeze(0) - depth_t = torch.nn.functional.interpolate( - depth_t, - size=(H0, W0), - mode="bilinear", - align_corners=False, - ) - - if use_predicted_pose and depth_scale_factor != 1.0: - depth_t = depth_t * depth_scale_factor - - # Optional sky mask for UI/visualization only; do not modify depth_t. - valid_mask_t = None - if sky_np is not None: - sky_arr = np.asarray(sky_np)[local_i].astype(np.uint8) # [H,W], 1 for sky - sky_t = torch.from_numpy(sky_arr).to(self.cam_w2c.device, dtype=torch.float32).unsqueeze(0).unsqueeze(0) - if int(sky_t.shape[-2]) != H0 or int(sky_t.shape[-1]) != W0: - sky_t = torch.nn.functional.interpolate(sky_t, size=(H0, W0), mode="nearest") - sky_hw = sky_t > 0.5 - valid_mask_t = (~sky_hw).to(torch.float32) - - # Further filter with depth reliability mask (same as Gen3C forward-warp cleanup). - depth_rel = reliable_depth_mask_range_batch(depth_t, ratio_thresh=0.15) - # Some variants may return a tuple/list; first element is the mask. - if isinstance(depth_rel, (tuple, list)): - depth_rel = depth_rel[0] - depth_rel = depth_rel.to(dtype=torch.float32, device=depth_t.device) - valid_mask_t = depth_rel if valid_mask_t is None else (valid_mask_t * depth_rel) - - if aligned_w2c_per_local is not None: - w2c_t = aligned_w2c_per_local[local_i].to(torch.float32) - else: - w2c_t = self.cam_w2c[:, f_abs].to(torch.float32) - K_t = self.intrinsics[:, f_abs].to(torch.float32) - # Track latest buffer depth (the most recent history pixel frame). - if int(f_abs) == int(end_px_idx - 1): - self.buffer_depth_latest = depth_t.detach() - self.buffer_depth_latest_frame_idx = int(f_abs) - if valid_mask_t is not None: - self.buffer_mask_latest = valid_mask_t.detach() - else: - self.buffer_mask_latest = (depth_t > 0).to(torch.float32).detach() - - # Add to overlap-selection cache depending on mode. - if not getattr(self.args, "disable_cache_update", False): - cache_id = int(f_abs) if self.use_image_spatial else int((int(f_abs) + int(self.start_index)) // int(self.frames_per_latent)) - - # Case a): also update frame 0's depth/pose in the cache. - if use_predicted_pose and is_first_segment and int(f_abs) == 0 and cache_id in existing_ids: - self.retrieval_cache.update_by_frame_id( - frame_id=0, - depth_B_1_H_W=depth_t, - w2c_B_4_4=w2c_t, - K_B_3_3=K_t, - ) - continue - - if cache_id in existing_ids: - continue - if self.use_image_spatial and int(f_abs) != 0 and (int(f_abs) % int(stride) != 0): - continue - existing_ids.add(cache_id) - self.retrieval_cache.add( - depth_t, - w2c_t, - K_t, - latent_index=int(cache_id), - frame_id=int(f_abs), - ) - - # Store predicted-pose update info for the persistent wrapper to return to client. - if use_predicted_pose and aligned_w2c_per_local is not None: - last_f_abs = int(end_px_idx - 1) - last_local = None - for li, f in enumerate(da3_frames): - if int(f) == last_f_abs: - last_local = li - break - if last_local is not None: - self._predicted_pose_last_w2c = aligned_w2c_per_local[last_local].detach().cpu() - if is_first_segment: - frame0_local = None - for li, f in enumerate(da3_frames): - if int(f) == 0: - frame0_local = li - break - if frame0_local is not None: - d0_np = depths_np[frame0_local] - d0_t = torch.from_numpy(d0_np).to(torch.float32) - if d0_t.dim() == 2: - d0_t = d0_t.unsqueeze(0).unsqueeze(0) - elif d0_t.dim() == 3: - d0_t = d0_t.unsqueeze(0) - d0_t = torch.nn.functional.interpolate(d0_t, size=(H0, W0), mode="bilinear", align_corners=False) - if depth_scale_factor != 1.0: - d0_t = d0_t * depth_scale_factor - self._predicted_pose_updated_seed_depth = d0_t.detach().cpu() - # Mask for seed depth - seed_mask = (d0_t > 0).to(torch.float32) - if sky_np is not None: - sky0 = np.asarray(sky_np)[frame0_local].astype(np.uint8) - sky0_t = torch.from_numpy(sky0).unsqueeze(0).unsqueeze(0).to(torch.float32) - if sky0_t.shape[-2:] != (H0, W0): - sky0_t = torch.nn.functional.interpolate(sky0_t, size=(H0, W0), mode="nearest") - seed_mask = seed_mask * (~(sky0_t > 0.5)).to(torch.float32) - self._predicted_pose_updated_seed_mask = seed_mask.detach().cpu() - self._predicted_pose_is_first_segment = is_first_segment - if self.args.offload: - self.buffer_depth_latest = self.buffer_depth_latest.cpu() - if self.buffer_mask_latest is not None: - self.buffer_mask_latest = self.buffer_mask_latest.cpu() - return - - raise ValueError(f"Only depth_backend='da3' is supported in this script (VIPE backend removed).") - - def build_outputs(self, da3_gs_export_stem, log_prefix): - video = self.history_frames[:, :, self.start_index:] - warp_video = None - if self.use_pose and len(self.warp_video_collect) > 0: - warp_video = torch.cat(self.warp_video_collect, dim=2) - first_frame = video[:, :, :1] - warp_video = torch.cat([first_frame.cpu(), warp_video], dim=2) - - video_out = video.float().cpu() - warp_out = warp_video.float().cpu() if warp_video is not None else None - warp_out_merged = None - - del self.history_frames, self.history_latents, self.enc_feat_cache, self.dec_feat_cache - del self.warp_video_collect - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return { - "video": video_out, - "warp_video": warp_out, - "warp_video_merged": warp_out_merged, - "use_pose": self.use_pose, - "use_plucker": self.use_plucker, - } - -def run_lyra2_sample( - model, - data_batch, - args, - *, - process_group=None, - da3_model=None, - show_progress=False, - log_prefix="Start AR spatial inference", - da3_gs_export_stem=None, - vipe_input_dump_dir=None, -): - """Shared Lyra2 autoregressive generation logic for a single prepared sample.""" - model._normalize_video_databatch_inplace(data_batch) - - cp_group = None - if args.context_parallel_size > 1: - cp_group = process_group if process_group is not None else parallel_state.get_context_parallel_group() - - init_frame = data_batch["video"][:, :, :1] - first_depth = data_batch["depth"][:, 0] - first_cam_w2c = data_batch["camera_w2c"][:, 0] - first_intrinsics = data_batch["intrinsics"][:, 0] - - # Multiview input: pass full data tensors so the pipeline can seed the cache. - multiview_data = None - if getattr(args, "multiview_ids", None): - multiview_data = { - "video": data_batch["video"], - "depth": data_batch["depth"], - "camera_w2c": data_batch["camera_w2c"], - "intrinsics": data_batch["intrinsics"], - } - - pipeline = Lyra2InferencePipeline( - model=model, - args=args, - first_frame=init_frame, - first_depth=first_depth, - first_cam_w2c=first_cam_w2c, - first_intrinsics=first_intrinsics, - da3_model=da3_model, - cp_group=cp_group, - base_t5_text_embeddings=data_batch.get("t5_text_embeddings", None), - base_neg_t5_text_embeddings=data_batch.get("neg_t5_text_embeddings", None), - padding_mask=data_batch.get("padding_mask", None), - fps=data_batch.get("fps", None), - vipe_input_dump_dir=vipe_input_dump_dir, - vipe_input_dump_prefix=log_prefix, - multiview_data=multiview_data, - ) - - num_frames = int(args.num_frames) - assert (num_frames - 1) % (pipeline.tokens_per_step * pipeline.frames_per_latent) == 0, ( - f"N-1 must be divisible by tokens_per_step*frames_per_latent, but got {num_frames-1} " - f"and {pipeline.tokens_per_step * pipeline.frames_per_latent}" - ) - - tokens_needed = (num_frames - 1 + pipeline.frames_per_latent - 1) // pipeline.frames_per_latent - num_iters = (tokens_needed + pipeline.tokens_per_step - 1) // pipeline.tokens_per_step - - with torch.no_grad(): - log.info(log_prefix, rank0_only=True) - for ar_idx in tqdm.tqdm(range(num_iters)): - start_px_idx = 1 + ar_idx * model.framepack_num_new_video_frames - end_px_idx = start_px_idx + model.framepack_num_new_video_frames - cam_chunk = data_batch["camera_w2c"][:, start_px_idx:end_px_idx] - intr_chunk = data_batch["intrinsics"][:, start_px_idx:end_px_idx] - - if "t5_chunk_keys" in data_batch: - t5_chunk_embeddings = data_batch["t5_chunk_embeddings"] - t5_chunk_mask = data_batch["t5_chunk_mask"] - B = int(data_batch["t5_chunk_keys"].shape[0]) - if args.ablate_same_t5: - pos_t5 = t5_chunk_embeddings[:, 0] - data_batch["t5_text_embeddings"] = pos_t5 - data_batch["t5_text_mask"] = t5_chunk_mask[:, 0] - else: - last_hist_px_abs = ar_idx * model.framepack_num_new_video_frames + 1 - sample_frame_indices = data_batch["sample_frame_indices"] - t5_chunk_keys = data_batch["t5_chunk_keys"] - F_total = int(sample_frame_indices.shape[1]) - idx_clamped = min(max(0, last_hist_px_abs), F_total - 1) - first_abs_idx_B = sample_frame_indices[:, idx_clamped].to(dtype=torch.long) - selected_emb_list = [] - selected_mask_list = [] - for b in range(B): - keys_b = t5_chunk_keys[b] - Kb = int(keys_b.numel()) - val = int(first_abs_idx_B[b].item()) - pos = torch.searchsorted( - keys_b, torch.tensor([val], device=keys_b.device, dtype=keys_b.dtype), right=True - ).item() - sel_idx = max(0, min(int(pos) - 1, Kb - 1)) - emb_b = t5_chunk_embeddings[b, sel_idx] - msk_b = t5_chunk_mask[b, sel_idx] - selected_emb_list.append(emb_b) - selected_mask_list.append(msk_b) - pos_t5 = torch.stack(selected_emb_list, dim=0) - data_batch["t5_text_embeddings"] = pos_t5 - data_batch["t5_text_mask"] = torch.stack(selected_mask_list, dim=0) - else: - pos_t5 = data_batch["t5_text_embeddings"] - neg_t5 = data_batch["neg_t5_text_embeddings"] - - step_out = pipeline.autoregressive_step( - cam_w2c_chunk=cam_chunk, - intrinsics_chunk=intr_chunk, - t5_text_embeddings=pos_t5, - neg_t5_text_embeddings=neg_t5, - is_last_step=ar_idx == num_iters - 1, - ) - - if step_out["abort"]: - break - - return pipeline.build_outputs(da3_gs_export_stem, log_prefix) diff --git a/lyra_2/_src/inference/lyra2_custom_traj_inference.py b/lyra_2/_src/inference/lyra2_custom_traj_inference.py deleted file mode 100644 index f566fd775f44b438addd0f4d3ceac8769a6990e6..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/lyra2_custom_traj_inference.py +++ /dev/null @@ -1,551 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Single-image → video generation along a user-supplied camera trajectory. - -Pipeline: -1. Read a single input image. -2. Load per-chunk text captions from a captions.json file (or single --prompt). -3. Load camera trajectory (w2c + intrinsics) from an .npz file, - take the first ``num_frames`` poses. -4. Produce a video using FramePack AR spatial generation with per-chunk T5 embeddings. -5. Save the output video. -""" - -from __future__ import annotations - -import argparse -import gc -import json -import os - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F - -from lyra_2._ext.imaginaire.utils import log, misc -from lyra_2._ext.imaginaire.visualize.video import save_img_or_video -from lyra_2._src.inference.lyra2_ar_inference import ( - save_output, - safe_to, - run_lyra2_sample, -) -from lyra_2._src.inference.lyra2_zoomgs_inference import ( - _da3_infer_depth_intrinsics_single, - _build_image_list, -) -from lyra_2._src.utils.model_loader import load_model_from_checkpoint - -torch.enable_grad(False) -torch.backends.cudnn.enabled = False - - -# --------------------------------------------------------------------------- -# Trajectory loading -# --------------------------------------------------------------------------- - -def load_trajectory( - path: str, - num_frames: int, - target_hw: tuple[int, int] | None = None, - pose_scale: float = 1.0, -): - """Load camera trajectory from an .npz file. - - Expected keys: - w2c – (N, 4, 4) world-to-camera matrices (float32/64) - intrinsics – (N, 3, 3) camera intrinsic matrices (float32/64) - image_height, image_width – original resolution the intrinsics refer to - - If *target_hw* is provided and differs from the stored resolution, - intrinsics are rescaled accordingly. - - Returns the first *num_frames* entries as torch tensors. - """ - data = np.load(path) - w2c = torch.from_numpy(data["w2c"][:num_frames].astype(np.float32)) - intrinsics = torch.from_numpy(data["intrinsics"][:num_frames].astype(np.float32)) - - if pose_scale != 1.0: - w2c[:, :3, 3] *= pose_scale - - if target_hw is not None and "image_height" in data and "image_width" in data: - orig_h, orig_w = int(data["image_height"]), int(data["image_width"]) - tgt_h, tgt_w = target_hw - if (orig_h, orig_w) != (tgt_h, tgt_w): - sx = tgt_w / orig_w - sy = tgt_h / orig_h - intrinsics[:, 0, 0] *= sx - intrinsics[:, 0, 2] *= sx - intrinsics[:, 1, 1] *= sy - intrinsics[:, 1, 2] *= sy - - return w2c, intrinsics - - -# --------------------------------------------------------------------------- -# Argument parser -# --------------------------------------------------------------------------- - -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Single-image video generation with a custom camera trajectory" - ) - # Input - parser.add_argument("--input_image_path", type=str, required=True, - help="Path to a single image or a folder of images") - parser.add_argument("--trajectory_path", type=str, required=True, - help="Path to .npz trajectory file (or a folder of per-image .npz files). " - "Expected keys: w2c (N,4,4), intrinsics (N,3,3), " - "image_height, image_width.") - parser.add_argument("--num_samples", type=int, default=10) - parser.add_argument("--sample_start_idx", type=int, default=0) - parser.add_argument("--prompt", type=str, default="", - help="Optional explicit prompt applied to ALL images (single caption).") - parser.add_argument("--prompt_dir", type=str, default=None, - help="Directory containing per-image .txt caption files (single caption).") - parser.add_argument("--captions_path", type=str, default=None, - help="Path to captions.json (or dir with per-image .json files). " - "JSON maps frame-index strings to caption text. " - "Each AR chunk uses the caption whose key is <= current frame.") - parser.add_argument("--prompt_suffix", type=str, default="", - help="Text appended to every prompt.") - - # Model and generation - parser.add_argument("--experiment", type=str, default="lyra_framepack_spatial") - parser.add_argument("--checkpoint_dir", type=str, default="checkpoints/model") - parser.add_argument("--output_path", type=str, default="inference/lyra2_custom_traj") - parser.add_argument("--guidance", type=float, default=5.0) - parser.add_argument("--shift", type=float, default=5.0) - parser.add_argument("--num_sampling_step", type=int, default=35) - parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--fps", type=int, default=16) - parser.add_argument("--num_frames", type=int, default=161, - help="Number of frames to generate (taken from the start of the trajectory).") - parser.add_argument("--pose_scale", type=float, default=1.1, - help="Scale factor applied to w2c translation vectors.") - parser.add_argument("--resolution", type=str, default="480,832", help="H,W") - parser.add_argument("--context_parallel_size", type=int, default=1) - parser.add_argument("--lora_paths", type=str, default=None, nargs="+") - parser.add_argument("--lora_weights", type=float, default=None, nargs="+") - parser.add_argument("--offload", action="store_true") - parser.add_argument("--offload_when_prompt", action="store_true") - parser.add_argument("--debug", action="store_true") - - # Depth backend - parser.add_argument("--use_moge_scale", action=argparse.BooleanOptionalAction, default=True, - help="Align DA3 depth to MoGe scale (default: True).") - parser.add_argument("--depth_backend", type=str, default="da3", choices=["da3"]) - parser.add_argument("--da3_model_name", type=str, default="depth-anything/DA3NESTED-GIANT-LARGE-1.1") - parser.add_argument("--da3_model_path_custom", type=str, default=None) - parser.add_argument("--da3_frame_interval", type=int, default=8) - parser.add_argument("--da3_max_history_frames", type=int, default=10) - parser.add_argument("--da3_include_ar_chunk_last_frames", action="store_true") - parser.add_argument("--da3_use_predicted_pose", action="store_true") - parser.add_argument("--da3_predicted_pose_continuation", action="store_true") - - # DMD distillation (4-step fast inference) - parser.add_argument("--use_dmd", action="store_true", - help="Enable DMD fast inference: loads DMD distillation LoRA, " - "activates DMD scheduler, and reduces sampling steps.") - - # Misc flags needed by run_lyra2_sample internals - parser.add_argument("--ablate_same_t5", action="store_true") - parser.add_argument("--use_dmd_scheduler", action="store_true") - parser.add_argument("--warp_chunk_size", type=int, default=None) - parser.add_argument("--num_retrieval_views", type=int, default=1) - parser.add_argument("--disable_cache_update", action="store_true") - parser.add_argument("--multiview_ids", type=int, nargs="+", default=None) - parser.add_argument("--offload_da3_diffusion", action="store_true") - - return parser.parse_args() - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -DMD_LORA_PATH = "checkpoints/lora/dmd_distillation.safetensors" -DMD_LORA_WEIGHT = 1.0 - - -def _apply_dmd_defaults(args): - """When --use_dmd is set, inject the DMD LoRA and switch to the DMD scheduler. - - Note: the DMD scheduler uses a fixed 4-step denoising list internally, so - ``--num_sampling_step`` is ignored in this code path. - """ - if not args.use_dmd: - return - args.use_dmd_scheduler = True - if args.lora_paths is None: - args.lora_paths = [] - if args.lora_weights is None: - args.lora_weights = [] - args.lora_paths.append(DMD_LORA_PATH) - args.lora_weights.append(DMD_LORA_WEIGHT) - log.info( - f"[DMD] Enabled: lora={DMD_LORA_PATH}, scheduler=dmd (4 fixed steps)", - rank0_only=True, - ) - - -if __name__ == "__main__": - args = parse_arguments() - _apply_dmd_defaults(args) - - if args.debug: - import debugpy - debugpy.listen(5678) - log.info("Waiting for debugger to attach...") - debugpy.wait_for_client() - - process_group = None - if args.context_parallel_size > 1: - import imaginaire - from megatron.core import parallel_state - imaginaire.utils.distributed.init() - parallel_state.initialize_model_parallel(context_parallel_size=args.context_parallel_size) - process_group = parallel_state.get_context_parallel_group() - - os.makedirs(args.output_path, exist_ok=True) - misc.set_random_seed(seed=args.seed, by_rank=True) - - # Negative prompt embeddings - negative_prompt_data = torch.load( - "checkpoints/text_encoder/negative_prompt.pt", map_location="cpu", weights_only=False - ) - - # ---- Load FramePack model ---- - experiment_opts = [ - "model.config.use_mp_policy_fsdp=False", - "model.config.keep_original_net_dtype=False", - ] - if args.lora_paths: - experiment_opts += ["model.config.net.postpone_checkpoint=True"] - model, config = load_model_from_checkpoint( - config_file="lyra_2/_src/configs/config.py", - experiment_name=args.experiment, - checkpoint_path=args.checkpoint_dir, - enable_fsdp=False, - instantiate_ema=False, - load_ema_to_reg=False, - experiment_opts=experiment_opts, - ) - if args.lora_paths: - lora_names = [] - for lora_path in args.lora_paths: - lora_name = model.load_lora_weights(lora_path) - lora_names.append(lora_name) - model.set_weights_and_activate_adapters(lora_names, args.lora_weights) - if hasattr(model, "net") and hasattr(model.net, "enable_selective_checkpoint"): - model.net.enable_selective_checkpoint(model.net.sac_config, model.net.blocks) - - desired_dtype = model.tensor_kwargs.get("dtype", None) - desired_device = model.tensor_kwargs.get("device", None) - if desired_dtype is not None: - model.net = model.net.to(device=desired_device, dtype=desired_dtype) - log.info(f"Casted model.net to dtype={desired_dtype}", rank0_only=True) - - assert getattr(model.config, "important_start", True) is True - assert getattr(model.config, "encode_video_from_start", True) is True - assert not getattr(model.config, "use_hd_map_cond", False) - - model.eval() - if args.context_parallel_size > 1: - model.net.enable_context_parallel(process_group) - - if args.warp_chunk_size is not None: - model.config.warp_chunk_size = args.warp_chunk_size - model.warp_chunk_size = args.warp_chunk_size - - # Resolution - target_h, target_w = [int(x) for x in args.resolution.split(",")] - - # ---- Load DA3 model ---- - from lyra_2._src.inference.depth_utils import load_da3_model - da3_device = model.tensor_kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") - da3_model = load_da3_model( - da3_model_name=args.da3_model_name, - da3_model_path_custom=args.da3_model_path_custom, - device=da3_device, - ) - da3_model.eval() - - # ---- Optionally load MoGe model for depth scale alignment ---- - moge_model = None - if args.use_moge_scale: - from lyra_2._src.inference.depth_utils import load_moge_model - moge_device = model.tensor_kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") - moge_model = load_moge_model(moge_device) - moge_model.eval() - log.info("MoGe model loaded for depth scale alignment.", rank0_only=True) - - # ---- Resolve image(s) ---- - image_paths = _build_image_list(args.input_image_path)[ - args.sample_start_idx : args.sample_start_idx + args.num_samples - ] - - # Resolve trajectory file(s): single file shared across images, or per-image files in a folder. - traj_is_dir = os.path.isdir(args.trajectory_path) - - # Resolve captions source: per-chunk JSON or single caption - captions_is_dir = args.captions_path is not None and os.path.isdir(args.captions_path) - - N = int(args.num_frames) - - for img_idx, img_path in enumerate(image_paths): - base_name = os.path.splitext(os.path.basename(img_path))[0] - - video_path = os.path.join(args.output_path, f"{base_name}.mp4") - if os.path.exists(video_path): - log.info(f"Skipping {img_path} (video already exists at {video_path})", rank0_only=True) - continue - - log.info(f"Processing [{img_idx}]: {img_path}", rank0_only=True) - misc.set_random_seed(seed=args.seed, by_rank=True) - - # ---- Load trajectory ---- - if traj_is_dir: - traj_file = os.path.join(args.trajectory_path, f"{base_name}.npz") - else: - traj_file = args.trajectory_path - if not os.path.isfile(traj_file): - log.error(f"Trajectory file not found: {traj_file}") - continue - - w2cs_T_44, Ks_T_33 = load_trajectory(traj_file, N, target_hw=(target_h, target_w), pose_scale=args.pose_scale) - log.info(f"Loaded trajectory: {w2cs_T_44.shape[0]} frames from {traj_file}", rank0_only=True) - - # ---- Read image ---- - bgr = cv2.imread(img_path) - if bgr is None: - log.error(f"Cannot read: {img_path}") - continue - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - rgb_t = torch.from_numpy(rgb) - - # ---- Depth & intrinsics for the first frame (via DA3) ---- - log.info("Running DA3 single-image depth...", rank0_only=True) - image_chw01, depth_hw, _K_33_da3, mask_hw = _da3_infer_depth_intrinsics_single( - da3_model=da3_model, - img_rgb_uint8=rgb_t, - target_hw=(target_h, target_w), - ) - H, W = image_chw01.shape[-2:] - - # ---- Optionally align DA3 depth to MoGe scale ---- - if args.use_moge_scale and moge_model is not None: - log.info("Aligning DA3 depth to MoGe scale...", rank0_only=True) - from lyra_2._src.inference.depth_utils import moge_infer_depth_intrinsics - - moge_model.to(desired_device) - with torch.nn.attention.sdpa_kernel( - [torch.nn.attention.SDPBackend.MATH] - ): - _, moge_depth_hw, _, moge_mask_hw = moge_infer_depth_intrinsics( - moge_model, - rgb_t, - depth_pred_hw=(target_h, target_w), - target_hw=(target_h, target_w), - ) - - da3_d = depth_hw.to(moge_depth_hw.device) - da3_m = mask_hw.to(moge_mask_hw.device) - - valid_mask = (da3_m > 0.5) & (moge_mask_hw > 0.5) - if valid_mask.sum() > 10: - d_da3_vals = da3_d[valid_mask] - d_moge_vals = moge_depth_hw[valid_mask] - - inv_da3 = 1.0 / (d_da3_vals + 1e-6) - inv_moge = 1.0 / (d_moge_vals + 1e-6) - - numerator = (inv_da3 * inv_moge).sum() - denominator = (inv_da3 * inv_da3).sum() - - if denominator > 1e-8: - scale = numerator / denominator - log.info(f"Global inverse-depth scale factor: {scale.item()}", rank0_only=True) - if scale > 1e-6: - depth_hw = depth_hw / scale.to(depth_hw.device) - else: - log.warning(f"Scale too small ({scale.item()}), skipping alignment.", rank0_only=True) - else: - log.warning("Denominator too small for LS scale alignment.", rank0_only=True) - else: - log.warning("Not enough overlapping valid pixels for scale alignment.", rank0_only=True) - - moge_model.cpu() - del moge_depth_hw, moge_mask_hw, da3_d, da3_m - torch.cuda.empty_cache() - gc.collect() - - img_bchw = image_chw01.to(device=desired_device) * 2.0 - 1.0 - - # ---- Load captions ---- - from lyra_2._src.inference.get_t5_emb import get_umt5_embedding, get_umt5_embedding_offloaded - neg_t5 = misc.to(negative_prompt_data["t5_text_embeddings"], **model.tensor_kwargs) - - captions_file = None - if args.captions_path is not None: - if captions_is_dir: - captions_file = os.path.join(args.captions_path, f"{base_name}.json") - else: - captions_file = args.captions_path - if not os.path.isfile(captions_file): - log.warning(f"Captions file not found: {captions_file}, falling back to single caption") - captions_file = None - - use_chunk_captions = False - if captions_file is not None: - with open(captions_file, "r") as f: - captions_dict = json.load(f) - chunk_keys_int = sorted(int(k) for k in captions_dict) - chunk_keys_int = [k for k in chunk_keys_int if k < N] - if len(chunk_keys_int) > 1: - use_chunk_captions = True - log.info(f"Loaded {len(chunk_keys_int)} chunk captions from {captions_file}", rank0_only=True) - - chunk_keys = torch.tensor(chunk_keys_int, dtype=torch.long, device=desired_device) - chunk_embs = [] - chunk_masks = [] - for ck in chunk_keys_int: - cap = captions_dict[str(ck)] - if args.prompt_suffix: - cap = cap.rstrip() + " " + args.prompt_suffix - if args.offload_when_prompt: - emb = get_umt5_embedding_offloaded(cap, device=desired_device).to(dtype=desired_dtype) - else: - emb = get_umt5_embedding(cap, device=desired_device).to(dtype=desired_dtype) - if emb.dim() == 3: - emb = emb[0] - S, D = emb.shape - S = min(S, 512) - D = min(D, 4096) - padded_emb = torch.zeros(512, 4096, dtype=desired_dtype, device=desired_device) - padded_emb[:S, :D] = emb[:S, :D] - padded_mask = torch.zeros(512, dtype=desired_dtype, device=desired_device) - padded_mask[:S] = 1.0 - chunk_embs.append(padded_emb) - chunk_masks.append(padded_mask) - - t5_chunk_embeddings = torch.stack(chunk_embs).unsqueeze(0) - t5_chunk_mask = torch.stack(chunk_masks).unsqueeze(0) - t5_chunk_keys = chunk_keys.unsqueeze(0) - sample_frame_indices = torch.arange(N, dtype=torch.long, device=desired_device).unsqueeze(0) - t5 = t5_chunk_embeddings[:, 0, :, :] - else: - single_caption = captions_dict.get(str(chunk_keys_int[0]), "") if chunk_keys_int else "" - if args.prompt_suffix: - single_caption = single_caption.rstrip() + " " + args.prompt_suffix - - if not use_chunk_captions: - if args.prompt: - caption = args.prompt - elif captions_file is not None: - caption = single_caption - elif args.prompt_dir: - txt_path = os.path.join(args.prompt_dir, f"{base_name}.txt") - if not os.path.isfile(txt_path): - log.error(f"Caption file not found: {txt_path}") - continue - with open(txt_path, "r") as f: - caption = f.read().strip() - log.info(f"Loaded caption from {txt_path}", rank0_only=True) - else: - raise RuntimeError( - "No caption source specified. Use --captions_path, --prompt, or --prompt_dir." - ) - if args.prompt_suffix: - caption = caption.rstrip() + " " + args.prompt_suffix - if args.offload_when_prompt: - t5 = get_umt5_embedding_offloaded(caption, device=desired_device).to(dtype=desired_dtype) - else: - t5 = get_umt5_embedding(caption, device=desired_device).to(dtype=desired_dtype) - if t5.dim() == 2: - t5 = t5.unsqueeze(0) - elif t5.dim() == 3 and t5.shape[0] != 1: - t5 = t5[:1] - - # ---- Assemble data batch ---- - w2cs_b_t_44 = w2cs_T_44.unsqueeze(0).to(dtype=torch.float32, device=desired_device) - Ks_b_t_33 = Ks_T_33.unsqueeze(0).to(dtype=torch.float32, device=desired_device) - depth_b_thw = depth_hw.unsqueeze(0).unsqueeze(0).repeat(1, N, 1, 1).to(device=desired_device) - - data_batch = { - "video": img_bchw.unsqueeze(2), - "t5_text_embeddings": t5, - "neg_t5_text_embeddings": neg_t5, - "fps": torch.tensor([args.fps], dtype=torch.int32, device=desired_device), - "padding_mask": torch.zeros((1, 1, H, W), dtype=model.tensor_kwargs["dtype"], device=desired_device), - "is_preprocessed": torch.tensor([True], dtype=torch.bool, device=desired_device), - "camera_w2c": w2cs_b_t_44, - "intrinsics": Ks_b_t_33, - "depth": depth_b_thw, - } - - if use_chunk_captions: - data_batch["t5_chunk_keys"] = t5_chunk_keys - data_batch["t5_chunk_embeddings"] = t5_chunk_embeddings - data_batch["t5_chunk_mask"] = t5_chunk_mask - data_batch["sample_frame_indices"] = sample_frame_indices - - skip_keys = {"camera_w2c", "intrinsics", "depth", "t5_chunk_keys", "sample_frame_indices"} - data_batch = safe_to( - data_batch, - device=model.tensor_kwargs.get("device", None), - dtype=model.tensor_kwargs.get("dtype", None), - skip_keys=skip_keys, - ) - - # ---- Run AR inference ---- - log.info(f"=== Generating video ({N} frames) ===", rank0_only=True) - result = run_lyra2_sample( - model, - data_batch, - args, - process_group=process_group, - da3_model=da3_model, - show_progress=True, - log_prefix=f"{base_name}_custom_traj", - ) - - if result is None: - log.warning(f"Generation failed for {img_path}", rank0_only=True) - continue - - # ---- Save output video ---- - video_01 = (result["video"][0].clamp(-1, 1) * 0.5 + 0.5).float().cpu() - save_img_or_video(video_01, video_path.replace(".mp4", ""), fps=args.fps) - log.info(f"Saved video: {video_path}", rank0_only=True) - - del result, data_batch - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # Clean up distributed - if args.context_parallel_size > 1: - from megatron.core import parallel_state - parallel_state.destroy_model_parallel() - try: - import torch.distributed as dist - dist.destroy_process_group() - except Exception: - pass - - log.info("Done.", rank0_only=True) diff --git a/lyra_2/_src/inference/lyra2_zoomgs_inference.py b/lyra_2/_src/inference/lyra2_zoomgs_inference.py deleted file mode 100644 index 31eea036af6b9d4a16141ff6405acd593fade1a7..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/lyra2_zoomgs_inference.py +++ /dev/null @@ -1,811 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Single-image → zoom-in + zoom-out video generation. - -Pipeline: -1. Read a single input image. -2. Load its text caption from a pre-generated .txt file (see scripts/gemini_caption.py). -3. Produce a zoom-in and a zoom-out video using Lyra2 AR spatial generation. -4. Save individual + combined videos. - -GS reconstruction is handled separately by vipe_da3_gs_recon.py. -""" - -from __future__ import annotations - -import argparse -import gc -import os -from typing import List, Tuple - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F - -from lyra_2._ext.imaginaire.utils import log, misc -from lyra_2._ext.imaginaire.visualize.video import save_img_or_video -from lyra_2._src.inference.lyra2_ar_inference import ( - save_output, - safe_to, - run_lyra2_sample, -) -from lyra_2._src.inference.camera_traj_utils import ( - build_camera_trajectory, - CAMERA_TRAJECTORY_CHOICES, -) -from lyra_2._src.utils.model_loader import load_model_from_checkpoint - -torch.enable_grad(False) -torch.backends.cudnn.enabled = False - -# --------------------------------------------------------------------------- -# DA3 single-image depth (reused from lyra2_ar_inference_from_image) -# --------------------------------------------------------------------------- - -def _da3_infer_depth_intrinsics_single( - da3_model, - img_rgb_uint8: torch.Tensor, - target_hw: Tuple[int, int], -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """DA3 single-image depth: RGB uint8 HWC -> (image_chw01, depth_hw, K_33, mask_hw).""" - Ht, Wt = target_hw - img_np = img_rgb_uint8.detach().cpu().numpy() - img_resized = cv2.resize(img_np, (Wt, Ht), interpolation=cv2.INTER_LINEAR) - - image_chw01 = torch.from_numpy(img_resized.astype(np.float32) / 255.0) - image_chw01 = image_chw01.permute(2, 0, 1).unsqueeze(0).contiguous() - - images = [img_resized.astype(np.uint8)] - prediction = da3_model.inference( - image=images, - extrinsics=None, - intrinsics=None, - align_to_input_ext_scale=True, - infer_gs=False, - process_res=int(max(Ht, Wt)), - process_res_method="upper_bound_resize", - export_dir=None, - export_format="mini_npz", - ) - - depths_np = getattr(prediction, "depth", None) - if depths_np is None: - raise RuntimeError("DA3 prediction has no 'depth' field.") - if isinstance(depths_np, torch.Tensor): - depth_np = depths_np[0].detach().cpu().numpy() - else: - depth_np = np.asarray(depths_np)[0] - Hd, Wd = depth_np.shape[-2:] - - depth_t = torch.from_numpy(depth_np.astype(np.float32)).unsqueeze(0).unsqueeze(0) - if (Hd, Wd) != (Ht, Wt): - depth_t = F.interpolate(depth_t, size=(Ht, Wt), mode="bilinear", align_corners=False) - depth_hw = depth_t[0, 0] - depth_hw = torch.nan_to_num(depth_hw, nan=1e4).clamp(min=0, max=1e4) - mask_hw = (depth_hw < 999.9).to(dtype=torch.float32) - - try: - ixts_np = getattr(prediction, "intrinsics", None) - if ixts_np is None: - raise AttributeError - if isinstance(ixts_np, torch.Tensor): - K_np = ixts_np[0].detach().cpu().numpy() - else: - K_np = np.asarray(ixts_np)[0] - K_33 = torch.from_numpy(K_np.astype(np.float32)) - scale_x = float(Wt) / float(Wd) - scale_y = float(Ht) / float(Hd) - K_33 = K_33.clone() - K_33[0, 0] *= scale_x - K_33[1, 1] *= scale_y - K_33[0, 2] *= scale_x - K_33[1, 2] *= scale_y - except Exception: - fx = fy = max(Ht, Wt) * 1.5 - cx, cy = Wt / 2.0, Ht / 2.0 - K_33 = torch.tensor([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], dtype=torch.float32) - - return image_chw01, depth_hw, K_33, mask_hw - - -def _camera_centers_from_w2c(w2c: torch.Tensor) -> torch.Tensor: - R = w2c[:, :3, :3] - t = w2c[:, :3, 3] - return -(R.transpose(1, 2) @ t.unsqueeze(-1)).squeeze(-1) - - -# --------------------------------------------------------------------------- -# Argument parser -# --------------------------------------------------------------------------- - -def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Single-image zoom-in/zoom-out video generation" - ) - # Input - parser.add_argument("--input_image_path", type=str, required=True, - help="Path to a single image or a folder of images") - parser.add_argument("--num_samples", type=int, default=10) - parser.add_argument("--sample_start_idx", type=int, default=0) - parser.add_argument("--sample_id", type=int, default=None, - help="Run only the sample at this index (0-based). " - "Overrides --num_samples and --sample_start_idx.") - parser.add_argument("--prompt", type=str, default="", - help="Optional explicit prompt applied to ALL images.") - parser.add_argument("--prompt_dir", type=str, default=None, - help="Directory containing per-image .txt caption files. " - "Each file should be named .txt. " - "When set, Gemini captioning is skipped entirely.") - parser.add_argument("--prompt_suffix", type=str, default="", - help="Text appended to every prompt.") - - # Model and generation - parser.add_argument("--experiment", type=str, default="lyra_framepack_spatial") - parser.add_argument("--checkpoint_dir", type=str, default="checkpoints/model") - parser.add_argument("--output_path", type=str, default="inference/lyra2_zoomgs") - parser.add_argument("--guidance", type=float, default=5.0) - parser.add_argument("--shift", type=float, default=5.0) - parser.add_argument("--num_sampling_step", type=int, default=50) - parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--fps", type=int, default=16) - parser.add_argument("--num_frames", type=int, default=161, - help="Default frames per direction. Overridden by --num_frames_zoom_in / --num_frames_zoom_out.") - parser.add_argument("--num_frames_zoom_in", type=int, default=81, - help="Frames for zoom-in. Falls back to --num_frames if not set.") - parser.add_argument("--num_frames_zoom_out", type=int, default=241, - help="Frames for zoom-out. Falls back to --num_frames if not set.") - parser.add_argument("--resolution", type=str, default="480,832", help="H,W") - parser.add_argument("--context_parallel_size", type=int, default=1) - parser.add_argument("--lora_paths", type=str, nargs="+", - default=["checkpoints/lora/realism_boost.safetensors", - "checkpoints/lora/detail_enhancer.safetensors"]) - parser.add_argument("--lora_weights", type=float, nargs="+", default=[0.4, 0.4]) - parser.add_argument("--offload", action="store_true") - parser.add_argument("--offload_when_prompt", action="store_true") - - # Camera trajectory for zoom - parser.add_argument("--zoom_in_trajectory", type=str, default="horizontal_zoom", - choices=list(CAMERA_TRAJECTORY_CHOICES), - help="Camera trajectory for zoom-in video.") - parser.add_argument("--zoom_out_trajectory", type=str, default="horizontal_zoom", - choices=list(CAMERA_TRAJECTORY_CHOICES), - help="Camera trajectory for zoom-out video.") - parser.add_argument("--zoom_in_direction", type=str, default="right", - choices=["left", "right", "up", "down"], - help="Direction for zoom-in (right = forward along z).") - parser.add_argument("--zoom_out_direction", type=str, default="left", - choices=["left", "right", "up", "down"], - help="Direction for zoom-out (left = backward along z).") - parser.add_argument("--zoom_in_strength", type=float, default=0.5) - parser.add_argument("--zoom_out_strength", type=float, default=1.5) - - # Depth backend - parser.add_argument("--use_moge_scale", action=argparse.BooleanOptionalAction, default=True, - help="Align DA3 depth to MoGe scale during seeding (default: True).") - parser.add_argument("--ground_plane_align", action="store_true", - help="Fit a ground plane from depth and move camera parallel to it.") - parser.add_argument("--ground_plane_bottom_frac", type=float, default=0.4, - help="Fraction of the image (from bottom) to use for ground plane fitting.") - parser.add_argument("--zoom_out_upward_shift", type=float, default=0.05, - help="Extra linear upward shift (along ground normal) applied to zoom-out " - "trajectory. 0 = disabled. Units are in camera-space translation.") - parser.add_argument("--zoom_out_upward_ratio", type=float, default=0.15, - help="Ratio of upward component added to the zoom-out backward trajectory. " - "0 = pure z-axis retreat, 0.15 = slight diagonal upward tilt. " - "Applied independently of --zoom_out_upward_shift.") - parser.add_argument("--depth_backend", type=str, default="da3", choices=["da3"]) - parser.add_argument("--da3_model_name", type=str, default="depth-anything/DA3NESTED-GIANT-LARGE-1.1") - parser.add_argument("--da3_model_path_custom", type=str, default="checkpoints/recon/model.pt") - parser.add_argument("--da3_frame_interval", type=int, default=8) - parser.add_argument("--da3_max_history_frames", type=int, default=10) - parser.add_argument("--da3_include_ar_chunk_last_frames", action="store_true") - parser.add_argument("--da3_use_predicted_pose", action="store_true", - help="Use DA3-predicted camera poses (aligned to pipeline coords) for cache updates.") - parser.add_argument("--da3_predicted_pose_continuation", action="store_true", - help="Apply DA3-predicted pose alignment for continuation segments.") - - # DMD distillation (4-step fast inference) - parser.add_argument("--use_dmd", action="store_true", - help="Enable DMD fast inference: loads DMD distillation LoRA, " - "activates DMD scheduler, and reduces sampling steps.") - - # Misc flags needed by run_lyra2_sample internals - parser.add_argument("--ablate_same_t5", action="store_true") - parser.add_argument("--use_dmd_scheduler", action="store_true") - parser.add_argument("--warp_chunk_size", type=int, default=None) - parser.add_argument("--num_retrieval_views", type=int, default=1) - parser.add_argument("--disable_cache_update", action="store_true") - parser.add_argument("--multiview_ids", type=int, nargs="+", default=None) - parser.add_argument("--offload_da3_diffusion", action="store_true") - - return parser.parse_args() - - - - -def _build_image_list(path: str) -> List[str]: - if os.path.isdir(path): - exts = {".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"} - files = [os.path.join(path, f) for f in sorted(os.listdir(path)) if os.path.splitext(f)[1] in exts] - if not files: - raise FileNotFoundError(f"No images found in folder: {path}") - return files - else: - if not os.path.isfile(path): - raise FileNotFoundError(f"Input image not found: {path}") - return [path] - - -def _fit_ground_normal_from_depth( - depth_hw: torch.Tensor, - K_33: torch.Tensor, - mask_hw: torch.Tensor, - bottom_frac: float = 0.4, - ransac_iters: int = 200, - ransac_thresh: float = 0.05, -) -> torch.Tensor | None: - """Fit a ground plane from the bottom portion of the depth map. - - Returns the plane normal in camera space (pointing 'up' away from ground), - or None if fitting fails. - """ - H, W = depth_hw.shape - y_start = int(H * (1.0 - bottom_frac)) - - valid = (mask_hw[y_start:] > 0.5) & (depth_hw[y_start:] > 0.01) & (depth_hw[y_start:] < 500.0) - if valid.sum() < 50: - return None - - ys, xs = torch.where(valid) - ys = ys + y_start - depths = depth_hw[ys, xs] - - fx, fy = K_33[0, 0], K_33[1, 1] - cx, cy = K_33[0, 2], K_33[1, 2] - X = (xs.float() - cx) / fx * depths - Y = (ys.float() - cy) / fy * depths - Z = depths - pts = torch.stack([X, Y, Z], dim=-1) # (N, 3) - - N_pts = pts.shape[0] - best_normal = None - best_inliers = 0 - - for _ in range(ransac_iters): - idx = torch.randint(0, N_pts, (3,)) - p0, p1, p2 = pts[idx[0]], pts[idx[1]], pts[idx[2]] - v1 = p1 - p0 - v2 = p2 - p0 - n = torch.cross(v1, v2, dim=0) - norm = n.norm() - if norm < 1e-8: - continue - n = n / norm - d = -torch.dot(n, p0) - dists = (pts @ n + d).abs() - inlier_count = (dists < ransac_thresh * Z.abs().clamp(min=0.1)).sum().item() - if inlier_count > best_inliers: - best_inliers = inlier_count - best_normal = n - - if best_normal is None: - return None - - # Ensure normal points "up" in camera space (negative y direction = up in image coords) - if best_normal[1] > 0: - best_normal = -best_normal - - log.info( - f"[ground_plane] Fitted normal: [{best_normal[0]:.4f}, {best_normal[1]:.4f}, {best_normal[2]:.4f}], " - f"inliers: {best_inliers}/{N_pts}", - rank0_only=True, - ) - return best_normal - - -def _correct_trajectory_ground_parallel( - w2cs_T_44: torch.Tensor, - ground_normal_cam: torch.Tensor, -) -> torch.Tensor: - """Re-project w2c translations so camera moves parallel to the ground plane. - - The original trajectory's translation direction (typically camera z-axis) is - projected onto the ground plane, preserving the total displacement magnitude. - Camera orientation (rotation) is kept unchanged. - """ - T = w2cs_T_44.shape[0] - n = ground_normal_cam.to(w2cs_T_44.device, dtype=w2cs_T_44.dtype) - - t0 = w2cs_T_44[0, :3, 3] - displacements = w2cs_T_44[:, :3, 3] - t0.unsqueeze(0) # (T, 3) - - # Project each displacement onto the ground plane - n_dot_d = (displacements * n.unsqueeze(0)).sum(dim=-1, keepdim=True) # (T, 1) - projected = displacements - n_dot_d * n.unsqueeze(0) # (T, 3) - - # Preserve original displacement magnitudes - orig_norms = displacements.norm(dim=-1, keepdim=True).clamp(min=1e-8) - proj_norms = projected.norm(dim=-1, keepdim=True).clamp(min=1e-8) - projected = projected * (orig_norms / proj_norms) - # First frame stays at origin (no displacement) - projected[0] = 0.0 - - corrected = w2cs_T_44.clone() - corrected[:, :3, 3] = t0.unsqueeze(0) + projected - return corrected - - -def _generate_one_direction( - *, - model, - args: argparse.Namespace, - img_bchw: torch.Tensor, - depth_hw: torch.Tensor, - mask_hw: torch.Tensor, - K_33: torch.Tensor, - t5_embeddings: torch.Tensor, - neg_t5_embeddings: torch.Tensor, - trajectory: str, - direction: str, - strength: float, - N: int, - da3_model=None, - process_group=None, - log_prefix: str = "", - ground_normal_cam: torch.Tensor | None = None, - upward_shift: float = 0.0, - zoom_out_upward_ratio: float = 0.0, -) -> dict | None: - """Run AR spatial inference for a single camera trajectory direction.""" - device = model.tensor_kwargs.get("device", None) - H, W = img_bchw.shape[-2:] - - initial_w2c = torch.eye(4, dtype=torch.float32, device=device) - center_depth = torch.quantile(depth_hw[mask_hw > 0.5], 0.25) - - w2cs_T_44, Ks_T_33 = build_camera_trajectory( - initial_w2c, - K_33.to(initial_w2c), - center_depth, - N, - trajectory, - direction, - strength, - ) - - if zoom_out_upward_ratio > 0.0: - cam_centers = _camera_centers_from_w2c(w2cs_T_44) - z_disp = cam_centers[:, 2] - cam_centers[0, 2] - backward_amount = (-z_disp).clamp(min=0) - upward_amount = backward_amount * zoom_out_upward_ratio - cam_centers_shifted = cam_centers.clone() - cam_centers_shifted[:, 1] -= upward_amount - R = w2cs_T_44[:, :3, :3] - new_t = -(R @ cam_centers_shifted.unsqueeze(-1)).squeeze(-1) - w2cs_T_44 = w2cs_T_44.clone() - w2cs_T_44[:, :3, 3] = new_t - log.info( - f"{log_prefix} [upward_tilt] Added upward ratio={zoom_out_upward_ratio:.3f}, " - f"max_upward={upward_amount.max().item():.4f}", - rank0_only=True, - ) - - if ground_normal_cam is not None: - w2cs_T_44 = _correct_trajectory_ground_parallel(w2cs_T_44, ground_normal_cam) - - if upward_shift > 0.0: - n = ground_normal_cam.to(w2cs_T_44.device, dtype=w2cs_T_44.dtype) - T = w2cs_T_44.shape[0] - ramp = torch.linspace(0, upward_shift, T, device=w2cs_T_44.device, dtype=w2cs_T_44.dtype) - w2cs_T_44 = w2cs_T_44.clone() - w2cs_T_44[:, :3, 3] -= ramp.unsqueeze(-1) * n.unsqueeze(0) - - w2cs_b_t_44 = w2cs_T_44.unsqueeze(0).to(dtype=torch.float32) - Ks_b_t_33 = Ks_T_33.unsqueeze(0).to(dtype=torch.float32) - - depth_b_thw = depth_hw.unsqueeze(0).unsqueeze(0).repeat(1, N, 1, 1).to(device=device) - - data_batch = { - "video": img_bchw.unsqueeze(2), - "t5_text_embeddings": t5_embeddings, - "neg_t5_text_embeddings": neg_t5_embeddings, - "fps": torch.tensor([args.fps], dtype=torch.int32, device=device), - "padding_mask": torch.zeros((1, 1, H, W), dtype=model.tensor_kwargs["dtype"], device=device), - "is_preprocessed": torch.tensor([True], dtype=torch.bool, device=device), - "camera_w2c": w2cs_b_t_44, - "intrinsics": Ks_b_t_33, - "depth": depth_b_thw, - } - - skip_keys = {"camera_w2c", "intrinsics", "depth"} - data_batch = safe_to( - data_batch, - device=model.tensor_kwargs.get("device", None), - dtype=model.tensor_kwargs.get("dtype", None), - skip_keys=skip_keys, - ) - - saved_num_frames = args.num_frames - args.num_frames = N - try: - result = run_lyra2_sample( - model, - data_batch, - args, - process_group=process_group, - da3_model=da3_model, - show_progress=True, - log_prefix=log_prefix, - ) - finally: - args.num_frames = saved_num_frames - - return result - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - -DMD_LORA_PATH = "checkpoints/lora/dmd_distillation.safetensors" -DMD_LORA_WEIGHT = 1.0 - - -def _apply_dmd_defaults(args): - """When --use_dmd is set, inject the DMD LoRA and switch to the DMD scheduler. - - Note: the DMD scheduler uses a fixed 4-step denoising list internally, so - ``--num_sampling_step`` is ignored in this code path. - """ - if not args.use_dmd: - return - args.use_dmd_scheduler = True - if args.lora_paths is None: - args.lora_paths = [] - if args.lora_weights is None: - args.lora_weights = [] - args.lora_paths.append(DMD_LORA_PATH) - args.lora_weights.append(DMD_LORA_WEIGHT) - log.info( - f"[DMD] Enabled: lora={DMD_LORA_PATH}, scheduler=dmd (4 fixed steps)", - rank0_only=True, - ) - - -if __name__ == "__main__": - args = parse_arguments() - _apply_dmd_defaults(args) - - process_group = None - if args.context_parallel_size > 1: - import imaginaire - from megatron.core import parallel_state - imaginaire.utils.distributed.init() - parallel_state.initialize_model_parallel(context_parallel_size=args.context_parallel_size) - process_group = parallel_state.get_context_parallel_group() - - os.makedirs(args.output_path, exist_ok=True) - misc.set_random_seed(seed=args.seed, by_rank=True) - - # Negative prompt embeddings - negative_prompt_data = torch.load( - "checkpoints/text_encoder/negative_prompt.pt", map_location="cpu", weights_only=False - ) - - # Load Lyra2 model - experiment_opts = [ - "model.config.use_mp_policy_fsdp=False", - "model.config.keep_original_net_dtype=False", - ] - if args.lora_paths: - experiment_opts += ["model.config.net.postpone_checkpoint=True"] - model, config = load_model_from_checkpoint( - config_file="lyra_2/_src/configs/config.py", - experiment_name=args.experiment, - checkpoint_path=args.checkpoint_dir, - enable_fsdp=False, - instantiate_ema=False, - load_ema_to_reg=False, - experiment_opts=experiment_opts, - ) - if args.lora_paths: - lora_names = [] - for lora_path in args.lora_paths: - lora_name = model.load_lora_weights(lora_path) - lora_names.append(lora_name) - model.set_weights_and_activate_adapters(lora_names, args.lora_weights) - if hasattr(model, "net") and hasattr(model.net, "enable_selective_checkpoint"): - model.net.enable_selective_checkpoint(model.net.sac_config, model.net.blocks) - - desired_dtype = model.tensor_kwargs.get("dtype", None) - desired_device = model.tensor_kwargs.get("device", None) - if desired_dtype is not None: - model.net = model.net.to(device=desired_device, dtype=desired_dtype) - log.info(f"Casted model.net to dtype={desired_dtype}", rank0_only=True) - - assert getattr(model.config, "important_start", True) is True - assert getattr(model.config, "encode_video_from_start", True) is True - assert not getattr(model.config, "use_hd_map_cond", False) - - model.eval() - if args.context_parallel_size > 1: - model.net.enable_context_parallel(process_group) - - if args.warp_chunk_size is not None: - model.config.warp_chunk_size = args.warp_chunk_size - model.warp_chunk_size = args.warp_chunk_size - - # Resolution - target_h, target_w = [int(x) for x in args.resolution.split(",")] - - # Load DA3 model - from lyra_2._src.inference.depth_utils import load_da3_model - da3_device = model.tensor_kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") - da3_model = load_da3_model( - da3_model_name=args.da3_model_name, - da3_model_path_custom=args.da3_model_path_custom, - device=da3_device, - ) - da3_model.eval() - - # Optionally load MoGe model for depth scale alignment - moge_model = None - if args.use_moge_scale: - from lyra_2._src.inference.depth_utils import load_moge_model - moge_device = model.tensor_kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") - moge_model = load_moge_model(moge_device) - moge_model.eval() - log.info("MoGe model loaded for depth scale alignment.", rank0_only=True) - - # Resolve image(s) - all_image_paths = _build_image_list(args.input_image_path) - if args.sample_id is not None: - if args.sample_id < 0 or args.sample_id >= len(all_image_paths): - raise IndexError( - f"--sample_id {args.sample_id} out of range [0, {len(all_image_paths) - 1}]" - ) - image_paths = [all_image_paths[args.sample_id]] - else: - image_paths = all_image_paths[args.sample_start_idx:args.sample_start_idx + args.num_samples] - - videos_dir = os.path.join(args.output_path, "videos") - os.makedirs(videos_dir, exist_ok=True) - - for img_idx, img_path in enumerate(image_paths): - base_name = os.path.splitext(os.path.basename(img_path))[0] - per_image_dir = os.path.join(args.output_path, base_name) - os.makedirs(per_image_dir, exist_ok=True) - - combined_video_path = os.path.join(videos_dir, f"{base_name}.mp4") - if os.path.exists(combined_video_path): - log.info(f"Skipping {img_path} (combined video already exists at {combined_video_path})", rank0_only=True) - continue - - log.info(f"Processing [{img_idx}]: {img_path}", rank0_only=True) - misc.set_random_seed(seed=args.seed, by_rank=True) - - # Read image - bgr = cv2.imread(img_path) - if bgr is None: - log.error(f"Cannot read: {img_path}") - continue - rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) - rgb_t = torch.from_numpy(rgb) # H,W,3 uint8 - - # Step 1: Depth & intrinsics - log.info("Running DA3 single-image depth...", rank0_only=True) - image_chw01, depth_hw, K_33, mask_hw = _da3_infer_depth_intrinsics_single( - da3_model=da3_model, - img_rgb_uint8=rgb_t, - target_hw=(target_h, target_w), - ) - H, W = image_chw01.shape[-2:] - - # Step 1b: Optionally align DA3 depth to MoGe scale - if args.use_moge_scale and moge_model is not None: - log.info("Aligning DA3 depth to MoGe scale...", rank0_only=True) - from lyra_2._src.inference.depth_utils import moge_infer_depth_intrinsics - - moge_model.to(desired_device) - with torch.nn.attention.sdpa_kernel( - [torch.nn.attention.SDPBackend.MATH] - ): - _, moge_depth_hw, _, moge_mask_hw = moge_infer_depth_intrinsics( - moge_model, - rgb_t, - depth_pred_hw=(target_h, target_w), - target_hw=(target_h, target_w), - ) - - da3_d = depth_hw.to(moge_depth_hw.device) - da3_m = mask_hw.to(moge_mask_hw.device) - - valid_mask = (da3_m > 0.5) & (moge_mask_hw > 0.5) - if valid_mask.sum() > 10: - d_da3_vals = da3_d[valid_mask] - d_moge_vals = moge_depth_hw[valid_mask] - - inv_da3 = 1.0 / (d_da3_vals + 1e-6) - inv_moge = 1.0 / (d_moge_vals + 1e-6) - - numerator = (inv_da3 * inv_moge).sum() - denominator = (inv_da3 * inv_da3).sum() - - if denominator > 1e-8: - scale = numerator / denominator - log.info(f"Global inverse-depth scale factor: {scale.item()}", rank0_only=True) - if scale > 1e-6: - depth_hw = depth_hw / scale.to(depth_hw.device) - else: - log.warning(f"Scale too small ({scale.item()}), skipping alignment.", rank0_only=True) - else: - log.warning("Denominator too small for LS scale alignment.", rank0_only=True) - else: - log.warning("Not enough overlapping valid pixels for scale alignment.", rank0_only=True) - - # Free MoGe GPU memory before video generation - moge_model.cpu() - del moge_depth_hw, moge_mask_hw, da3_d, da3_m - torch.cuda.empty_cache() - gc.collect() - - img_bchw = image_chw01.to(device=desired_device) * 2.0 - 1.0 # [-1,1] - - # Step 2: Load caption from .txt file or use explicit prompt - if args.prompt: - caption = args.prompt - log.info(f"Using provided prompt: {caption}", rank0_only=True) - elif args.prompt_dir: - txt_path = os.path.join(args.prompt_dir, f"{base_name}.txt") - if not os.path.isfile(txt_path): - log.error(f"Caption file not found: {txt_path} (expected for image {base_name})") - continue - with open(txt_path, "r") as f: - caption = f.read().strip() - log.info(f"Loaded caption from {txt_path}: {caption}", rank0_only=True) - else: - raise RuntimeError( - "No caption source specified. Use --prompt for a global prompt, " - "or --prompt_dir pointing to a folder of .txt files. " - "Run scripts/gemini_caption.py first to generate captions." - ) - - if args.prompt_suffix: - caption = caption.rstrip() + " " + args.prompt_suffix - - # Step 2b: T5 embeddings - from lyra_2._src.inference.get_t5_emb import get_umt5_embedding, get_umt5_embedding_offloaded - if args.offload_when_prompt: - t5 = get_umt5_embedding_offloaded(caption, device=desired_device).to(dtype=desired_dtype) - else: - t5 = get_umt5_embedding(caption, device=desired_device).to(dtype=desired_dtype) - if t5.dim() == 2: - t5 = t5.unsqueeze(0) - elif t5.dim() == 3 and t5.shape[0] != 1: - t5 = t5[:1] - neg_t5 = misc.to(negative_prompt_data["t5_text_embeddings"], **model.tensor_kwargs) - - N_in = int(args.num_frames_zoom_in or args.num_frames) - N_out = int(args.num_frames_zoom_out or args.num_frames) - - # Step 2c: Optionally fit ground plane for trajectory alignment - ground_normal = None - if args.ground_plane_align: - ground_normal = _fit_ground_normal_from_depth( - depth_hw, K_33, mask_hw, - bottom_frac=args.ground_plane_bottom_frac, - ) - if ground_normal is None: - log.warning("Ground plane fitting failed, using original trajectory.", rank0_only=True) - - # Step 3: Generate zoom-in video - log.info(f"=== Generating ZOOM-IN video ({args.zoom_in_trajectory} {args.zoom_in_direction} str={args.zoom_in_strength}, N={N_in}) ===", rank0_only=True) - result_in = _generate_one_direction( - model=model, - args=args, - img_bchw=img_bchw, - depth_hw=depth_hw, - mask_hw=mask_hw, - K_33=K_33, - t5_embeddings=t5, - neg_t5_embeddings=neg_t5, - trajectory=args.zoom_in_trajectory, - direction=args.zoom_in_direction, - strength=args.zoom_in_strength, - N=N_in, - da3_model=da3_model, - process_group=process_group, - log_prefix=f"{base_name}_zoom_in", - ground_normal_cam=ground_normal, - ) - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Step 3b: Generate zoom-out video - log.info(f"=== Generating ZOOM-OUT video ({args.zoom_out_trajectory} {args.zoom_out_direction} str={args.zoom_out_strength}, N={N_out}) ===", rank0_only=True) - result_out = _generate_one_direction( - model=model, - args=args, - img_bchw=img_bchw, - depth_hw=depth_hw, - mask_hw=mask_hw, - K_33=K_33, - t5_embeddings=t5, - neg_t5_embeddings=neg_t5, - trajectory=args.zoom_out_trajectory, - direction=args.zoom_out_direction, - strength=args.zoom_out_strength, - N=N_out, - da3_model=da3_model, - process_group=process_group, - log_prefix=f"{base_name}_zoom_out", - upward_shift=args.zoom_out_upward_shift, - ground_normal_cam=ground_normal, - zoom_out_upward_ratio=args.zoom_out_upward_ratio, - ) - - if result_in is None and result_out is None: - log.warning(f"Both zoom-in and zoom-out failed for {img_path}", rank0_only=True) - continue - - # Save individual direction videos - for tag, res in [("zoom_in", result_in), ("zoom_out", result_out)]: - if res is None: - continue - vid_stem = os.path.join(per_image_dir, tag) - to_show = [] - if res.get("warp_video") is not None: - to_show.append(res["warp_video"]) - to_show.append(res["video"]) - save_output(to_show, vid_stem + ".mp4") - log.info(f"Saved {tag} video: {vid_stem}.mp4", rank0_only=True) - - # Combine zoom-out (reversed) + zoom-in into a single video - videos_to_combine = [] - if result_out is not None: - videos_to_combine.append(result_out["video"].flip(dims=[2])) - if result_in is not None: - videos_to_combine.append(result_in["video"]) - - combined_video = torch.cat(videos_to_combine, dim=2) # [B, C, T_total, H, W] - log.info(f"Combined video: {combined_video.shape[2]} frames from both directions", rank0_only=True) - - combined_01 = (combined_video[0].clamp(-1, 1) * 0.5 + 0.5).float().cpu() - save_img_or_video(combined_01, combined_video_path.replace(".mp4", ""), fps=args.fps) - log.info(f"Saved combined video: {combined_video_path}", rank0_only=True) - - per_image_combined = os.path.join(per_image_dir, "combined") - save_img_or_video(combined_01, per_image_combined, fps=args.fps) - - del combined_video, combined_01 - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # Clean up distributed - if args.context_parallel_size > 1: - from megatron.core import parallel_state - parallel_state.destroy_model_parallel() - try: - import torch.distributed as dist - dist.destroy_process_group() - except Exception: - pass - - log.info("Done.", rank0_only=True) diff --git a/lyra_2/_src/inference/vipe_da3_gs_recon.py b/lyra_2/_src/inference/vipe_da3_gs_recon.py deleted file mode 100644 index e31cec2dfd6938ad7f93df09f6952dbb72f9dc71..0000000000000000000000000000000000000000 --- a/lyra_2/_src/inference/vipe_da3_gs_recon.py +++ /dev/null @@ -1,915 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import argparse -import os -import sys -import tempfile -from dataclasses import dataclass -from pathlib import Path -from typing import List, Optional, Sequence, Tuple - -import cv2 -import numpy as np -import torch -from plyfile import PlyData - -from lyra_2._src.inference.depth_utils import load_da3_model - - -REPO_ROOT = Path(__file__).resolve().parents[3] -DEFAULT_RECON_DA3_MODEL_PATH = REPO_ROOT / "checkpoints" / "recon" / "model.pt" - - -@dataclass(slots=True) -class VIPEOutputs: - intrinsics: torch.Tensor - extrinsics_c2w: torch.Tensor - depth: torch.Tensor - frame_ids: torch.Tensor - - -def _ensure_da3_on_syspath() -> Path: - da3_src_root = Path(__file__).resolve().parent / "depth_anything_3" / "src" - if str(da3_src_root) not in sys.path: - sys.path.insert(0, str(da3_src_root)) - return da3_src_root - - -def _to_rgb_tensor_0_1(frames: torch.Tensor) -> torch.Tensor: - if not isinstance(frames, torch.Tensor): - raise TypeError(f"frames must be a torch.Tensor, got {type(frames)}") - if frames.ndim != 4: - raise ValueError(f"frames must be 4D, got shape {tuple(frames.shape)}") - - if frames.shape[-1] == 3: - rgb = frames - elif frames.shape[1] == 3: - rgb = frames.permute(0, 2, 3, 1) - else: - raise ValueError(f"frames must be (T,H,W,3) or (T,3,H,W), got {tuple(frames.shape)}") - - if rgb.dtype == torch.uint8: - rgb = rgb.float() / 255.0 - else: - rgb = rgb.float() - if rgb.max() > 1.5: - rgb = rgb / 255.0 - - return rgb.clamp_(0.0, 1.0) - - -def _compose_vipe_config( - *, - overrides: Sequence[str], - config_dir: str | Path, - config_name: str = "default", -): - from hydra import compose, initialize_config_dir - from hydra.core.global_hydra import GlobalHydra - - cfg_dir = Path(config_dir).resolve() - if GlobalHydra.instance().is_initialized(): - GlobalHydra.instance().clear() - with initialize_config_dir(version_base=None, config_dir=str(cfg_dir)): - cfg = compose(config_name=config_name, overrides=list(overrides)) - return cfg - - -def _import_vipe_class(): - vipe_root = Path(__file__).resolve().parent / "vipe" - if not (vipe_root / "vipe" / "__init__.py").is_file() or not (vipe_root / "configs").is_dir(): - raise ImportError(f"VIPE submodule not found at {vipe_root}") - - import_root = vipe_root - config_dir = vipe_root / "configs" - if str(import_root) not in sys.path: - sys.path.insert(0, str(import_root)) - - try: - from vipe.pipeline import make_pipeline # type: ignore - from vipe.streams.base import VideoFrame, VideoStream # type: ignore - except Exception as e: - raise ImportError(f"Failed to import VIPE from {import_root}") from e - - class InMemoryVideoStream(VideoStream): # type: ignore[misc] - def __init__( - self, - frames_rgb_0_1_thwc: torch.Tensor, - fps_value: float = 16.0, - stream_name: str = "inmem", - device: Optional[torch.device] = None, - ) -> None: - super().__init__() - self.frames_rgb_0_1_thwc = _to_rgb_tensor_0_1(frames_rgb_0_1_thwc) - self.fps_value = float(fps_value) - self.stream_name = stream_name - self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def frame_size(self) -> tuple[int, int]: - _, h, w, _ = self.frames_rgb_0_1_thwc.shape - return (h, w) - - def name(self) -> str: - return self.stream_name - - def fps(self) -> float: - return self.fps_value - - def __len__(self) -> int: - return int(self.frames_rgb_0_1_thwc.shape[0]) - - def __getitem__(self, idx: int) -> VideoFrame: - if idx < 0: - idx = len(self) + idx - if idx < 0 or idx >= len(self): - raise IndexError(idx) - rgb = self.frames_rgb_0_1_thwc[idx].to(self.device, non_blocking=True) - return VideoFrame(raw_frame_idx=idx, rgb=rgb) - - def __iter__(self): - for i in range(len(self)): - yield self[i] - - class VIPEWrapper: - def __init__( - self, - overrides: Sequence[str], - *, - config_dir: str | Path | None = None, - config_name: str = "default", - device: str | torch.device | None = None, - fast_mode: bool = True, - ) -> None: - if config_dir is None: - config_dir = _config_dir - self.cfg = _compose_vipe_config( - overrides=overrides, - config_dir=config_dir, - config_name=config_name, - ) - if fast_mode: - if self.cfg.pipeline.init.get("instance") is not None: - self.cfg.pipeline.init.instance = None - if self.cfg.pipeline.post.get("compute_backward_flow") is not None: - self.cfg.pipeline.post.compute_backward_flow = False - if self.cfg.pipeline.output.get("save_viz") is not None: - self.cfg.pipeline.output.save_viz = False - if self.cfg.pipeline.output.get("save_artifacts") is not None: - self.cfg.pipeline.output.save_artifacts = False - if self.cfg.pipeline.output.get("save_metrics") is not None: - self.cfg.pipeline.output.save_metrics = False - if self.cfg.pipeline.output.get("save_slam_map") is not None: - self.cfg.pipeline.output.save_slam_map = False - self.pipeline = make_pipeline(self.cfg.pipeline) - self.pipeline.return_output_streams = True - self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) - - def infer_frames( - self, - frames: torch.Tensor, - *, - fps: float = 30.0, - name: str = "inmem", - ) -> VIPEOutputs: - stream = InMemoryVideoStream( - frames_rgb_0_1_thwc=frames, - fps_value=float(fps), - stream_name=name, - device=self.device, - ) - out = self.pipeline.run(stream) - assert out.output_streams is not None and len(out.output_streams) == 1 - output_stream = out.output_streams[0] - - intr_list = [] - c2w_list = [] - depth_list = [] - frame_ids = [] - for frame in output_stream: - f = frame.cpu() - frame_ids.append(f.raw_frame_idx) - assert f.intrinsics is not None and f.pose is not None - intr_list.append(f.intrinsics.float()) - c2w_list.append(f.pose.matrix().float()) - if f.metric_depth is not None: - depth_metric = f.metric_depth.float() - else: - height, width = stream.frame_size() - depth_metric = torch.zeros((height, width), dtype=torch.float32) - depth_list.append(depth_metric) - - return VIPEOutputs( - intrinsics=torch.stack(intr_list, dim=0), - extrinsics_c2w=torch.stack(c2w_list, dim=0), - depth=torch.stack(depth_list, dim=0), - frame_ids=torch.tensor(frame_ids, dtype=torch.int64), - ) - - _config_dir = config_dir - return VIPEWrapper - - -def _vipe_default_overrides(output_path: Path) -> List[str]: - return [ - "pipeline=default", - "pipeline.slam.optimize_intrinsics=false", - "pipeline.post.depth_align_model=null", - "pipeline.output.save_artifacts=false", - "pipeline.output.save_viz=false", - f"pipeline.output.path={output_path}", - ] - - -def _intrinsics_vec_to_k33(intrinsics_vec: torch.Tensor) -> torch.Tensor: - if intrinsics_vec.ndim != 2 or intrinsics_vec.shape[1] < 4: - raise ValueError(f"Expected intrinsics shape (T,>=4), got {tuple(intrinsics_vec.shape)}") - fx, fy, cx, cy = intrinsics_vec[:, 0], intrinsics_vec[:, 1], intrinsics_vec[:, 2], intrinsics_vec[:, 3] - t = int(intrinsics_vec.shape[0]) - k = torch.zeros((t, 3, 3), dtype=intrinsics_vec.dtype, device=intrinsics_vec.device) - k[:, 0, 0] = fx - k[:, 1, 1] = fy - k[:, 0, 2] = cx - k[:, 1, 2] = cy - k[:, 2, 2] = 1.0 - return k - - -def _probe_video(video_path: str) -> Tuple[int, float]: - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise FileNotFoundError(f"Failed to open video: {video_path}") - frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0) - fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0) - cap.release() - if frame_count <= 0: - frame_count = 0 - if fps <= 1e-6: - fps = 30.0 - return frame_count, fps - - -def _sample_indices(num_frames: int, stride: int, max_views: int = 0) -> List[int]: - stride = max(int(stride), 1) - indices = list(range(0, int(num_frames), stride)) - if max_views > 0: - indices = indices[: int(max_views)] - return indices - - -def _uniform_subsample_indices(num_frames: int, max_frames: int) -> List[int]: - num_frames = int(num_frames) - max_frames = int(max_frames) - if num_frames <= 0: - return [] - if max_frames <= 0 or num_frames <= max_frames: - return list(range(num_frames)) - return np.floor(np.linspace(0, num_frames - 1, num=max_frames)).astype(np.int64).tolist() - - -def _read_video_frames_rgb(video_path: str, indices: List[int]) -> List[np.ndarray]: - if not indices: - return [] - - wanted = set(int(i) for i in indices) - last_needed = int(max(wanted)) - frames: List[np.ndarray] = [] - read_ids: List[int] = [] - - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise FileNotFoundError(f"Failed to open video: {video_path}") - - frame_idx = 0 - try: - while True: - ok, frame_bgr = cap.read() - if not ok or frame_bgr is None: - break - if frame_idx in wanted: - frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)) - read_ids.append(frame_idx) - if len(frames) == len(wanted): - break - if frame_idx >= last_needed: - break - frame_idx += 1 - finally: - cap.release() - - if len(frames) != len(wanted): - missing = sorted(list(wanted - set(read_ids))) - print( - f"[vipe_da3_gs] Warning: requested {len(wanted)} frames, got {len(frames)}. " - f"Missing={missing[:10]}" - ) - - return frames - - -def _compute_aligned_pred_w2c(pred_extr_np: np.ndarray, input_w2c_np: np.ndarray) -> np.ndarray: - _ensure_da3_on_syspath() - from depth_anything_3.utils.geometry import affine_inverse_np # type: ignore - from depth_anything_3.utils.pose_align import align_poses_umeyama # type: ignore - - pred_44 = pred_extr_np.copy() - if pred_44.shape[-2] == 3: - pad = np.zeros((*pred_44.shape[:-2], 4, 4), dtype=pred_44.dtype) - pad[..., :3, :4] = pred_44 - pad[..., 3, 3] = 1.0 - pred_44 = pad - - inp_44 = input_w2c_np.copy() - if inp_44.shape[-2] == 3: - pad = np.zeros((*inp_44.shape[:-2], 4, 4), dtype=inp_44.dtype) - pad[..., :3, :4] = inp_44 - pad[..., 3, 3] = 1.0 - inp_44 = pad - - r, t, s = align_poses_umeyama(pred_44, inp_44) - r_inv = r.T - pred_c2w = affine_inverse_np(pred_44) - - aligned_c2w = np.zeros_like(pred_c2w) - aligned_c2w[:, :3, :3] = np.einsum("ij,njk->nik", r_inv, pred_c2w[:, :3, :3]) - trans_shifted = pred_c2w[:, :3, 3] - t[None, :] - aligned_c2w[:, :3, 3] = np.einsum("ij,nj->ni", r_inv, trans_shifted) / s - aligned_c2w[:, 3, 3] = 1.0 - - return affine_inverse_np(aligned_c2w).astype(np.float32) - - -def _pad_to_44(mat: np.ndarray) -> np.ndarray: - if mat.shape[-2:] == (4, 4): - return mat - padded = np.zeros((*mat.shape[:-2], 4, 4), dtype=mat.dtype) - padded[..., :3, :4] = mat[..., :3, :4] - padded[..., 3, 3] = 1.0 - return padded - - -def _interpolate_w2c( - w2c_keyframes: np.ndarray, - key_indices: List[int], - n_total: int, -) -> np.ndarray: - w2c_keyframes = _pad_to_44(w2c_keyframes) - - if len(key_indices) == 1: - return np.repeat(w2c_keyframes[:1], n_total, axis=0).astype(np.float32) - - from scipy.spatial.transform import Rotation, Slerp - - c2w = np.linalg.inv(w2c_keyframes) - times_key = np.array(key_indices, dtype=np.float64) - rotations = Rotation.from_matrix(c2w[:, :3, :3]) - translations = c2w[:, :3, 3].astype(np.float64) - - slerp = Slerp(times_key, rotations) - times_all = np.arange(n_total, dtype=np.float64) - times_clamped = np.clip(times_all, times_key[0], times_key[-1]) - - rotations_interp = slerp(times_clamped) - translations_interp = np.column_stack( - [np.interp(times_clamped, times_key, translations[:, dim]) for dim in range(3)] - ) - - c2w_dense = np.zeros((n_total, 4, 4), dtype=np.float64) - c2w_dense[:, :3, :3] = rotations_interp.as_matrix() - c2w_dense[:, :3, 3] = translations_interp - c2w_dense[:, 3, 3] = 1.0 - return np.linalg.inv(c2w_dense).astype(np.float32) - - -def _load_gaussian_ply_to_gaussians(ply_path: str, device: torch.device): - _ensure_da3_on_syspath() - from depth_anything_3.specs import Gaussians # type: ignore - - ply = PlyData.read(ply_path) - if "vertex" not in ply: - raise ValueError(f"No 'vertex' element in PLY: {ply_path}") - vertices = ply["vertex"].data - names = list(vertices.dtype.names or []) - - def _stack_props(prefix: str, count: int) -> np.ndarray: - props = [] - for idx in range(count): - key = f"{prefix}{idx}" - if key not in names: - raise ValueError(f"Missing '{key}' in PLY: {ply_path}") - props.append(vertices[key].astype(np.float32, copy=False)) - return np.stack(props, axis=1) - - means = np.stack( - [ - vertices["x"].astype(np.float32, copy=False), - vertices["y"].astype(np.float32, copy=False), - vertices["z"].astype(np.float32, copy=False), - ], - axis=1, - ) - scales = np.exp(_stack_props("scale_", 3)) - rotations = _stack_props("rot_", 4) - opacities = 1.0 / (1.0 + np.exp(-vertices["opacity"].astype(np.float32, copy=False))) - - f_dc = _stack_props("f_dc_", 3) - f_rest_keys = sorted( - [key for key in names if key.startswith("f_rest_")], - key=lambda key: int(key.split("_")[-1]), - ) - if f_rest_keys: - f_rest = np.stack([vertices[key].astype(np.float32, copy=False) for key in f_rest_keys], axis=1) - if f_rest.shape[1] % 3 != 0: - raise ValueError(f"Unexpected f_rest size {f_rest.shape[1]} in PLY: {ply_path}") - d_sh = f_rest.shape[1] // 3 + 1 - f_rest = f_rest.reshape(f_rest.shape[0], 3, d_sh - 1) - harmonics = np.concatenate([f_dc[:, :, None], f_rest], axis=2) - else: - harmonics = f_dc[:, :, None] - - return Gaussians( - means=torch.from_numpy(means).to(device=device).unsqueeze(0), - scales=torch.from_numpy(scales).to(device=device).unsqueeze(0), - rotations=torch.from_numpy(rotations).to(device=device).unsqueeze(0), - harmonics=torch.from_numpy(harmonics).to(device=device).unsqueeze(0), - opacities=torch.from_numpy(opacities).to(device=device).unsqueeze(0), - ) - - -def _save_video_mp4(video_path: str, frames_thwc: np.ndarray, fps: float) -> None: - if frames_thwc.ndim != 4 or frames_thwc.shape[-1] != 3: - raise ValueError(f"Expected frames shape (T,H,W,3), got {tuple(frames_thwc.shape)}") - - t = int(frames_thwc.shape[0]) - if t == 0: - raise ValueError("No frames to save.") - - os.makedirs(os.path.dirname(video_path), exist_ok=True) - from moviepy.video.io.ImageSequenceClip import ImageSequenceClip - - frames_uint8 = frames_thwc - if frames_uint8.dtype != np.uint8: - frames_uint8 = np.clip(frames_uint8, 0, 255).astype(np.uint8) - frames_list = [frame for frame in frames_uint8] - clip = ImageSequenceClip(frames_list, fps=float(max(fps, 1.0))) - try: - clip.write_videofile( - video_path, - codec="libx264", - audio=False, - fps=float(max(fps, 1.0)), - ffmpeg_params=["-crf", "18", "-preset", "slow", "-pix_fmt", "yuv420p"], - ) - finally: - clip.close() - - -def _collect_vipe_images( - video_path: str, - vipe_stride: int, - max_frames: int, - max_views: int, -) -> tuple[List[np.ndarray], List[int], float]: - frame_count, fps = _probe_video(video_path) - - if frame_count > 0: - total = frame_count - if max_frames > 0: - total = min(total, max_frames) - indices_vipe = _sample_indices(total, vipe_stride, max_views) - images_vipe = _read_video_frames_rgb(video_path, indices_vipe) - return images_vipe, indices_vipe, fps - - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - raise FileNotFoundError(f"Failed to open video: {video_path}") - - frames_tmp: List[np.ndarray] = [] - try: - while True: - ok, frame_bgr = cap.read() - if not ok or frame_bgr is None: - break - frames_tmp.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)) - if max_frames > 0 and len(frames_tmp) >= max_frames: - break - finally: - cap.release() - - indices_vipe = _sample_indices(len(frames_tmp), vipe_stride, max_views) - images_vipe = [frames_tmp[idx] for idx in indices_vipe] - return images_vipe, indices_vipe, fps - - -def _build_output_dir(input_video_path: str, output_dir: str | None) -> Path: - if output_dir: - return Path(output_dir).expanduser().resolve() - - input_video = Path(input_video_path).expanduser().resolve() - return input_video.with_name(f"{input_video.stem}_gs_ours") - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Single-video VIPE pose estimation + DA3 Gaussian reconstruction." - ) - parser.add_argument("--input_video_path", type=str, required=True) - parser.add_argument("--output_dir", type=str, default=None) - parser.add_argument("--force", action="store_true") - - parser.add_argument("--device", type=str, default=None) - parser.add_argument( - "--no_vipe", - action="store_true", - help="Skip VIPE pose estimation; DA3 reconstructs without input poses and its predicted poses are used for rendering.", - ) - parser.add_argument("--vipe_overrides", type=str, nargs="+", default=None) - parser.add_argument("--vipe_full_mode", action="store_true") - - parser.add_argument("--max_frames", type=int, default=0) - parser.add_argument( - "--da3_max_frames", - type=int, - default=128, - help="Uniformly subsample VIPE frames to at most this many views for DA3.", - ) - - parser.add_argument( - "--da3_model_name", - type=str, - default="depth-anything/DA3NESTED-GIANT-LARGE-1.1", - ) - parser.add_argument( - "--da3_model_path_custom", - type=str, - default=str(DEFAULT_RECON_DA3_MODEL_PATH), - ) - parser.add_argument("--da3_process_res", type=int, default=None) - parser.add_argument( - "--da3_process_method", - type=str, - default="upper_bound_resize", - choices=["upper_bound_resize", "lower_bound_resize"], - ) - parser.add_argument( - "--max_resolution", - type=int, - default=0, - help="If > 0, use as DA3 short-side cap via lower_bound_resize.", - ) - - parser.add_argument("--gs_down_ratio", type=int, default=2) - parser.add_argument("--gs_scale_extra_multiplier", type=float, default=1.0) - parser.add_argument("--gs_ply_prune_opacity_percentile", type=float, default=None) - parser.add_argument( - "--no_gs_ds_feature_mode", - dest="gs_ds_feature_mode", - action="store_false", - help="Disable the default release-friendly GS feature downsampling mode.", - ) - parser.set_defaults(gs_ds_feature_mode=True) - - parser.add_argument( - "--use_da3_render_pose", - dest="use_da3_render_pose", - action="store_true", - help="Render with DA3-aligned predicted poses interpolated to VIPE cadence.", - ) - parser.add_argument( - "--no_da3_render_pose", - dest="use_da3_render_pose", - action="store_false", - help="Render with raw VIPE poses instead.", - ) - parser.set_defaults(use_da3_render_pose=True) - - parser.add_argument("--render_fps", type=float, default=None) - parser.add_argument("--render_chunk_size", type=int, default=1) - - return parser.parse_args() - - -def main() -> None: - args = parse_args() - - input_video = Path(args.input_video_path).expanduser().resolve() - if not input_video.is_file(): - raise FileNotFoundError(f"Input video not found: {input_video}") - - output_dir = _build_output_dir(str(input_video), args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - done_marker = output_dir / ".done" - if done_marker.is_file() and not args.force: - print(f"[vipe_da3_gs] Skipping {input_video.name}: {done_marker} exists. Use --force to re-run.") - return - - device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu")) - print(f"[vipe_da3_gs] Input video: {input_video}") - print(f"[vipe_da3_gs] Output dir: {output_dir}") - print(f"[vipe_da3_gs] Device: {device}") - - da3_model_path_custom = None - if args.da3_model_path_custom: - da3_model_path_custom = str(Path(args.da3_model_path_custom).expanduser().resolve()) - if not Path(da3_model_path_custom).is_file(): - raise FileNotFoundError(f"DA3 checkpoint not found: {da3_model_path_custom}") - print(f"[vipe_da3_gs] DA3 ckpt: {da3_model_path_custom}") - - print("[vipe_da3_gs] Loading DA3 model...") - da3_model = load_da3_model( - da3_model_name=args.da3_model_name, - da3_model_path_custom=da3_model_path_custom, - device=str(device), - ) - da3_model.eval() - - skip_vipe = bool(args.no_vipe) - - if not skip_vipe: - VIPE = _import_vipe_class() - - print("[vipe_da3_gs] Reading video frames...") - images_all, indices_all, fps = _collect_vipe_images( - str(input_video), - vipe_stride=1, - max_frames=args.max_frames, - max_views=0, - ) - if not images_all: - raise RuntimeError("No frames read from video.") - - indices_da3_rel = _uniform_subsample_indices(len(images_all), args.da3_max_frames) - if not indices_da3_rel: - raise RuntimeError("No frames selected for DA3.") - - images_da3 = [images_all[idx] for idx in indices_da3_rel] - indices_da3 = [indices_all[idx] for idx in indices_da3_rel] - eff_fps = float(fps) - - print( - f"[vipe_da3_gs] fps={fps:.4g}, no_vipe={skip_vipe}, da3_max_frames={args.da3_max_frames}, " - f"frames_all={len(images_all)}, frames_da3={len(images_da3)}" - ) - - if skip_vipe: - w2c_np_vipe_full = None - k_np_vipe_full = None - w2c_np_da3 = None - k_np_da3 = None - else: - frames_np = np.stack(images_all, axis=0).astype(np.float32) / 255.0 - frames_thwc = torch.from_numpy(frames_np).contiguous() - - with tempfile.TemporaryDirectory(prefix="vipe_da3_gs_") as tmpdir: - if not skip_vipe: - vipe_output_path = Path(tmpdir) / "vipe_out" - vipe_output_path.mkdir(parents=True, exist_ok=True) - vipe_overrides = args.vipe_overrides or _vipe_default_overrides( - vipe_output_path, - ) - - print("[vipe_da3_gs] Loading VIPE...") - vipe_kwargs = {"fast_mode": not bool(args.vipe_full_mode)} - vipe = VIPE(vipe_overrides, **vipe_kwargs) - - print("[vipe_da3_gs] Running VIPE...") - vipe_out = vipe.infer_frames(frames_thwc, fps=eff_fps, name=input_video.stem) - - c2w = vipe_out.extrinsics_c2w.to(dtype=torch.float32) - w2c = torch.linalg.inv(c2w) - intrinsics_vipe = _intrinsics_vec_to_k33(vipe_out.intrinsics.to(dtype=torch.float32)) - - w2c_np_vipe_full = w2c.cpu().numpy().astype(np.float32) - k_np_vipe_full = intrinsics_vipe.cpu().numpy().astype(np.float32) - w2c_np_da3 = w2c_np_vipe_full[indices_da3_rel] - k_np_da3 = k_np_vipe_full[indices_da3_rel] - - np.savez( - output_dir / "vipe_predictions.npz", - frame_ids=vipe_out.frame_ids.cpu().numpy().astype(np.int64), - w2c_vipe=w2c_np_vipe_full, - intrinsics_vipe=k_np_vipe_full, - w2c_da3=w2c_np_da3, - intrinsics_da3=k_np_da3, - indices_vipe=np.asarray(indices_all, dtype=np.int64), - indices_da3=np.asarray(indices_da3, dtype=np.int64), - fps=np.asarray([eff_fps], dtype=np.float32), - input_video_path=np.asarray([str(input_video)]), - ) - - if args.da3_process_res is not None: - da3_process_res = int(args.da3_process_res) - da3_process_method = str(args.da3_process_method) - elif int(args.max_resolution) > 0: - da3_process_res = int(args.max_resolution) - da3_process_method = "lower_bound_resize" - else: - h0, w0 = images_da3[0].shape[:2] - da3_process_res = int(max(h0, w0)) - da3_process_method = "upper_bound_resize" - - _ensure_da3_on_syspath() - from depth_anything_3.utils.gsply_helpers import save_gaussian_ply # type: ignore - - if skip_vipe: - # Pass 1: DA3 without input poses to get predicted extrinsics - print( - f"[vipe_da3_gs] DA3 pass 1 (pose estimation): views={len(images_da3)} " - f"process_res={da3_process_res} infer_gs=False" - ) - pred_pose = da3_model.inference( - image=images_da3, - extrinsics=None, - intrinsics=None, - align_to_input_extrinsics=False, - align_to_input_ext_scale=False, - infer_gs=False, - process_res=da3_process_res, - process_res_method=da3_process_method, - export_dir=None, - export_format="mini_npz", - ) - if pred_pose.extrinsics is None or pred_pose.intrinsics is None: - raise RuntimeError("DA3 pass 1 did not return predicted poses.") - w2c_np_da3 = _pad_to_44(np.asarray(pred_pose.extrinsics, dtype=np.float32)) - k_np_da3 = np.asarray(pred_pose.intrinsics, dtype=np.float32) - da3_pred_w2c = w2c_np_da3.copy() - da3_pred_k = k_np_da3.copy() - print( - f"[vipe_da3_gs] DA3 pass 1 done. " - f"extrinsics {w2c_np_da3.shape}, intrinsics {k_np_da3.shape}" - ) - - del pred_pose - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # Pass 2: DA3 with pass-1 poses as condition for GS reconstruction - print( - f"[vipe_da3_gs] DA3 pass 2 (GS recon): views={len(images_da3)} " - f"process_res={da3_process_res} infer_gs=True" - ) - - else: - print( - f"[vipe_da3_gs] DA3 single pass: views={len(images_da3)} " - f"process_res={da3_process_res}" - ) - - pred = da3_model.inference( - image=images_da3, - extrinsics=w2c_np_da3, - intrinsics=k_np_da3, - align_to_input_extrinsics=False, - align_to_input_ext_scale=False, - infer_gs=True, - process_res=da3_process_res, - process_res_method=da3_process_method, - export_dir=None, - export_format="mini_npz", - use_aligned_pred_cam=True, - gs_down_ratio=args.gs_down_ratio, - gs_scale_extra_multiplier=args.gs_scale_extra_multiplier, - gs_ds_feature_mode=args.gs_ds_feature_mode, - ) - - if not skip_vipe: - aligned_w2c_da3 = None - if args.use_da3_render_pose and pred.extrinsics is not None: - aligned_w2c_da3 = _compute_aligned_pred_w2c( - np.asarray(pred.extrinsics, dtype=np.float32), - w2c_np_da3, - ) - - final_ply_path = output_dir / "reconstructed_scene.ply" - depth_t = torch.from_numpy(np.asarray(pred.depth, dtype=np.float32)).float() - save_gaussian_ply( - pred.gaussians, - str(final_ply_path), - ctx_depth=depth_t.unsqueeze(-1), - prune_by_opacity_percentile=args.gs_ply_prune_opacity_percentile, - prune_border_gs=False - if ( - args.gs_ply_prune_opacity_percentile is not None - and args.gs_ply_prune_opacity_percentile > 0 - ) - else True, - ) - print(f"[vipe_da3_gs] Saved PLY to {final_ply_path}") - - if skip_vipe: - w2c_render = _interpolate_w2c(da3_pred_w2c, indices_da3_rel, len(images_all)) - if da3_pred_k is not None: - k_da3_first = da3_pred_k[0:1] - k_render = np.repeat(k_da3_first, len(images_all), axis=0).astype(np.float32) - else: - raise RuntimeError("DA3 did not return predicted intrinsics; cannot render without VIPE.") - print("[vipe_da3_gs] Rendering with DA3 predicted poses (no VIPE).") - elif args.use_da3_render_pose and aligned_w2c_da3 is not None: - w2c_render = _interpolate_w2c(aligned_w2c_da3, indices_da3_rel, len(images_all)) - k_render = k_np_vipe_full - print("[vipe_da3_gs] Rendering with DA3-aligned poses.") - else: - w2c_render = w2c_np_vipe_full - k_render = k_np_vipe_full - if args.use_da3_render_pose: - print("[vipe_da3_gs] Warning: DA3 poses unavailable, falling back to VIPE poses.") - else: - print("[vipe_da3_gs] Rendering with VIPE poses.") - - cameras_data = { - "w2c_render": w2c_render, - "indices_da3": np.asarray(indices_da3, dtype=np.int64), - "fps": np.asarray([eff_fps], dtype=np.float32), - "no_vipe": np.asarray([int(skip_vipe)], dtype=np.int32), - } - if not skip_vipe: - cameras_data.update({ - "w2c_vipe": w2c_np_vipe_full, - "intrinsics_vipe": k_np_vipe_full, - "w2c_da3": w2c_np_da3, - "intrinsics_da3": k_np_da3, - "indices_vipe": np.asarray(indices_all, dtype=np.int64), - "use_da3_render_pose": np.asarray([int(args.use_da3_render_pose)], dtype=np.int32), - }) - else: - cameras_data.update({ - "w2c_da3_pred": da3_pred_w2c, - "intrinsics_da3_pred": da3_pred_k, - "intrinsics_render": k_render, - }) - np.savez(output_dir / "cameras.npz", **cameras_data) - - del pred - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - from depth_anything_3.model.utils.gs_renderer import run_renderer_in_chunk_w_trj_mode # type: ignore - - gs_device = device - if hasattr(da3_model, "model"): - try: - gs_device = next(da3_model.model.parameters()).device - except StopIteration: - gs_device = device - - gaussians = _load_gaussian_ply_to_gaussians(str(final_ply_path), device=gs_device) - render_extr = torch.from_numpy(w2c_render).to(device=gs_device, dtype=gaussians.means.dtype)[None] - render_intr = torch.from_numpy(k_render).to(device=gs_device, dtype=gaussians.means.dtype)[None] - if render_extr.shape[-2:] == (3, 4): - pad = torch.tensor([0, 0, 0, 1], device=gs_device, dtype=gaussians.means.dtype).view(1, 1, 1, 4) - render_extr = torch.cat( - [render_extr, pad.expand(render_extr.shape[0], render_extr.shape[1], -1, -1)], - dim=-2, - ) - - render_h, render_w = images_all[0].shape[:2] - render_fps = float(args.render_fps) if args.render_fps is not None else float(max(1, round(eff_fps))) - print( - f"[vipe_da3_gs] Rendering {render_extr.shape[1]} frames at {render_h}x{render_w} " - f"(fps={render_fps:.2f}, chunk_size={args.render_chunk_size})..." - ) - color, depth = run_renderer_in_chunk_w_trj_mode( - gaussians=gaussians, - extrinsics=render_extr, - intrinsics=render_intr, - image_shape=(render_h, render_w), - chunk_size=int(args.render_chunk_size), - trj_mode="original", - use_sh=True, - color_mode="RGB+ED", - enable_tqdm=True, - ) - - frames_render = ( - color[0].clamp(0.0, 1.0).mul(255.0).byte().permute(0, 2, 3, 1).cpu().numpy() - ) - video_path = output_dir / "gs_trajectory.mp4" - _save_video_mp4(str(video_path), frames_render, fps=render_fps) - print(f"[vipe_da3_gs] Saved GS render video to {video_path}") - - del gaussians, render_extr, render_intr, color, depth, frames_render - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - done_marker.write_text("done\n") - print("[vipe_da3_gs] Done.") - - -if __name__ == "__main__": - main() diff --git a/lyra_2/_src/models/__init__.py b/lyra_2/_src/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/models/fm_solvers_unipc.py b/lyra_2/_src/models/fm_solvers_unipc.py deleted file mode 100644 index 0d1f5f12bf9ccb6dba097dfa1bba28769b8d7878..0000000000000000000000000000000000000000 --- a/lyra_2/_src/models/fm_solvers_unipc.py +++ /dev/null @@ -1,772 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py -# Convert unipc for flow matching -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput -from diffusers.utils import deprecate, is_scipy_available - -if is_scipy_available(): - pass - - -class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - solver_order (`int`, default `2`): - The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` - due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for - unconditional sampling. - prediction_type (`str`, defaults to "flow_prediction"): - Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts - the flow of the diffusion process. - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - predict_x0 (`bool`, defaults to `True`): - Whether to use the updating algorithm on the predicted x0. - solver_type (`str`, default `bh2`): - Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` - otherwise. - lower_order_final (`bool`, default `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - disable_corrector (`list`, default `[]`): - Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` - and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is - usually disabled during the first few steps. - solver_p (`SchedulerMixin`, default `None`): - Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - use_exponential_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps, as required by some model families. - final_sigmas_type (`str`, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - solver_order: int = 2, - prediction_type: str = "flow_prediction", - shift: Optional[float] = 1.0, - use_dynamic_shifting=False, - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: SchedulerMixin = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - ): - if solver_type not in ["bh1", "bh2"]: - if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh2") - else: - raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") - - self.predict_x0 = predict_x0 - # setable values - self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - self.sigmas = sigmas - self.timesteps = sigmas * num_train_timesteps - - self.model_outputs = [None] * solver_order - self.timestep_list = [None] * solver_order - self.lower_order_nums = 0 - self.disable_corrector = disable_corrector - self.solver_p = solver_p - self.last_sample = None - self._step_index = None - self._begin_index = None - - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - - # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps - def set_timesteps( - self, - num_inference_steps: Union[int, None] = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[Union[float, None]] = None, - shift: Optional[Union[float, None]] = None, - ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - Args: - num_inference_steps (`int`): - Total number of the spacing of the time steps. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - - if self.config.use_dynamic_shifting and mu is None: - raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") - - if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore - - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore - else: - if shift is None: - shift = self.config.shift - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - self.last_sample = None - if self.solver_p: - self.solver_p.set_timesteps(self.num_inference_steps, device=device) - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma): - return sigma * self.config.num_train_timesteps - - def _sigma_to_alpha_sigma_t(self, sigma): - return 1 - sigma, sigma - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - def convert_model_output( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - r""" - Convert the model output to the corresponding type the UniPC algorithm needs. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.Tensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError("missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - - # print("sigma_t ==>", self.step_index, sigma, sigma_t, alpha_t, sample.shape, model_output.shape) - if self.predict_x0: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - x0_pred = self._threshold_sample(x0_pred) - # print("self.config.thresholding", self.config.thresholding) - return x0_pred - else: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - epsilon = sample - (1 - sigma_t) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - x0_pred = self._threshold_sample(x0_pred) - epsilon = model_output + x0_pred - - return epsilon - - def multistep_uni_p_bh_update( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model at the current timestep. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - order (`int`): - The order of UniP at this timestep (corresponds to the *p* in UniPC-p). - - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError(" missing `sample` as a required keyward argument") - if order is None: - if len(args) > 2: - order = args[2] - else: - raise ValueError(" missing `order` as a required keyward argument") - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - model_output_list = self.model_outputs - - s0 = self.timestep_list[-1] - m0 = model_output_list[-1] - x = sample - - if self.solver_p: - x_t = self.solver_p.step(model_output, s0, x).prev_sample - return x_t - - sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - - h = lambda_t - lambda_s0 - device = sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = self.step_index - i # pyright: ignore - mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) - lambda_si = torch.log(alpha_si) - torch.log(sigma_si) - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) # (B, K) - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) - else: - D1s = None - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore - else: - pred_res = 0 - x_t = x_t_ - alpha_t * B_h * pred_res - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore - else: - pred_res = 0 - x_t = x_t_ - sigma_t * B_h * pred_res - - x_t = x_t.to(x.dtype) - return x_t - - def multistep_uni_c_bh_update( - self, - this_model_output: torch.Tensor, - *args, - last_sample: torch.Tensor = None, - this_sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - One step for the UniC (B(h) version). - - Args: - this_model_output (`torch.Tensor`): - The model outputs at `x_t`. - this_timestep (`int`): - The current timestep `t`. - last_sample (`torch.Tensor`): - The generated sample before the last predictor `x_{t-1}`. - this_sample (`torch.Tensor`): - The generated sample after the last predictor `x_{t}`. - order (`int`): - The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. - - Returns: - `torch.Tensor`: - The corrected sample tensor at the current timestep. - """ - this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) - if last_sample is None: - if len(args) > 1: - last_sample = args[1] - else: - raise ValueError(" missing`last_sample` as a required keyward argument") - if this_sample is None: - if len(args) > 2: - this_sample = args[2] - else: - raise ValueError(" missing`this_sample` as a required keyward argument") - if order is None: - if len(args) > 3: - order = args[3] - else: - raise ValueError(" missing`order` as a required keyward argument") - if this_timestep is not None: - deprecate( - "this_timestep", - "1.0.0", - "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - model_output_list = self.model_outputs - - m0 = model_output_list[-1] - x = last_sample - x_t = this_sample - model_t = this_model_output - - sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - - h = lambda_t - lambda_s0 - device = this_sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = self.step_index - (i + 1) # pyright: ignore - mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) - lambda_si = torch.log(alpha_si) - torch.log(sigma_si) - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) - else: - D1s = None - - # for order 1, we use a simplified version - if order == 1: - rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) - x_t = x_t.to(x.dtype) - return x_t - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): - """ - Initialize the step_index counter for the scheduler. - """ - - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - - def step( - self, - model_output: torch.Tensor, - timestep: Union[int, torch.Tensor], - sample: torch.Tensor, - return_dict: bool = True, - generator=None, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep UniPC. - - Args: - model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # print("self.step_index ==> ", self.step_index) - - use_corrector = ( - self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore - ) - - model_output_convert = self.convert_model_output(model_output, sample=sample) - - if use_corrector: - sample = self.multistep_uni_c_bh_update( - this_model_output=model_output_convert, - last_sample=self.last_sample, - this_sample=sample, - order=self.this_order, - ) - - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.timestep_list[i] = self.timestep_list[i + 1] - - self.model_outputs[-1] = model_output_convert - self.timestep_list[-1] = timestep # pyright: ignore - - if self.config.lower_order_final: - this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore - else: - this_order = self.config.solver_order - - self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep - assert self.this_order > 0 - - self.last_sample = sample - prev_sample = self.multistep_uni_p_bh_update( - model_output=model_output, # pass the original non-converted model output, in case solver-p is used - sample=sample, - order=self.this_order, - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # upon completion increase step index by one - self._step_index += 1 # pyright: ignore - - if not return_dict: - return (prev_sample, model_output_convert) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.Tensor`): - The input sample. - - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/lyra_2/_src/models/lyra2_model.py b/lyra_2/_src/models/lyra2_model.py deleted file mode 100644 index d09570920152fb465f2566f04905081a84d477c6..0000000000000000000000000000000000000000 --- a/lyra_2/_src/models/lyra2_model.py +++ /dev/null @@ -1,2850 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any, Optional, List, cast -import random -from statistics import NormalDist -import numpy as np -import torch -from einops import rearrange -import attrs -import gc -from lyra_2._ext.imaginaire.lazy_config import instantiate as lazy_instantiate -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils import misc -from lyra_2._src.modules.conditioner import DataType, T2VCondition -from lyra_2._src.models.wan_t2v_model import WANDiffusionModel, T2VModelConfig -from megatron.core import parallel_state -from torch.distributed.tensor import DTensor -from torch.distributed._composable.fsdp import fully_shard -from lyra_2._src.utils.dtensor_helper import broadcast_dtensor_model_states -from lyra_2._src.utils.context_parallel import broadcast -from lyra_2._src.datasets.forward_warp_utils_pytorch import ( - unproject_points, - forward_warp_multiframes, -) -from lyra_2._src.datasets.plucker_embed_corrupter import ( - ray_condition, -) - - -WAN2PT1_I2V_COND_LATENT_KEY = "i2v_WAN2PT1_cond_latents" -LYRA2_BUFFER_SINCOS_MULTIRES = 2 -LYRA2_BUFFER_MLP_SQUEEZE_DIM = 256 -LYRA2_CORRESPONDENCE_CHANNELS_PER_SLOT = 4 * 8 * 8 - - -@attrs.define(slots=False) -class Lyra2T2VConfig(T2VModelConfig): - """Configuration for Lyra2 spatial model""" - init_framepack_weights: bool = True # disable this if resume training - - # Lyra2 AR configuration - framepack_type: str = "f16k4f2k2f1k1_g3" # fkfkfk_gN where N=new latent frames - num_frames_per_latent: int = 4 # frames per latent tokenization (video tokenizer) - max_segments: int = 10 - starting_frame_ratio: float = 0.1 - - apply_corruption_to_spatial_region: str = "none" - augment_sigma_sample_p_mean: float = 0.0 - augment_sigma_sample_p_std: float = 1.0 - augment_sigma_sample_multiplier: float = 1.0 - condition_video_augment_sigma_in_inference: float = 0.001 - - # Stage-A self-augmentation (training-time short denoise) configuration - self_aug_enabled: bool = False - self_aug_steps: int = 5 - # Number of discrete inference-like timesteps - self_aug_num_discrete_timesteps: int = 8 - self_aug_guidance: Optional[float] = None - self_aug_scheduler_shift: Optional[float] = None - self_aug_every_k: int = 1 - self_aug_prob: float = 1.0 - self_aug_max_T: int = 50 - self_aug_copy_chunk: bool = False - self_aug_encode_gt_with_clean_history: bool = False - self_aug_i2v_ratio: float = 0.3 # ensure i2v case is trained - - spatial_memory_stride: int = 8 - spatial_memory_skip_recent: int = 100 - spatial_memory_use_image: bool = False - spatial_memory_dropout_rate: float = 0.1 - - # Optional: comma-separated list of submodule names in Lyra2AttentionBlock to train. - # If provided, all other parameters are frozen. - framepack_trainable_modules: Optional[str] = None - - # This port is intentionally collapsed to the single target branch: - # accumulated correspondence + multibuffer + depth-augmented slots + K/Q-only injection. - # Drop spatial memory cache completely with this probability. - spatial_memory_drop_rate: float = 0.1 - # Max number of spatial buffers to keep (pad with zeros if fewer are available). - # If None, defaults to the number of spatial history slots from framepack_type. - multibuffer_max_spatial_frames: Optional[int] = None - warp_chunk_size: int = 2 - - -class Lyra2Model(WANDiffusionModel): - """Lyra2 spatial model""" - - def __init__(self, config: Lyra2T2VConfig): - super().__init__(config) - # Transient diagnostics/visualization cache controls and storage - self._collect_return_condition_state: bool = False - self._latest_condition_state_pixels = None - self._latest_plucker_rays_pixels = None - self._latest_gt_gen_pixels = None - # Parse Lyra2 AR metadata - self._init_lyra2_metadata() - self.framepack_weights_initialized = False - self._cached_spatial_coords: Optional[torch.Tensor] = None - self._cached_spatial_coords_meta: Optional[tuple[int, int, int, torch.device, torch.dtype]] = None - - self._spatial_history_positions = self._compute_spatial_history_positions() - log.info( - f"Lyra2Model spatial history positions: {self._spatial_history_positions}, " - f"spatial_memory_stride={self.config.spatial_memory_stride}, " - f"spatial_memory_skip_recent={self.config.spatial_memory_skip_recent}", - rank0_only=True, - ) - - def _compute_spatial_history_positions(self) -> tuple[int, ...]: - positions: list[int] = [] - offset = 0 - for count, kernel_type in zip(self.framepack_clean_latent_frame_splits, self.framepack_clean_latent_frame_kernel_types): - cnt = int(count) - if kernel_type == "s": - positions.extend(range(offset, offset + cnt)) - offset += cnt - return tuple(positions) - - def _apply_spatial_region_corruption(self, latents: torch.Tensor, cond_latent: torch.Tensor) -> None: - if len(self._spatial_history_positions) == 0: - return - spatial_idx = torch.tensor(self._spatial_history_positions, device=latents.device, dtype=torch.long) - if spatial_idx.numel() == 0: - return - latents_slice = latents[:, :, spatial_idx] - latents_corrupted, aug_sigma = self.augment_conditional_latent_frames( - latents_slice, - target_mode=self.config.apply_corruption_to_spatial_region, - ) - latents[:, :, spatial_idx] = latents_corrupted - cond_latent[:, :16, spatial_idx] = latents_corrupted - - def _prepare_video_window(self, video, start=None, cur_segment_id=None): - """Step 1: crop/pad video to the current segment window and return bookkeeping. - - Returns: - video_win: cropped/padded video on correct dtype/device - video_indices: absolute frame indices into the original video timeline for each frame in video_win - start: chosen start index - cur_segment_id: chosen segment id - chunk_len: number of frames in the window - to_repeat_front: number of frames repeated at the front (for history padding) - """ - video_length = int(video.shape[2]) - video_indices = torch.arange(video_length, device=video.device) - cfg = self.config - - # Choose start - if start is None: - start = int(np.random.randint(0, max(1, int(video_length * self.framepack_starting_frame_ratio)))) - - # Choose segment id - if cur_segment_id is None: - max_segments = (video_length - start - 1) // int(self.framepack_num_new_video_frames) - if cfg.self_aug_enabled: - max_segments = max_segments - 1 - max_segments = max(1, min(max_segments, int(self.framepack_max_segments))) - cur_segment_id = int(np.random.randint(0, max_segments)) - - # number of frames to keep - chunk_len = (cur_segment_id + 1) * int(self.framepack_num_new_video_frames) + 1 - - # Crop-right or pad-right by repeating the last available frame - if start + chunk_len > video_length: - video_win = video[:, :, start:] - video_indices = video_indices[start:] - to_repeat = start + chunk_len - video_length - video_win = torch.cat([video_win, video_win[:, :, -1:].repeat(1, 1, to_repeat, 1, 1)], dim=2) - video_indices = torch.cat([video_indices, video_indices[-1:].repeat(to_repeat)], dim=0) - else: - video_win = video[:, :, start:start + chunk_len] - video_indices = video_indices[start:start + chunk_len] - - # Lyra2 first iter, condition is i iiii iiii ... - to_repeat_front = (self.framepack_num_history_latent - 1) * int(self.framepack_num_frames_per_latent) - if to_repeat_front > 0: - video_win = torch.cat([video_win[:, :, :1].repeat(1, 1, to_repeat_front, 1, 1), video_win], dim=2) - video_indices = torch.cat([video_indices[:1].repeat(to_repeat_front), video_indices], dim=0) - - video_win = video_win.to(dtype=self.tensor_kwargs["dtype"], device=self.tensor_kwargs["device"]) - - if self._collect_return_condition_state: - self._latest_gt_gen_pixels = video_win[:, :, -int(self.framepack_num_new_video_frames):].contiguous() - - return video_win, video_indices, int(start), int(cur_segment_id), int(chunk_len) - - def build_net(self): - """Add clean patch embeddings before FSDP so they are sharded/initialized correctly.""" - config = self.config - if config.use_mp_policy_fsdp: - fsdp_kwargs = {"mp_policy": torch.distributed.fsdp.MixedPrecisionPolicy( - param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, - cast_forward_inputs=False, - )} - else: - fsdp_kwargs = {} - init_device = "meta" - with misc.timer("Creating PyTorch model"): - # Initialize clean patch embeddings BEFORE FSDP sharding - log.info("Constructing clean patch embeddings before FSDP") - # Derive kernel splits from instance attributes if metadata is already initialized - # (e.g. called from a second build_net()), otherwise fall back to parsing the - # framepack_type config string directly. This is necessary because - # build_net() is called from WANDiffusionModel.__init__ -> set_up_model() - # before Lyra2Model.__init__ calls _init_lyra2_metadata(). - try: - splits = self.framepack_clean_latent_frame_splits - kernel_sizes = self.framepack_clean_latent_frame_kernel_sizes - kernel_types = getattr(self, "framepack_clean_latent_frame_kernel_types", ["k"] * len(kernel_sizes)) - except AttributeError: - fp_splits = config.framepack_type.split("_") - fk_substring = fp_splits[0] - segments = [seg for seg in fk_substring.split("f")[1:] if len(seg) > 0] - splits = [] - kernel_sizes = [] - kernel_types = [] - for seg in segments: - i = 0 - while i < len(seg) and seg[i].isdigit(): - i += 1 - assert i > 0 and i < len(seg), f"Invalid framepack segment: {seg}" - f_count = int(seg[:i]) - t = seg[i] - assert t in ("k", "s"), f"Unknown kernel type {t} in segment {seg}" - ksize = int(seg[i + 1:]) - splits.append(f_count) - kernel_sizes.append(ksize) - kernel_types.append(t) - - def _cfg_set(cfg, key, value): - if isinstance(cfg, dict): - cfg[key] = value - return - try: - cfg[key] = value - except Exception: - setattr(cfg, key, value) - - max_spatial = config.multibuffer_max_spatial_frames - if max_spatial is None: - max_spatial = sum(s for s, t in zip(splits, kernel_types) if t == "s") - max_spatial = int(max_spatial) - - buffer_in_dim = 0 - if max_spatial > 0: - buffer_in_dim = LYRA2_CORRESPONDENCE_CHANNELS_PER_SLOT * max_spatial - - _cfg_set(config.net, "buffer_pixelshuffle", True) - _cfg_set(config.net, "buffer_in_dim", int(buffer_in_dim)) - _cfg_set(config.net, "buffer_sincos_multires", LYRA2_BUFFER_SINCOS_MULTIRES) - _cfg_set(config.net, "use_correspondence", True) - _cfg_set(config.net, "use_plucker_condition", True) - _cfg_set(config.net, "inject_kq_only", True) - _cfg_set(config.net, "buffer_mlp_squeeze_dim", LYRA2_BUFFER_MLP_SQUEEZE_DIM) - - with torch.device(init_device): - net = lazy_instantiate(config.net) - - net.init_clean_patch_embeddings(kernel_sizes, kernel_types) - if hasattr(net, "buffer_pixelshuffle"): - net.buffer_pixelshuffle = True - - if self.fsdp_device_mesh: - net.fully_shard(mesh=self.fsdp_device_mesh, **fsdp_kwargs) - net = fully_shard(net, mesh=self.fsdp_device_mesh, reshard_after_forward=True, **fsdp_kwargs) - - with misc.timer("meta to cuda and broadcast model states"): - net.to_empty(device="cuda") - net.init_weights() - - if self.fsdp_device_mesh: - broadcast_dtensor_model_states(net, self.fsdp_device_mesh) - for name, param in net.named_parameters(): - assert isinstance(param, DTensor), f"param should be DTensor, {name} got {type(param)}" - - if config.framepack_trainable_modules: - whitelist = [p.strip() for p in config.framepack_trainable_modules.split(",") if p.strip()] - if whitelist: - log.info(f"Freezing model and unfreezing Lyra2AttentionBlock layers matching: {whitelist}") - - for param in net.parameters(): - param.requires_grad = False - - trainable_param_names = set() - - for name, module in net.named_modules(): - if type(module).__name__ == "Lyra2AttentionBlock": - for sub_name, param in module.named_parameters(): - if any(pattern in sub_name for pattern in whitelist): - param.requires_grad = True - full_name = f"{name}.{sub_name}" - trainable_param_names.add(full_name) - - if any("clean_patch_embeddings" in p for p in whitelist): - for name, param in net.named_parameters(): - if "clean_patch_embeddings" in name: - param.requires_grad = True - trainable_param_names.add(name) - - if "patch_embedding" in whitelist: - for name, param in net.named_parameters(): - if "patch_embedding" in name: - param.requires_grad = True - trainable_param_names.add(name) - if "patch_embedding_buffer" in whitelist: - for name, param in net.named_parameters(): - if "patch_embedding_buffer" in name: - param.requires_grad = True - trainable_param_names.add(name) - - log.info(f"Enabled gradients for {len(trainable_param_names)} parameters: {sorted(list(trainable_param_names))}") - - - return net - - @torch.no_grad() - def decode(self, latents): - """Decode video from latents""" - if self.tokenizer.model.video_std.shape[2] > 1: - self.tokenizer.model.video_std = self.tokenizer.model.video_std[:,:,:1] - self.tokenizer.model.video_mean = self.tokenizer.model.video_mean[:,:,:1] - return self.tokenizer.decode(latents) - - def augment_conditional_latent_frames( - self, - gt_latent, - condition_video_augment_sigma_in_inference=0.001, - seed_inference=1, - augment_sigma=None, - target_mode="none", - ): - if target_mode == "none": - return gt_latent - - elif target_mode == "noise_with_sigma": - # Training only, sample sigma for the condition region - if augment_sigma is None: - augment_sigma, _ = self.draw_augment_sigma_and_epsilon_gen3c( - gt_latent.shape, - self.config.augment_sigma_sample_p_mean, - self.config.augment_sigma_sample_p_std, - self.config.augment_sigma_sample_multiplier, - ) - noise = torch.randn(*gt_latent.shape, **self.tensor_kwargs) - - elif target_mode == "noise_with_sigma_fixed": - # Inference only, use fixed sigma for the condition region - log.debug(f"condition_video_augment_sigma_in_inference={condition_video_augment_sigma_in_inference}") - assert ( - condition_video_augment_sigma_in_inference is not None - ), "condition_video_augment_sigma_in_inference should be provided" - s = float(condition_video_augment_sigma_in_inference) - B, _, T, _, _ = gt_latent.shape - if augment_sigma is None: - augment_sigma = torch.full((B, T), s, device=self.tensor_kwargs["device"], dtype=torch.float32).to(**self.tensor_kwargs) - - # Inference, use fixed seed - noise = misc.arch_invariant_rand( - gt_latent.shape, - torch.float32, - self.tensor_kwargs["device"], - seed_inference, - ) - else: - raise ValueError(f"does not support {target_mode}") - - B, _, T, _, _ = gt_latent.shape - augment_latent = gt_latent + noise * augment_sigma.view(B, 1, T, 1, 1) - - return augment_latent, augment_sigma - - def draw_augment_sigma_and_epsilon_gen3c(self, size, p_mean, p_std, multiplier): - B, _, T, _, _ = size - epsilon = torch.randn(size, **self.tensor_kwargs) - - gaussian_dist = NormalDist(mu=p_mean, sigma=p_std) - cdf_vals = np.random.uniform(size=(B * T)) - samples_interval_gaussian = [gaussian_dist.inv_cdf(cdf_val) for cdf_val in cdf_vals] - log_sigma = torch.tensor(samples_interval_gaussian, device=self.tensor_kwargs["device"], dtype=torch.float32).view(B, T) - sigma_B = torch.exp(log_sigma).to(**self.tensor_kwargs) - return sigma_B, epsilon - - def _init_lyra2_metadata(self): - cfg = self.config - fp_splits = cfg.framepack_type.split("_") - if fp_splits[0].startswith("f") and fp_splits[1].startswith("g"): - self.framepack_num_new_latent_frames = int(fp_splits[1].split("g")[1]) - fk_substring = fp_splits[0] - else: - raise ValueError( - f"Unsupported framepack_type: {cfg.framepack_type}. Expected fk..._g..." - ) - - # Parse segments supporting both temporal ('k') and spatial ('s') kernels per segment - segments = [seg for seg in fk_substring.split("f")[1:] if len(seg) > 0] - splits: List[int] = [] - kernel_sizes: List[int] = [] - kernel_types: List[str] = [] - for seg in segments: - # seg like '2k2' or '2s2' - i = 0 - while i < len(seg) and seg[i].isdigit(): - i += 1 - assert i > 0 and i < len(seg), f"Invalid framepack segment: {seg}" - f_count = int(seg[:i]) - t = seg[i] - assert t in ("k", "s"), f"Unknown kernel type {t} in segment {seg}" - ksize = int(seg[i+1:]) - splits.append(f_count) - kernel_sizes.append(ksize) - kernel_types.append(t) - - self.framepack_clean_latent_frame_splits: List[int] = splits - self.framepack_clean_latent_frame_kernel_sizes: List[int] = kernel_sizes - self.framepack_clean_latent_frame_kernel_types: List[str] = kernel_types - # Cache commonly used counts (temporal vs spatial history slots) for downstream helpers. - # Temporal slots correspond to kernel_type == 'k'; spatial slots correspond to 's'. - self.framepack_num_temporal_hist = int( - sum(s for s, t in zip(self.framepack_clean_latent_frame_splits, self.framepack_clean_latent_frame_kernel_types) if t == "k") - ) - self.framepack_num_spatial_hist = int( - sum(s for s, t in zip(self.framepack_clean_latent_frame_splits, self.framepack_clean_latent_frame_kernel_types) if t == "s") - ) - log.info( - f"Lyra2: splits={self.framepack_clean_latent_frame_splits}, " - f"kernel_sizes={self.framepack_clean_latent_frame_kernel_sizes}, " - f"kernel_types={self.framepack_clean_latent_frame_kernel_types}, " - f"new_latent_frames={self.framepack_num_new_latent_frames}" - ) - - max_num_clean_latent_frames = sum(self.framepack_clean_latent_frame_splits) - self.framepack_total_max_num_latent_frames = ( - max_num_clean_latent_frames + self.framepack_num_new_latent_frames - ) - - # framepack_splits: e.g., [16, 2, 1, 9] - framepack_splits = self.framepack_clean_latent_frame_splits + [ - self.framepack_num_new_latent_frames - ] - # framepack_indices: 0, 1, ..., 18 (history), 19, 20, ..., 27 (new) - framepack_indices = torch.arange(self.framepack_total_max_num_latent_frames) - # framepack_kernel_ids: 0, 1, 2, -1 (new frames -- using the original patch embedding kernel) - framepack_kernel_ids = list(range(len(self.framepack_clean_latent_frame_kernel_sizes))) + [-1] - self.framepack_params = { - "framepack_indices": framepack_indices, - "framepack_splits": framepack_splits, - "framepack_kernel_ids": framepack_kernel_ids, - "framepack_kernel_types": self.framepack_clean_latent_frame_kernel_types, - } - self.framepack_max_segments = cfg.max_segments - self.framepack_starting_frame_ratio = cfg.starting_frame_ratio - self.framepack_num_frames_per_latent = cfg.num_frames_per_latent - self.framepack_num_new_video_frames = ( - self.framepack_num_frames_per_latent * self.framepack_num_new_latent_frames - ) - self.framepack_num_history_latent = max_num_clean_latent_frames - - def get_data_and_condition( - self, - data_batch, - dropout=False, - ): - """Prepare Lyra2 latents and I2V conditioning, always using tokenizer.""" - if self.is_image_batch(data_batch): - raise ValueError("Lyra2 expects video inputs.") - - # Normalize and tokenize video into fixed-length latents for Lyra2 - # Frank: skip assertion to save memory - _flag = data_batch.get("is_preprocessed", False) - if not _flag: - self._normalize_video_databatch_inplace(data_batch) - - raw_state = data_batch[self.input_data_key] - latent_state, last_hist_frame, cond_latent, cond_latent_mask = self._tokenizing_video_to_latents(raw_state, dropout=dropout, data_batch=data_batch) - - # Populate images key for CLIP embedding: use last frame in history segment - data_batch["last_hist_frame"] = last_hist_frame - data_batch["cond_latent_mask"] = cond_latent_mask - - data_batch[WAN2PT1_I2V_COND_LATENT_KEY] = cond_latent - - condition = self.conditioner(data_batch) - condition = condition.edit_data_type(DataType.VIDEO) - return raw_state, latent_state, condition - - def get_x0_fn_from_batch( - self, - data_batch, - guidance=1.5, - is_negative_prompt=False, - seed=None, - ): - """Prepare Lyra2-aware inference closure and initial latents. - - - Conditions are prepared the same way as training and broadcast for CP. - - Returns a closure that stitches the generated-region prediction back into a - full-length tensor with zeros over history frames, and an initial latent - state whose history is clean x0 and generated region is random noise. - """ - # Prepare normalized/tokenized latents and condition - # data_batch_uncond = data_batch.copy() - _, x0_latents, _ = self.get_data_and_condition(data_batch) - # _, x0_latents_uncond, _ = self.get_data_and_condition(data_batch_uncond, dropout=True) - - is_image_batch = self.is_image_batch(data_batch) - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - # _, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch_uncond) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - # _, uncondition = self.conditioner.get_condition_uncondition(data_batch_uncond) - - condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - uncondition = uncondition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - - # Enable CP and broadcast conditions (no temporal split; net handles CP internally) - _, condition, _, _ = self.broadcast_split_for_model_parallelsim(None, condition, None, None) - _, uncondition, _, _ = self.broadcast_split_for_model_parallelsim(None, uncondition, None, None) - - if not parallel_state.is_initialized(): - assert not self.net.is_context_parallel_enabled, "parallel_state is not initialized, context parallel should be turned off." - - # Build initial latents: clean history + random noise on generated region - T_hist = self.framepack_total_max_num_latent_frames - self.framepack_num_new_latent_frames - init_latents = torch.zeros_like(x0_latents, dtype=torch.float32) - init_latents[:, :, :T_hist] = x0_latents[:, :, :T_hist].to(dtype=torch.float32) - - gen_shape = tuple(x0_latents[:, :, T_hist:].shape) - gen_noise = misc.arch_invariant_rand( - gen_shape, - torch.float32, - self.tensor_kwargs["device"], - seed if seed is not None else 0, - ) - init_latents[:, :, T_hist:] = gen_noise - - def x0_fn(noise_x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - # noise_x_uncond = noise_x.clone() - cond_v_gen = self.denoise(noise_x, timestep, condition) - - # noise_x_uncond[:, :, :T_hist] = x0_latents_uncond[:, :, :T_hist].to(noise_x_uncond) # set history region of x to encoding of zeros - uncond_v_gen = self.denoise(noise_x, timestep, uncondition) - gen_v = uncond_v_gen + guidance * (cond_v_gen - uncond_v_gen) - - vt_full = torch.zeros_like(noise_x, dtype=gen_v.dtype) - vt_full[:, :, T_hist:] = gen_v # zero predicted noise for history frames to keep history unchanged. - return vt_full - - return x0_fn, init_latents - - @torch.no_grad() - def inference(self, history_latents, cond_latent, guidance, seed, num_steps, shift, t5_text_embeddings, neg_t5_text_embeddings, **kwargs): - # 1) Validate history latent length - T_hist_expected = self.framepack_total_max_num_latent_frames - self.framepack_num_new_latent_frames - assert ( - history_latents.shape[2] == T_hist_expected - ), f"history_latents has T={history_latents.shape[2]} but expected {T_hist_expected}" - - # 2) Build conditioner inputs from provided kwargs and history latents - # Required: last_hist_frame in pixel space [B,3,H,W] - assert ( - "last_hist_frame" in kwargs - ), "last_hist_frame (pixel) is required in kwargs for Lyra2 inference" - last_hist_frame = kwargs["last_hist_frame"] - - # Optional: fps, padding_mask passthrough if provided - data_batch = { - "t5_text_embeddings": t5_text_embeddings, - "neg_t5_text_embeddings": neg_t5_text_embeddings, - "last_hist_frame": last_hist_frame, - } - - # Build cond media latents = [history, zeros(gen_region)] and its mask (1 on history, 0 on gen) - B, C, T_hist, H, W = history_latents.shape - T_new = self.framepack_num_new_latent_frames - assert cond_latent.shape[2] == T_hist + T_new, "cond_latent must have T_hist + T_new frames" - - mask = kwargs.get("cond_latent_mask", None) - if mask is None: - mask = torch.ones((B, 4, T_hist + T_new, H, W), dtype=history_latents.dtype, device=history_latents.device) - - data_batch["cond_latent_mask"] = mask - data_batch[WAN2PT1_I2V_COND_LATENT_KEY] = cond_latent - data_batch["cond_latent_buffer"] = kwargs.get("cond_latent_buffer", None) - - # Add-through extras if present - if "fps" in kwargs and kwargs["fps"] is not None: - data_batch["fps"] = kwargs["fps"] - if "padding_mask" in kwargs and kwargs["padding_mask"] is not None: - data_batch["padding_mask"] = kwargs["padding_mask"] - - # 3) Build condition / uncondition with negative prompt - is_image_batch = False - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - uncondition = uncondition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - - # 4) Init latents: keep history clean, random on generated region - init_latents = torch.zeros((B, C, T_hist + T_new, H, W), dtype=torch.float32, device=self.tensor_kwargs["device"]) # float32 for sampler math - init_latents[:, :, :T_hist] = history_latents.to(dtype=torch.float32, device=self.tensor_kwargs["device"]) # clean history - - gen_shape = (B, C, T_new, H, W) - gen_noise = misc.arch_invariant_rand( - gen_shape, - torch.float32, - self.tensor_kwargs["device"], - seed if seed is not None else 0, - ) - init_latents[:, :, T_hist:] = gen_noise - - # 5) CP broadcast (no temporal split for Lyra2; net handles internally) - cp_group = self.get_context_parallel_group() - if cp_group is not None: - init_latents = broadcast(init_latents.contiguous(), cp_group) - condition = condition.broadcast(cp_group) - uncondition = uncondition.broadcast(cp_group) - else: - # Some network variants may not expose the property until enabled; use getattr - assert not getattr(self.net, "is_context_parallel_enabled", False), ( - "context parallel should be disabled if parallel_state is not initialized" - ) - - # 6) Define x0_fn following Lyra2 get_x0_fn_from_batch semantics (zeros over history) - def x0_fn(noise_x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - cond_v_gen = self.denoise(noise_x, timestep, condition) - uncond_v_gen = self.denoise(noise_x, timestep, uncondition) - gen_v = uncond_v_gen + guidance * (cond_v_gen - uncond_v_gen) - vt_full = torch.zeros_like(noise_x, dtype=gen_v.dtype) - vt_full[:, :, T_hist:] = gen_v - return vt_full - - # 7) Sampling loop - self.sample_scheduler.set_timesteps(num_steps, device=self.tensor_kwargs["device"], shift=shift) - timesteps = self.sample_scheduler.timesteps - - seed_g = torch.Generator(device=self.tensor_kwargs["device"]) - seed_g.manual_seed(seed if seed is not None else 0) - - latents = init_latents - for _, t in enumerate(timesteps): - latent_model_input = latents - timestep = torch.stack([t]) - velocity_pred = x0_fn(latent_model_input, timestep.unsqueeze(0)) - temp_x0 = self.sample_scheduler.step( - velocity_pred.unsqueeze(0), - t, - latents[0].unsqueeze(0), - return_dict=False, - generator=seed_g, - )[0] - latents = temp_x0.squeeze(0) - - # 8) Return only the newly generated latent chunk - return latents[:, :, T_hist:] - - def _convert_flow_pred_to_x0( - self, scheduler, flow_pred: torch.Tensor, xt: torch.Tensor, timestep: torch.Tensor - ) -> torch.Tensor: - """Convert flow-matching prediction (noise - x0) to x0 prediction. - - x_t = (1 - sigma_t) * x0 + sigma_t * noise, pred = noise - x0 - => x0 = x_t - sigma_t * pred - """ - original_dtype = flow_pred.dtype - flow_pred, xt, sigmas, timesteps = map( - lambda x: x.double().to(flow_pred.device), - [flow_pred, xt, scheduler.sigmas, scheduler.timesteps], - ) - timestep_id = torch.argmin( - (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 - ) - sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) - x0_pred = xt - sigma_t * flow_pred - return x0_pred.to(original_dtype) - - def inference_dmd( - self, - history_latents, - cond_latent, - guidance, - seed, - num_steps, - shift, - t5_text_embeddings, - neg_t5_text_embeddings, - **kwargs, - ): - """DMD-distilled fast (4-step) inference. Mirrors :meth:`inference` but uses - a fixed 4-step flow-matching schedule and does not run CFG (the distilled - LoRA is already conditional-only). - """ - # 0) Create DMD flow scheduler (4-step schedule over 1000 train timesteps) - denoising_step_list: List[int] = [1000, 750, 500, 250] - num_train_timestep: int = 1000 - if self.dmd_scheduler is None: - from lyra_2._src.schedulers.self_forcing_scheduler import FlowMatchScheduler - - self.dmd_scheduler = FlowMatchScheduler( - shift=5.0, sigma_min=0.0, extra_one_step=True - ) - self.dmd_scheduler.set_timesteps(num_train_timestep, training=True) - self.dmd_scheduler.timesteps = self.dmd_scheduler.timesteps.to(history_latents.device) - self.denoising_step_list = torch.LongTensor(denoising_step_list) - timesteps = torch.cat( - ( - self.dmd_scheduler.timesteps.cpu(), - torch.tensor([0], dtype=torch.float32), - ) - ) - self.denoising_step_list = timesteps[num_train_timestep - self.denoising_step_list] - - # 1) Validate history latent length - T_hist_expected = self.framepack_total_max_num_latent_frames - self.framepack_num_new_latent_frames - assert ( - history_latents.shape[2] == T_hist_expected - ), f"history_latents has T={history_latents.shape[2]} but expected {T_hist_expected}" - - # 2) Build conditioner inputs - assert ( - "last_hist_frame" in kwargs - ), "last_hist_frame (pixel) is required in kwargs for Lyra2 inference" - last_hist_frame = kwargs["last_hist_frame"] - - data_batch = { - "t5_text_embeddings": t5_text_embeddings, - "neg_t5_text_embeddings": neg_t5_text_embeddings, - "last_hist_frame": last_hist_frame, - } - - B, C, T_hist, H, W = history_latents.shape - T_new = self.framepack_num_new_latent_frames - assert cond_latent.shape[2] == T_hist + T_new, "cond_latent must have T_hist + T_new frames" - - mask = kwargs.get("cond_latent_mask", None) - if mask is None: - mask = torch.ones((B, 4, T_hist + T_new, H, W), dtype=history_latents.dtype, device=history_latents.device) - - data_batch["cond_latent_mask"] = mask - data_batch[WAN2PT1_I2V_COND_LATENT_KEY] = cond_latent - data_batch["cond_latent_buffer"] = kwargs.get("cond_latent_buffer", None) - - if "fps" in kwargs and kwargs["fps"] is not None: - data_batch["fps"] = kwargs["fps"] - if "padding_mask" in kwargs and kwargs["padding_mask"] is not None: - data_batch["padding_mask"] = kwargs["padding_mask"] - - # 3) Build condition / uncondition (uncondition unused in DMD path but required by conditioner API) - is_image_batch = False - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - uncondition = uncondition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - - # 4) Init latents: keep history clean, random noise on generated region - init_latents = torch.zeros( - (B, C, T_hist + T_new, H, W), dtype=torch.float32, device=self.tensor_kwargs["device"] - ) - init_latents[:, :, :T_hist] = history_latents.to( - dtype=torch.float32, device=self.tensor_kwargs["device"] - ) - gen_shape = (B, C, T_new, H, W) - gen_noise = misc.arch_invariant_rand( - gen_shape, - torch.float32, - self.tensor_kwargs["device"], - seed if seed is not None else 0, - ) - init_latents[:, :, T_hist:] = gen_noise - - # 5) CP broadcast - cp_group = self.get_context_parallel_group() - if cp_group is not None: - init_latents = broadcast(init_latents.contiguous(), cp_group) - condition = condition.broadcast(cp_group) - uncondition = uncondition.broadcast(cp_group) - else: - assert not getattr(self.net, "is_context_parallel_enabled", False), ( - "context parallel should be disabled if parallel_state is not initialized" - ) - - # 6) x0 prediction function (no CFG: distilled LoRA is conditional-only) - def x0_fn(noise_x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - flow_pred = self.denoise(noise_x, timestep, condition) # B, C, T, H, W - flow_full = torch.zeros_like(noise_x, dtype=flow_pred.dtype) - flow_full[:, :, T_hist:] = flow_pred - flow_full = flow_full.permute(0, 2, 1, 3, 4) # B, T, C, H, W - noisy_image_or_video = noise_x.permute(0, 2, 1, 3, 4) - - pred_x0 = self._convert_flow_pred_to_x0( - scheduler=self.dmd_scheduler, - flow_pred=flow_full.flatten(0, 1), - xt=noisy_image_or_video.flatten(0, 1), - timestep=timestep.flatten(0, 1), - ).unflatten(0, flow_full.shape[:2]) - - return pred_x0.permute(0, 2, 1, 3, 4) # back to B, C, T, H, W - - # 7) Sampling loop (4-step) - seed_g = torch.Generator(device=self.tensor_kwargs["device"]) - seed_g.manual_seed(seed if seed is not None else 0) - - denoising_step_list = self.denoising_step_list - exit_flag = len(denoising_step_list) - 1 - - latents = init_latents - for index, current_timestep in enumerate(denoising_step_list): - latent_model_input = latents - timestep = torch.stack([current_timestep]).to(self.tensor_kwargs["device"]) - - if index < exit_flag: - with torch.no_grad(): - noise_pred = x0_fn(latent_model_input, timestep.unsqueeze(0)) # B, C, T, H, W - noise_pred = noise_pred.permute(0, 2, 1, 3, 4) # B, T, C, H, W - next_timestep = denoising_step_list[index + 1] - new_noise = torch.randn_like(noise_pred.flatten(0, 1)) - if cp_group is not None: - new_noise = broadcast(new_noise.contiguous(), cp_group) - temp_x0 = self.dmd_scheduler.add_noise( - noise_pred.flatten(0, 1), - new_noise, - next_timestep * torch.ones( - [noise_pred.shape[0] * noise_pred.shape[1]], - device=noise_pred.device, - dtype=torch.long, - ), - ).unflatten(0, noise_pred.shape[:2]) - latents = temp_x0.permute(0, 2, 1, 3, 4) - latents[:, :, :T_hist] = latent_model_input[:, :, :T_hist] - else: - noise_pred = x0_fn(latent_model_input, timestep.unsqueeze(0)) - latents = noise_pred - latents[:, :, :T_hist] = latent_model_input[:, :, :T_hist] - break - - # 8) Return only the newly generated latent chunk - return latents[:, :, T_hist:] - - @torch.no_grad() - def generate_samples_from_batch( - self, - data_batch: dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - shift: float = 5.0, - return_condition_state: bool = False, - **kwargs, - ) -> torch.Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Args: - data_batch (dict): raw data batch draw from the training data loader. - iteration (int): Current iteration number. - guidance (float): guidance weights - seed (int): random seed - state_shape (tuple): shape of the state, default to data batch if not provided - n_sample (int): number of samples to generate - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - num_steps (int): number of steps for the diffusion process - solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) - """ - - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - input_key = self.input_image_key if is_image_batch else self.input_data_key - - - seed_g = torch.Generator(device=self.tensor_kwargs["device"]) - seed_g.manual_seed(seed) - - self.sample_scheduler.set_timesteps( - num_steps, device=self.tensor_kwargs["device"], shift=shift) - - timesteps = self.sample_scheduler.timesteps - - # Indicate whether to collect visualization/condition payloads during tokenization - prev_collect_flag = getattr(self, "_collect_return_condition_state", False) - self._collect_return_condition_state = bool(return_condition_state) - x0_fn, init_latents = self.get_x0_fn_from_batch( - data_batch, guidance, is_negative_prompt=is_negative_prompt, seed=seed - ) - latents = init_latents - - # Broadcast initial latents across CP ranks; do not split (net handles CP internally) - cp_group = self.get_context_parallel_group() - if cp_group is not None: - latents = broadcast(latents.contiguous(), cp_group) - - for _, t in enumerate(timesteps): - latent_model_input = latents - timestep = [t] - - timestep = torch.stack(timestep) - - noise_pred = x0_fn(latent_model_input, timestep.unsqueeze(0)) - temp_x0 = self.sample_scheduler.step( - noise_pred.unsqueeze(0), - t, - latents[0].unsqueeze(0), - return_dict=False, - generator=seed_g)[0] - latents = temp_x0.squeeze(0) - - - if return_condition_state: - cond_parts = [] - gt_vis = getattr(self, "_latest_gt_gen_pixels", None) - cond_pixels = getattr(self, "_latest_condition_state_pixels", None) - if cond_pixels is not None: - cond_parts.append(cond_pixels) - rays = getattr(self, "_latest_plucker_rays_pixels", None) - cond_plucker = None - if rays is not None and isinstance(rays, dict): - ray_o = rays.get("ray_origin", None) - ray_d = rays.get("ray_direction", None) - if ray_o is not None and ray_d is not None: - # Convert to [B, 6, T, H, W] - cond_plucker = torch.cat([ - ray_o.permute(0, 4, 1, 2, 3).contiguous(), - ray_d.permute(0, 4, 1, 2, 3).contiguous(), - ], dim=1) - if cond_plucker is not None: - cond_parts.append(cond_plucker) - cond_state = None - if len(cond_parts) > 0: - if len(cond_parts) == 1: - cond_state = cond_parts[0] - else: - target_T = max(part.shape[2] for part in cond_parts) - target_HW = {(part.shape[3], part.shape[4]) for part in cond_parts} - assert len(target_HW) == 1, ( - f"Cannot concatenate condition parts with different spatial sizes: {target_HW}" - ) - aligned_parts = [] - for part in cond_parts: - if part.shape[2] < target_T: - pad_frames = target_T - part.shape[2] - pad = torch.zeros( - (part.shape[0], part.shape[1], pad_frames, part.shape[3], part.shape[4]), - dtype=part.dtype, - device=part.device, - ) - part = torch.cat([pad, part], dim=2) - elif part.shape[2] > target_T: - part = part[:, :, -target_T:] - aligned_parts.append(part) - cond_state = torch.cat(aligned_parts, dim=1) - # Clear cached references immediately after assembling return payload - self._latest_condition_state_pixels = None - self._latest_plucker_rays_pixels = None - self._latest_gt_gen_pixels = None - # Restore flag - self._collect_return_condition_state = prev_collect_flag - if gt_vis is not None and cond_state is not None: - return latents, cond_state, gt_vis - if cond_state is not None: - return latents, cond_state - if gt_vis is not None: - return latents, None, gt_vis - # Restore flag when not returning condition state - self._collect_return_condition_state = prev_collect_flag - return latents - - # ------------------------ Inference utilities ------------------------ - def denoise(self, xt_B_C_T_H_W, timestep, condition): - """Lyra2-aware denoise: return only generated-region prediction from net.""" - vt_pred_gen = self.net( - x_B_C_T_H_W=xt_B_C_T_H_W.to(**self.tensor_kwargs), - timesteps_B_T=timestep, - **condition.to_dict(), - **self.framepack_params, - ).float() - return vt_pred_gen - - # ------------------------ Training utilities ------------------------ - def _clone_vae_cache(self, cache_list): - """Clone a VAE encoder cache list, cloning tensors to avoid aliasing.""" - cloned = [] - for it in cache_list: - if torch.is_tensor(it): - cloned.append(it.clone()) - else: - cloned.append(it) - return cloned - - def _vae_encode_range_stream(self, x_vid, start_t, end_t, skip_first_frame: bool = False): - """Stream-encode frames in [start_t, end_t) using current encoder caches, return features.""" - vae_wrap = self.tokenizer.model # WanVAE wrapper - vae_core = vae_wrap.model # WanVAE_ core - temporal_window = vae_wrap.temporal_window - - feats = [] - # First frame handling only when starting from t=0 and not resuming from a prior cache - if start_t == 0 and not bool(skip_first_frame): - vae_core._enc_conv_idx = [0] - out0 = vae_core.encoder( - x_vid[:, :, :1, :, :], - feat_cache=vae_core._enc_feat_map, - feat_idx=vae_core._enc_conv_idx, - ) - feats.append(out0) - start_t = 1 - # Full temporal_window chunks - pos = start_t - while pos + temporal_window <= end_t: - vae_core._enc_conv_idx = [0] - out_w = vae_core.encoder( - x_vid[:, :, pos: pos + temporal_window, :, :], - feat_cache=vae_core._enc_feat_map, - feat_idx=vae_core._enc_conv_idx, - ) - feats.append(out_w) - pos += temporal_window - # Remainder - if pos < end_t: - vae_core._enc_conv_idx = [0] - out_r = vae_core.encoder( - x_vid[:, :, pos: end_t, :, :], - feat_cache=vae_core._enc_feat_map, - feat_idx=vae_core._enc_conv_idx, - ) - feats.append(out_r) - if len(feats) == 1: - return feats[0] - return torch.cat(feats, dim=2) - - @torch.no_grad() - def vae_encode_with_cache(self, enc_cache, video, start_t=None, end_t=None, return_cache=False): - """Resume encoder from `cache` and stream-encode `video[start_t:end_t)` into features. - - Args: - cache: Encoder feature cache captured from VAE encoder state. - Accepts either a list of per-layer cached tensors, or a legacy (enc_cache, enc_idx) tuple. - video: [B, C, T, H, W] pixels in the model's pixel range. - start_t: start index (inclusive). Defaults to 0. - end_t: end index (exclusive). Defaults to T. - Returns: - Tensor of encoder features over the requested range. - """ - vae_wrap = self.tokenizer.model # WanVAE wrapper - vae_core = vae_wrap.model # WanVAE_ core - with vae_wrap.context: - if not vae_wrap.is_amp: - x_vid = video.to(vae_wrap.dtype) - else: - x_vid = video - # Restore encoder caches - vae_core._enc_feat_map = self._clone_vae_cache(enc_cache) - # Always start layer-iteration index at 0 when resuming with an explicit cache - vae_core._enc_conv_idx = [0] - # Resolve range - T_total = int(x_vid.shape[2]) - s = 0 if start_t is None else int(start_t) - e = T_total if end_t is None or int(end_t) < 0 else int(end_t) - # When resuming from a provided cache, avoid the I-frame special case even if s==0 - feats = self._vae_encode_range_stream(x_vid, s, e, skip_first_frame=True) - if return_cache: - cache_current = self._clone_vae_cache(vae_core._enc_feat_map) - return feats, cache_current - else: - return feats - - @torch.no_grad() - def _encoder_feats_to_normalized_latents(self, encoder_feats: torch.Tensor) -> torch.Tensor: - """Project encoder features to [mu, logvar], take mu, and apply channel-wise and per-frame normalization. - - Matches the normalization used by Wan2pt1VAEInterface.encode in this codebase. - """ - vae_iface = self.tokenizer - vae_wrap = vae_iface.model # WanVAE wrapper - vae_core = vae_wrap.model # WanVAE_ core - with vae_wrap.context: - mu_logvar = vae_core.conv1(encoder_feats) - mu, _ = mu_logvar.chunk(2, dim=1) - # Channel-wise normalization - mean_c = vae_wrap.scale[0] - inv_std_c = vae_wrap.scale[1] - if torch.is_tensor(mean_c): - mu = (mu - mean_c.view(1, vae_core.z_dim, 1, 1, 1).type_as(mu)) * inv_std_c.view(1, vae_core.z_dim, 1, 1, 1).type_as(mu) - else: - mu = (mu - mean_c) * inv_std_c - # Per-frame normalization - if int(mu.shape[2]) == 1: - latents = (mu - vae_wrap.img_mean.type_as(mu)) / vae_wrap.img_std.type_as(mu) - else: - latents = (mu - vae_wrap.video_mean[:, :, :1].type_as(mu)) / vae_wrap.video_std[:, :, :1].type_as(mu) - return latents - - @torch.no_grad() - def _vae_encode_with_shared_prefix(self, video, gen_cond_pixels=None, return_cache=False): - """Efficiently encode full latents and zero-tailed conditional latents by reusing VAE encoder caches. - - Returns (latents, cond_latent), both normalized the same way as tokenizer.encode. - """ - # Build zero-tail video for I2V conditioning - _video = video.clone() - if gen_cond_pixels is None: - _video[:, :, -self.framepack_num_new_video_frames:] = 0 - else: - _video[:, :, -self.framepack_num_new_video_frames:] = gen_cond_pixels - - # Stream-encode prefix once, - # snapshot encoder caches, then continue with zero-tail and real-tail separately. - vae_iface = self.tokenizer # Wan2pt1VAEInterface - vae_wrap = vae_iface.model # WanVAE wrapper - vae_core = vae_wrap.model # WanVAE_ core - - T_total = int(video.shape[2]) - T_hist = T_total - int(self.framepack_num_new_video_frames) - - # Use VAE's autocast/dtype context for consistency - with vae_wrap.context: - if not vae_wrap.is_amp: - video_cast = video.to(vae_wrap.dtype) - video_zero_cast = _video.to(vae_wrap.dtype) - else: - video_cast = video - video_zero_cast = _video - - # Reset encoder caches and build prefix features once - vae_core.clear_cache() - vae_core._enc_conv_idx = [0] - prefix_feats = self._vae_encode_range_stream(video_cast, 0, T_hist) - - # Snapshot caches post-prefix (per-layer feature map cache only) - enc_cache_after_prefix = self._clone_vae_cache(vae_core._enc_feat_map) - - # Continue with zero-tail from cached state - zero_tail_feats = self.vae_encode_with_cache( - enc_cache_after_prefix, - video_zero_cast, - start_t=T_hist, - end_t=T_total, - ) - - # Continue with real-tail from the SAME cached state - real_tail_feats = self.vae_encode_with_cache( - enc_cache_after_prefix, - video_cast, - start_t=T_hist, - end_t=T_total, - ) - - # Stitch features and run final 1x1 causal conv to produce [mu, logvar] - feats_cond = torch.cat([prefix_feats, zero_tail_feats], dim=2) if T_hist < T_total else prefix_feats - feats_full = torch.cat([prefix_feats, real_tail_feats], dim=2) if T_hist < T_total else prefix_feats - cond_latent = self._encoder_feats_to_normalized_latents(feats_cond) - latents = self._encoder_feats_to_normalized_latents(feats_full) - - in_dtype = video.dtype - latents = latents.contiguous().to(in_dtype) - cond_latent = cond_latent.contiguous().to(in_dtype) - # Clear encoder feature caches to release memory - - if return_cache: - cache_after_prefix = enc_cache_after_prefix - cache_current = self._clone_vae_cache(vae_core._enc_feat_map) - return latents, cond_latent, cache_after_prefix, cache_current - else: - return latents, cond_latent - - # ------------------------ Training utilities ------------------------ - def prepare_latent_conditon( - self, - condition_state, - condition_state_mask, - dtype, - condition_video_augment_sigma_in_inference=0.001, - seed_inference=1, - is_testing=True, - ): - """Encode warped condition frames into VAE latents. - - Only support the simple case where condition_state has 5 dims: [B, C, T, H, W]. - """ - if condition_state.dim() == 5: - # Prepend the first frame to satisfy VAE's I-frame requirement, - # then drop the first latent token after encoding. - first = condition_state[:, :, :1] - condition_state = torch.cat([first, condition_state], dim=2) - latent_condition = self.encode(condition_state.to(dtype)).contiguous() - latent_condition = latent_condition[:, :, 1:] # drop the inserted I-frame token - else: - raise NotImplementedError("prepare_latent_conditon only supports 5D condition_state in Lyra2Model") - return latent_condition - - # ----------------------------- Training ----------------------------- - def training_step( - self, data_batch, iteration - ): - """Lyra2 training: loss only on generated latents; history latents are clean. - """ - if not self.framepack_weights_initialized and iteration == 0 and self.config.init_framepack_weights: - self.net.copy_weights_to_clean_patch_embeddings() - self.framepack_weights_initialized = True - self._update_train_stats(data_batch) - dropout = False - - prob = float(getattr(self.config, "self_aug_prob", 1.0)) - # Synchronized decision across devices via broadcast from rank 0 - rand_tensor = torch.zeros(1, device=self.tensor_kwargs["device"], dtype=torch.float32) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - rand_tensor.uniform_(0.0, 1.0) - torch.distributed.broadcast(rand_tensor, src=0) - else: - rand_tensor.uniform_(0.0, 1.0) - rand_val = float(rand_tensor.item()) - # ========================= Stage A: Self-augmentation (optional) ========================= - if ( - getattr(self.config, "self_aug_enabled", False) - and (iteration % int(getattr(self.config, "self_aug_every_k", 1)) == 0) - and rand_val <= prob - ): - if float(rand_tensor.item()) > self.config.self_aug_i2v_ratio: - raw_state, x0_latents, condition = self.get_data_and_condition(data_batch, dropout=dropout) - with misc.timer("self_aug_training_step", debug=False), torch.no_grad(): - # Indices and shapes - T_hist = self.framepack_total_max_num_latent_frames - self.framepack_num_new_latent_frames - stride = int(self.framepack_num_frames_per_latent) - - # Build init latents: keep history clean; tail will be partially noised via RectifiedFlow - init_latents_aug = torch.zeros_like(x0_latents, dtype=torch.float32) - init_latents_aug[:, :, :T_hist] = x0_latents[:, :, :T_hist].to(torch.float32).clone() - - # Stage-A RectifiedFlow: uniform time with max_T scaling - stageA_flow = self.rectified_flow - - B = int(x0_latents.shape[0]) - tA_single = stageA_flow.sample_train_time(1).to(**self.flow_matching_kwargs) # [1] e.g. 0.193 - timestepA_single = stageA_flow.get_discrete_timestamp(tA_single, self.flow_matching_kwargs) # [1], e.g. 954 - max_T = int(getattr(self.config, "self_aug_max_T", 50)) - scale = max_T / 1000.0 - timestepA_single = (timestepA_single.float() * scale).floor().clamp(min=1).to(dtype=timestepA_single.dtype) - - # Find the nearest value in stageA_flow.noise_scheduler.timesteps to the single timestep - # stageA_flow.noise_scheduler.timesteps is a 1D tensor of available timesteps - available_timesteps = stageA_flow.noise_scheduler.timesteps.to(timestepA_single.device, dtype=timestepA_single.dtype) - - # timestepA_single: [1], available_timesteps: [N] - # Find the index of the closest available timestep - diff = torch.abs(available_timesteps - timestepA_single) # [N] - nearest_index = diff.argmin() # scalar - timestepA_single = available_timesteps[nearest_index].unsqueeze(0) # [1] - - timestepsA = timestepA_single.expand(B, 1) # [B, 1] - broadcast to batch - sigmasA = stageA_flow.get_sigmas(timestepsA, self.flow_matching_kwargs) # [B] - sigmasA_B1 = rearrange(sigmasA, "b -> b 1") - - log.info(f"self aug timestep={timestepA_single}", rank0_only=False) - - # Build epsilon only for the tail region and form xt via interpolation - eps_tail = torch.randn_like( - x0_latents[:, :, T_hist:].to(torch.float32), dtype=self.flow_matching_kwargs["dtype"] - ) #e.g. ([1, 16, 9, 56, 96]) - - xt_tail, _ = stageA_flow.get_interpolation( - eps_tail, x0_latents[:, :, T_hist:].to(torch.float32), sigmasA_B1 - ) - init_latents_aug[:, :, T_hist:] = xt_tail - - # Build Stage-A CFG conditions (video) and CP broadcast if needed - is_image_batch = self.is_image_batch(data_batch) - condition_A, uncondition_A = self.conditioner.get_condition_with_negative_prompt(data_batch) - condition_A = condition_A.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - uncondition_A = uncondition_A.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - - cp_group = self.get_context_parallel_group() - if cp_group is not None: - init_latents_aug = broadcast(init_latents_aug.contiguous(), cp_group) - condition_A = condition_A.broadcast(cp_group) - uncondition_A = uncondition_A.broadcast(cp_group) - sigmasA = broadcast(sigmasA.contiguous(), cp_group) - - # x0_fn: zeros history, predict only tail with CFG - guidance_A = float(getattr(self.config, "self_aug_guidance", 1.0) or 1.0) - - def x0_fn_A(noise_x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - cond_v_gen = self.denoise(noise_x, timestep, condition_A) - if guidance_A == 1.0: - gen_v = cond_v_gen - else: - uncond_v_gen = self.denoise(noise_x, timestep, uncondition_A) - gen_v = uncond_v_gen + guidance_A * (cond_v_gen - uncond_v_gen) - vt_full = torch.zeros_like(noise_x, dtype=gen_v.dtype) - vt_full[:, :, T_hist:] = gen_v - return vt_full - - # Short sampling for Stage-A - steps_A = int(getattr(self.config, "self_aug_steps", 3)) - shift_A = ( - float(self.config.self_aug_scheduler_shift) - if getattr(self.config, "self_aug_scheduler_shift", None) is not None - else 1.0 - ) - - sigmas_steps =np.linspace(sigmasA.squeeze().item(), 0.0, steps_A) - self.sample_scheduler.set_timesteps( - steps_A, device=self.tensor_kwargs["device"], sigmas=sigmas_steps, shift=shift_A - ) - timesteps_iter = self.sample_scheduler.timesteps - - latents_A = init_latents_aug - for _, t in enumerate(timesteps_iter): - latent_model_input = latents_A - timestep = torch.stack([t]) - vt_pred = x0_fn_A(latent_model_input, timestep.unsqueeze(0)) - temp_x0 = self.sample_scheduler.step( - vt_pred.unsqueeze(0), - t, - latents_A[0].unsqueeze(0), - return_dict=False, - )[0] - latents_A = temp_x0.squeeze(0) - - # Decode Stage-A tail to pixels, stitch into the pixelized version of current latent window, re-encode - # Cached streaming decode/encode update - latents_new = latents_A[:, :, T_hist:] - assert latents_new.shape[2] == self.framepack_num_new_latent_frames - history_latents = data_batch["_stage_a_full_latents"][:, :, : -self.framepack_num_new_latent_frames] - latents_to_decode = torch.cat([history_latents, latents_new], dim=2) - full_frames = self.decode(latents_to_decode.to(self.tensor_kwargs["dtype"])) - # First latent frame is the start I-frame; remaining are V-frames. - chunk_frames = full_frames[:, :, -(self.framepack_total_max_num_latent_frames - 1)*self.framepack_num_frames_per_latent :] - del history_latents - del latents_to_decode, latents_new, full_frames - - # Get the frame indices and metadata from stage A - video_indices = data_batch.get("_stage_a_video_indices") - stage_a_start = data_batch.get("_stage_a_start") - stage_a_cur_segment_id = data_batch.get("_stage_a_cur_segment_id") - - # Create a modified data_batch with augmented frames inserted at correct positions - data_batch_aug = data_batch - original_video = data_batch[self.input_data_key] - # Insert chunk_frames into the original video at the correct positions - if video_indices is not None and len(video_indices) > 0: - # Create a copy of the original video - augmented_video = original_video.clone() - original_indices = video_indices.cpu().numpy() - # Only replace frames that actually exist in the original video and haven't been repeated/padded - valid_indices = original_indices[original_indices < original_video.shape[2]] - - if len(valid_indices) > 0: - # Map the valid indices to the corresponding frames in video_px_aug2 - # align the tails - if len(valid_indices) > chunk_frames.shape[2]: - n_extra = len(valid_indices) - chunk_frames.shape[2] - valid_indices = valid_indices[n_extra:] - elif len(valid_indices) < chunk_frames.shape[2]: - n_extra = chunk_frames.shape[2] - len(valid_indices) - chunk_frames = chunk_frames[-n_extra:] - - if not self.config.self_aug_copy_chunk: - valid_indices = valid_indices[-self.framepack_num_new_video_frames:] - chunk_frames = chunk_frames[:, :, -self.framepack_num_new_video_frames:] - max_aug_frames = min(len(valid_indices), chunk_frames.shape[2]) - for i in range(max_aug_frames): - orig_idx = int(valid_indices[i]) - augmented_video[:, :, orig_idx] = chunk_frames[:, :, i].clamp(-1, 1) - - data_batch_aug[self.input_data_key] = augmented_video - else: - raise ValueError("Fallback not implemented") - - # Preserve the original start to maintain consistency - if stage_a_start is not None: - data_batch_aug["start"] = stage_a_start - # increment cur_segment_id by 1 - if stage_a_cur_segment_id is not None: - data_batch_aug["cur_segment_id"] = stage_a_cur_segment_id + 1 - - # Regenerate x0_latents and condition with the preserved parameters - raw_state, x0_latents, condition = self.get_data_and_condition(data_batch_aug, dropout=dropout) - else: - data_batch["cur_segment_id"] = 0 # train i2v - data_batch["is_i2v"] = True - raw_state, x0_latents, condition = self.get_data_and_condition(data_batch, dropout=dropout) - else: - raw_state, x0_latents, condition = self.get_data_and_condition(data_batch, dropout=dropout) - - # Remove a series of cache and input keys from data_batch if present - if hasattr(self.tokenizer.model.model, "clear_cache"): - try: - self.tokenizer.model.model.clear_cache() - del original_video - del augmented_video - except Exception: - pass - - del raw_state - for k in [ - "_stage_a_vae_cache_T-2", - "_stage_a_vae_cache_T-1", - "stage_a_full_latents", - "t5_chunk_embeddings", - "t5_chunk_mask", - "depth", - "_stage_a_full_latents", - "camera_w2c", - "intrinsics", - "t5_chunk_keys", - "sample_frame_indices", - "control_input_world_scenario", - WAN2PT1_I2V_COND_LATENT_KEY, - "video", - "last_hist_frame", - ]: - if k in data_batch: - del data_batch[k] - data_batch["video"] = torch.zeros([1]) - - gc.collect() - torch.cuda.empty_cache() - - # Sample times - batch_size = x0_latents.size(0) - t_B = self.rectified_flow.sample_train_time(batch_size).to(**self.flow_matching_kwargs) - t_B = rearrange(t_B, "b -> b 1") - - # Build epsilon BEFORE CP split: zeros on history, noise on generated region. - T_hist = self.framepack_total_max_num_latent_frames - self.framepack_num_new_latent_frames - epsilon_full = x0_latents.clone() - epsilon_full[:, :, T_hist:] = torch.randn_like( - x0_latents[:, :, T_hist:], dtype=self.flow_matching_kwargs["dtype"] - ) - - # Enable CP (without splitting T) and broadcast inputs so all CP ranks see identical data; - # CP splitting happens after Lyra2 patchify inside net - cp_group = self.get_context_parallel_group() - if cp_group is not None: - x0_latents = broadcast(x0_latents.contiguous(), cp_group) - epsilon_full = broadcast(epsilon_full.contiguous(), cp_group) - t_B = broadcast(t_B.contiguous(), cp_group) - condition = condition.broadcast(cp_group) - self.net.enable_context_parallel(cp_group) - else: - self.net.disable_context_parallel() - # CP after Lyra2 patchify - timesteps = self.rectified_flow.get_discrete_timestamp(t_B, self.flow_matching_kwargs) - sigmas = self.rectified_flow.get_sigmas(timesteps, self.flow_matching_kwargs) - timesteps = rearrange(timesteps, "b -> b 1") - sigmas = rearrange(sigmas, "b -> b 1") - - # Compute interpolation directly on already-split tensors - xt, vt = self.rectified_flow.get_interpolation(epsilon_full, x0_latents, sigmas) - - # Net forward: it returns only generated region prediction when Lyra2 params are provided - vt_pred_gen = self.net( - x_B_C_T_H_W=xt.to(**self.tensor_kwargs), - timesteps_B_T=timesteps, - **condition.to_dict(), - **self.framepack_params, - ) - - # Loss only over generated region; align with base model weighting - time_weights_B = self.rectified_flow.train_time_weight(timesteps, self.flow_matching_kwargs) - vt_gen_target = vt[:, :, T_hist:].to(vt_pred_gen.dtype) - per_instance_loss = torch.mean((vt_pred_gen - vt_gen_target) ** 2, dim=list(range(1, vt_pred_gen.dim()))) - loss = torch.mean(time_weights_B * per_instance_loss) - output_batch = {"edm_loss": loss} - - return output_batch, loss - - - - def _select_temporal_history_indices(self, T_hist_total, num_temporal_hist): - """Select temporal history latent indices: always include important start (0) and most recent rest.""" - temporal_rest_needed = max(0, num_temporal_hist - 1) - recent_start = max(0, T_hist_total - temporal_rest_needed) - temporal_rest = list(range(recent_start, T_hist_total)) - return [0] + temporal_rest - - def _compose_selected_indices( - self, - splits, - types, - T_hist_total, - temporal_selected, - spatial_selected, - ): - device = self.tensor_kwargs["device"] - - temporal_pool = [idx for idx in temporal_selected if idx != 0] - spatial_pool = list(spatial_selected) - ordered_past: list[int] = [] - for seg_idx, (cnt, tp) in enumerate(zip(splits, types)): - if tp == "k": - for j in range(cnt): - if seg_idx == 0 and j == 0: - ordered_past.append(0) - else: - ordered_past.append(temporal_pool.pop(0)) - else: - for _ in range(cnt): - if len(spatial_pool) == 0: - ordered_past.append(T_hist_total - 1) - else: - ordered_past.append(spatial_pool.pop(0)) - - selected_idx = torch.tensor(ordered_past, device=device, dtype=torch.long) - return selected_idx - - @staticmethod - def _build_canonical_spatial_coords( - H: int, - W: int, - num_spatial_hist: int, - device: torch.device, - dtype: torch.dtype, - ) -> Optional[torch.Tensor]: - if num_spatial_hist <= 0: - return None - xs = torch.linspace(-1.0, 1.0, W, device=device, dtype=dtype) - ys = torch.linspace(-1.0, 1.0, H, device=device, dtype=dtype) - yy, xx = torch.meshgrid(ys, xs, indexing="ij") - base_xy = torch.stack([xx, yy], dim=0) # [2,H,W] - base_xy = base_xy.unsqueeze(0).repeat(num_spatial_hist, 1, 1, 1) # [N,2,H,W] - if num_spatial_hist == 1: - zs = torch.zeros(1, device=device, dtype=dtype) - else: - zs = torch.linspace(-1.0, 1.0, num_spatial_hist, device=device, dtype=dtype) - z = zs.view(num_spatial_hist, 1, 1, 1).expand(num_spatial_hist, 1, H, W) - coords = torch.cat([base_xy, z], dim=1) # [N,3,H,W] - return coords - - def _get_cached_spatial_coords( - self, - H: int, - W: int, - num_spatial_hist: int, - device: torch.device, - dtype: torch.dtype, - ) -> Optional[torch.Tensor]: - meta = (H, W, num_spatial_hist, device, dtype) - if self._cached_spatial_coords is None or self._cached_spatial_coords_meta != meta: - self._cached_spatial_coords = self._build_canonical_spatial_coords( - H=H, - W=W, - num_spatial_hist=num_spatial_hist, - device=device, - dtype=dtype, - ) - self._cached_spatial_coords_meta = meta - return self._cached_spatial_coords - - @staticmethod - def _pixelshuffle_hw_to_latent( - x: torch.Tensor, - *, - h8: int = 8, - w8: int = 8, - ) -> torch.Tensor: - return rearrange( - x, - "b c t (h h8) (w w8) -> b (c h8 w8) t h w", - h8=h8, - w8=w8, - ) - - def _coord_pixels_to_latents( - self, - coord_pixels: torch.Tensor, - *, - dtype: torch.dtype, - target_t: Optional[int] = None, - ) -> torch.Tensor: - """Convert warped coordinate pixels to latent grid via subsample+pixelshuffle.""" - frames_per_lat = int(self.framepack_num_frames_per_latent) - F = int(coord_pixels.shape[2]) - start = max(frames_per_lat - 1, 0) - time_idx = list(range(start, F, frames_per_lat)) - time_idx_t = torch.tensor(time_idx, device=coord_pixels.device, dtype=torch.long) - coord_sel = coord_pixels[:, :, time_idx_t] - coord_lat = self._pixelshuffle_hw_to_latent(coord_sel) - if target_t is not None and int(coord_lat.shape[2]) != int(target_t): - raise ValueError(f"Unexpected coord_lat.shape[2]={coord_lat.shape[2]} != target_t={target_t}") - return coord_lat.to(dtype=dtype) - - def _apply_camera_controls( - self, - cond_latent, - selected_idx, - video_indices, - spatial_selected_frame_ids: Optional[torch.Tensor] = None, - spatial_selected_coords: Optional[torch.Tensor] = None, - *, - video: Optional[torch.Tensor] = None, - camera_w2c: Optional[torch.Tensor] = None, - intrinsics: Optional[torch.Tensor] = None, - buffer_depth_B_1_H_W: Optional[torch.Tensor] = None, - spatial_cache: Optional[Sparse3DCache] = None, - is_training: bool = True, - ): - cfg = self.config - if video is None or camera_w2c is None or intrinsics is None: - return cond_latent, None - buffer_cond_latents: Optional[torch.Tensor] = None - T_new_lat = int(self.framepack_num_new_latent_frames) - # Depth warp / HD map condition overwriting the tail of cond_latent - with misc.timer("camera_pose_condition - build"): - device = cond_latent.device - B = int(cond_latent.shape[0]) - # Find the buffer frame absolute index - rel_buffer_idx = video_indices.shape[0] - self.framepack_num_new_video_frames - 1 - abs_buffer_idx = int(video_indices[rel_buffer_idx].item()) - abs_gen_indices = video_indices[-self.framepack_num_new_video_frames:] - - F = int(abs_gen_indices.numel()) - target_w2cs = camera_w2c[:, abs_gen_indices] - target_intrinsics = intrinsics[:, abs_gen_indices] - - spatial_condition_pixels_list: list[torch.Tensor] = [] - - with misc.timer("camera_pose_condition - corruptor"): - # _warp_multisrc: shared helper for accumulated PCD / correspondence warping. - def _warp_multisrc( - src_rgb: torch.Tensor, - src_depth: torch.Tensor, - src_w2c: torch.Tensor, - src_K: torch.Tensor, - *, - return_depth: bool = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - # Repeat sources across target frames, flatten (B,F) into batch, then chunk-wise warp. - src_rgb_bf = rearrange( - src_rgb.unsqueeze(1).repeat(1, F, 1, 1, 1, 1), - "b f n c h w -> (b f) n c h w", - ) - src_depth_bf = rearrange( - src_depth.unsqueeze(1).repeat(1, F, 1, 1, 1, 1), - "b f n c h w -> (b f) n c h w", - ) - src_w2c_bf = rearrange( - src_w2c.unsqueeze(1).repeat(1, F, 1, 1, 1), - "b f n c d -> (b f) n c d", - ) - src_K_bf = rearrange( - src_K.unsqueeze(1).repeat(1, F, 1, 1, 1), - "b f n c d -> (b f) n c d", - ) - tgt_w2c_bf = rearrange(target_w2cs.to(dtype=torch.float32), "b f c d -> (b f) c d") - tgt_K_bf = rearrange(target_intrinsics.to(dtype=torch.float32), "b f c d -> (b f) c d") - - warp_chunk_size = self.config.warp_chunk_size - warped_imgs_list: list[torch.Tensor] = [] - warped_masks_list: list[torch.Tensor] = [] - warped_depths_list: list[torch.Tensor] = [] - for i in range(0, int(src_rgb_bf.shape[0]), int(warp_chunk_size)): - w_img, w_mask, w_depth, _ = forward_warp_multiframes( - src_rgb_bf[i : i + warp_chunk_size], - mask1=None, - depth1=src_depth_bf[i : i + warp_chunk_size], - transformation1=src_w2c_bf[i : i + warp_chunk_size], - transformation2=tgt_w2c_bf[i : i + warp_chunk_size], - intrinsic1=src_K_bf[i : i + warp_chunk_size], - intrinsic2=tgt_K_bf[i : i + warp_chunk_size], - is_image=True, - render_depth=return_depth, - world_points1=None, - clean_points=True, - clean_points_continuity=True, - ) - warped_imgs_list.append(w_img) - warped_masks_list.append(w_mask) - if return_depth: - assert w_depth is not None - warped_depths_list.append(w_depth) - warped_imgs_bf = torch.cat(warped_imgs_list, dim=0) - # warped_masks_bf = torch.cat(warped_masks_list, dim=0) # currently unused downstream - warped_imgs_bf = warped_imgs_bf.contiguous() - warped_imgs_bf_reshaped = rearrange(warped_imgs_bf, "(b f) c h w -> b f c h w", b=B, f=F) - warped_imgs = warped_imgs_bf_reshaped.permute(0, 2, 1, 3, 4).contiguous() # [B,C,F,H,W] - if not return_depth: - return warped_imgs, None - warped_depths_bf = torch.cat(warped_depths_list, dim=0).contiguous() - if warped_depths_bf.ndim == 3: - warped_depths_bf = warped_depths_bf.unsqueeze(1) - warped_depths_bf_reshaped = rearrange( - warped_depths_bf, "(b f) c h w -> b f c h w", b=B, f=F - ) - warped_depths = warped_depths_bf_reshaped.permute(0, 2, 1, 3, 4).contiguous() # [B,1,F,H,W] - return warped_imgs, warped_depths - - # -- Multiview-safe accessors ---------------------------------------- - # For multiview inputs stored with negative frame_ids, retrieve - # camera/depth/rgb from spatial_cache instead of tensor indexing - # (which would wrap around with negative indices). - def _mv_w2c(f_id: int) -> torch.Tensor: - if int(f_id) < 0: - _, w, _ = spatial_cache.get_rgbd_by_frame_id(int(f_id)) - return w.to(device=device, dtype=torch.float32) - return camera_w2c[:, int(f_id)].to(dtype=torch.float32) - - def _mv_K(f_id: int) -> torch.Tensor: - if int(f_id) < 0: - _, _, k = spatial_cache.get_rgbd_by_frame_id(int(f_id)) - return k.to(device=device, dtype=torch.float32) - return intrinsics[:, int(f_id)].to(dtype=torch.float32) - - # -- end multiview-safe accessors ------------------------------------ - - # Accumulated correspondence warp: - # warp each spatial frame's canonical coordinates separately and - # append normalized depth as the fourth channel for each slot. - spatial_frame_ids: list[int] = ( - [int(x) for x in spatial_selected_frame_ids.tolist()] - if spatial_selected_frame_ids is not None - else [] - ) - spatial_coords: Optional[torch.Tensor] = spatial_selected_coords - - spatial_unique: list[int] = [] - seen_sp: set[int] = set() - keep_unique_idx: list[int] = [] - for j, idx in enumerate(spatial_frame_ids): - if idx not in seen_sp: - spatial_unique.append(int(idx)) - seen_sp.add(int(idx)) - keep_unique_idx.append(j) - spatial_frame_ids = spatial_unique - if spatial_coords is not None and spatial_coords.numel() > 0: - spatial_coords = spatial_coords[:, keep_unique_idx] if keep_unique_idx else spatial_coords[:, :0] - - max_spatial = cfg.multibuffer_max_spatial_frames - if max_spatial is None: - max_spatial = int(self.framepack_num_spatial_hist) - max_spatial = int(max_spatial) - - assert buffer_depth_B_1_H_W is not None, "buffer_depth_B_1_H_W is required for accumulated correspondence warping" - assert spatial_cache is not None and spatial_cache._store_values, ( - "spatial_cache(store_values=True) is required to provide depth for spatial frames" - ) - - # Buffer warp for main conditioning tail - buf_rgb = video[:, :, abs_buffer_idx].to(dtype=torch.float32).unsqueeze(1) # [B,1,C,H,W] - buf_depth = buffer_depth_B_1_H_W.to(device=device, dtype=torch.float32).unsqueeze(1) # [B,1,1,H,W] - buf_w2c = camera_w2c[:, abs_buffer_idx].to(dtype=torch.float32).unsqueeze(1) # [B,1,4,4] - buf_K = intrinsics[:, abs_buffer_idx].to(dtype=torch.float32).unsqueeze(1) # [B,1,3,3] - condition_state_pixels, _ = _warp_multisrc(buf_rgb, buf_depth, buf_w2c, buf_K) - - if max_spatial > 0: - spatial_latents: list[torch.Tensor] = [] - spatial_warped_coords: list[torch.Tensor] = [] - spatial_warped_depths: list[torch.Tensor] = [] - for j_rev in range(len(spatial_frame_ids) - 1, -1, -1): # due to left padding. - f_id = spatial_frame_ids[j_rev] - assert spatial_coords is not None, ( - "spatial_selected_coords is required for correspondence multibuffer warping" - ) - src_rgb = spatial_coords[:, j_rev : j_rev + 1].to( - device=device, dtype=torch.float32, non_blocking=True - ) # [B,1,3,H,W] - d, _w, _k = spatial_cache.get_rgbd_by_frame_id(int(f_id)) - src_depth = d.to(device=device, dtype=torch.float32, non_blocking=True).unsqueeze(1) - src_w2c = _mv_w2c(f_id).unsqueeze(1) - src_K = _mv_K(f_id).unsqueeze(1) - warped_coords, warped_depth = _warp_multisrc( - src_rgb, src_depth, src_w2c, src_K, return_depth=True - ) - assert warped_depth is not None - spatial_warped_depths.append(warped_depth) - spatial_condition_pixels_list.append(warped_coords) - spatial_warped_coords.append(warped_coords) - - depth_norm_per_spatial: list[torch.Tensor] = [] - if len(spatial_warped_depths) > 0: - depth_stack = torch.cat(spatial_warped_depths, dim=1) # [B,N,F,H,W] - dmin = depth_stack.amin(dim=(1, 2, 3, 4), keepdim=True) - dmax = depth_stack.amax(dim=(1, 2, 3, 4), keepdim=True) - depth_stack = 2.0 * (depth_stack - dmin) / torch.clamp(dmax - dmin, min=1e-6) - 1.0 - depth_norm_per_spatial = [depth_stack[:, i : i + 1] for i in range(int(depth_stack.shape[1]))] - - for i, warped_coords in enumerate(spatial_warped_coords): - warped_for_latent = torch.cat([warped_coords, depth_norm_per_spatial[i]], dim=1) - coord_lat = self._coord_pixels_to_latents( - warped_for_latent, dtype=cond_latent.dtype, target_t=T_new_lat, - ) - spatial_latents.append(coord_lat) - - if len(spatial_latents) < max_spatial: - H_lat = int(cond_latent.shape[3]) - W_lat = int(cond_latent.shape[4]) - pad_lat = torch.full( - (B, LYRA2_CORRESPONDENCE_CHANNELS_PER_SLOT, T_new_lat, H_lat, W_lat), - -1.0, - device=cond_latent.device, - dtype=cond_latent.dtype, - ) - pad_lat[:, 3 * 8 * 8 :, :, :, :] = 1.0 - spatial_latents.extend([pad_lat] * (max_spatial - len(spatial_latents))) - buffer_cond_latents = torch.cat(spatial_latents, dim=1) - if self._collect_return_condition_state: - if len(spatial_condition_pixels_list) > 0: - # Show buffer warp + all per-frame spatial warps (like multibuffer vis). - vis_list = [condition_state_pixels] + spatial_condition_pixels_list - condition_vis = torch.cat(vis_list, dim=1) - self._latest_condition_state_pixels = condition_vis - else: - self._latest_condition_state_pixels = condition_state_pixels - - with misc.timer("camera_pose_condition - encode camera tail"): - camera_latent = self.prepare_latent_conditon(condition_state_pixels, None, cond_latent.dtype) - assert int(camera_latent.shape[2]) == T_new_lat, ( - f"Unexpected camera latent T={camera_latent.shape[2]} != {T_new_lat}" - ) - cond_latent[:, :, -T_new_lat:] = camera_latent.type_as(cond_latent) - - # Plucker ray condition concatenated along channels - with misc.timer("plucker_condition - build"): - device = cond_latent.device - # Gather intrinsics and poses per selected latent (absolute indices) - K_sel = intrinsics[:, video_indices].to(device=device, dtype=torch.float32) - w2c_sel = camera_w2c[:, video_indices].to(device=device, dtype=torch.float32) - c2w_sel = torch.inverse(w2c_sel) - # c2w_ref for absolute pose mode: last history frame (buffer) - c2w_ref_inv = w2c_sel[:, -self.framepack_num_new_video_frames - 1:-self.framepack_num_new_video_frames] # [B,1,4,4] - - # Build intrinsics vector - fx = K_sel[..., 0, 0] - fy = K_sel[..., 1, 1] - cx = K_sel[..., 0, 2] - cy = K_sel[..., 1, 2] - K_vec = torch.stack([fx, fy, cx, cy], dim=-1) # [B, V, 4] - - # Downsample along time to align with VAE latents BEFORE computing rays - frames_per_lat = int(self.framepack_num_frames_per_latent) - v_len = int(K_sel.shape[1]) - # subsample long T to 0, 4, 8, 12, ... to match VAE latent indexing used elsewhere - time_idx = [0] + [i for i in range(4, v_len, frames_per_lat)] - time_idx_t = torch.tensor(time_idx, device=device, dtype=torch.long) - - # Select downsampled intrinsics and poses - K_vec_ds = K_vec[:, time_idx_t] # [B, T_ds, 4] - c2w_sel_ds = c2w_sel[:, time_idx_t] # [B, T_ds, 4, 4] - - # Compute relative camera-to-world transforms on the downsampled timeline - # absolute to the buffer (last history) pose - c2w_rel_ds = torch.matmul(c2w_ref_inv, c2w_sel_ds) # [B, T_ds, 4, 4] - - H_pix = int(video.shape[-2]) - W_pix = int(video.shape[-1]) - # Compute Plücker rays on the downsampled timeline - plucker_sel_B_T_H_W_6 = ray_condition( - K_vec_ds, - c2w_rel_ds, - H_pix, - W_pix, - device=device, - flip_flag=None, - use_ray_o=True, - ) # [B, T_ds, H, W, 6] - - # [B, T, H, W, 6] -> [B, 6, T, H, W] - plucker_5d = plucker_sel_B_T_H_W_6.permute(0, 4, 1, 2, 3).contiguous() - - # Spatial rearrange to latent grid: 6 -> 6*8*8 channels, then reorder by selected_idx - assert H_pix % 8 == 0 and W_pix % 8 == 0 - plucker_down384 = rearrange( - plucker_5d, - "b c t (h h8) (w w8) -> b (c h8 w8) t h w", - h8=8, - w8=8, - ) # [B,384,T,H/8,W/8] - plucker_down384 = plucker_down384[:, :, selected_idx] - - cond_latent = torch.cat([cond_latent, plucker_down384.type_as(cond_latent)], dim=1) - - return cond_latent, buffer_cond_latents - - @torch.no_grad() - def _prepare_lyra2_inputs( - self, - history_full: torch.Tensor, - gen_cond: torch.Tensor, - spatial_cache: Optional[Sparse3DCache], - video: torch.Tensor, - buffer_depth_B_1_H_W: Optional[torch.Tensor], - camera_w2c: torch.Tensor, - intrinsics: torch.Tensor, - video_indices: torch.Tensor, - *, - is_training: bool = True, - spatial_cache_skip_last_n: int = 0, - num_retrieval_views: int = 1, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Prepare (latents, cond_latent, mask) for Lyra2Model. - - This consolidates the history selection + camera-control tail + optional image-based - spatial memory insertion. - - Args: - history_full: [B, C_lat, T_hist, H, W] - gen_cond: [B, C_lat, T_new, H, W] (conditioning tail latents, e.g. zero-tail) - camera_w2c: [B, T_video, 4, 4] - intrinsics: [B, T_video, 3, 3] - video_indices: [T_window] absolute indices into the original video timeline - spatial_cache: pre-built Sparse3DCache instance (built outside; inference uses AR-updated cache) - video: full video tensor for RGB lookup [B, C, T, H, W] - buffer_depth_B_1_H_W: last history-frame depth for warping (NOT stored in cache) - is_training: enables stochastic retrieval only during training - num_retrieval_views: Number of evenly-spaced target views within the generation - window used for multi-view coverage retrieval. 1 = single last-frame (legacy). - """ - cfg = self.config - device = history_full.device - - num_temporal_hist = int(self.framepack_num_temporal_hist) - num_spatial_hist = int(self.framepack_num_spatial_hist) - - T_hist_total = int(history_full.shape[2]) - T_new_lat = self.framepack_num_new_latent_frames - # We don't pass gen_gt into this helper: build a dummy tail for shape bookkeeping. - gen_lat_dummy = torch.zeros_like(gen_cond) - - temporal_selected = self._select_temporal_history_indices(T_hist_total, num_temporal_hist) - - spatial_selected_frame_ids_t: Optional[torch.Tensor] = None - spatial_selected: list[int] = [] - spatial_coords_all: Optional[torch.Tensor] = None - spatial_selected_coords: Optional[torch.Tensor] = None - - if num_spatial_hist > 0: - H = int(video.shape[-2]) - W = int(video.shape[-1]) - spatial_coords_all = self._get_cached_spatial_coords( - H=H, - W=W, - num_spatial_hist=num_spatial_hist, - device=device, - dtype=torch.float32, - ) # [num_spatial_hist, 3, H, W] - - # Pre-compute retrieval targets (multi-view or single last-frame). - if num_spatial_hist > 0 and num_retrieval_views > 1 and not is_training: - T_new_pix = int(self.framepack_num_new_video_frames) - gen_start = int(video_indices.shape[0]) - T_new_pix - pts = torch.linspace(0, T_new_pix - 1, num_retrieval_views + 1).long().tolist() - target_offsets = pts[1:] - target_abs_list = [int(video_indices[gen_start + off].item()) for off in target_offsets] - retrieval_w2c = torch.stack( - [camera_w2c[:, idx].to(device=device, dtype=torch.float32) for idx in target_abs_list], dim=1 - ) # [B, V, 4, 4] - retrieval_K = torch.stack( - [intrinsics[:, idx].to(device=device, dtype=torch.float32) for idx in target_abs_list], dim=1 - ) # [B, V, 3, 3] - log.info( - f"Multi-view retrieval: {num_retrieval_views} views at abs indices {target_abs_list}", - rank0_only=True, - ) - elif num_spatial_hist > 0: - last_abs = int(video_indices[-1].item()) - retrieval_w2c = camera_w2c[:, last_abs].to(device=device, dtype=torch.float32) # [B, 4, 4] - retrieval_K = intrinsics[:, last_abs].to(device=device, dtype=torch.float32) # [B, 3, 3] - - if cfg.spatial_memory_use_image and num_spatial_hist > 0: - assert spatial_cache is not None, "spatial_cache is required when spatial_memory_use_image=True" - H = int(video.shape[-2]) - W = int(video.shape[-1]) - - retrieved = spatial_cache.retrieve( - retrieval_w2c, - retrieval_K, - (H, W), - num_latents=num_spatial_hist, - skip_last_n=int(spatial_cache_skip_last_n), - random=bool(is_training), - max_coverage=not bool(is_training), - ) - spatial_selected_frame_ids_t = torch.tensor([int(fi) for (_li, fi) in retrieved], device=device, dtype=torch.long) - if spatial_coords_all is not None: - n_retrieved = int(spatial_selected_frame_ids_t.numel()) - offset = max(0, int(num_spatial_hist - n_retrieved)) - coords_sel = spatial_coords_all[offset : offset + n_retrieved] - if coords_sel.numel() > 0: - B0 = int(history_full.shape[0]) - spatial_selected_coords = coords_sel.unsqueeze(0).repeat(B0, 1, 1, 1, 1) - - # Post-retrieval spatial memory dropout: randomly drop 1..N of selected spatial frames. - if is_training and float(cfg.spatial_memory_drop_rate) > 0: - N_sel = int(spatial_selected_frame_ids_t.numel()) if spatial_selected_frame_ids_t is not None else 0 - if N_sel > 0 and torch.rand(1, device=device).item() < float(cfg.spatial_memory_drop_rate): - num_drop = int(torch.randint(1, N_sel + 1, (1,), device=device).item()) - perm = torch.randperm(N_sel, device=device) - keep_idx = perm[num_drop:].sort().values - spatial_selected_frame_ids_t = spatial_selected_frame_ids_t[keep_idx] - # Re-align canonical coords with the new count so that the right-aligned - # offset matches the left-padding of image tokens at s-positions. - if spatial_coords_all is not None: - n_post = int(spatial_selected_frame_ids_t.numel()) - offset_post = max(0, int(num_spatial_hist - n_post)) - coords_post = spatial_coords_all[offset_post : offset_post + n_post] - if coords_post.numel() > 0: - B0 = int(history_full.shape[0]) - spatial_selected_coords = coords_post.unsqueeze(0).repeat(B0, 1, 1, 1, 1) - else: - spatial_selected_coords = None - log.info( - f"Spatial memory dropout: dropped {num_drop}/{N_sel}, kept {N_sel - num_drop}", - rank0_only=True, - ) - - # For the base history selection, use temporal-only indices (spatial slots will be inserted later). - splits_temp = [s for s, t in zip(self.framepack_clean_latent_frame_splits, self.framepack_clean_latent_frame_kernel_types) if t == "k"] - types_temp = ["k"] * len(splits_temp) - selected_idx_hist = self._compose_selected_indices( - splits=splits_temp, - types=types_temp, - T_hist_total=T_hist_total, - temporal_selected=temporal_selected, - spatial_selected=[], - ) - elif num_spatial_hist > 0: - raise NotImplementedError( - "Lyra2Model is collapsed to the image-token target branch and requires spatial_memory_use_image=True" - ) - else: - selected_idx_hist = self._compose_selected_indices( - splits=self.framepack_clean_latent_frame_splits, - types=self.framepack_clean_latent_frame_kernel_types, - T_hist_total=T_hist_total, - temporal_selected=temporal_selected, - spatial_selected=[], - ) - # Mask: history always clean; generation tail masked out only when no pose conditioning. - B, C_lat, _, H_lat, W_lat = history_full.shape - mask_hist = torch.ones(B, 4, int(selected_idx_hist.shape[0]), H_lat, W_lat, dtype=history_full.dtype, device=device) - mask_gen = torch.ones(B, 4, T_new_lat, H_lat, W_lat, dtype=history_full.dtype, device=device) - - # Reorder history and concatenate tails. - latents_hist = history_full[:, :, selected_idx_hist] - cond_hist = latents_hist - latents = torch.cat([latents_hist, gen_lat_dummy], dim=2) - cond_latent = torch.cat([cond_hist, gen_cond], dim=2) - mask = torch.cat([mask_hist, mask_gen], dim=2) - - # Build full selected_idx for plucker alignment (history indices + gen indices). - gen_idx = torch.arange(T_hist_total, T_hist_total + T_new_lat, device=device, dtype=torch.long) - selected_idx_full = torch.cat([selected_idx_hist, gen_idx], dim=0) - - # Camera controls (depth warp + plucker) applied after reordering. - cond_latent, buffer_cond_latents = self._apply_camera_controls( - cond_latent, - selected_idx_full, - video_indices, - spatial_selected_frame_ids=spatial_selected_frame_ids_t, - spatial_selected_coords=spatial_selected_coords, - video=video, - camera_w2c=camera_w2c, - intrinsics=intrinsics, - buffer_depth_B_1_H_W=buffer_depth_B_1_H_W, - spatial_cache=spatial_cache, - is_training=bool(is_training), - ) - - # Optional image-based spatial memory insertion (encode retrieved frames and interleave). - if cfg.spatial_memory_use_image: - spatial_image_ids: list[int] = spatial_selected_frame_ids_t.tolist() if spatial_selected_frame_ids_t is not None and spatial_selected_frame_ids_t.numel() > 0 else [] - if self.framepack_num_spatial_hist <= 0: - spatial_image_ids = [] - if len(spatial_image_ids) < self.framepack_num_spatial_hist: - spatial_image_ids = [video_indices[0].item()] * (self.framepack_num_spatial_hist - len(spatial_image_ids)) + spatial_image_ids - - spatial_latents_list = [] - for t in spatial_image_ids: - if int(t) < 0: - mv_rgb = spatial_cache.get_rgb_by_frame_id(int(t)).to(device=device, dtype=video.dtype) - if mv_rgb.dim() == 4: - mv_rgb = mv_rgb.unsqueeze(2) - spatial_latents_list.append(self.encode(mv_rgb)) - else: - spatial_latents_list.append(self.encode(video[:, :, t : t + 1])) - spatial_latents = torch.cat(spatial_latents_list, dim=2) - - spatial_plucker = None - if spatial_image_ids: - # Multiview-safe: build K_sel/w2c_sel per-element for negative frame IDs. - if any(int(t) < 0 for t in spatial_image_ids): - K_list, w2c_list = [], [] - for t in spatial_image_ids: - if int(t) < 0: - _, w, k = spatial_cache.get_rgbd_by_frame_id(int(t)) - w2c_list.append(w.to(device=device, dtype=torch.float32)) - K_list.append(k.to(device=device, dtype=torch.float32)) - else: - w2c_list.append(camera_w2c[:, int(t)].to(dtype=torch.float32)) - K_list.append(intrinsics[:, int(t)].to(dtype=torch.float32)) - K_sel = torch.stack(K_list, dim=1) # [B, T_sp, 3, 3] - w2c_sel = torch.stack(w2c_list, dim=1) - else: - spatial_image_ids_t = torch.tensor(spatial_image_ids, device=device, dtype=torch.long) - K_sel = intrinsics[:, spatial_image_ids_t] # [B, T_sp, 3, 3] - w2c_sel = camera_w2c[:, spatial_image_ids_t] - c2w_sel = torch.inverse(w2c_sel) - ref_idx = video_indices[-int(self.framepack_num_new_video_frames) - 1] - c2w_ref_inv = camera_w2c[:, ref_idx : ref_idx + 1] - c2w_rel = torch.matmul(c2w_ref_inv, c2w_sel) - fx = K_sel[..., 0, 0] - fy = K_sel[..., 1, 1] - cx = K_sel[..., 0, 2] - cy = K_sel[..., 1, 2] - K_vec_sp = torch.stack([fx, fy, cx, cy], dim=-1) # [B, T_sp, 4] - plucker = ray_condition( - K_vec_sp, - c2w_rel, - int(video.shape[-2]), - int(video.shape[-1]), - device=device, - flip_flag=None, - use_ray_o=True, - ) - plucker_5d = plucker.permute(0, 4, 1, 2, 3).contiguous() - spatial_plucker = rearrange( - plucker_5d, - "b c t (h h8) (w w8) -> b (c h8 w8) t h w", - h8=8, - w8=8, - ) - - # Insert spatial latents into history according to (splits, types), keep gen tail at end. - final_latents = [] - final_cond = [] - final_mask = [] - gen_ptr = latents.shape[2] - int(self.framepack_num_new_latent_frames) - current_temp_lat = latents[:, :, :gen_ptr] - current_gen_lat = latents[:, :, gen_ptr:] - current_temp_cond = cond_latent[:, :, :gen_ptr] - current_gen_cond = cond_latent[:, :, gen_ptr:] - current_temp_mask = mask[:, :, :gen_ptr] - current_gen_mask = mask[:, :, gen_ptr:] - t_cursor = 0 - s_cursor = 0 - B0 = latents.shape[0] - for s_cnt, tp in zip(self.framepack_clean_latent_frame_splits, self.framepack_clean_latent_frame_kernel_types): - if tp == "k": - final_latents.append(current_temp_lat[:, :, t_cursor : t_cursor + s_cnt]) - final_cond.append(current_temp_cond[:, :, t_cursor : t_cursor + s_cnt]) - final_mask.append(current_temp_mask[:, :, t_cursor : t_cursor + s_cnt]) - t_cursor += int(s_cnt) - elif tp == "s": - chunk_lat = spatial_latents[:, :, s_cursor : s_cursor + s_cnt] - final_latents.append(chunk_lat) - if spatial_plucker is not None: - chunk_plucker = spatial_plucker[:, :, s_cursor : s_cursor + s_cnt] - chunk_cond = torch.cat([chunk_lat, chunk_plucker.type_as(chunk_lat)], dim=1) - else: - chunk_cond = chunk_lat - final_cond.append(chunk_cond) - chunk_mask = torch.ones( - B0, - 4, - int(s_cnt), - latents.shape[3], - latents.shape[4], - device=device, - dtype=latents.dtype, - ) - final_mask.append(chunk_mask) - s_cursor += int(s_cnt) - - final_latents.append(current_gen_lat) - final_cond.append(current_gen_cond) - final_mask.append(current_gen_mask) - latents = torch.cat(final_latents, dim=2) - cond_latent = torch.cat(final_cond, dim=2) - mask = torch.cat(final_mask, dim=2) - - # Pad buffer_cond_latents to full length and inject spatial 3D coordinates. - B0 = int(cond_latent.shape[0]) - H_lat = int(cond_latent.shape[3]) - W_lat = int(cond_latent.shape[4]) - T_hist = int(cond_latent.shape[2]) - int(self.framepack_num_new_latent_frames) - - _max_spatial = cfg.multibuffer_max_spatial_frames - if _max_spatial is None: - _max_spatial = int(self.framepack_num_spatial_hist) - _max_spatial = int(_max_spatial) - - C_coord_lat = LYRA2_CORRESPONDENCE_CHANNELS_PER_SLOT * _max_spatial - - def _coords_to_slotted_latent(coords_b: torch.Tensor) -> torch.Tensor: - """[B, 3, T, H, W] raw pixel coords → [B, _max_spatial*4*8*8, T, H_lat, W_lat].""" - d_minus = torch.full( - (B0, 1, coords_b.shape[2], coords_b.shape[3], coords_b.shape[4]), - -1.0, device=coords_b.device, dtype=coords_b.dtype, - ) - d_plus = torch.ones_like(d_minus) - lat_minus = self._pixelshuffle_hw_to_latent( - torch.cat([coords_b, d_minus], dim=1) - ).to(dtype=cond_latent.dtype) - lat_plus = self._pixelshuffle_hw_to_latent( - torch.cat([coords_b, d_plus], dim=1) - ).to(dtype=cond_latent.dtype) - if _max_spatial <= 0: - return lat_minus[:, :0] - if _max_spatial == 1: - return lat_minus - return torch.cat([lat_minus] + [lat_plus] * (_max_spatial - 1), dim=1) - - buffer_hist_chunks: list[torch.Tensor] = [] - C_buf = int(buffer_cond_latents.shape[1]) if buffer_cond_latents is not None else C_coord_lat - - # Convert spatial coords to slotted latent space. - spatial_coords_latent: Optional[torch.Tensor] = None - if spatial_coords_all is not None and num_spatial_hist > 0: - spatial_coords_all_b = spatial_coords_all.unsqueeze(0).permute(0, 2, 1, 3, 4).repeat(B0, 1, 1, 1, 1) - spatial_coords_latent = _coords_to_slotted_latent(spatial_coords_all_b) - - # Assemble per-chunk history buffer with debug logging. - s_cursor = 0 - _debug_labels: list[str] = [] - for _pos, (s_cnt, tp) in enumerate(zip(self.framepack_clean_latent_frame_splits, self.framepack_clean_latent_frame_kernel_types)): - s_cnt = int(s_cnt) - if tp == "s" and spatial_coords_latent is not None: - chunk = spatial_coords_latent[:, :, s_cursor : s_cursor + s_cnt] - buffer_hist_chunks.append(chunk) - s_cursor += int(s_cnt) - _debug_labels.append(f"pos{_pos}:f{s_cnt}{tp}=spatial_coord") - else: - buffer_hist_chunks.append( - torch.zeros((B0, C_buf, s_cnt, H_lat, W_lat), device=cond_latent.device, dtype=cond_latent.dtype) - ) - _debug_labels.append(f"pos{_pos}:f{s_cnt}{tp}=zeros") - log.info( - f"Buffer hist padding: {_debug_labels}, C_buf={C_buf}, _max_spatial={_max_spatial}", - rank0_only=True, - ) - - buffer_hist = torch.cat(buffer_hist_chunks, dim=2) if buffer_hist_chunks else torch.zeros( - (B0, C_buf, T_hist, H_lat, W_lat), - device=cond_latent.device, - dtype=cond_latent.dtype, - ) - if buffer_cond_latents is None: - buffer_tail = torch.zeros( - (B0, C_buf, T_new_lat, H_lat, W_lat), - device=cond_latent.device, - dtype=cond_latent.dtype, - ) - else: - buffer_tail = buffer_cond_latents.to(dtype=cond_latent.dtype) - buffer_cond_latents = torch.cat([buffer_hist, buffer_tail], dim=2) - - return latents, cond_latent, mask, buffer_cond_latents - - @torch.no_grad() - def _tokenizing_video_to_latents(self, video, dropout=False, data_batch=None): - cfg = self.config - assert data_batch is not None, "Lyra2Model._tokenizing_video_to_latents requires data_batch" - with misc.timer("_tokenizing_video_to_latents(spatial) - total"): - # Step 1: windowing - video, video_indices, start, cur_segment_id, chunk_len = self._prepare_video_window( - video, - data_batch.get("start") if data_batch is not None else None, - data_batch.get("cur_segment_id") if data_batch is not None else None, - ) - - # Encode latents and cond_latent with shared prefix - with misc.timer("vae_encoding - shared prefix"): - if cfg.self_aug_enabled: - if "_stage_a_vae_cache_T-2" not in data_batch: # self aug step. Save vae cache - out = self._vae_encode_with_shared_prefix(video, None, return_cache=True) - latents, cond_latent, cache_after_prefix, cache_current = cast(tuple[torch.Tensor, torch.Tensor, Any, Any], out) - data_batch["_stage_a_full_latents"] = latents.clone() - data_batch["_stage_a_vae_cache_T-2"] = cache_after_prefix - data_batch["_stage_a_vae_cache_T-1"] = cache_current - else: - # Temporal slices along T dimension - prev_gen_chunk_aug = video[:, :, -2 * self.framepack_num_new_video_frames : -1 * self.framepack_num_new_video_frames] # self-augmented previous chunk - curr_gen_chunk = video[:, :, -1 * self.framepack_num_new_video_frames :] # clean current chunk - # 1) Encode self-augmented previous-chunk - feat1_enc, cache_after_prev = self.vae_encode_with_cache( - data_batch["_stage_a_vae_cache_T-2"], - prev_gen_chunk_aug, - start_t=0, - end_t=prev_gen_chunk_aug.shape[2], - return_cache=True, - ) - # 2) Encode zero-tail for the next chunk - zeros_last = torch.zeros_like(prev_gen_chunk_aug) - feat2_enc = self.vae_encode_with_cache( - cache_after_prev, - zeros_last, - start_t=0, - end_t=zeros_last.shape[2], - return_cache=False, - ) - # 3) Encode GT with clean cache - feat3_enc = self.vae_encode_with_cache( - data_batch["_stage_a_vae_cache_T-1"], - curr_gen_chunk, - start_t=0, - end_t=curr_gen_chunk.shape[2], - return_cache=False, - ) - # Convert encoder feats to normalized latents using shared helper - lat1 = self._encoder_feats_to_normalized_latents(feat1_enc) - lat2 = self._encoder_feats_to_normalized_latents(feat2_enc) - lat3 = self._encoder_feats_to_normalized_latents(feat3_enc) - # Cast to input dtype and stitch - in_dtype = video.dtype - lat1 = lat1.contiguous().to(in_dtype) - lat2 = lat2.contiguous().to(in_dtype) - lat3 = lat3.contiguous().to(in_dtype) - # replace previous chunk with self-augmented latents, and concatenate with clean gt / zero latents - latents = torch.cat([data_batch["_stage_a_full_latents"][:, :, :-self.framepack_num_new_latent_frames], lat1, lat3], dim=2) - cond_latent = torch.cat([data_batch["_stage_a_full_latents"][:, :, :-self.framepack_num_new_latent_frames], lat1, lat2], dim=2) - del data_batch["_stage_a_full_latents"] - else: - out2 = self._vae_encode_with_shared_prefix(video, None, return_cache=False) - latents, cond_latent = cast(tuple[torch.Tensor, torch.Tensor], out2) - - history_full = latents[:, :, : -self.framepack_num_new_latent_frames] - gen_gt = latents[:, :, -self.framepack_num_new_latent_frames :] - gen_cond = cond_latent[:, :, -self.framepack_num_new_latent_frames :] - - # Build Sparse3DCache OUTSIDE _prepare_lyra2_inputs. - # In inference, this cache is maintained incrementally during AR, so the build logic differs. - splits = self.framepack_clean_latent_frame_splits - types = self.framepack_clean_latent_frame_kernel_types - num_spatial_hist = int(sum(s for s, t in zip(splits, types) if t == "s")) - num_temporal_hist = int(sum(s for s, t in zip(splits, types) if t == "k")) - _ = num_temporal_hist # explicit for readability - - spatial_cache: Optional[Sparse3DCache] = None - buffer_depth_B_1_H_W: Optional[torch.Tensor] = None - # Last history frame (buffer) absolute index; depth passed explicitly for warping and NOT cached. - rel_buffer_idx = video_indices.shape[0] - self.framepack_num_new_video_frames - 1 - abs_buffer_idx = int(video_indices[rel_buffer_idx].item()) - buffer_depth_B_1_H_W = data_batch["depth"][:, abs_buffer_idx].to(device=latents.device, dtype=torch.float32) - if buffer_depth_B_1_H_W.dim() == 3: - buffer_depth_B_1_H_W = buffer_depth_B_1_H_W.unsqueeze(1) - - if num_spatial_hist > 0: - spatial_cache = Sparse3DCache( - downsample=4, - store_device=str(latents.device), - store_values=True, - ) - - if cfg.spatial_memory_use_image: - # Image-based spatial memory: cache over global timeline, excluding frames near the current window. - is_i2v = bool(data_batch.get("is_i2v", False)) - use_only_first = is_i2v - if use_only_first: - log.info( - f"Spatial memory: only use first frame (is_i2v={is_i2v}).", - rank0_only=True, - ) - if int(data_batch["video"].shape[2]) > 0: - spatial_cache.add( - data_batch["depth"][:, video_indices[0].item()].to(device=latents.device, dtype=torch.float32), - data_batch["camera_w2c"][:, video_indices[0].item()].to(device=latents.device, dtype=torch.float32), - data_batch["intrinsics"][:, video_indices[0].item()].to(device=latents.device, dtype=torch.float32), - latent_index=0, - frame_id=video_indices[0].item(), - ) - else: - skip_recent = int(cfg.spatial_memory_skip_recent) - stride = max(int(cfg.spatial_memory_stride), 1) - t0 = int(video_indices[-int(self.framepack_num_new_video_frames)].item()) - t1 = int(video_indices[-1].item()) - abs_buffer_idx = int(video_indices[video_indices.shape[0] - self.framepack_num_new_video_frames - 1].item()) - for t in range(int(data_batch["video"].shape[2])): - if t == abs_buffer_idx: - continue - if t < (t0 - skip_recent) or t > (t1 + skip_recent): - if (t % stride == 0) and t != 0: - spatial_cache.add( - data_batch["depth"][:, t].to(device=latents.device, dtype=torch.float32), - data_batch["camera_w2c"][:, t].to(device=latents.device, dtype=torch.float32), - data_batch["intrinsics"][:, t].to(device=latents.device, dtype=torch.float32), - latent_index=int(t), - frame_id=int(t), - ) - - # Prepare final (latents, cond_latent, mask) using the unified helper. - with misc.timer("post - prepare_lyra2_inputs"): - latents, cond_latent, mask, buffer_cond_latents = self._prepare_lyra2_inputs( - history_full=history_full, - gen_cond=gen_cond, - spatial_cache=spatial_cache, - video=data_batch["video"].to(device=latents.device, dtype=latents.dtype), - buffer_depth_B_1_H_W=buffer_depth_B_1_H_W, - camera_w2c=data_batch["camera_w2c"], - intrinsics=data_batch["intrinsics"], - video_indices=video_indices, - is_training=True, - ) - data_batch["cond_latent_buffer"] = buffer_cond_latents - # Replace dummy tail latents with the actual ground-truth tail for training. - latents = torch.cat([latents[:, :, : -self.framepack_num_new_latent_frames], gen_gt], dim=2) - - # Recompute visualization rays so they align with final latent order. - if self._collect_return_condition_state: - C_lat = latents.shape[1] - plucker_grid = cond_latent[:, C_lat:, :, :, :] # [B, 6*8*8, T_lat, H_lat, W_lat] - if plucker_grid.numel() > 0: - plucker_lat = rearrange( - plucker_grid, - "b (c h8 w8) t h w -> b t (h h8) (w w8) c", - h8=8, - w8=8, - ) # [B, T_lat, H_pix, W_pix, 6] - B_vis, T_lat, H_pix, W_pix, _ = plucker_lat.shape - F = int(self.framepack_num_frames_per_latent) - if T_lat <= 1 or F <= 1: - plucker_vis = plucker_lat - else: - head = plucker_lat[:, :1] # first latent -> 1 frame - tail = plucker_lat[:, 1:] # remaining latents - tail_rep = tail.unsqueeze(2).repeat(1, 1, F, 1, 1, 1) - tail_rep = tail_rep.reshape(B_vis, (T_lat - 1) * F, H_pix, W_pix, 6) - plucker_vis = torch.cat([head, tail_rep], dim=1) # [B, 1 + (T_lat-1)*F, H, W, 6] - M_all = plucker_vis[..., 0:3] - d_all = plucker_vis[..., 3:6] - self._latest_plucker_rays_pixels = { - "ray_origin": M_all.detach(), - "ray_direction": d_all.detach(), - } - - # Optional corruption/augmentation (must always run, regardless of visualization flags). - if self.config.apply_corruption_to_spatial_region != "none": - self._apply_spatial_region_corruption(latents, cond_latent) - - # last history frame in pixel space (from pre-encoded video timeline) - last_hist_frame = video[:, :, -self.framepack_num_new_video_frames - 1].clone() - - data_batch["_stage_a_video_indices"] = video_indices - data_batch["_stage_a_start"] = start - data_batch["_stage_a_cur_segment_id"] = cur_segment_id - data_batch["_stage_a_chunk_len"] = chunk_len - data_batch["_stage_a_video_shape"] = video.shape - - if "t5_chunk_keys" in data_batch: - # absolute index of the last history frame (first frame before generation) - rel_gen_first_idx = int(video_indices[-int(self.framepack_num_new_video_frames)].item()) - sample_frame_indices = data_batch["sample_frame_indices"] # [B, F] - t5_chunk_keys = data_batch["t5_chunk_keys"] # [B, K] - t5_chunk_embeddings = data_batch["t5_chunk_embeddings"] # [B, K, 512, 4096] - t5_chunk_mask = data_batch["t5_chunk_mask"] # [B, K, 512] - assert torch.is_tensor(sample_frame_indices) and torch.is_tensor(t5_chunk_keys) - assert torch.is_tensor(t5_chunk_embeddings) and torch.is_tensor(t5_chunk_mask) - B = int(t5_chunk_keys.shape[0]) - # Per-sample absolute index into original sequence - first_abs_idx_B = sample_frame_indices[:, rel_gen_first_idx].to(dtype=torch.long) # [B] - selected_emb_list = [] - selected_mask_list = [] - for b in range(B): - keys_b = t5_chunk_keys[b] # [K], ascending - K = int(keys_b.numel()) - val = int(first_abs_idx_B[b].item()) - # strictly smaller w.r.t first index where key > val, then minus 1 - pos = torch.searchsorted(keys_b, torch.tensor([val], device=keys_b.device, dtype=keys_b.dtype), right=True).item() - sel_idx = max(0, min(int(pos) - 1, K - 1)) - emb_b = t5_chunk_embeddings[b, sel_idx] # [512, 4096] - msk_b = t5_chunk_mask[b, sel_idx] # [512] - - sel_key = int(keys_b[sel_idx].item()) if K > 0 else -1 - - selected_emb_list.append(emb_b) - selected_mask_list.append(msk_b) - data_batch["t5_text_embeddings"] = torch.stack(selected_emb_list, dim=0) # [B, 512, 4096] - data_batch["t5_text_mask"] = torch.stack(selected_mask_list, dim=0) # [B, 512] - return latents, last_hist_frame, cond_latent, mask - - -class Sparse3DCache: - def __init__( - self, - downsample: int = 4, - store_device: str = "cuda", - store_values: bool = False, - ) -> None: - self.downsample = int(downsample) - self._store_device = str(store_device) - self._store_values = bool(store_values) - self._world_points: list[torch.Tensor] = [] # each: [B, H', W', 3] - self._latent_indices: list[int] = [] # latent index per entry - self._frame_ids: list[int] = [] # original video frame id per entry - # Optional raw RGBD camera storage for value lookup (used in inference warping). - self._depths: list[torch.Tensor] = [] - self._w2cs: list[torch.Tensor] = [] - self._Ks: list[torch.Tensor] = [] - # Multiview RGB storage keyed by frame_id (only populated for multiview inputs). - self._rgbs: dict[int, torch.Tensor] = {} - - @staticmethod - def _scale_intrinsics(intrinsic: torch.Tensor, scale: float) -> torch.Tensor: - """Scale pinhole intrinsics for a downsampled grid by factor `scale` (e.g., 1/4).""" - assert intrinsic.dim() == 3 and intrinsic.shape[-2:] == (3, 3) - K = intrinsic.clone() - K[:, 0, 0] = K[:, 0, 0] * scale - K[:, 1, 1] = K[:, 1, 1] * scale - K[:, 0, 2] = K[:, 0, 2] * scale - K[:, 1, 2] = K[:, 1, 2] * scale - return K - - def add( - self, - depth_B_1_H_W: torch.Tensor, - w2c_B_4_4: torch.Tensor, - K_B_3_3: torch.Tensor, - latent_index: int, - frame_id: Optional[int] = None, - ) -> None: - ds = self.downsample - # Subsample depth and scale intrinsics accordingly - depth_ds = depth_B_1_H_W[:, :, ::ds, ::ds] - scale = 1.0 / float(ds) - K_scaled = self._scale_intrinsics(K_B_3_3, scale) - # Valid mask where depth > 0 - mask_valid = (depth_ds > 0) - world_pts: torch.Tensor = unproject_points( - depth=depth_ds, - w2c=w2c_B_4_4, - intrinsic=K_scaled, - is_depth=True, - is_ftheta=False, - mask=mask_valid, - return_sparse=False, - ) # [B, H', W', 3] - if self._store_device == "cpu": - world_pts = world_pts.detach().to("cpu", non_blocking=True) - self._world_points.append(world_pts) - self._latent_indices.append(int(latent_index)) - self._frame_ids.append(int(latent_index) if frame_id is None else int(frame_id)) - if self._store_values: - # Store full-res values (no downsample) for later retrieval by frame_id. - d = depth_B_1_H_W.detach() - w = w2c_B_4_4.detach() - k = K_B_3_3.detach() - if self._store_device == "cpu": - d = d.to("cpu", non_blocking=True) - w = w.to("cpu", non_blocking=True) - k = k.to("cpu", non_blocking=True) - self._depths.append(d) - self._w2cs.append(w) - self._Ks.append(k) - - def store_rgb(self, frame_id: int, rgb: torch.Tensor) -> None: - """Store RGB pixels for a frame (used for multiview inputs with negative frame_id).""" - t = rgb.detach() - if self._store_device == "cpu": - t = t.to("cpu", non_blocking=True) - self._rgbs[int(frame_id)] = t - - def get_rgb_by_frame_id(self, frame_id: int) -> torch.Tensor: - """Return stored RGB for a frame_id. Raises KeyError if not found.""" - fid = int(frame_id) - if fid not in self._rgbs: - raise KeyError(f"frame_id={fid} not found in Sparse3DCache RGB storage") - return self._rgbs[fid] - - def get_rgbd_by_frame_id(self, frame_id: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Return (depth, w2c, K) for a stored frame_id. - - Requires store_values=True at construction time. - """ - if not self._store_values: - raise RuntimeError("Sparse3DCache.get_rgbd_by_frame_id requires store_values=True") - # Prefer most recent match if duplicated. - for i in range(len(self._frame_ids) - 1, -1, -1): - if int(self._frame_ids[i]) == int(frame_id): - return self._depths[i], self._w2cs[i], self._Ks[i] - raise KeyError(f"frame_id={int(frame_id)} not found in Sparse3DCache") - - def update_by_frame_id( - self, - frame_id: int, - depth_B_1_H_W: torch.Tensor, - w2c_B_4_4: torch.Tensor, - K_B_3_3: torch.Tensor, - ) -> bool: - """Replace depth/w2c/K and recompute world points for an existing frame_id. - - Returns True if the frame was found and updated, False otherwise. - """ - fid = int(frame_id) - idx = None - for i in range(len(self._frame_ids)): - if int(self._frame_ids[i]) == fid: - idx = i - break - if idx is None: - return False - - # Ensure all tensors are on the same device for unproject_points. - compute_device = depth_B_1_H_W.device - _depth = depth_B_1_H_W.to(compute_device) - _w2c = w2c_B_4_4.to(compute_device) - _K = K_B_3_3.to(compute_device) - - ds = self.downsample - depth_ds = _depth[:, :, ::ds, ::ds] - scale = 1.0 / float(ds) - K_scaled = self._scale_intrinsics(_K, scale) - mask_valid = (depth_ds > 0) - world_pts: torch.Tensor = unproject_points( - depth=depth_ds, - w2c=_w2c, - intrinsic=K_scaled, - is_depth=True, - is_ftheta=False, - mask=mask_valid, - return_sparse=False, - ) - if self._store_device == "cpu": - world_pts = world_pts.detach().to("cpu", non_blocking=True) - self._world_points[idx] = world_pts - if self._store_values: - d = depth_B_1_H_W.detach() - w = w2c_B_4_4.detach() - k = K_B_3_3.detach() - if self._store_device == "cpu": - d = d.to("cpu", non_blocking=True) - w = w.to("cpu", non_blocking=True) - k = k.to("cpu", non_blocking=True) - self._depths[idx] = d - self._w2cs[idx] = w - self._Ks[idx] = k - return True - - - @torch.no_grad() - def retrieve( - self, - target_w2c_B_4_4: torch.Tensor, - target_K_B_3_3: torch.Tensor, - target_hw: tuple[int, int], - num_latents: int, - skip_last_n: int = 0, - random: bool = False, - max_coverage: bool = False, - depth_threshold: float = 0.1, - ) -> list[tuple[int, int]]: - """Retrieve the best candidate frames from the cache. - - Args: - target_w2c_B_4_4: Target world-to-camera matrices. - Single view [B, 4, 4] or multi-view [B, V, 4, 4]. - target_K_B_3_3: Target intrinsics. - Single view [B, 3, 3] or multi-view [B, V, 3, 3]. - target_hw: (H, W) of the target image in pixels. - num_latents: Maximum number of candidates to return. - skip_last_n: Skip the most recent N entries in the cache. - random: Stochastic sampling (training only). - max_coverage: Greedy set-cover maximizing pixel coverage. - When multi-view targets are given, coverage is maximized - across the union of all views' pixels. - depth_threshold: Tolerance for depth-based occlusion filtering. - """ - Ht, Wt = target_hw - num_total = len(self._world_points) - if num_total == 0 or num_latents <= 0: - return [] - device = target_w2c_B_4_4.device - ds = self.downsample - scale = 1.0 / float(ds) - Ht_ds = int((Ht + ds - 1) // ds) - Wt_ds = int((Wt + ds - 1) // ds) - - # Handle multi-view [B, V, 4, 4] vs single-view [B, 4, 4] - if target_w2c_B_4_4.dim() == 4: - num_views = int(target_w2c_B_4_4.shape[1]) - w2c_views = [target_w2c_B_4_4[:, v] for v in range(num_views)] - K_views = [target_K_B_3_3[:, v] for v in range(num_views)] - else: - num_views = 1 - w2c_views = [target_w2c_B_4_4] - K_views = [target_K_B_3_3] - - s = int(skip_last_n) if skip_last_n is not None else 0 - avail = max(0, num_total - max(0, s)) - if avail <= 0: - return [] - - num_cands = avail - pts_list = self._world_points[:avail] - - # Vectorized projection of all (view, candidate) pairs at once. - # pts_stacked: [C, B, H', W', 3] - pts_stacked = torch.stack([p.to(device=device) for p in pts_list], dim=0) - C, Bp, Hp, Wp, _ = pts_stacked.shape - - # Homogeneous coordinates: [C, B, H', W', 4, 1] - ones_hw = torch.ones(C, Bp, Hp, Wp, 1, device=device, dtype=pts_stacked.dtype) - pts_homo = torch.cat([pts_stacked, ones_hw], dim=-1).unsqueeze(-1) - - # Build per-view downsampled intrinsics: [V, B, 3, 3] and [V, B, 4, 4] - K_ds_views = [self._scale_intrinsics(K_v, scale) for K_v in K_views] - w2c_stack = torch.stack(w2c_views, dim=0) # [V, B, 4, 4] - K_ds_stack = torch.stack(K_ds_views, dim=0) # [V, B, 3, 3] - - # Broadcast matmul: w2c [V,1,B,1,1,4,4] x pts_homo [1,C,B,H',W',4,1] - cam_homo = torch.matmul( - w2c_stack[:, None, :, None, None], # [V, 1, B, 1, 1, 4, 4] - pts_homo[None], # [1, C, B, H', W', 4, 1] - ) # [V, C, B, H', W', 4, 1] - cam_pts = cam_homo[..., :3, :] # [V, C, B, H', W', 3, 1] - - # Broadcast matmul: K [V,1,B,1,1,3,3] x cam_pts [V,C,B,H',W',3,1] - proj = torch.matmul( - K_ds_stack[:, None, :, None, None], # [V, 1, B, 1, 1, 3, 3] - cam_pts, # [V, C, B, H', W', 3, 1] - ) # [V, C, B, H', W', 3, 1] - - z_all = proj[..., 2, 0] # [V, C, B, H', W'] - u_all = proj[..., 0, 0] / (z_all + 1e-7) - v_all = proj[..., 1, 0] / (z_all + 1e-7) - x_all = u_all.round().long() - y_all = v_all.round().long() - valid = (z_all > 0) & (x_all >= 0) & (x_all < Wt_ds) & (y_all >= 0) & (y_all < Ht_ds) - - if not valid.any(): - log.info( - f"Sparse3DCache.retrieve: no valid projections for any of {num_cands} candidates " - f"(frame_ids={self._frame_ids[:avail]})", - rank0_only=True, - ) - return [] - - # valid dims: [V, C, B, H', W'] → nonzero gives (view_ids, cand_ids, b_idx, _, _) - view_ids, cand_ids, b_idx, _, _ = valid.nonzero(as_tuple=True) - y_idx = y_all[valid] - x_idx = x_all[valid] - z_vals = z_all[valid].to(torch.float32) - - Btot = Bp - pixels_per_view = Btot * Ht_ds * Wt_ds - # Each (view, batch, y, x) gets a unique key so depth fusion is per-view. - lin_keys = view_ids * pixels_per_view + b_idx * (Ht_ds * Wt_ds) + y_idx * Wt_ds + x_idx - n_keys = num_views * pixels_per_view - - inf_val = torch.tensor(float('inf'), device=device, dtype=z_vals.dtype) - min_depth = torch.full((n_keys,), inf_val, device=device, dtype=z_vals.dtype) - min_depth.scatter_reduce_(0, lin_keys, z_vals, reduce='amin', include_self=True) - - min_d_for_pts = min_depth[lin_keys] - if max_coverage: - keep = z_vals <= (min_d_for_pts + float(depth_threshold)) - if not keep.any(): - return [] - - lin_keys_keep = lin_keys[keep] - cand_keep = cand_ids[keep].to(torch.long) - - flat_idx = cand_keep * n_keys + lin_keys_keep - mask_flat = torch.zeros((num_cands * n_keys,), device=device, dtype=torch.bool) - mask_flat.scatter_(0, flat_idx, torch.ones_like(flat_idx, dtype=torch.bool)) - mask = mask_flat.view(num_cands, n_keys) - - k = min(int(num_latents), num_cands) - if k <= 0: - return [] - - # Pre-cover pixels from the temporally closest frame (largest frame_id) - # because its warping is already included in the network condition. - # Only applies when the most recent frame_id > 0 (skip the seed frame_id=0 - # and multiview inputs with negative IDs). - avail_frame_ids = self._frame_ids[:avail] - max_frame_id = max(avail_frame_ids) - excluded: set[int] = set() - if max_frame_id > 0: - last_cand_idx = int(max(range(avail), key=lambda i: avail_frame_ids[i])) - covered = mask[last_cand_idx].clone() - excluded.add(last_cand_idx) - log.info( - f"Sparse3DCache.retrieve(max_coverage): pre-covering pixels from temporally closest " - f"frame_id={avail_frame_ids[last_cand_idx]} (cand_idx={last_cand_idx}, " - f"pixels={int(covered.sum().item())})", - rank0_only=True, - ) - else: - covered = torch.zeros((n_keys,), device=device, dtype=torch.bool) - - selected: list[int] = [] - for _ in range(k): - additional = (mask & (~covered)).sum(dim=1) - exclude_indices = list(selected) + list(excluded) - if len(exclude_indices) > 0: - additional[torch.tensor(exclude_indices, device=device)] = -1 - best = int(torch.argmax(additional).item()) - if additional[best].item() <= 0: - break - selected.append(best) - covered |= mask[best] - - if len(selected) == 0: - return [] - top_ids = selected - else: - is_min = z_vals <= (min_d_for_pts + 1e-6) - big_int = torch.iinfo(torch.long).max - cid_masked = torch.where(is_min, cand_ids.to(torch.long), torch.full_like(cand_ids, big_int, dtype=torch.long)) - - owner_lin = torch.full((n_keys,), -1, device=device, dtype=torch.long) - owner_lin_tmp = torch.full((n_keys,), big_int, device=device, dtype=torch.long) - owner_lin_tmp.scatter_reduce_(0, lin_keys, cid_masked, reduce='amin', include_self=True) - owner_lin = torch.where(owner_lin_tmp == big_int, owner_lin, owner_lin_tmp) - - valid_owner = owner_lin[owner_lin >= 0] - counts = torch.bincount(valid_owner, minlength=num_cands) - - scores_t = counts.float() - scores = scores_t.tolist() - - score_map = { - int(self._latent_indices[i]): {"score": float(scores[i]), "frame_id": int(self._frame_ids[i])} - for i in range(num_cands) - } - log.info(f"Sparse3DCache.retrieve scores (latent_index -> score): {score_map}", rank0_only=True) - - if random and num_latents > 0: - max_score = scores_t.max() if scores_t.numel() > 0 else scores_t.new_tensor(1.0) - weights = torch.clamp(scores_t, min=0.0) + max_score * 0.02 - - k = min(int(num_latents), scores_t.shape[0]) - if k <= 0: - return [] - sampled_ids = torch.multinomial(weights, num_samples=k, replacement=False) - top_ids = [int(i) for i in sampled_ids.tolist()] - - else: - top_ids = sorted(range(num_cands), key=lambda i: scores[i], reverse=True)[:num_latents] - - top_ids_reversed = top_ids[::-1] - return [(self._latent_indices[i], self._frame_ids[i]) for i in top_ids_reversed] - - diff --git a/lyra_2/_src/models/utils.py b/lyra_2/_src/models/utils.py deleted file mode 100644 index 907764d2f939eaab687e64e5821b944ea220bb37..0000000000000000000000000000000000000000 --- a/lyra_2/_src/models/utils.py +++ /dev/null @@ -1,457 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -from safetensors.torch import load as safetensors_torch_load - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - -def load_state_dict_from_safetensors(file_path, torch_dtype=None): - byte_stream = easy_io.load(file_path, file_format="byte") - state_dict = safetensors_torch_load(byte_stream) - return state_dict - - -def load_state_dict_from_folder(file_path, torch_dtype=None): - state_dict = {} - for file_name in os.listdir(file_path): - if "." in file_name and file_name.split(".")[-1] in ["safetensors", "bin", "ckpt", "pth", "pt"]: - state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype)) - return state_dict - - -def load_state_dict(file_path, torch_dtype=None): - if file_path.endswith(".safetensors"): - return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) - else: - return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) - - -def load_state_dict_from_bin(file_path, torch_dtype=None): - state_dict = easy_io.load(file_path, file_format="pt", map_location="cpu", weights_only=False) - if torch_dtype is not None: - for i in state_dict: - if isinstance(state_dict[i], torch.Tensor): - state_dict[i] = state_dict[i].to(torch_dtype) - return state_dict - - -# based on https://github.com/huggingface/diffusers/blob/b793debd9d09225582943a1e9cb4ccdab30f1b37/src/diffusers/loaders/lora_conversion_utils.py#L1817 -# since our model is the same as non-diffusers Wan, we only need to change the lora keys: -# 1. add adapter_name to the key, 2. change lora keys -def _convert_non_diffusers_wan_lora_to_diffusers(state_dict, adapter_name="default"): - converted_state_dict = {} - if any("diffusion_model." in k for k in state_dict.keys()): - original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()} - else: - original_state_dict = state_dict - - block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")} - min_block = min(block_numbers) - max_block = max(block_numbers) - - is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) - lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down" - lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up" - - def get_alpha_scales(down_weight, key): - rank = down_weight.shape[0] - alpha = original_state_dict.pop(key + ".alpha").item() - scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - scale_down = scale - scale_up = 1.0 - while scale_down * 2 < scale_up: - scale_down *= 2 - scale_up /= 2 - return scale_down, scale_up - - diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))] - if diff_keys: - for diff_k in diff_keys: - param = original_state_dict[diff_k] - # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3, - # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond - # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It - # is okay to ignore because they do not affect the model output in a significant manner. - threshold = 1.6e-2 - absdiff = param.abs().max() - param.abs().min() - all_zero = torch.all(param == 0).item() - all_absdiff_lower_than_threshold = absdiff < threshold - if all_zero or all_absdiff_lower_than_threshold: - log.debug( - f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold." - ) - original_state_dict.pop(diff_k) - - # For the `diff_b` keys, we treat them as lora_bias. - # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias - - for i in range(min_block, max_block + 1): - # Self-attention - for o, c in zip(["q", "k", "v", "o"], ["q", "k", "v", "o"]): - original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight" - converted_down_key = f"blocks.{i}.self_attn.{c}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight" - converted_up_key = f"blocks.{i}.self_attn.{c}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"blocks.{i}.self_attn.{o}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.self_attn.{o}") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - original_key = f"blocks.{i}.self_attn.{o}.diff_b" - converted_key = f"blocks.{i}.self_attn.{c}.lora_B.{adapter_name}.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - # Cross-attention - for o, c in zip(["q", "k", "v", "o"], ["q", "k", "v", "o"]): - original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - converted_down_key = f"blocks.{i}.cross_attn.{c}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - converted_up_key = f"blocks.{i}.cross_attn.{c}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"blocks.{i}.cross_attn.{o}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - original_key = f"blocks.{i}.cross_attn.{o}.diff_b" - converted_key = f"blocks.{i}.cross_attn.{c}.lora_B.{adapter_name}.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - if is_i2v_lora: - for o, c in zip(["k_img", "v_img"], ["k_img", "v_img"]): - original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight" - converted_down_key = f"blocks.{i}.cross_attn.{c}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight" - converted_up_key = f"blocks.{i}.cross_attn.{c}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"blocks.{i}.cross_attn.{o}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.cross_attn.{o}") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - original_key = f"blocks.{i}.cross_attn.{o}.diff_b" - converted_key = f"blocks.{i}.cross_attn.{c}.lora_B.{adapter_name}.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - # FFN - for o, c in zip(["ffn.0", "ffn.2"], ["ffn.0", "ffn.2"]): - original_key = f"blocks.{i}.{o}.{lora_down_key}.weight" - converted_down_key = f"blocks.{i}.{c}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.{o}.{lora_up_key}.weight" - converted_up_key = f"blocks.{i}.{c}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"blocks.{i}.{o}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.{o}") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - original_key = f"blocks.{i}.{o}.diff_b" - converted_key = f"blocks.{i}.{c}.lora_B.{adapter_name}.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - # Lyra2-specific: cam_encoder, buffer_encoder.{0,1} - for o in ["cam_encoder", "buffer_encoder.0", "buffer_encoder.1"]: - original_key = f"blocks.{i}.{o}.{lora_down_key}.weight" - converted_down_key = f"blocks.{i}.{o}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"blocks.{i}.{o}.{lora_up_key}.weight" - converted_up_key = f"blocks.{i}.{o}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"blocks.{i}.{o}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks.{i}.{o}") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - original_key = f"blocks.{i}.{o}.diff_b" - converted_key = f"blocks.{i}.{o}.lora_B.{adapter_name}.bias" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - # Remaining. - if original_state_dict: - if any("time_projection" in k for k in original_state_dict): - original_key = f"time_projection.1.{lora_down_key}.weight" - converted_down_key = f"time_projection.1.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"time_projection.1.{lora_up_key}.weight" - converted_up_key = f"time_projection.1.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"time_projection.1.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"time_projection.1") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - if "time_projection.1.diff_b" in original_state_dict: - converted_state_dict[f"time_projection.1.lora_B.{adapter_name}.bias"] = original_state_dict.pop( - "time_projection.1.diff_b" - ) - - if any("head.head" in k for k in state_dict): - converted_state_dict[f"head.head.lora_A.{adapter_name}.weight"] = original_state_dict.pop( - f"head.head.{lora_down_key}.weight" - ) - converted_state_dict[f"head.head.lora_B.{adapter_name}.weight"] = original_state_dict.pop( - f"head.head.{lora_up_key}.weight" - ) - if "head.head.diff_b" in original_state_dict: - converted_state_dict[f"head.head.lora_B.{adapter_name}.bias"] = original_state_dict.pop( - "head.head.diff_b" - ) - - alpha_key = f"head.head.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[f"head.head.lora_A.{adapter_name}.weight"] - up_weight = converted_state_dict[f"head.head.lora_B.{adapter_name}.weight"] - scale_down, scale_up = get_alpha_scales(down_weight, f"head.head") - converted_state_dict[f"head.head.lora_A.{adapter_name}.weight"] = down_weight * scale_down - converted_state_dict[f"head.head.lora_B.{adapter_name}.weight"] = up_weight * scale_up - - for text_time in ["text_embedding", "time_embedding"]: - if any(text_time in k for k in original_state_dict): - for b_n in [0, 2]: - diffusers_b_n = b_n - diffusers_name = text_time - if any(f"{text_time}.{b_n}" in k for k in original_state_dict): - converted_state_dict[f"{diffusers_name}.{diffusers_b_n}.lora_A.{adapter_name}.weight"] = ( - original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight") - ) - converted_state_dict[f"{diffusers_name}.{diffusers_b_n}.lora_B.{adapter_name}.weight"] = ( - original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight") - ) - alpha_key = f"{text_time}.{b_n}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[ - f"{diffusers_name}.{diffusers_b_n}.lora_A.{adapter_name}.weight" - ] - up_weight = converted_state_dict[ - f"{diffusers_name}.{diffusers_b_n}.lora_B.{adapter_name}.weight" - ] - scale_down, scale_up = get_alpha_scales(down_weight, f"{text_time}.{b_n}") - converted_state_dict[f"{diffusers_name}.{diffusers_b_n}.lora_A.{adapter_name}.weight"] = ( - down_weight * scale_down - ) - converted_state_dict[f"{diffusers_name}.{diffusers_b_n}.lora_B.{adapter_name}.weight"] = ( - up_weight * scale_up - ) - - if f"{text_time}.{b_n}.diff_b" in original_state_dict: - converted_state_dict[f"{diffusers_name}.{diffusers_b_n}.lora_B.{adapter_name}.bias"] = ( - original_state_dict.pop(f"{text_time}.{b_n}.diff_b") - ) - - # Lyra2-specific: top-level patch_embedding - if any("patch_embedding" in k and "clean_patch_embeddings" not in k for k in original_state_dict): - lora_name = "patch_embedding" - diffusers_name = lora_name - original_key = f"{lora_name}.{lora_down_key}.weight" - if original_key in original_state_dict: - converted_state_dict[f"{diffusers_name}.lora_A.{adapter_name}.weight"] = original_state_dict.pop( - original_key - ) - original_key = f"{lora_name}.{lora_up_key}.weight" - if original_key in original_state_dict: - converted_state_dict[f"{diffusers_name}.lora_B.{adapter_name}.weight"] = original_state_dict.pop( - original_key - ) - alpha_key = f"{lora_name}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[f"{diffusers_name}.lora_A.{adapter_name}.weight"] - up_weight = converted_state_dict[f"{diffusers_name}.lora_B.{adapter_name}.weight"] - scale_down, scale_up = get_alpha_scales(down_weight, f"{lora_name}") - converted_state_dict[f"{diffusers_name}.lora_A.{adapter_name}.weight"] = down_weight * scale_down - converted_state_dict[f"{diffusers_name}.lora_B.{adapter_name}.weight"] = up_weight * scale_up - if f"{lora_name}.diff_b" in original_state_dict: - converted_state_dict[f"{diffusers_name}.lora_B.{adapter_name}.bias"] = original_state_dict.pop( - f"{lora_name}.diff_b" - ) - - # Lyra2-specific: clean_patch_embeddings.{0..5} - for cp_idx in range(6): - lora_name = f"clean_patch_embeddings.{cp_idx}" - if any(lora_name in k for k in original_state_dict): - diffusers_name = lora_name - original_key = f"{lora_name}.{lora_down_key}.weight" - converted_down_key = f"{diffusers_name}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_down_key] = original_state_dict.pop(original_key) - - original_key = f"{lora_name}.{lora_up_key}.weight" - converted_up_key = f"{diffusers_name}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_up_key] = original_state_dict.pop(original_key) - - alpha_key = f"{lora_name}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[converted_down_key] - up_weight = converted_state_dict[converted_up_key] - scale_down, scale_up = get_alpha_scales(down_weight, f"{lora_name}") - converted_state_dict[converted_down_key] = down_weight * scale_down - converted_state_dict[converted_up_key] = up_weight * scale_up - - if f"{lora_name}.diff_b" in original_state_dict: - converted_state_dict[f"{diffusers_name}.lora_B.{adapter_name}.bias"] = original_state_dict.pop( - f"{lora_name}.diff_b" - ) - - for img_ours, img_theirs in [ - ("img_emb.proj.1", "img_emb.proj.1"), - ("img_emb.proj.3", "img_emb.proj.3"), - ]: - original_key = f"{img_theirs}.{lora_down_key}.weight" - converted_key = f"{img_ours}.lora_A.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - original_key = f"{img_theirs}.{lora_up_key}.weight" - converted_key = f"{img_ours}.lora_B.{adapter_name}.weight" - if original_key in original_state_dict: - converted_state_dict[converted_key] = original_state_dict.pop(original_key) - - alpha_key = f"{img_ours}.alpha" - if alpha_key in original_state_dict: - down_weight = converted_state_dict[f"{img_ours}.lora_A.{adapter_name}.weight"] - up_weight = converted_state_dict[f"{img_ours}.lora_B.{adapter_name}.weight"] - scale_down, scale_up = get_alpha_scales(down_weight, f"{img_ours}") - converted_state_dict[f"{img_ours}.lora_A.{adapter_name}.weight"] = down_weight * scale_down - converted_state_dict[f"{img_ours}.lora_B.{adapter_name}.weight"] = up_weight * scale_up - - if len(original_state_dict) > 0: - diff = all(".diff" in k for k in original_state_dict) - if diff: - diff_keys = {k for k in original_state_dict if k.endswith(".diff")} - if not all("lora" not in k for k in diff_keys): - raise ValueError - log.info( - "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: " - "https://github.com/huggingface/diffusers//issues/new" - ) - else: - raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") - - return converted_state_dict - - -def _convert_musubi_wan_lora_to_non_diffusers_wan(state_dict): - # https://github.com/kohya-ss/musubi-tuner - converted_state_dict = {} - original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()} - - num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict}) - is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict) - - def get_alpha_scales(down_weight, key): - rank = down_weight.shape[0] - alpha = original_state_dict.pop(key + ".alpha").item() - scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here - scale_down = scale - scale_up = 1.0 - while scale_down * 2 < scale_up: - scale_down *= 2 - scale_up /= 2 - return scale_down, scale_up - - for i in range(num_blocks): - # Self-attention - for o, c in zip(["q", "k", "v", "o"], ["q", "k", "v", "o"]): - down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight") - up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight") - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}") - converted_state_dict[f"blocks.{i}.self_attn.{c}.lora_down.weight"] = down_weight * scale_down - converted_state_dict[f"blocks.{i}.self_attn.{c}.lora_up.weight"] = up_weight * scale_up - - # Cross-attention - for o, c in zip(["q", "k", "v", "o"], ["q", "k", "v", "o"]): - down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight") - up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight") - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}") - converted_state_dict[f"blocks.{i}.cross_attn.{c}.lora_down.weight"] = down_weight * scale_down - converted_state_dict[f"blocks.{i}.cross_attn.{c}.lora_up.weight"] = up_weight * scale_up - - if is_i2v_lora: - for o, c in zip(["k_img", "v_img"], ["k_img", "v_img"]): - down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight") - up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight") - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}") - converted_state_dict[f"blocks.{i}.cross_attn.{c}.lora_down.weight"] = down_weight * scale_down - converted_state_dict[f"blocks.{i}.cross_attn.{c}.lora_up.weight"] = up_weight * scale_up - - # FFN - for o, c in zip(["ffn_0", "ffn_2"], ["ffn.0", "ffn.2"]): - down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight") - up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight") - scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}") - converted_state_dict[f"blocks.{i}.{c}.lora_down.weight"] = down_weight * scale_down - converted_state_dict[f"blocks.{i}.{c}.lora_up.weight"] = up_weight * scale_up - - if len(original_state_dict) > 0: - raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}") - - return converted_state_dict diff --git a/lyra_2/_src/models/wan_t2v_model.py b/lyra_2/_src/models/wan_t2v_model.py deleted file mode 100644 index b97fa28437f5c6fd24d29669fc4306dce8dcb3ec..0000000000000000000000000000000000000000 --- a/lyra_2/_src/models/wan_t2v_model.py +++ /dev/null @@ -1,1116 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import collections -import os -from contextlib import contextmanager -from functools import partial -from typing import Any, Callable, Dict, Mapping, Optional, Tuple - -import attrs -import numpy as np -import torch -from einops import rearrange -from megatron.core import parallel_state -from torch import Tensor -from torch.distributed._composable.fsdp import FSDPModule, fully_shard -from torch.distributed._tensor.api import DTensor -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict -from torch.distributed.tensor import distribute_tensor -from torch.nn.modules.module import _IncompatibleKeys - -try: - from peft import LoraConfig, inject_adapter_in_model - from peft.tuners.tuners_utils import BaseTunerLayer, _find_minimal_target_modules -except ImportError: - print("peft is not installed, Lora is not supported") - LoraConfig = None - inject_adapter_in_model = None - _find_minimal_target_modules = None - BaseTunerLayer = None - -from lyra_2._ext.imaginaire.lazy_config import LazyDict -from lyra_2._ext.imaginaire.lazy_config import instantiate as lazy_instantiate -from lyra_2._ext.imaginaire.model import ImaginaireModel -from lyra_2._ext.imaginaire.types.denoise_prediction import DenoisePrediction -from lyra_2._ext.imaginaire.utils import log, misc -from lyra_2._ext.imaginaire.utils.checkpointer import non_strict_load_model -from lyra_2._ext.imaginaire.utils.count_params import count_params -from lyra_2._ext.imaginaire.utils.ema import FastEmaModelUpdater -from lyra_2._ext.imaginaire.utils.fsdp_helper import hsdp_device_mesh -from lyra_2._ext.imaginaire.utils.optim_instantiate import get_base_scheduler -from lyra_2._src.callbacks.model_weights_stats import WeightTrainingStat -from lyra_2._src.datasets.utils import VIDEO_RES_SIZE_INFO -from lyra_2._src.models.fm_solvers_unipc import FlowUniPCMultistepScheduler -from lyra_2._src.models.utils import ( - _convert_musubi_wan_lora_to_non_diffusers_wan, - _convert_non_diffusers_wan_lora_to_diffusers, - load_state_dict, -) -from lyra_2._src.modules.conditioner import DataType, T2VCondition -from lyra_2._src.schedulers.rectified_flow import RectifiedFlow -from lyra_2._src.tokenizers.base_vae import BaseVAE -from lyra_2._src.utils.context_parallel import ( - broadcast, - broadcast_split_tensor, - cat_outputs_cp, -) -from lyra_2._src.utils.dtensor_helper import DTensorFastEmaModelUpdater, broadcast_dtensor_model_states -from lyra_2._src.utils.misc import sync_timer -from lyra_2._src.utils.torch_future import clip_grad_norm_ - -IS_PREPROCESSED_KEY = "is_preprocessed" -NUM_EMBEDDING_PADDING_TOKENS = 512 - - -@attrs.define(slots=False) -class EMAConfig: - """ - Config for the EMA. - """ - - enabled: bool = True - rate: float = 0.1 - iteration_shift: int = 0 - - -@attrs.define(slots=False) -class I4LoraConfig: - enabled: bool = False - pretrained_lora_path: str = "" - lora_rank: int = -1 - adapter_name: str = "default" - lora_target_modules: list[str] = [] - init_lora_weights: str = "kaiming" - - -@attrs.define(slots=False) -class T2VModelConfig: - tokenizer: LazyDict = None - conditioner: LazyDict = None - net: LazyDict = None - ema: EMAConfig = EMAConfig() - - fsdp_shard_size: int = 1 - precision: str = "bfloat16" - input_data_key: str = "video" # key to fetch input data from data_batch - input_image_key: str = "image" # key to fetch input image from data_batch - input_caption_key: str = "ai_caption" # Key used to fetch input captions - use_torch_compile: bool = False - lora_config: I4LoraConfig = I4LoraConfig() - use_mp_policy_fsdp: bool = False - keep_original_net_dtype: bool = False - - state_ch: int = 16 # for latent model, ref to the latent channel number - state_t: int = 8 # for latent model, ref to the latent number of frames - resolution: str = "512" - - shift: int = 5 - use_dynamic_shift: bool = False - train_time_weight: str = "uniform" - train_time_distribution: str = "logitnormal" - max_timestep_boundary: float = 1.0 - min_timestep_boundary: float = 0.0 - - -class WANDiffusionModel(ImaginaireModel): - """ - Diffusion model. - """ - - def __init__(self, config: T2VModelConfig): - super().__init__() - - self.config = config - - self.precision = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - }[config.precision] - self.tensor_kwargs = {"device": "cuda", "dtype": self.precision} - self.flow_matching_kwargs = {"device": "cuda", "dtype": torch.float32} - - log.warning(f"DiffusionModel: precision {self.precision}") - log.warning(f"Flow Matching: precision {self.flow_matching_kwargs['dtype']}") - - # 1. set data keys and data information - # self.sigma_data = config.sigma_data - self.setup_data_key() - - # 2. setup up diffusion processing and scaling~(pre-condition), sampler - self.sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - # Lazily-initialized flow-matching scheduler for DMD distillation (4-step) inference. - self.dmd_scheduler = None - - # 3. tokenizer - with misc.timer("DiffusionModel: set_up_tokenizer"): - self.tokenizer: BaseVAE = lazy_instantiate(config.tokenizer) - assert self.tokenizer.latent_ch == self.config.state_ch, ( - f"latent_ch {self.tokenizer.latent_ch} != state_shape {self.config.state_ch}" - ) - - # 5. create fsdp mesh if needed - if config.fsdp_shard_size > 1: - self.fsdp_device_mesh = hsdp_device_mesh( - sharding_group_size=config.fsdp_shard_size, - ) - else: - self.fsdp_device_mesh = None - - # 6. diffusion neural networks part - if "lora_config" in config: - if config.lora_config.enabled: - self.config.net.postpone_checkpoint = True - self.set_up_model() - - # 7. training states - if parallel_state.is_initialized(): - self.data_parallel_size = parallel_state.get_data_parallel_world_size() - else: - self.data_parallel_size = 1 - - # 8. rectified flow - self.rectified_flow = RectifiedFlow( - velocity_field=self.net, - train_time_distribution=config.train_time_distribution, - max_timestep_boundary=config.max_timestep_boundary, - min_timestep_boundary=config.min_timestep_boundary, - use_dynamic_shift=config.use_dynamic_shift, - shift=config.shift, - train_time_weight_method=config.train_time_weight, - device=torch.device("cuda"), - dtype=self.flow_matching_kwargs["dtype"], - ) - - if not config.lora_config.enabled: - self.net.requires_grad_(True) - if config.lora_config.enabled: - self.net.enable_selective_checkpoint(self.net.sac_config, self.net.blocks) - - def maybe_inject_lora_to_net(self, net, lora_config=None, skip_inject=False, skip_load=False): - if lora_config is None: - lora_config = self.config.lora_config - if lora_config.enabled: - if lora_config.pretrained_lora_path: - self.load_lora_weights( - lora_path=lora_config.pretrained_lora_path, - adapter_name=lora_config.adapter_name, - training_mode=True, - skip_inject=skip_inject, - skip_load=skip_load, - model=net, - ) - elif not skip_inject: - self.add_lora_to_model( - adapter_name=lora_config.adapter_name, - lora_rank=lora_config.lora_rank, - lora_target_modules=lora_config.lora_target_modules, - init_lora_weights=lora_config.init_lora_weights, - training_mode=True, - model=net, - ) - - def setup_data_key(self) -> None: - self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model - self.input_image_key = self.config.input_image_key - - def build_net(self): - config = self.config - # NOTE: (ruiyuang) always use meta device, no need to use cpu - init_device = "meta" - with misc.timer("Creating PyTorch model"): - with sync_timer("net instantiate"): - with torch.device(init_device): - net = lazy_instantiate(config.net) - - if "lora_config" in config: - self.maybe_inject_lora_to_net(net, skip_load=True) - self._param_count = count_params(net, verbose=False) - - if self.fsdp_device_mesh: - net.fully_shard(mesh=self.fsdp_device_mesh) - net = fully_shard(net, mesh=self.fsdp_device_mesh, reshard_after_forward=True) - - with misc.timer("meta to cuda and broadcast model states"): - net.to_empty(device="cuda") - net.init_weights() - - if self.fsdp_device_mesh: - broadcast_dtensor_model_states(net, self.fsdp_device_mesh) - for name, param in net.named_parameters(): - assert isinstance(param, DTensor), f"param should be DTensor, {name} got {type(param)}" - if "lora_config" in config: - self.maybe_inject_lora_to_net(net, skip_inject=True) - return net - - @misc.timer("DiffusionModel: set_up_model") - def set_up_model(self): - config = self.config - with misc.timer("Creating PyTorch model and ema if enabled"): - self.conditioner = lazy_instantiate(config.conditioner) - assert sum(p.numel() for p in self.conditioner.parameters() if p.requires_grad) == 0, ( - "conditioner should not have learnable parameters" - ) - self.net = self.build_net() - self._param_count = count_params(self.net, verbose=False) - - if config.ema.enabled: - self.net_ema = self.build_net() - self.net_ema.requires_grad_(False) - - if self.fsdp_device_mesh: - self.net_ema_worker = DTensorFastEmaModelUpdater() - else: - self.net_ema_worker = FastEmaModelUpdater() - - s = config.ema.rate - self.ema_exp_coefficient = np.roots([1, 7, 16 - s**-2, 12 - s**-2]).real.max() - - self.net_ema_worker.copy_to(src_model=self.net, tgt_model=self.net_ema) - torch.cuda.empty_cache() - - def init_optimizer_scheduler( - self, optimizer_config: LazyDict, scheduler_config: LazyDict - ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]: - """Creates the optimizer and scheduler for the model. - - Args: - config_model (ModelConfig): The config object for the model. - - Returns: - optimizer (torch.optim.Optimizer): The model optimizer. - scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. - """ - optimizer = lazy_instantiate(optimizer_config, model=self.net) - scheduler = get_base_scheduler(optimizer, self, scheduler_config) - return optimizer, scheduler - - # ------------------------ training hooks ------------------------ - def on_before_zero_grad( - self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int - ) -> None: - """ - update the net_ema - """ - del scheduler, optimizer - - if self.config.ema.enabled: - # calculate beta for EMA update - ema_beta = self.ema_beta(iteration) - self.net_ema_worker.update_average(self.net, self.net_ema, beta=ema_beta) - - def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None: - if self.config.ema.enabled: - self.net_ema.to(dtype=torch.float32) - if hasattr(self.tokenizer, "reset_dtype"): - self.tokenizer.reset_dtype() - self.net = self.net.to(memory_format=memory_format, **self.tensor_kwargs) - - if hasattr(self.config, "use_torch_compile") and self.config.use_torch_compile: # compatible with old config - if torch.__version__ < "2.3": - log.warning( - "torch.compile in Pytorch version older than 2.3 doesn't work well with activation checkpointing.\n" - "It's very likely there will be no significant speedup from torch.compile.\n" - "Please use at least 24.04 Pytorch container, or imaginaire4:v7 container." - ) - # Increasing cache size. It's required because of the model size and dynamic input shapes resulting in - # multiple different triton kernels. For 28 TransformerBlocks, the cache limit of 256 should be enough for - # up to 9 different input shapes, as 28*9 < 256. If you have more Blocks or input shapes, and you observe - # graph breaks at each Block (detectable with torch._dynamo.explain) or warnings about - # exceeding cache limit, you may want to increase this size. - # Starting with 24.05 Pytorch container, the default value is 256 anyway. - # You can read more about it in the comments in Pytorch source code under path torch/_dynamo/cache_size.py. - torch._dynamo.config.accumulated_cache_size_limit = 256 - # dynamic=False means that a separate kernel is created for each shape. It incurs higher compilation costs - # at initial iterations, but can result in more specialized and efficient kernels. - # dynamic=True currently throws errors in pytorch 2.3. - self.net = torch.compile(self.net, dynamic=False, disable=not self.config.use_torch_compile) - - def is_lora_model(self, model): - for _, module in model.named_modules(): - if isinstance(module, BaseTunerLayer): - return True - return False - - def load_lora_weights( - self, - lora_path, - adapter_name=None, - lora_alpha=None, - training_mode=False, - skip_inject=False, - skip_load=False, - model=None, - ): - if adapter_name is None: - adapter_name = os.path.basename(lora_path).replace(".", "--") - if model is None: - model = self.net - - for _, module in model.named_modules(): - if isinstance(module, BaseTunerLayer) and not skip_inject and not skip_load: - # on `skip_inject` or `skip_load`, we allow override the existing LoRA with same name. - if adapter_name in module.active_adapter: - log.info(f"LoRA {adapter_name} already loaded, skip loading") - return adapter_name - else: - log.warning(f"LoRA {module.active_adapter} loaded, loading {adapter_name} will override it!") - break - - state_dict = load_state_dict(lora_path, torch_dtype=self.precision) - if any(k.startswith("lora_unet_") for k in state_dict): - state_dict = _convert_musubi_wan_lora_to_non_diffusers_wan(state_dict) - # remove the prefix of the state dict - state_dict = {k.replace("diffusion_model.", ""): v for k, v in state_dict.items()} - # strip leading "net." prefix: load_lora_weights injects into `self.net`, - # so the LoRA keys must be relative to `self.net` (not the full model). - state_dict = {k[len("net.") :] if k.startswith("net.") else k: v for k, v in state_dict.items()} - - if not skip_inject: - target_modules = list({name.split(".lora")[0] for name in state_dict.keys()}) - for key, val in state_dict.items(): - if ("lora_down" in key or "lora_A" in key) and val.ndim > 1: - rank = val.shape[0] - break - else: - raise ValueError("Rank is not found in the state dict") - if lora_alpha is None: - lora_alpha = rank - lora_bias = any("diff_b" in k for k in state_dict) - - named_modules = model.named_modules() - key_list = [key for key, _ in named_modules] - names_no_target = [ - name - for name in key_list - if not any((name == suffix) or name.endswith("." + suffix) for suffix in target_modules) - ] - new_target_modules = _find_minimal_target_modules(target_modules, names_no_target) - - log.info( - f"Injecting LoRA as from {lora_path}, rank: {rank}, lora_alpha: {lora_alpha}, lora_bias: {lora_bias}, target_modules: {new_target_modules}" - ) - self.add_lora_to_model( - adapter_name=adapter_name, - lora_rank=rank, - lora_alpha=lora_alpha, - lora_bias=lora_bias, - lora_target_modules=new_target_modules, - training_mode=training_mode, - model=model, - ) - log.info(f"Injected LoRA weights as from {lora_path}") - else: - log.info(f"LoRA {adapter_name} skip injecting on this call.") - if not skip_load: - self.load_weights_to_lora( - pretrained_lora_state_dict=state_dict, - state_dict_converter=partial(_convert_non_diffusers_wan_lora_to_diffusers, adapter_name=adapter_name), - model=model, - ) - log.info(f"Loaded LoRA weights from {lora_path}") - else: - log.info(f"LoRA {adapter_name} skip loading on this call.") - return adapter_name - - def add_lora_to_model( - self, - adapter_name="default", - lora_rank=4, - lora_alpha=None, - lora_bias=False, - lora_target_modules="q,k,v,o,ffn.0,ffn.2", - training_mode=True, - init_lora_weights="kaiming", - model=None, - ): - if model is None: - model = self.net - - # Add LoRA to UNet - if init_lora_weights == "kaiming": - init_lora_weights = True - if lora_alpha is None: - lora_alpha = lora_rank - - if isinstance(lora_target_modules, str): - lora_target_modules = lora_target_modules.split(",") - - lora_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_alpha, - init_lora_weights=init_lora_weights, - target_modules=lora_target_modules, - lora_bias=lora_bias, - ) - - # this op is inplace - inject_adapter_in_model(lora_config, model, adapter_name=adapter_name) - - # count trainable and total parameters - count_trainable_params = 0 - count_total_params = 0 - if training_mode: - for param in model.parameters(): - if param.requires_grad: - # # Upcast LoRA parameters into fp32 - # param.data = param.to(torch.float32) - param.data = param.data.to(self.precision) - count_trainable_params += param.numel() - count_total_params += param.numel() - log.info( - f"Trainable parameters after adding LoRA: {count_trainable_params:,} / Total parameters: {count_total_params:,}" - ) - - def load_weights_to_lora( - self, - pretrained_lora_path=None, - pretrained_lora_state_dict=None, - state_dict_converter=None, - model=None, - ): - assert pretrained_lora_path is None or pretrained_lora_state_dict is None, ( - "Only one of pretrained_lora_path or pretrained_lora_state_dict should be provided" - ) - - if model is None: - model = self.net - # Lora pretrained lora weights - if pretrained_lora_path is not None or pretrained_lora_state_dict is not None: - if pretrained_lora_path is not None and pretrained_lora_state_dict is None: - state_dict = load_state_dict(pretrained_lora_path, torch_dtype=self.precision) - pretrained_lora_state_dict = state_dict - if state_dict_converter is not None: - pretrained_lora_state_dict = state_dict_converter(pretrained_lora_state_dict) - if self.fsdp_device_mesh: - _state_dict = get_model_state_dict(model) - missing_keys = [] - unexpected_keys = [] - for k in _state_dict.keys(): - if "_extra_state" in k: - pass - if k in pretrained_lora_state_dict: - # set local tensor to DTensor - _state_dict[k] = distribute_tensor( - pretrained_lora_state_dict.pop(k), - _state_dict[k].device_mesh, - _state_dict[k].placements, - ) - else: - missing_keys.append(k) - unexpected_keys = list(pretrained_lora_state_dict.keys()) - log.info(set_model_state_dict(model, _state_dict, options=StateDictOptions(strict=True))) - else: - missing_keys, unexpected_keys = model.load_state_dict(pretrained_lora_state_dict, strict=False) - all_keys = [i for i, _ in model.named_parameters()] - if any("k_img" in k for k in unexpected_keys): - total_unexpected_keys = len(unexpected_keys) - unexpected_keys = [k for k in unexpected_keys if "k_img" not in k] - unexpected_keys = [k for k in unexpected_keys if "v_img" not in k] - unexpected_keys = [k for k in unexpected_keys if "img_emb.proj" not in k] - ignore_keys = total_unexpected_keys - len(unexpected_keys) - log.critical(f"You may loading a I2V LoRA into T2V model. Ignore {ignore_keys} unexpected keys.") - num_updated_keys = len(all_keys) - len(missing_keys) - num_unexpected_keys = len(unexpected_keys) - log.info( - f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected." - ) - if num_unexpected_keys > 0: - log.critical(f"Unexpected keys: {unexpected_keys}") - - def set_weights_and_activate_adapters(self, adapter_names, weights=None): - if isinstance(adapter_names, str): - adapter_names = [adapter_names] - if weights is None: - weights = [1.0] * len(adapter_names) - assert len(adapter_names) == len(weights), "adapter_names and weights should have the same length" - - def get_module_weight(weight_for_adapter, module_name): - if not isinstance(weight_for_adapter, dict): - # If weight_for_adapter is a single number, always return it. - return weight_for_adapter - - for layer_name, weight_ in weight_for_adapter.items(): - if layer_name in module_name: - return weight_ - - raise ValueError( - "weight_for_adapter should be a single number or a dict containing the layer " - f"name, got {weight_for_adapter} for {module_name}" - ) - - for module_name, module in self.net.named_modules(): - if isinstance(module, BaseTunerLayer): - # For backward compatibility with previous PEFT versions, set multiple active adapters - if hasattr(module, "set_adapter"): - module.set_adapter(adapter_names) - else: - module.active_adapter = adapter_names - - # Set the scaling weight for each adapter for this module - for adapter_name, weight in zip(adapter_names, weights): - module.set_scale(adapter_name, get_module_weight(weight, module_name)) - log.info(f"Set weights: {weights}, and activate adapters: {adapter_names}") - - def training_step( - self, data_batch: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - """ - Performs a single training step for the diffusion model. - - This method is responsible for executing one iteration of the model's training. It involves: - 1. Adding noise to the input data using the SDE process. - 2. Passing the noisy data through the network to generate predictions. - 3. Computing the loss based on the difference between the predictions and the original data, \ - considering any configured loss weighting. - - Args: - data_batch (dict): raw data batch draw from the training data loader. - iteration (int): Current iteration number. - - Returns: - tuple: A tuple containing two elements: - - dict: additional data that used to debug / logging / callbacks - - Tensor: The computed loss for the training step as a PyTorch Tensor. - - Raises: - AssertionError: If the class is conditional, \ - but no number of classes is specified in the network configuration. - - Notes: - - The method handles different types of conditioning - - The method also supports Kendall's loss - """ - self._update_train_stats(data_batch) - # Get the input data to noise and denoise~(image, video) and the corresponding conditioner. - _, x0_B_C_T_H_W, condition = self.get_data_and_condition(data_batch) - - # Sample pertubation noise levels and N(0, 1) noises - epsilon_B_C_T_H_W = torch.randn(x0_B_C_T_H_W.size(), **self.flow_matching_kwargs) - batch_size = x0_B_C_T_H_W.size()[0] - t_B = self.rectified_flow.sample_train_time(batch_size).to(**self.flow_matching_kwargs) - t_B = rearrange(t_B, "b -> b 1") # add a dimension for T, all frames share the same sigma - x0_B_C_T_H_W, condition, epsilon_B_C_T_H_W, t_B = self.broadcast_split_for_model_parallelsim( - x0_B_C_T_H_W, condition, epsilon_B_C_T_H_W, t_B - ) - timesteps = self.rectified_flow.get_discrete_timestamp(t_B, self.flow_matching_kwargs) - sigmas = self.rectified_flow.get_sigmas( - timesteps, - self.flow_matching_kwargs, - ) - timesteps = rearrange(timesteps, "b -> b 1") - sigmas = rearrange(sigmas, "b -> b 1") - xt_B_C_T_H_W, vt_B_C_T_H_W = self.rectified_flow.get_interpolation(epsilon_B_C_T_H_W, x0_B_C_T_H_W, sigmas) - - vt_pred_B_C_T_H_W = self.net( - x_B_C_T_H_W=xt_B_C_T_H_W.to(**self.tensor_kwargs), - timesteps_B_T=timesteps.to(**self.tensor_kwargs), - **condition.to_dict(), - ) - - time_weights_B = self.rectified_flow.train_time_weight(timesteps, self.flow_matching_kwargs) - per_instance_loss = torch.mean( - (vt_pred_B_C_T_H_W - vt_B_C_T_H_W) ** 2, dim=list(range(1, vt_pred_B_C_T_H_W.dim())) - ) - loss = torch.mean(time_weights_B * per_instance_loss) - output_batch = {"edm_loss": loss} - - return output_batch, loss - - @staticmethod - def get_context_parallel_group(): - if parallel_state.is_initialized(): - return parallel_state.get_context_parallel_group() - return None - - def broadcast_split_for_model_parallelsim(self, x0_B_C_T_H_W, condition, epsilon_B_C_T_H_W, sigma_B_T): - """ - Broadcast and split the input data and condition for model parallelism. - Currently, we only support context parallelism. - """ - cp_group = self.get_context_parallel_group() - cp_size = 1 if cp_group is None else cp_group.size() - if condition.is_video and cp_size > 1: - if x0_B_C_T_H_W is not None: - x0_B_C_T_H_W = broadcast_split_tensor(x0_B_C_T_H_W, seq_dim=2, process_group=cp_group) - epsilon_B_C_T_H_W = broadcast_split_tensor(epsilon_B_C_T_H_W, seq_dim=2, process_group=cp_group) - if sigma_B_T is not None: - assert sigma_B_T.ndim == 2, "sigma_B_T should be 2D tensor" - if sigma_B_T.shape[-1] == 1: # single sigma is shared across all frames - sigma_B_T = broadcast(sigma_B_T, cp_group) - else: # different sigma for each frame - sigma_B_T = broadcast_split_tensor(sigma_B_T, seq_dim=1, process_group=cp_group) - if condition is not None: - condition = condition.broadcast(cp_group) - self.net.enable_context_parallel(cp_group) - else: - self.net.disable_context_parallel() - - return x0_B_C_T_H_W, condition, epsilon_B_C_T_H_W, sigma_B_T - - def _update_train_stats(self, data_batch: dict[str, torch.Tensor]) -> None: - is_image = self.is_image_batch(data_batch) - input_key = self.input_image_key if is_image else self.input_data_key - if isinstance(self.net, WeightTrainingStat): - if is_image: - self.net.accum_image_sample_counter += data_batch[input_key].shape[0] * self.data_parallel_size - else: - self.net.accum_video_sample_counter += data_batch[input_key].shape[0] * self.data_parallel_size - - # ------------------------ Sampling ------------------------ - - def get_x0_fn_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - is_negative_prompt: bool = False, - ) -> Callable: - """ - Generates a callable function `x0_fn` based on the provided data batch and guidance factor. - - This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states. - - Args: - - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner` - - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5. - - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - - Returns: - - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and return x0 predictoin - - The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence. - """ - _, x0, _ = self.get_data_and_condition(data_batch) # we need always process the data batch first. - is_image_batch = self.is_image_batch(data_batch) - - if is_negative_prompt: - condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch) - else: - condition, uncondition = self.conditioner.get_condition_uncondition(data_batch) - - condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - uncondition = uncondition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - _, condition, _, _ = self.broadcast_split_for_model_parallelsim(x0, condition, None, None) - _, uncondition, _, _ = self.broadcast_split_for_model_parallelsim(x0, uncondition, None, None) - - # For inference, check if parallel_state is initialized - if parallel_state.is_initialized(): - pass - else: - assert not self.net.is_context_parallel_enabled, ( - "parallel_state is not initialized, context parallel should be turned off." - ) - - def x0_fn(noise_x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: - if guidance == 1.0: - cond_v = self.denoise(noise_x, timestep, condition) - noise_pred = cond_v - elif guidance == 0.0: - uncond_v = self.denoise(noise_x, timestep, uncondition) - noise_pred = uncond_v - else: - cond_v = self.denoise(noise_x, timestep, condition) - uncond_v = self.denoise(noise_x, timestep, uncondition) - noise_pred = uncond_v + guidance * (cond_v - uncond_v) - return noise_pred - - return x0_fn - - @sync_timer("WANDiffusionModel: generate_samples_from_batch") - @torch.no_grad() - def generate_samples_from_batch( - self, - data_batch: Dict, - guidance: float = 1.5, - seed: int = 1, - state_shape: Tuple | None = None, - n_sample: int | None = None, - is_negative_prompt: bool = False, - num_steps: int = 35, - shift: float = 5.0, - **kwargs, - ) -> torch.Tensor: - """ - Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples. - Args: - data_batch (dict): raw data batch draw from the training data loader. - iteration (int): Current iteration number. - guidance (float): guidance weights - seed (int): random seed - state_shape (tuple): shape of the state, default to data batch if not provided - n_sample (int): number of samples to generate - is_negative_prompt (bool): use negative prompt t5 in uncondition if true - num_steps (int): number of steps for the diffusion process - solver_option (str): differential equation solver option, default to "2ab"~(mulitstep solver) - """ - - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - input_key = self.input_image_key if is_image_batch else self.input_data_key - if n_sample is None: - n_sample = data_batch[input_key].shape[0] - if state_shape is None: - _T, _H, _W = data_batch[input_key].shape[-3:] - state_shape = [ - self.config.state_ch, - self.tokenizer.get_latent_num_frames(_T), - _H // self.tokenizer.spatial_compression_factor, - _W // self.tokenizer.spatial_compression_factor, - ] - - noise = misc.arch_invariant_rand( - (n_sample,) + tuple(state_shape), - torch.float32, - self.tensor_kwargs["device"], - seed, - ) - - seed_g = torch.Generator(device=self.tensor_kwargs["device"]) - seed_g.manual_seed(seed) - - self.sample_scheduler.set_timesteps(num_steps, device=self.tensor_kwargs["device"], shift=shift) - - timesteps = self.sample_scheduler.timesteps - - x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt) - latents = noise - - if self.net.is_context_parallel_enabled: - latents = broadcast_split_tensor(latents, seq_dim=2, process_group=self.get_context_parallel_group()) - - with sync_timer(f"WANDiffusionModel: generate_samples_from_batch: {num_steps} diffusion_steps"): - for _, t in enumerate(timesteps): - latent_model_input = latents - timestep = [t] - - timestep = torch.stack(timestep) - - velocity_field_pred = x0_fn(latent_model_input, timestep.unsqueeze(0)) # velocity field - temp_x0 = self.sample_scheduler.step( - velocity_field_pred.unsqueeze(0), t, latents, return_dict=False, generator=seed_g - )[0] - latents = temp_x0.squeeze(0) - - if self.net.is_context_parallel_enabled: - latents = cat_outputs_cp(latents, seq_dim=2, cp_group=self.get_context_parallel_group()) - - return latents - - @torch.no_grad() - def validation_step( - self, data: dict[str, torch.Tensor], iteration: int - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - pass - - @torch.no_grad() - def forward(self, xt, t, condition: T2VCondition): - pass - - def get_data_and_condition( - self, data_batch: dict[str, torch.Tensor], return_latent_state: bool = True - ) -> Tuple[Tensor, Tensor, T2VCondition]: - self._normalize_video_databatch_inplace(data_batch) - self._augment_image_dim_inplace(data_batch) - is_image_batch = self.is_image_batch(data_batch) - - # Latent state - raw_state = data_batch[self.input_image_key if is_image_batch else self.input_data_key] - if return_latent_state: - latent_state = self.encode(raw_state).contiguous().float() - else: - latent_state = None - - # Condition - condition = self.conditioner(data_batch) - condition = condition.edit_data_type(DataType.IMAGE if is_image_batch else DataType.VIDEO) - return raw_state, latent_state, condition - - def _normalize_video_databatch_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: - """ - Normalizes video data in-place on a CUDA device to reduce data loading overhead. - - This function modifies the video data tensor within the provided data_batch dictionary - in-place, scaling the uint8 data from the range [0, 255] to the normalized range [-1, 1]. - - Warning: - A warning is issued if the data has not been previously normalized. - - Args: - data_batch (dict[str, Tensor]): A dictionary containing the video data under a specific key. - This tensor is expected to be on a CUDA device and have dtype of torch.uint8. - - Side Effects: - Modifies the 'input_data_key' tensor within the 'data_batch' dictionary in-place. - - Note: - This operation is performed directly on the CUDA device to avoid the overhead associated - with moving data to/from the GPU. Ensure that the tensor is already on the appropriate device - and has the correct dtype (torch.uint8) to avoid unexpected behaviors. - """ - input_key = self.input_data_key if input_key is None else input_key - # only handle video batch - if input_key in data_batch: - # Check if the data has already been normalized and avoid re-normalizing - _flag = data_batch.get(IS_PREPROCESSED_KEY, False) - if isinstance(_flag, torch.Tensor): - try: - _flag = bool(_flag.bool().all().item()) - except Exception: - _flag = False - else: - _flag = bool(_flag) - - if _flag: - assert torch.is_floating_point(data_batch[input_key]), "Video data is not in float format." - assert torch.all((data_batch[input_key] >= -1.0001) & (data_batch[input_key] <= 1.0001)), ( - f"Video data is not in the range [-1, 1]. get data range [{data_batch[input_key].min()}, {data_batch[input_key].max()}]" - ) - else: - assert data_batch[input_key].dtype == torch.uint8, "Video data is not in uint8 format." - data_batch[input_key] = data_batch[input_key].to(**self.tensor_kwargs) / 127.5 - 1.0 - data_batch[IS_PREPROCESSED_KEY] = True - - def _augment_image_dim_inplace(self, data_batch: dict[str, Tensor], input_key: str = None) -> None: - input_key = self.input_image_key if input_key is None else input_key - if input_key in data_batch: - # Check if the data has already been augmented and avoid re-augmenting - _flag = data_batch.get(IS_PREPROCESSED_KEY, False) - if isinstance(_flag, torch.Tensor): - try: - _flag = bool(_flag.bool().all().item()) - except Exception: - _flag = False - else: - _flag = bool(_flag) - - if _flag: - assert data_batch[input_key].shape[2] == 1, ( - f"Image data is claimed be augmented while its shape is {data_batch[input_key].shape}" - ) - return - else: - data_batch[input_key] = rearrange(data_batch[input_key], "b c h w -> b c 1 h w").contiguous() - data_batch[IS_PREPROCESSED_KEY] = True - - # ------------------ Checkpointing ------------------ - - def state_dict(self) -> Dict[str, Any]: - net_state_dict = self.net.state_dict(prefix="net.") - if self.config.ema.enabled: - ema_state_dict = self.net_ema.state_dict(prefix="net_ema.") - net_state_dict.update(ema_state_dict) - return net_state_dict - - def load_state_dict( - self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False, pretrain_copy: bool = False - ) -> None: - """Only called when using .pth checkpoint - Loads a state dictionary into the model and optionally its EMA counterpart. - Different from torch strict=False mode, the method will not raise error for unmatched state shape while raise warning. - - Parameters:e - state_dict (Mapping[str, Any]): A dictionary containing separate state dictionaries for the model and - potentially for an EMA version of the model under the keys 'model' and 'ema', respectively. - strict (bool, optional): If True, the method will enforce that the keys in the state dict match exactly - those in the model and EMA model (if applicable). Defaults to True. - assign (bool, optional): If True and in strict mode, will assign the state dictionary directly rather than - matching keys one-by-one. This is typically used when loading parts of state dicts - or using customized loading procedures. Defaults to False. - """ - - if pretrain_copy: - if strict: - reg_results: _IncompatibleKeys = self.net.load_state_dict(state_dict, strict=strict, assign=assign) - - if self.config.ema.enabled: - ema_results: _IncompatibleKeys = self.net_ema.load_state_dict( - state_dict, strict=strict, assign=assign - ) - - return _IncompatibleKeys( - missing_keys=reg_results.missing_keys - + (ema_results.missing_keys if self.config.ema.enabled else []), - unexpected_keys=reg_results.unexpected_keys - + (ema_results.unexpected_keys if self.config.ema.enabled else []), - ) - else: - log.critical("load model in non-strict mode") - log.critical(non_strict_load_model(self.net, state_dict), rank0_only=False) - if self.config.ema.enabled: - log.critical("load ema model in non-strict mode") - log.critical(non_strict_load_model(self.net_ema, state_dict), rank0_only=False) - else: - _reg_state_dict = collections.OrderedDict() - _ema_state_dict = collections.OrderedDict() - for k, v in state_dict.items(): - if k.startswith("net."): - _reg_state_dict[k.replace("net.", "")] = v - elif k.startswith("net_ema."): - _ema_state_dict[k.replace("net_ema.", "")] = v - else: - _reg_state_dict[k] = v - - state_dict = _reg_state_dict - if strict: - reg_results: _IncompatibleKeys = self.net.load_state_dict(_reg_state_dict, strict=strict, assign=assign) - - if self.config.ema.enabled: - ema_results: _IncompatibleKeys = self.net_ema.load_state_dict( - _ema_state_dict, strict=strict, assign=assign - ) - - return _IncompatibleKeys( - missing_keys=reg_results.missing_keys - + (ema_results.missing_keys if self.config.ema.enabled else []), - unexpected_keys=reg_results.unexpected_keys - + (ema_results.unexpected_keys if self.config.ema.enabled else []), - ) - else: - log.warning("load model in non-strict mode") - log.warning(non_strict_load_model(self.net, _reg_state_dict), rank0_only=False) - if self.config.ema.enabled: - log.warning("load ema model in non-strict mode") - log.warning(non_strict_load_model(self.net_ema, _ema_state_dict), rank0_only=False) - - # ------------------ public methods ------------------ - def ema_beta(self, iteration: int) -> float: - """ - Calculate the beta value for EMA update. - weights = weights * beta + (1 - beta) * new_weights - - Args: - iteration (int): Current iteration number. - - Returns: - float: The calculated beta value. - """ - iteration = iteration + self.config.ema.iteration_shift - # Prevent iteration from being 0 or negative to avoid beta=0.0 or division issues - if iteration <= 0: - return 0.0 - # Safe division: iteration + 1 is at least 1 - return (1 - 1 / (iteration + 1)) ** (self.ema_exp_coefficient + 1) - - def model_param_stats(self) -> Dict[str, int]: - return {"total_learnable_param_num": self._param_count} - - def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool: - """We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch. - Another comes from a dataloader which we by default assumes as video_data for video model training. - """ - is_image = self.input_image_key in data_batch - is_video = self.input_data_key in data_batch - assert is_image != is_video, ( - "Only one of the input_image_key or input_data_key should be present in the data_batch." - ) - return is_image - - def return_data_type(self, data_batch: dict[str, Tensor]) -> DataType: - if self.is_image_batch(data_batch): - return "image" - else: - return "video" - - def denoise(self, xt_B_C_T_H_W: torch.Tensor, timestep: torch.Tensor, condition: T2VCondition) -> DenoisePrediction: - """ - Performs denoising on the input noise data, noise level, and condition - - Args: - xt (torch.Tensor): The input noise data. - timestep (torch.Tensor): The timestep level. - condition (T2VCondition): conditional information, generated from self.conditioner - - Returns: - DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \ - noise prediction (eps_pred). - """ - # forward pass through the network - net_output_B_C_T_H_W = self.net( - x_B_C_T_H_W=(xt_B_C_T_H_W).to(**self.tensor_kwargs), - timesteps_B_T=timestep, - **condition.to_dict(), - ).float() - - return net_output_B_C_T_H_W - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - return self.tokenizer.encode(state) - - @torch.no_grad() - def decode(self, latent: torch.Tensor, T_latent_seq_lens: Optional[torch.Tensor] = None) -> torch.Tensor: - if T_latent_seq_lens is not None: - latent_list = torch.split(latent, T_latent_seq_lens.tolist(), dim=2) - decoded_list = [self.tokenizer.decode(latent) for latent in latent_list] - return torch.cat(decoded_list, dim=2) - else: - return self.tokenizer.decode(latent) - - def get_video_height_width(self) -> Tuple[int, int]: - return VIDEO_RES_SIZE_INFO[self.config.resolution]["9,16"] - - def get_video_latent_height_width(self) -> Tuple[int, int]: - height, width = VIDEO_RES_SIZE_INFO[self.config.resolution]["9,16"] - return height // self.tokenizer.spatial_compression_factor, width // self.tokenizer.spatial_compression_factor - - def get_num_video_latent_frames(self) -> int: - return self.config.state_t - - @contextmanager - def ema_scope(self, context=None, is_cpu=False): - if self.config.ema.enabled: - # https://github.com/pytorch/pytorch/issues/144289 - for module in self.net.modules(): - if isinstance(module, FSDPModule): - module.reshard() - self.net_ema_worker.cache(self.net.parameters(), is_cpu=is_cpu) - self.net_ema_worker.copy_to(src_model=self.net_ema, tgt_model=self.net) - if context is not None: - log.info(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.config.ema.enabled: - for module in self.net.modules(): - if isinstance(module, FSDPModule): - module.reshard() - self.net_ema_worker.restore(self.net.parameters()) - if context is not None: - log.info(f"{context}: Restored training weights") - - def clip_grad_norm_( - self, - max_norm: float, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: Optional[bool] = None, - ): - return clip_grad_norm_( - self.net.parameters(), - max_norm, - norm_type=norm_type, - error_if_nonfinite=error_if_nonfinite, - foreach=foreach, - ) - - -NUM_CONDITIONAL_FRAMES_KEY: str = "num_conditional_frames" diff --git a/lyra_2/_src/modules/__init__.py b/lyra_2/_src/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/modules/attention.py b/lyra_2/_src/modules/attention.py deleted file mode 100644 index 6cd759cd5148fec58460111880000206b25583f3..0000000000000000000000000000000000000000 --- a/lyra_2/_src/modules/attention.py +++ /dev/null @@ -1,181 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# From Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -# Description: -# Single point of entry for all generic attention ops (self and cross attention), that tries to -# deliver the best performance possible given any use case (GPU and environment). -# -# On Hopper GPUs (i.e. H100, H20, H200), Flash Attention 3 is the best-performing choice, but it -# needs to be installed. When it is not available, the second best choice is cuDNN attention, which -# we get using PyTorch's SDPA API. -# -# For all other use cases, we will just use PyTorch's SDPA, but we need to specify backends and -# priorities. -# Flash Attention 2, which is one of the backends, is the best choice for Ampere GPUs (both RTX and -# datacenter-class). -# -# For anything pre-Ampere, the only choice is "memory-efficient" (xformers) FMHA. -# -# For Ada and Blackwell RTX, it is unclear at the moment, so we defer to Flash Attention 2, and -# fallbacks are cuDNN and xformers. -# -# For Blackwell datacenter-class (B200, GB200), cuDNN is the best choice. -# -# -# Dispatching to the desired backends/paths are done by checking the compute capability (really SM -# number, which is just compute capability * 10) of the GPU device the input tensors are on. -# -# Here's a breakdown of relevant compute capabilities: -# -# | GPU / category | Arch | -# |================|=======| -# | A100 | SM80 | -# | A40 | SM80 | -# | Ampere RTX | SM86 | -# |----------------|-------| -# | Ada Lovelace | SM89 | -# |----------------|-------| -# | H20 | SM90 | -# | H100 | SM90 | -# | H200 | SM90 | -# |----------------|-------| -# | B200 | SM100 | -# | Blackwell RTX | SM103 | -# |----------------|-------| -# - -from functools import partial - -import torch -from torch.nn.attention import SDPBackend, sdpa_kernel - -try: - from flash_attn_3.flash_attn_interface import flash_attn_func - - FLASH_ATTN_3_AVAILABLE = True -except ModuleNotFoundError: - FLASH_ATTN_3_AVAILABLE = False - - -def get_device_cc(device) -> int: - """ - Returns the compute capability of a given torch device if it's a CUDA device, otherwise returns 0. - - Args: - device: torch device. - - Returns: - device_cc (int): compute capability in the SmXXX format (i.e. 90 for Hopper). - """ - if torch.cuda.is_available() and torch.version.cuda and device.type == "cuda": - major, minor = torch.cuda.get_device_capability(device) - return major * 10 + minor - return 0 - - -def attention( - q, - k, - v, - q_lens=None, - k_lens=None, - dropout_p=0.0, - softmax_scale=None, - q_scale=None, - causal=False, - deterministic=False, - dtype=torch.bfloat16, -): - supported_dtypes = [torch.bfloat16, torch.float16, torch.float32] - is_half = dtype in [torch.bfloat16, torch.float16] - compute_cap = get_device_cc(q.device) - - if dtype not in supported_dtypes: - raise NotImplementedError(f"{dtype=} is not supported.") - - q = q.to(dtype) - k = k.to(dtype) - v = v.to(dtype) - - if q_scale is not None: - q = q * q_scale - - # If Flash Attention 3 is installed, and the user's running on a Hopper GPU (compute capability - # 9.0, or SM90), use Flash Attention 3. - if compute_cap == 90 and FLASH_ATTN_3_AVAILABLE and is_half: - return flash_attn_func( - q=q, - k=k, - v=v, - softmax_scale=softmax_scale, - causal=causal, - deterministic=deterministic, - )[0] - else: - # If Blackwell or Hopper (SM100 or SM90), cuDNN has native FMHA kernels. The Hopper one is - # not always as fast as Flash Attention 3, but when Flash Attention is unavailable, it's - # still a far better choice than Flash Attention 2 (Ampere). - if compute_cap in [90, 100] and is_half: - SDPA_BACKENDS = [ - SDPBackend.CUDNN_ATTENTION, - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - ] - BEST_SDPA_BACKEND = SDPBackend.CUDNN_ATTENTION - elif is_half: - SDPA_BACKENDS = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.CUDNN_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - ] - BEST_SDPA_BACKEND = SDPBackend.FLASH_ATTENTION if compute_cap >= 80 else SDPBackend.EFFICIENT_ATTENTION - else: - assert dtype == torch.float32, f"Unrecognized {dtype=}." - SDPA_BACKENDS = [SDPBackend.EFFICIENT_ATTENTION] - BEST_SDPA_BACKEND = SDPBackend.EFFICIENT_ATTENTION - - if deterministic: - raise NotImplementedError( - "Deterministic mode in attention is only supported when Flash Attention 3 is available." - ) - - # Torch 2.6 and later allows priorities for backends, but for older versions - # we can only run with a specific backend. As long as we pick ones we're certain - # will work on that device, it should be fine. - try: - sdpa_kernel(backends=SDPA_BACKENDS, set_priority_order=True) - sdpa_kernel_ = partial(sdpa_kernel, set_priority_order=True) - except TypeError: - sdpa_kernel_ = sdpa_kernel - SDPA_BACKENDS = [BEST_SDPA_BACKEND] - - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - with sdpa_kernel_(backends=SDPA_BACKENDS): - out = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - is_causal=causal, - dropout_p=dropout_p, - scale=softmax_scale, - ) - - out = out.transpose(1, 2).contiguous() - return out diff --git a/lyra_2/_src/modules/clip.py b/lyra_2/_src/modules/clip.py deleted file mode 100644 index 8c81f6d6e512ea049abae883bccececf9e3ad484..0000000000000000000000000000000000000000 --- a/lyra_2/_src/modules/clip.py +++ /dev/null @@ -1,529 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -import math -from typing import Dict, List, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as T - -from lyra_2._ext.imaginaire.utils import distributed, log -from lyra_2._ext.imaginaire.utils.easy_io import easy_io -from lyra_2._src.modules.attention import attention -from lyra_2._src.modules.conditioner import AbstractEmbModel -from lyra_2._src.modules.umt5 import HuggingfaceTokenizer -from lyra_2._src.modules.xlm_roberta import XLMRoberta - -__all__ = [ - "CLIPModel", -] - - -class QuickGELU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(1.702 * x) - - -class LayerNorm(nn.LayerNorm): - def forward(self, x): - return super().forward(x.float()).type_as(x) - - -class SelfAttention(nn.Module): - def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0): - assert dim % num_heads == 0 - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.causal = causal - self.attn_dropout = attn_dropout - self.proj_dropout = proj_dropout - - # layers - self.to_qkv = nn.Linear(dim, dim * 3) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - """ - x: [B, L, C]. - """ - b, s, c, n, d = *x.size(), self.num_heads, self.head_dim - - # compute query, key, value - q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) - - # compute attention - p = self.attn_dropout if self.training else 0.0 - x = attention(q, k, v, dropout_p=p, causal=self.causal) - x = x.reshape(b, s, c) - - # output - x = self.proj(x) - x = F.dropout(x, self.proj_dropout, self.training) - return x - - -class SwiGLU(nn.Module): - def __init__(self, dim, mid_dim): - super().__init__() - self.dim = dim - self.mid_dim = mid_dim - - # layers - self.fc1 = nn.Linear(dim, mid_dim) - self.fc2 = nn.Linear(dim, mid_dim) - self.fc3 = nn.Linear(mid_dim, dim) - - def forward(self, x): - x = F.silu(self.fc1(x)) * self.fc2(x) - x = self.fc3(x) - return x - - -class AttentionBlock(nn.Module): - def __init__( - self, - dim, - mlp_ratio, - num_heads, - post_norm=False, - causal=False, - activation="quick_gelu", - attn_dropout=0.0, - proj_dropout=0.0, - norm_eps=1e-5, - ): - assert activation in ["quick_gelu", "gelu", "swi_glu"] - super().__init__() - self.dim = dim - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.post_norm = post_norm - self.causal = causal - self.norm_eps = norm_eps - - # layers - self.norm1 = LayerNorm(dim, eps=norm_eps) - self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout) - self.norm2 = LayerNorm(dim, eps=norm_eps) - if activation == "swi_glu": - self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) - else: - self.mlp = nn.Sequential( - nn.Linear(dim, int(dim * mlp_ratio)), - QuickGELU() if activation == "quick_gelu" else nn.GELU(), - nn.Linear(int(dim * mlp_ratio), dim), - nn.Dropout(proj_dropout), - ) - - def forward(self, x): - if self.post_norm: - x = x + self.norm1(self.attn(x)) - x = x + self.norm2(self.mlp(x)) - else: - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - - -class AttentionPool(nn.Module): - def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5): - assert dim % num_heads == 0 - super().__init__() - self.dim = dim - self.mlp_ratio = mlp_ratio - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.proj_dropout = proj_dropout - self.norm_eps = norm_eps - - # layers - gain = 1.0 / math.sqrt(dim) - self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) - self.to_q = nn.Linear(dim, dim) - self.to_kv = nn.Linear(dim, dim * 2) - self.proj = nn.Linear(dim, dim) - self.norm = LayerNorm(dim, eps=norm_eps) - self.mlp = nn.Sequential( - nn.Linear(dim, int(dim * mlp_ratio)), - QuickGELU() if activation == "quick_gelu" else nn.GELU(), - nn.Linear(int(dim * mlp_ratio), dim), - nn.Dropout(proj_dropout), - ) - - def forward(self, x): - """ - x: [B, L, C]. - """ - b, s, c, n, d = *x.size(), self.num_heads, self.head_dim - - # compute query, key, value - q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) - k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) - - # compute attention - x = attention(q, k, v) - x = x.reshape(b, 1, c) - - # output - x = self.proj(x) - x = F.dropout(x, self.proj_dropout, self.training) - - # mlp - x = x + self.mlp(self.norm(x)) - return x[:, 0] - - -class VisionTransformer(nn.Module): - def __init__( - self, - image_size=224, - patch_size=16, - dim=768, - mlp_ratio=4, - out_dim=512, - num_heads=12, - num_layers=12, - pool_type="token", - pre_norm=True, - post_norm=False, - activation="quick_gelu", - attn_dropout=0.0, - proj_dropout=0.0, - embedding_dropout=0.0, - norm_eps=1e-5, - ): - if image_size % patch_size != 0: - print("[WARNING] image_size is not divisible by patch_size", flush=True) - assert pool_type in ("token", "token_fc", "attn_pool") - out_dim = out_dim or dim - super().__init__() - self.image_size = image_size - self.patch_size = patch_size - self.num_patches = (image_size // patch_size) ** 2 - self.dim = dim - self.mlp_ratio = mlp_ratio - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.pool_type = pool_type - self.post_norm = post_norm - self.norm_eps = norm_eps - - # embeddings - gain = 1.0 / math.sqrt(dim) - self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm) - if pool_type in ("token", "token_fc"): - self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) - self.pos_embedding = nn.Parameter( - gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim) - ) - self.dropout = nn.Dropout(embedding_dropout) - - # transformer - self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None - self.transformer = nn.Sequential( - *[ - AttentionBlock( - dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps - ) - for _ in range(num_layers) - ] - ) - self.post_norm = LayerNorm(dim, eps=norm_eps) - - # head - if pool_type == "token": - self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) - elif pool_type == "token_fc": - self.head = nn.Linear(dim, out_dim) - elif pool_type == "attn_pool": - self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps) - - def forward(self, x, interpolation=False, use_31_block=False): - b = x.size(0) - - # embeddings - x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) - if self.pool_type in ("token", "token_fc"): - x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) - if interpolation: - e = pos_interpolate(self.pos_embedding, x.size(1)) # noqa: F821 - else: - e = self.pos_embedding - x = self.dropout(x + e) - if self.pre_norm is not None: - x = self.pre_norm(x) - - # transformer - if use_31_block: - x = self.transformer[:-1](x) - return x - else: - x = self.transformer(x) - return x - - -class XLMRobertaWithHead(XLMRoberta): - def __init__(self, **kwargs): - self.out_dim = kwargs.pop("out_dim") - super().__init__(**kwargs) - - # head - mid_dim = (self.dim + self.out_dim) // 2 - self.head = nn.Sequential( - nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False) - ) - - def forward(self, ids): - # xlm-roberta - x = super().forward(ids) - - # average pooling - mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) - x = (x * mask).sum(dim=1) / mask.sum(dim=1) - - # head - x = self.head(x) - return x - - -class XLMRobertaCLIP(nn.Module): - def __init__( - self, - embed_dim=1024, - image_size=224, - patch_size=14, - vision_dim=1280, - vision_mlp_ratio=4, - vision_heads=16, - vision_layers=32, - vision_pool="token", - vision_pre_norm=True, - vision_post_norm=False, - activation="gelu", - vocab_size=250002, - max_text_len=514, - type_size=1, - pad_id=1, - text_dim=1024, - text_heads=16, - text_layers=24, - text_post_norm=True, - text_dropout=0.1, - attn_dropout=0.0, - proj_dropout=0.0, - embedding_dropout=0.0, - norm_eps=1e-5, - ): - super().__init__() - self.embed_dim = embed_dim - self.image_size = image_size - self.patch_size = patch_size - self.vision_dim = vision_dim - self.vision_mlp_ratio = vision_mlp_ratio - self.vision_heads = vision_heads - self.vision_layers = vision_layers - self.vision_pre_norm = vision_pre_norm - self.vision_post_norm = vision_post_norm - self.activation = activation - self.vocab_size = vocab_size - self.max_text_len = max_text_len - self.type_size = type_size - self.pad_id = pad_id - self.text_dim = text_dim - self.text_heads = text_heads - self.text_layers = text_layers - self.text_post_norm = text_post_norm - self.norm_eps = norm_eps - - # models - self.visual = VisionTransformer( - image_size=image_size, - patch_size=patch_size, - dim=vision_dim, - mlp_ratio=vision_mlp_ratio, - out_dim=embed_dim, - num_heads=vision_heads, - num_layers=vision_layers, - pool_type=vision_pool, - pre_norm=vision_pre_norm, - post_norm=vision_post_norm, - activation=activation, - attn_dropout=attn_dropout, - proj_dropout=proj_dropout, - embedding_dropout=embedding_dropout, - norm_eps=norm_eps, - ) - self.textual = XLMRobertaWithHead( - vocab_size=vocab_size, - max_seq_len=max_text_len, - type_size=type_size, - pad_id=pad_id, - dim=text_dim, - out_dim=embed_dim, - num_heads=text_heads, - num_layers=text_layers, - post_norm=text_post_norm, - dropout=text_dropout, - ) - self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) - - def forward(self, imgs, txt_ids): - """ - imgs: [B, 3, H, W] of torch.float32. - - mean: [0.48145466, 0.4578275, 0.40821073] - - std: [0.26862954, 0.26130258, 0.27577711] - txt_ids: [B, L] of torch.long. - Encoded by data.CLIPTokenizer. - """ - xi = self.visual(imgs) - xt = self.textual(txt_ids) - return xi, xt - - def param_groups(self): - groups = [ - { - "params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], - "weight_decay": 0.0, - }, - {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]}, - ] - return groups - - -def _clip( - pretrained=False, - pretrained_name=None, - model_cls=XLMRobertaCLIP, - return_transforms=False, - return_tokenizer=False, - tokenizer_padding="eos", - dtype=torch.float32, - device="cpu", - **kwargs, -): - # init a model on device - with torch.device(device): - model = model_cls(**kwargs) - - # set device - model = model.to(dtype=dtype, device=device) - output = (model,) - - # init transforms - if return_transforms: - # mean and std - if "siglip" in pretrained_name.lower(): - mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] - else: - mean = [0.48145466, 0.4578275, 0.40821073] - std = [0.26862954, 0.26130258, 0.27577711] - - # transforms - transforms = T.Compose( - [ - T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=mean, std=std), - ] - ) - output += (transforms,) - return output[0] if len(output) == 1 else output - - -def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs): - cfg = dict( - embed_dim=1024, - image_size=224, - patch_size=14, - vision_dim=1280, - vision_mlp_ratio=4, - vision_heads=16, - vision_layers=32, - vision_pool="token", - activation="gelu", - vocab_size=250002, - max_text_len=514, - type_size=1, - pad_id=1, - text_dim=1024, - text_heads=16, - text_layers=24, - text_post_norm=True, - text_dropout=0.1, - attn_dropout=0.0, - proj_dropout=0.0, - embedding_dropout=0.0, - ) - cfg.update(**kwargs) - return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) - - -def load_model_torch(model, ckpt_path): - log.info(f"loading weights from {ckpt_path}") - if distributed.is_rank0(): - ckpt = easy_io.load( - ckpt_path, map_location="cpu", fast_backend=False, weights_only=True - ) - model.load_state_dict(ckpt) - - distributed.sync_model_states(model, src=0) - return model - - -class CLIPModel: - def __init__( - self, - dtype=torch.float16, - device="cuda", - checkpoint_path="./checkpoints/image_encoder/model.pth", - tokenizer_path="xlm-roberta-large", - ): - self.dtype = dtype - self.device = device - self.checkpoint_path = checkpoint_path - self.tokenizer_path = tokenizer_path - - # init model - self.model, self.transforms = clip_xlm_roberta_vit_h_14( - pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device - ) - self.model = self.model.cuda().eval().requires_grad_(False) - self.model = load_model_torch(self.model, checkpoint_path) - - # init tokenizer - self.tokenizer = HuggingfaceTokenizer( - name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace" - ) - - def visual(self, videos_B_C_H_W_n1_p1): - # preprocess - size = (self.model.image_size,) * 2 - videos = F.interpolate(videos_B_C_H_W_n1_p1, size=size, mode="bicubic", align_corners=False) - videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) - - # forward - with torch.amp.autocast("cuda", dtype=self.dtype): - out = self.model.visual(videos, use_31_block=True) - return out - - diff --git a/lyra_2/_src/modules/conditioner.py b/lyra_2/_src/modules/conditioner.py deleted file mode 100644 index 50718c46088dd3cc29e712e888d95f84856380df..0000000000000000000000000000000000000000 --- a/lyra_2/_src/modules/conditioner.py +++ /dev/null @@ -1,476 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import copy -from abc import ABC, abstractmethod -from collections import defaultdict -from contextlib import nullcontext -from dataclasses import dataclass, fields -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union - -import omegaconf -import torch -import torch.nn as nn -from torch.distributed import ProcessGroup - -from lyra_2._ext.imaginaire.functional.batch_ops import batch_mul -from lyra_2._ext.imaginaire.lazy_config import instantiate -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.count_params import count_params, disabled_train -from lyra_2._ext.imaginaire.utils.easy_io import easy_io -try: - from lyra_2._src.datasets.data_sources.item_datasets_for_validation import get_itemdataset_option_local -except ImportError: - get_itemdataset_option_local = None -from lyra_2._src.utils.context_parallel import broadcast - -T = TypeVar("T", bound="BaseCondition") - - -class DataType(str, Enum): - IMAGE = "image" - VIDEO = "video" - MIX = "mix" - COUPLED = "coupled" # for coupled dataloader, video and image data are in one batch and concat together - - def __str__(self) -> str: - return self.value - - -def broadcast_condition(condition: BaseCondition, process_group: Optional[ProcessGroup] = None) -> BaseCondition: - """ - Broadcast the condition from the minimum rank in the specified group(s). - """ - if condition.is_broadcasted: - return condition - - kwargs = condition.to_dict(skip_underscore=False) - for key, value in kwargs.items(): - if value is not None: - kwargs[key] = broadcast(value, process_group) - kwargs["_is_broadcasted"] = True - return type(condition)(**kwargs) - - -@dataclass(frozen=True) -class BaseCondition(ABC): - """ - Attributes: - _is_broadcasted: Flag indicating if parallel broadcast splitting - has been performed. This is an internal implementation detail. - """ - - _is_broadcasted: bool = False - - def to_dict(self, skip_underscore: bool = True) -> Dict[str, Any]: - """Converts the condition to a dictionary. - - Returns: - Dictionary containing the condition's fields and values. - """ - return {f.name: getattr(self, f.name) for f in fields(self) if not (f.name.startswith("_") and skip_underscore)} - - @property - def is_broadcasted(self) -> bool: - return self._is_broadcasted - - def broadcast(self, process_group: torch.distributed.ProcessGroup) -> BaseCondition: - """Broadcasts and splits the condition across the checkpoint parallelism group. - For most condition, such asT2VCondition, we do not need split. - - Args: - process_group: The process group for broadcast and split - - Returns: - A new BaseCondition instance with the broadcasted and split condition. - """ - if self.is_broadcasted: - return self - return broadcast_condition(self, process_group) - - -@dataclass(frozen=True) -class T2VCondition(BaseCondition): - crossattn_emb: Optional[torch.Tensor] = None - data_type: DataType = DataType.VIDEO - padding_mask: Optional[torch.Tensor] = None - fps: Optional[torch.Tensor] = None - - def edit_data_type(self, data_type: DataType) -> T2VCondition: - """Edit the data type of the condition. - - Args: - data_type: The new data type. - - Returns: - A new T2VCondition instance with the new data type. - """ - kwargs = self.to_dict(skip_underscore=False) - kwargs["data_type"] = data_type - return type(self)(**kwargs) - - @property - def is_video(self) -> bool: - return self.data_type == DataType.VIDEO - - -class AbstractEmbModel(nn.Module): - def __init__(self): - super().__init__() - - self._is_trainable = None - self._dropout_rate = None - self._input_key = None - self._return_dict = False - - @property - def is_trainable(self) -> bool: - return self._is_trainable - - @property - def dropout_rate(self) -> Union[float, torch.Tensor]: - return self._dropout_rate - - @property - def input_key(self) -> str: - return self._input_key - - @property - def is_return_dict(self) -> bool: - return self._return_dict - - @is_trainable.setter - def is_trainable(self, value: bool): - self._is_trainable = value - - @dropout_rate.setter - def dropout_rate(self, value: Union[float, torch.Tensor]): - self._dropout_rate = value - - @input_key.setter - def input_key(self, value: str): - self._input_key = value - - @is_return_dict.setter - def is_return_dict(self, value: bool): - self._return_dict = value - - @is_trainable.deleter - def is_trainable(self): - del self._is_trainable - - @dropout_rate.deleter - def dropout_rate(self): - del self._dropout_rate - - @input_key.deleter - def input_key(self): - del self._input_key - - @is_return_dict.deleter - def is_return_dict(self): - del self._return_dict - - def random_dropout_input( - self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None - ) -> torch.Tensor: - del key - dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate - return batch_mul( - torch.bernoulli((1.0 - dropout_rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor), - in_tensor, - ) - - def details(self) -> str: - return "" - - def summary(self) -> str: - input_key = self.input_key if self.input_key is not None else getattr(self, "input_keys", None) - return ( - f"{self.__class__.__name__} \n\tinput key: {input_key}" - f"\n\tParam count: {count_params(self, False)} \n\tTrainable: {self.is_trainable}" - f"\n\tDropout rate: {self.dropout_rate}" - f"\n\t{self.details()}" - ) - - -class TextAttr(AbstractEmbModel): - def __init__(self, input_key: List[str], dropout_rate: Optional[float] = 0.0): - super().__init__() - self._input_key = input_key - self._dropout_rate = dropout_rate - - def forward(self, token: torch.Tensor): - return {"crossattn_emb": token} - - def random_dropout_input( - self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None - ) -> torch.Tensor: - if key is not None and "mask" in key: - return in_tensor - return super().random_dropout_input(in_tensor, dropout_rate, key) - - def details(self) -> str: - return "Output key: [crossattn_emb]" - - -class TextAttrEmptyStringDrop(AbstractEmbModel): - def __init__(self, input_key: List[str], dropout_rate: Optional[float] = 0.0): - super().__init__() - self._input_key = input_key - self._dropout_rate = dropout_rate - self.empty_prompt_data = None - - def forward(self, token: torch.Tensor): - return {"crossattn_emb": token} - - def random_dropout_input( - self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None - ) -> torch.Tensor: - if key is not None and "mask" in key: - return in_tensor - del key - dropout_rate = dropout_rate if dropout_rate is not None else self.dropout_rate - if dropout_rate == 0.0: - return in_tensor - if self.empty_prompt_data is None: - if get_itemdataset_option_local is None: - raise ImportError( - "Text dropout requires lyra_2._src.datasets.data_sources.item_datasets_for_validation " - "which is not available in this installation." - ) - self.empty_prompt_data = easy_io.load(get_itemdataset_option_local("empty_string_umt5").path) - - B = in_tensor.shape[0] # batch size - # Create dropout mask: 1 -> keep in_tensor, 0 -> use empty_prompt_data - keep_mask = torch.bernoulli((1.0 - dropout_rate) * torch.ones(B, device=in_tensor.device)).type_as(in_tensor) - keep_mask = keep_mask.view(B, *[1] * (in_tensor.dim() - 1)) # broadcastable shape - # Prepare empty_prompt_data with correct shape, dtype, and device - empty_prompt = self.empty_prompt_data.to(dtype=in_tensor.dtype, device=in_tensor.device) - # Repeat empty_prompt along batch dimension if needed - if empty_prompt.shape[0] != B: - if empty_prompt.shape[0] == 1: - empty_prompt = empty_prompt.expand(B, *empty_prompt.shape[1:]) - else: - raise ValueError( - f"empty_prompt_data batch size {empty_prompt.shape[0]} does not match in_tensor batch size {B}" - ) - - # Mix using the dropout mask - return keep_mask * in_tensor + (1.0 - keep_mask) * empty_prompt - - def details(self) -> str: - return "Output key: [crossattn_emb]" - - -class ReMapkey(AbstractEmbModel): - def __init__( - self, - input_key: str, - output_key: Optional[str] = None, - dropout_rate: Optional[float] = 0.0, - dtype: Optional[str] = None, - ): - super().__init__() - self.output_key = output_key - self.dtype = { - None: None, - "float": torch.float32, - "bfloat16": torch.bfloat16, - "half": torch.float16, - "float16": torch.float16, - "int": torch.int32, - "long": torch.int64, - }[dtype] - self._input_key = input_key - self._output_key = output_key - self._dropout_rate = dropout_rate - - def forward(self, element: torch.Tensor) -> Dict[str, torch.Tensor]: - key = self.output_key if self.output_key else self.input_key - if isinstance(element, torch.Tensor): - element = element.to(dtype=self.dtype) - return {key: element} - - def details(self) -> str: - key = self.output_key if self.output_key else self.input_key - return f"Output key: {key} \n\tDtype: {self.dtype}" - - -class GeneralConditioner(nn.Module, ABC): - """ - An abstract module designed to handle various embedding models with conditional and unconditional configurations. - This abstract base class initializes and manages a collection of embedders that can dynamically adjust - their dropout rates based on conditioning. - - Attributes: - KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation. - embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and configured - based on the provided configurations. - - Parameters: - emb_models (Union[List, Any]): A dictionary where keys are embedder names and values are configurations - for initializing the embedders. - - Example: - See Edify4ConditionerConfig - """ - - KEY2DIM = {"crossattn_emb": 1} - - def __init__(self, **emb_models: Union[List, Any]): - super().__init__() - self.embedders = nn.ModuleDict() - for n, (emb_name, emb_config) in enumerate(emb_models.items()): - embedder = instantiate(emb_config) - # assert isinstance( - # embedder, AbstractEmbModel - # ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" - embedder.is_trainable = getattr(emb_config, "is_trainable", True) - embedder.dropout_rate = getattr(emb_config, "dropout_rate", 0.0) - if not embedder.is_trainable: - embedder.train = disabled_train - for param in embedder.parameters(): - param.requires_grad = False - embedder.eval() - - log.info(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}") - self.embedders[emb_name] = embedder - - @abstractmethod - def forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> Any: - """Should be implemented in subclasses to handle conditon datatype""" - raise NotImplementedError - - def _forward( - self, - batch: Dict, - override_dropout_rate: Optional[Dict[str, float]] = None, - ) -> Dict: - """ - Processes the input batch through all configured embedders, applying conditional dropout rates if specified. - Output tensors for each key are concatenated along the dimensions specified in KEY2DIM. - - Parameters: - batch (Dict): The input data batch to process. - override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates - per embedder key. - - Returns: - Dict: A dictionary of output tensors concatenated by specified dimensions. - - Note: - In case the network code is sensitive to the order of concatenation, you can either control the order via \ - config file or make sure the embedders return a unique key for each output. - """ - output = defaultdict(list) - if override_dropout_rate is None: - override_dropout_rate = {} - - # make sure emb_name in override_dropout_rate is valid - for emb_name in override_dropout_rate.keys(): - assert emb_name in self.embedders, f"invalid name found {emb_name}" - - for emb_name, embedder in self.embedders.items(): - embedding_context = nullcontext if embedder.is_trainable else torch.no_grad - with embedding_context(): - if isinstance(embedder.input_key, str): - emb_out = embedder( - embedder.random_dropout_input( - batch[embedder.input_key], override_dropout_rate.get(emb_name, None) - ) - ) - elif isinstance(embedder.input_key, (list, omegaconf.listconfig.ListConfig)): - emb_out = embedder( - *[ - embedder.random_dropout_input(batch.get(k), override_dropout_rate.get(emb_name, None), k) - for k in embedder.input_key - ] - ) - else: - raise KeyError( - f"Embedder '{embedder.__class__.__name__}' requires an 'input_key' attribute to be defined as either a string or list of strings" - ) - for k, v in emb_out.items(): - output[k].append(v) - # Concatenate the outputs - return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()} - - def get_condition_uncondition( - self, - data_batch: Dict, - ) -> Tuple[Any, Any]: - """ - Processes the provided data batch to generate two sets of outputs: conditioned and unconditioned. This method - manipulates the dropout rates of embedders to simulate two scenarios — one where all conditions are applied - (conditioned), and one where they are removed or reduced to the minimum (unconditioned). - - This method first sets the dropout rates to zero for the conditioned scenario to fully apply the embedders' effects. - For the unconditioned scenario, it sets the dropout rates to 1 (or to 0 if the initial unconditional dropout rate - is insignificant) to minimize the embedders' influences, simulating an unconditioned generation. - - Parameters: - data_batch (Dict): The input data batch that contains all necessary information for embedding processing. The - data is expected to match the required format and keys expected by the embedders. - - Returns: - Tuple[Any, Any]: A tuple containing two condition: - - The first one contains the outputs with all embedders fully applied (conditioned outputs). - - The second one contains the outputs with embedders minimized or not applied (unconditioned outputs). - """ - cond_dropout_rates, dropout_rates = {}, {} - for emb_name, embedder in self.embedders.items(): - cond_dropout_rates[emb_name] = 0.0 - dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 - - condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) - un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates) - return condition, un_condition - - def get_condition_with_negative_prompt( - self, - data_batch: Dict, - ) -> Tuple[Any, Any]: - """ - Similar functionality as get_condition_uncondition - But use negative prompts for unconditon - """ - cond_dropout_rates, uncond_dropout_rates = {}, {} - for emb_name, embedder in self.embedders.items(): - cond_dropout_rates[emb_name] = 0.0 - if isinstance(embedder, TextAttr) or isinstance(embedder, TextAttrEmptyStringDrop): - uncond_dropout_rates[emb_name] = 0.0 - else: - uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0 - - data_batch_neg_prompt = copy.deepcopy(data_batch) - if "neg_t5_text_embeddings" in data_batch_neg_prompt: - if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor): - data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"] - - condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates) - un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates) - - return condition, un_condition - - diff --git a/lyra_2/_src/modules/selective_activation_checkpoint.py b/lyra_2/_src/modules/selective_activation_checkpoint.py deleted file mode 100644 index 236b2d1a2b993c250eb3573efba2a1fd17b0c30b..0000000000000000000000000000000000000000 --- a/lyra_2/_src/modules/selective_activation_checkpoint.py +++ /dev/null @@ -1,73 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from enum import Enum - -import torch - -try: - from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts, noop_context_fn -except ImportError: - CheckpointPolicy = None - -mm_only_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten.addmm.default, -} - - -class CheckpointMode(str, Enum): - """ - Enum for the different checkpoint modes. - """ - - NONE = "none" - MM_ONLY = "mm_only" - BLOCK_WISE = "block_wise" - - def __str__(self) -> str: - # Optional: makes print() show just the value - return self.value - - -def mm_only_policy(ctx, func, *args, **kwargs): - """ - In newer flash-attn and TE versions, FA2 shows up in the list of ops with the name of 'flash_attn._flash_attn_forward'. - However, FA2 is much slower (2-3x) than FA3 or cuDNN kernel. Registering cuDNN kernel would require heavy changes in TE code. - That's why the best option is to use FA3 with small modifications to flash_attn_interface.py to register FA3 as PyTorch op. - """ - to_save = func in mm_only_save_list or "flash_attn" in str(func) - return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE - - -def mm_only_context_fn(): - return create_selective_checkpoint_contexts(mm_only_policy) - - -@dataclass -class SACConfig: - mode: str = "mm_only" - every_n_blocks: int = 1 - - def get_context_fn(self): - if self.mode == CheckpointMode.MM_ONLY: - return mm_only_context_fn - elif self.mode == CheckpointMode.BLOCK_WISE: - return noop_context_fn - else: - raise ValueError(f"Invalid mode: {self.mode}") diff --git a/lyra_2/_src/modules/umt5.py b/lyra_2/_src/modules/umt5.py deleted file mode 100644 index 648d49afcfefc4d464cedd623f8449d5adb1b9fc..0000000000000000000000000000000000000000 --- a/lyra_2/_src/modules/umt5.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import html -import string - -import ftfy -import regex as re -from transformers import AutoTokenizer - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def canonicalize(text, keep_punctuation_exact_string=None): - text = text.replace("_", " ") - if keep_punctuation_exact_string: - text = keep_punctuation_exact_string.join( - part.translate(str.maketrans("", "", string.punctuation)) - for part in text.split(keep_punctuation_exact_string) - ) - else: - text = text.translate(str.maketrans("", "", string.punctuation)) - text = text.lower() - text = re.sub(r"\s+", " ", text) - return text.strip() - - -class HuggingfaceTokenizer: - def __init__(self, name, seq_len=None, clean=None, **kwargs): - assert clean in (None, "whitespace", "lower", "canonicalize") - self.name = name - self.seq_len = seq_len - self.clean = clean - - # init tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) - self.vocab_size = self.tokenizer.vocab_size - - def __call__(self, sequence, **kwargs): - return_mask = kwargs.pop("return_mask", False) - - # arguments - _kwargs = {"return_tensors": "pt"} - if self.seq_len is not None: - _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) - _kwargs.update(**kwargs) - - # tokenization - if isinstance(sequence, str): - sequence = [sequence] - if self.clean: - sequence = [self._clean(u) for u in sequence] - ids = self.tokenizer(sequence, **_kwargs) - - # output - if return_mask: - return ids.input_ids, ids.attention_mask - else: - return ids.input_ids - - def _clean(self, text): - if self.clean == "whitespace": - text = whitespace_clean(basic_clean(text)) - elif self.clean == "lower": - text = whitespace_clean(basic_clean(text)).lower() - elif self.clean == "canonicalize": - text = canonicalize(basic_clean(text)) - return text - - diff --git a/lyra_2/_src/modules/xlm_roberta.py b/lyra_2/_src/modules/xlm_roberta.py deleted file mode 100644 index 1591ee3dfceb62dd58e127daca589448a0780bdf..0000000000000000000000000000000000000000 --- a/lyra_2/_src/modules/xlm_roberta.py +++ /dev/null @@ -1,157 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import torch -import torch.nn as nn -import torch.nn.functional as F - -__all__ = ["XLMRoberta"] - - -class SelfAttention(nn.Module): - def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): - assert dim % num_heads == 0 - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.eps = eps - - # layers - self.q = nn.Linear(dim, dim) - self.k = nn.Linear(dim, dim) - self.v = nn.Linear(dim, dim) - self.o = nn.Linear(dim, dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, mask): - """ - x: [B, L, C]. - """ - b, s, c, n, d = *x.size(), self.num_heads, self.head_dim - - # compute query, key, value - q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) - k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) - v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) - - # compute attention - p = self.dropout.p if self.training else 0.0 - x = F.scaled_dot_product_attention(q, k, v, mask, p) - x = x.permute(0, 2, 1, 3).reshape(b, s, c) - - # output - x = self.o(x) - x = self.dropout(x) - return x - - -class AttentionBlock(nn.Module): - def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.post_norm = post_norm - self.eps = eps - - # layers - self.attn = SelfAttention(dim, num_heads, dropout, eps) - self.norm1 = nn.LayerNorm(dim, eps=eps) - self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout)) - self.norm2 = nn.LayerNorm(dim, eps=eps) - - def forward(self, x, mask): - if self.post_norm: - x = self.norm1(x + self.attn(x, mask)) - x = self.norm2(x + self.ffn(x)) - else: - x = x + self.attn(self.norm1(x), mask) - x = x + self.ffn(self.norm2(x)) - return x - - -class XLMRoberta(nn.Module): - """ - XLMRobertaModel with no pooler and no LM head. - """ - - def __init__( - self, - vocab_size=250002, - max_seq_len=514, - type_size=1, - pad_id=1, - dim=1024, - num_heads=16, - num_layers=24, - post_norm=True, - dropout=0.1, - eps=1e-5, - ): - super().__init__() - self.vocab_size = vocab_size - self.max_seq_len = max_seq_len - self.type_size = type_size - self.pad_id = pad_id - self.dim = dim - self.num_heads = num_heads - self.num_layers = num_layers - self.post_norm = post_norm - self.eps = eps - - # embeddings - self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) - self.type_embedding = nn.Embedding(type_size, dim) - self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) - self.dropout = nn.Dropout(dropout) - - # blocks - self.blocks = nn.ModuleList( - [AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)] - ) - - # norm layer - self.norm = nn.LayerNorm(dim, eps=eps) - - def forward(self, ids): - """ - ids: [B, L] of torch.LongTensor. - """ - b, s = ids.shape - mask = ids.ne(self.pad_id).long() - - # embeddings - x = ( - self.token_embedding(ids) - + self.type_embedding(torch.zeros_like(ids)) - + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) - ) - if self.post_norm: - x = self.norm(x) - x = self.dropout(x) - - # blocks - mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min) - for block in self.blocks: - x = block(x, mask) - - # output - if not self.post_norm: - x = self.norm(x) - return x - - diff --git a/lyra_2/_src/networks/__init__.py b/lyra_2/_src/networks/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/networks/clip_lyra2.py b/lyra_2/_src/networks/clip_lyra2.py deleted file mode 100644 index 22ff0fb18b02095e2076a853bc35900e67632931..0000000000000000000000000000000000000000 --- a/lyra_2/_src/networks/clip_lyra2.py +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Dict, List, Optional - -import torch -import torch.nn as nn - -from lyra_2._src.modules.clip import CLIPModel -from lyra_2._src.modules.conditioner import AbstractEmbModel - - -class Wan2pt1CLIPEmbLyra2(AbstractEmbModel): - """Lyra2-aware CLIP embedder.""" - - def __init__( - self, - input_key: List[str], - dropout_rate: float = 0.0, - num_token: int = 257, - dtype: str = "bfloat16", - ): - super().__init__() - self.num_token = num_token - self.model_dim = 1280 - self.clip_model = CLIPModel() - - self._input_key = input_key - self._output_key = None - self._dropout_rate = dropout_rate - self.dtype = { - "bfloat16": torch.bfloat16, - "float16": torch.float16, - "float32": torch.float32, - }[dtype] - - def random_dropout_input(self, in_tensor=None, dropout_rate=None, key=None): - return in_tensor - - def forward( - self, - image_tensor: Optional[torch.Tensor] = None, - video_tensor: Optional[torch.Tensor] = None, - media_latents: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, - buffer_latents: Optional[torch.Tensor] = None, - ) -> Dict[str, torch.Tensor]: - assert media_latents is not None, "media_latents is required" - assert mask is not None, "mask is required" - with torch.no_grad(): - assert image_tensor is not None, "image_tensor is required" - context_B_L_D = self.clip_model.visual(image_tensor).to(self.dtype) - - y = torch.concat([mask, media_latents.to(self.dtype)], dim=1) - out = {"frame_cond_crossattn_emb_B_L_D": context_B_L_D, "y_B_C_T_H_W": y} - if buffer_latents is not None: - out["y_buffer_B_C_T_H_W"] = buffer_latents.to(self.dtype) - return out diff --git a/lyra_2/_src/networks/wan2pt1.py b/lyra_2/_src/networks/wan2pt1.py deleted file mode 100644 index b9cddd884a789374c872dfb9a37adea49637bd40..0000000000000000000000000000000000000000 --- a/lyra_2/_src/networks/wan2pt1.py +++ /dev/null @@ -1,970 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# from Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -import math -from typing import Optional - -import torch -import torch.amp as amp -import torch.nn as nn -from einops import rearrange, repeat - -try: - from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb -except ImportError: - flash_apply_rotary_emb = None - raise ImportError("flash_attn is not installed.") - -from torch.distributed import ProcessGroup, get_process_group_ranks -from torch.distributed._composable.fsdp import fully_shard -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper -from torchvision import transforms -from transformer_engine.pytorch.attention import DotProductAttention - -from lyra_2._ext.imaginaire.utils import log -from lyra_2._src.callbacks.model_weights_stats import WeightTrainingStat -from lyra_2._src.modules.selective_activation_checkpoint import ( - CheckpointMode, -) -from lyra_2._src.modules.selective_activation_checkpoint import SACConfig as SACConfig -from lyra_2._src.utils.context_parallel import split_inputs_cp - -T5_CONTEXT_TOKEN_NUMBER = 512 -FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 -from collections import namedtuple - -VideoSize = namedtuple("VideoSize", ["T", "H", "W"]) - - -class VideoPositionEmb(nn.Module): - def __init__(self): - super().__init__() - self._cp_group = None - - def enable_context_parallel(self, process_group: ProcessGroup): - self._cp_group = process_group - - def disable_context_parallel(self): - self._cp_group = None - - @property - def seq_dim(self): - return 1 - - def forward(self, x_B_T_H_W_C: torch.Tensor) -> torch.Tensor: - """ - With CP, the function assume that the input tensor is already split. - It delegates the embedding generation to generate_embeddings function. - """ - B_T_H_W_C = x_B_T_H_W_C.shape - if self._cp_group is not None: - cp_ranks = get_process_group_ranks(self._cp_group) - cp_size = len(cp_ranks) - B, T, H, W, C = B_T_H_W_C - B_T_H_W_C = (B, T * cp_size, H, W, C) - embeddings = self.generate_embeddings(B_T_H_W_C) - - return self._split_for_context_parallel(embeddings) - - def generate_embeddings(self, B_T_H_W_C: torch.Size): - raise NotImplementedError - - def _split_for_context_parallel(self, embeddings): - if self._cp_group is not None: - embeddings = split_inputs_cp(x=embeddings, seq_dim=self.seq_dim, cp_group=self._cp_group) - return embeddings - - -class VideoRopePosition3DEmb(VideoPositionEmb): - def __init__( - self, - head_dim: int, - len_h: int, - len_w: int, - len_t: int, - h_extrapolation_ratio: float = 1.0, - w_extrapolation_ratio: float = 1.0, - t_extrapolation_ratio: float = 1.0, - ): - super().__init__() - self.max_h = len_h - self.max_w = len_w - self.max_t = len_t - dim = head_dim - dim_h = dim // 6 * 2 - dim_w = dim_h - dim_t = dim - 2 * dim_h - assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" - self._dim_h = dim_h - self._dim_t = dim_t - - self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) - self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) - self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) - - self._is_initialized = False - - def cache_parameters(self) -> None: - if self._is_initialized: - return - - dim_h = self._dim_h - dim_t = self._dim_t - - self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().cuda() - self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h - self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t - self._is_initialized = True - - def generate_embeddings( - self, - B_T_H_W_C: torch.Size, - h_ntk_factor: Optional[float] = None, - w_ntk_factor: Optional[float] = None, - t_ntk_factor: Optional[float] = None, - ): - """ - Generate embeddings for the given input size. - - Args: - B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). - fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. - h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. - w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. - t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. - - Returns: - Not specified in the original code snippet. - """ - self.cache_parameters() - - h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor - w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor - t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor - - h_theta = 10000.0 * h_ntk_factor - w_theta = 10000.0 * w_ntk_factor - t_theta = 10000.0 * t_ntk_factor - - h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) - w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) - temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) - - B, T, H, W, _ = B_T_H_W_C - assert H <= self.max_h and W <= self.max_w, ( - f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" - ) - freqs_h = torch.outer(self.seq[:H], h_spatial_freqs) - freqs_w = torch.outer(self.seq[:W], w_spatial_freqs) - freqs_t = torch.outer(self.seq[:T], temporal_freqs) - freqs_T_H_W_D = torch.cat( - [ - repeat(freqs_t, "t d -> t h w d", h=H, w=W), - repeat(freqs_h, "h d -> t h w d", t=T, w=W), - repeat(freqs_w, "w d -> t h w d", t=T, h=H), - ], - dim=-1, - ) - - return rearrange(freqs_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() - - @property - def seq_dim(self): - return 0 - - -def sinusoidal_embedding_1d(dim, position): - # preprocess - assert dim % 2 == 0 - half = dim // 2 - position = position.type(torch.float64) - - # calculation - sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) - x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) - return x - - -def rope_apply(x, video_size: VideoSize, freqs): - """ - Optimized version of rope_apply using flash_attention's rotary embedding implementation. - This version processes the entire batch at once for efficiency. - - Args: - x (Tensor): Input tensor with shape [batch_size, seq_len, n_heads, head_dim] - video_size (VideoSize): Video dimensions with shape [T, H, W] - freqs (Tensor): Complex frequencies with shape [max_seq_len, head_dim // 2] - - Returns: - Tensor: Rotary-embedded tensor with same shape as input - """ - batch_size, seq_len, n_heads, head_dim = x.shape - - # Since all items in the batch share the same grid dimensions, we can use the first item - T, H, W = video_size - curr_seq_len = T * H * W - - # Make sure the sequence length matches the grid size - assert seq_len == curr_seq_len, "Sequence length must be equal to T*H*W" - - freqs = freqs.view(seq_len, head_dim // 2) - cos = torch.cos(freqs).to(torch.float32) - sin = torch.sin(freqs).to(torch.float32) - - # Apply the rotation - rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False) - - return rotated.to(x.dtype) - - -class WanRMSNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def reset_parameters(self): - self.weight.data.fill_(1.0) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return self._norm(x.float()).type_as(x) * self.weight - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) - - -class WanLayerNorm(nn.LayerNorm): - def __init__(self, dim, eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - # return super().forward(x.float()).type_as(x) - return super().forward(x) - - -class SelfAttnOp(DotProductAttention): - def forward( - self, - q_B_L_H_D, - k_B_L_H_D, - v_B_L_H_D, - video_size: Optional[VideoSize] = None, - ): - return super().forward(q_B_L_H_D, k_B_L_H_D, v_B_L_H_D) - - -class WanSelfAttention(nn.Module): - def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, cp_comm_type="p2p"): - assert dim % num_heads == 0 - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.eps = eps - self.qk_norm = qk_norm - self.cp_comm_type = cp_comm_type - - # layers - self.q = nn.Linear(dim, dim) - self.k = nn.Linear(dim, dim) - self.v = nn.Linear(dim, dim) - self.o = nn.Linear(dim, dim) - self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - self.attn_op = SelfAttnOp( - self.num_heads, - self.head_dim, - num_gqa_groups=self.num_heads, - attention_dropout=0, - qkv_format="bshd", - attn_mask_type="no_mask", - ) - - def init_weights(self): - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.q.weight, std=std) - torch.nn.init.trunc_normal_(self.k.weight, std=std) - torch.nn.init.trunc_normal_(self.v.weight, std=std) - torch.nn.init.trunc_normal_(self.o.weight, std=std) - # zero out bias - self.q.bias.data.zero_() - self.k.bias.data.zero_() - self.v.bias.data.zero_() - self.o.bias.data.zero_() - # reset norm weights - if self.qk_norm: - self.norm_q.reset_parameters() - self.norm_k.reset_parameters() - - def forward(self, x, seq_lens, video_size: VideoSize, freqs, kq_bias=None): - r""" - Args: - x(Tensor): Shape [B, L, num_heads, C / num_heads] - seq_lens(Tensor): Shape [B] - video_size(VideoSize): Shape [T, H, W] - freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] - kq_bias(Tensor|None): Optional bias added to input before Q and K projections only (not V). - """ - b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - - # query, key, value function - x_kq = x if kq_bias is None else x + kq_bias - - def qkv_fn(x, x_kq): - q = self.norm_q(self.q(x_kq)).view(b, s, n, d) - k = self.norm_k(self.k(x_kq)).view(b, s, n, d) - v = self.v(x).view(b, s, n, d) - return q, k, v - - q, k, v = qkv_fn(x, x_kq) - - x = self.attn_op(rope_apply(q, video_size, freqs), rope_apply(k, video_size, freqs), v, video_size) - - # output - x = x.flatten(2) - x = self.o(x) - return x - - def set_context_parallel_group(self, process_group, ranks, stream): - self.attn_op.set_context_parallel_group(process_group, ranks, stream, cp_comm_type=self.cp_comm_type) - - -class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context, context_lens): - r""" - Args: - x(Tensor): Shape [B, L1, C] - context(Tensor): Shape [B, L2, C] - context_lens(Tensor): Shape [B] - """ - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.norm_q(self.q(x)).view(b, -1, n, d) - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - - # compute attention - x = self.attn_op(q, k, v, None) - # output - x = x.flatten(2) - x = self.o(x) - return x - - -class WanI2VCrossAttention(WanSelfAttention): - def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6, cp_comm_type="p2p"): - super().__init__(dim, num_heads, window_size, qk_norm, eps, cp_comm_type) - - self.k_img = nn.Linear(dim, dim) - self.v_img = nn.Linear(dim, dim) - # self.alpha = nn.Parameter(torch.zeros((1, ))) - self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - self.attn_op_image = DotProductAttention( - self.num_heads, - self.head_dim, - num_gqa_groups=self.num_heads, - attention_dropout=0, - qkv_format="bshd", - attn_mask_type="no_mask", - ) - - def init_weights(self): - super().init_weights() - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.k_img.weight, std=std) - torch.nn.init.trunc_normal_(self.v_img.weight, std=std) - # zero out bias - self.k_img.bias.data.zero_() - self.v_img.bias.data.zero_() - # reset norm weights - if self.qk_norm: - self.norm_k_img.reset_parameters() - - def forward(self, x, context, context_lens): - r""" - Args: - x(Tensor): Shape [B, L1, C] - context(Tensor): Shape [B, L2, C] - context_lens(Tensor): Shape [B] - """ - image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER - context_img = context[:, :image_context_length] - context = context[:, image_context_length:] - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.norm_q(self.q(x)).view(b, -1, n, d) - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) - v_img = self.v_img(context_img).view(b, -1, n, d) - img_x = self.attn_op_image(q, k_img, v_img) - # compute attention - x = self.attn_op(q, k, v) - - # output - x = x.flatten(2) - img_x = img_x.flatten(2) - x = x + img_x - x = self.o(x) - return x - - -WAN_CROSSATTENTION_CLASSES = { - "t2v_cross_attn": WanT2VCrossAttention, - "i2v_cross_attn": WanI2VCrossAttention, -} - - -class WanAttentionBlock(nn.Module): - def __init__( - self, - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - cp_comm_type="p2p", - ): - super().__init__() - self.dim = dim - self.ffn_dim = ffn_dim - self.num_heads = num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - - # layers - self.norm1 = WanLayerNorm(dim, eps) - self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps, cp_comm_type) - self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() - self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type]( - dim, num_heads, (-1, -1), qk_norm, eps, cp_comm_type - ) - self.norm2 = WanLayerNorm(dim, eps) - self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) - - # modulation - self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) - - def init_weights(self): - self.self_attn.init_weights() - self.cross_attn.init_weights() - - self.norm1.reset_parameters() - self.norm2.reset_parameters() - self.norm3.reset_parameters() - - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.modulation, std=std) - - def forward( - self, - x, - e, - seq_lens, - video_size: VideoSize, - freqs, - context, - context_lens, - ): - r""" - Args: - x(Tensor): Shape [B, L, C] - e(Tensor): Shape [B, 6, C] - seq_lens(Tensor): Shape [B], length of each sequence in batch - video_size(VideoSize): Shape [T, H, W] - freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] - """ - assert e.dtype == torch.float32 - with amp.autocast("cuda", dtype=torch.float32): - e = (self.modulation + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 - - # self-attention - y = self.self_attn((self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, video_size, freqs) - with amp.autocast("cuda", dtype=torch.float32): - x = x + y * e[2].type_as(x) - - # cross-attention & ffn function - def cross_attn_ffn(x, context, context_lens, e): - x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).type_as(x)) - with amp.autocast("cuda", dtype=torch.float32): - x = x + y * e[5].type_as(x) - return x - - x = cross_attn_ffn(x, context, context_lens, e) - return x - - -class Head(nn.Module): - def __init__(self, dim, out_dim, patch_size, eps=1e-6): - super().__init__() - self.dim = dim - self.out_dim = out_dim - self.patch_size = patch_size - self.eps = eps - - # layers - out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) - self.head = nn.Linear(dim, out_dim) - - # modulation - self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) - - def init_weights(self): - self.norm.reset_parameters() - - std = 1.0 / math.sqrt(self.dim) - torch.nn.init.trunc_normal_(self.modulation, std=std) - torch.nn.init.trunc_normal_(self.head.weight, std=std) - self.head.bias.data.zero_() - - def forward(self, x, e): - r""" - Args: - x(Tensor): Shape [B, L1, C] - e(Tensor): Shape [B, C] - """ - assert e.dtype == torch.float32 - with amp.autocast("cuda", dtype=torch.float32): - e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = self.head(self.norm(x) * (1 + e[1]) + e[0]) - return x - - -class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, flf_pos_emb=False): - super().__init__() - - self.proj = torch.nn.Sequential( - torch.nn.LayerNorm(in_dim), - torch.nn.Linear(in_dim, in_dim), - torch.nn.GELU(), - torch.nn.Linear(in_dim, out_dim), - torch.nn.LayerNorm(out_dim), - ) - if flf_pos_emb: # NOTE: we only use this for `flf2v` - self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) - - def init_weights(self): - self.proj[0].reset_parameters() - self.proj[1].reset_parameters() - self.proj[3].reset_parameters() - self.proj[4].reset_parameters() - - if hasattr(self, "emb_pos"): - self.emb_pos.data.zero_() - - def forward(self, image_embeds): - if hasattr(self, "emb_pos"): - bs, n, d = image_embeds.shape - image_embeds = image_embeds.view(-1, 2 * n, d) - image_embeds = image_embeds + self.emb_pos - clip_extra_context_tokens = self.proj(image_embeds) - return clip_extra_context_tokens - - -class WanModel(WeightTrainingStat): - r""" - Wan diffusion backbone supporting both text-to-video and image-to-video. - """ - - def __init__( - self, - model_type="t2v", - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2048, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - eps=1e-6, - concat_padding_mask: bool = False, - sac_config: SACConfig = SACConfig(), - cp_comm_type: str = "p2p", - postpone_checkpoint: bool = False, - conv_patchify: bool = False, - ): - r""" - Initialize the diffusion model backbone. - - Args: - model_type (`str`, *optional*, defaults to 't2v'): - Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) - patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): - 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) - text_len (`int`, *optional*, defaults to 512): - Fixed length for text embeddings - in_dim (`int`, *optional*, defaults to 16): - Input video channels (C_in) - dim (`int`, *optional*, defaults to 2048): - Hidden dimension of the transformer - ffn_dim (`int`, *optional*, defaults to 8192): - Intermediate dimension in feed-forward network - freq_dim (`int`, *optional*, defaults to 256): - Dimension for sinusoidal time embeddings - text_dim (`int`, *optional*, defaults to 4096): - Input dimension for text embeddings - out_dim (`int`, *optional*, defaults to 16): - Output video channels (C_out) - num_heads (`int`, *optional*, defaults to 16): - Number of attention heads - num_layers (`int`, *optional*, defaults to 32): - Number of transformer blocks - window_size (`tuple`, *optional*, defaults to (-1, -1)): - Window size for local attention (-1 indicates global attention) - qk_norm (`bool`, *optional*, defaults to True): - Enable query/key normalization - cross_attn_norm (`bool`, *optional*, defaults to False): - Enable cross-attention normalization - eps (`float`, *optional*, defaults to 1e-6): - Epsilon value for normalization layers - concat_padding_mask (`bool`, *optional*, defaults to False): - Enable concat padding mask - cp_comm_type (str, *optional*, defaults to 'p2p'): - CP communication type passed to TE. - """ - - super().__init__() - - assert model_type in ["t2v", "i2v", "flf2v"] - self.model_type = model_type - - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - self.concat_padding_mask = concat_padding_mask - self.cp_comm_type = cp_comm_type - self.conv_patchify = conv_patchify - # embeddings - in_dim = in_dim + 1 if self.concat_padding_mask else in_dim - if self.conv_patchify: - self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) - else: - self.patch_embedding = nn.Linear(in_dim * patch_size[0] * patch_size[1] * patch_size[2], dim) - - self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) - - self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) - self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) - - # blocks - cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" - self.blocks = nn.ModuleList( - [ - WanAttentionBlock( - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size, - qk_norm, - cross_attn_norm, - eps, - self.cp_comm_type, - ) - for _ in range(num_layers) - ] - ) - - # head - self.head = Head(dim, out_dim, patch_size, eps) - - # buffers (don't use register_buffer otherwise dtype will be changed in to()) - assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 - - d = dim // num_heads - - self.rope_position_embedding = VideoRopePosition3DEmb( - head_dim=d, - len_h=128, - len_w=128, - len_t=32, - ) - - if model_type == "i2v" or model_type == "flf2v": - self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == "flf2v") - - # initialize weights - self.sac_config = sac_config - if not postpone_checkpoint: - self.enable_selective_checkpoint(sac_config, self.blocks) - - def forward( - self, - x_B_C_T_H_W, - timesteps_B_T, - crossattn_emb, - seq_len=None, - frame_cond_crossattn_emb_B_L_D=None, - y_B_C_T_H_W=None, - padding_mask: Optional[torch.Tensor] = None, - is_uncond=False, - slg_layers=None, - **kwargs, - ): - r""" - Forward pass through the diffusion model - - Args: - x_B_C_T_H_W (Tensor): - Input video tensor with shape [B, C_in, T, H, W] - t (Tensor): - Diffusion timesteps tensor of shape [B] - context (List[Tensor]): - List of text embeddings each with shape [L, C] - seq_len (`int`): - Maximum sequence length for positional encoding - frame_cond_crossattn_emb_B_L_D (Tensor, *optional*): - CLIP image features for image-to-video mode or first-last-frame-to-video mode - y_B_C_T_H_W (Tensor, *optional*): - Conditional video inputs for image-to-video mode, shape [B, C_in, T, H, W] - - Returns: - Tensor: - Denoised video tensor with shape [B, C_out, T, H / 8, W / 8] - """ - assert timesteps_B_T.shape[1] == 1 - t_B = timesteps_B_T[:, 0] - del kwargs - if self.model_type == "i2v" or self.model_type == "flf2v": - assert frame_cond_crossattn_emb_B_L_D is not None and y_B_C_T_H_W is not None - - if y_B_C_T_H_W is not None: - x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, y_B_C_T_H_W], dim=1) - - if self.concat_padding_mask: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) - - if self.conv_patchify: - x_B_D_T_H_W = self.patch_embedding(x_B_C_T_H_W) - x_B_T_H_W_D = rearrange(x_B_D_T_H_W, "b d t h w -> b t h w d") - else: - # embeddings - x_B_T_H_W_D = rearrange( - x_B_C_T_H_W, - "b c (t kt) (h kh) (w kw) -> b t h w (c kt kh kw)", - kt=self.patch_size[0], - kh=self.patch_size[1], - kw=self.patch_size[2], - ) - x_B_T_H_W_D = self.patch_embedding(x_B_T_H_W_D) - - video_size = VideoSize(T=x_B_T_H_W_D.shape[1], H=x_B_T_H_W_D.shape[2], W=x_B_T_H_W_D.shape[3]) - x_B_L_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d") - seq_lens = torch.tensor([u.size(0) for u in x_B_L_D], dtype=torch.long) - seq_len = seq_lens.max().item() - assert seq_lens.max() == seq_len - - # time embeddings - with amp.autocast("cuda", dtype=torch.float32): - e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float()) - e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim)) - assert e_B_D.dtype == torch.float32 and e0_B_6_D.dtype == torch.float32 - - # context - context_lens = None - context_B_L_D = self.text_embedding(crossattn_emb) - - if frame_cond_crossattn_emb_B_L_D is not None: - context_clip = self.img_emb(frame_cond_crossattn_emb_B_L_D) # bs x 257 (x2) x dim - context_B_L_D = torch.concat([context_clip, context_B_L_D], dim=1) - - # arguments - kwargs = dict( - e=e0_B_6_D, - seq_lens=seq_lens, - video_size=video_size, - freqs=self.rope_position_embedding(x_B_T_H_W_D), - context=context_B_L_D, - context_lens=context_lens, - ) - - for block_idx, block in enumerate(self.blocks): - if slg_layers is not None and block_idx in slg_layers and is_uncond: - continue - x_B_L_D = block(x_B_L_D, **kwargs) - - # head - x_B_L_D = self.head(x_B_L_D, e_B_D) - - # unpatchify - t, h, w = video_size - x_B_C_T_H_W = rearrange( - x_B_L_D, - "b (t h w) (nt nh nw d) -> b d (t nt) (h nh) (w nw)", - nt=self.patch_size[0], - nh=self.patch_size[1], - nw=self.patch_size[2], - t=t, - h=h, - w=w, - d=self.out_dim, - ) - - return x_B_C_T_H_W - - def init_weights(self): - r""" - Initialize model parameters using Xavier initialization. - """ - - # basic init - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - - for block in self.blocks: - block.init_weights() - self.head.init_weights() - - # init embeddings - nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) - nn.init.zeros_(self.patch_embedding.bias) - - for m in self.text_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - for m in self.time_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - for m in self.time_projection.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - # init output layer - nn.init.zeros_(self.head.head.weight) - if self.head.head.bias is not None: - nn.init.zeros_(self.head.head.bias) - - def fully_shard(self, mesh): - for i, block in enumerate(self.blocks): - fully_shard(block, mesh=mesh, reshard_after_forward=True) - fully_shard(self.head, mesh=mesh, reshard_after_forward=False) - fully_shard(self.text_embedding, mesh=mesh, reshard_after_forward=True) - fully_shard(self.time_embedding, mesh=mesh, reshard_after_forward=True) - fully_shard(self.patch_embedding, mesh=mesh, reshard_after_forward=True) - fully_shard(self.time_projection, mesh=mesh, reshard_after_forward=True) - - def disable_context_parallel(self): - # pos_embedder - self.rope_position_embedding.disable_context_parallel() - # attention - for block in self.blocks: - block.self_attn.set_context_parallel_group( - process_group=None, - ranks=None, - stream=torch.cuda.Stream(), - ) - - self._is_context_parallel_enabled = False - - def enable_context_parallel(self, process_group: Optional[ProcessGroup] = None): - # pos_embedder - self.rope_position_embedding.enable_context_parallel(process_group=process_group) - cp_ranks = get_process_group_ranks(process_group) - for block in self.blocks: - block.self_attn.set_context_parallel_group( - process_group=process_group, - ranks=cp_ranks, - stream=torch.cuda.Stream(), - ) - - self._is_context_parallel_enabled = True - - @property - def is_context_parallel_enabled(self): - return self._is_context_parallel_enabled - - def enable_selective_checkpoint(self, sac_config: SACConfig, blocks: nn.ModuleList): - if sac_config.mode == CheckpointMode.NONE: - pass - - log.info( - f"Enable selective checkpoint with {sac_config.mode}, for every {sac_config.every_n_blocks} blocks. Total blocks: {len(blocks)}" - ) - _context_fn = sac_config.get_context_fn() - for block_id, block in blocks.named_children(): - if int(block_id) % sac_config.every_n_blocks == 0: - log.info(f"Enable selective checkpoint for block {block_id}") - block = ptd_checkpoint_wrapper( - block, - context_fn=_context_fn, - preserve_rng_state=False, - ) - blocks.register_module(block_id, block) - self.register_module( - "head", - ptd_checkpoint_wrapper( - self.head, - context_fn=_context_fn, - preserve_rng_state=False, - ), - ) diff --git a/lyra_2/_src/networks/wan2pt1_lyra2.py b/lyra_2/_src/networks/wan2pt1_lyra2.py deleted file mode 100644 index fa8ce042c36640ebd3e93401c4f6ffec5059770e..0000000000000000000000000000000000000000 --- a/lyra_2/_src/networks/wan2pt1_lyra2.py +++ /dev/null @@ -1,1121 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import math -import torch.nn as nn -from typing import Optional, List, Tuple -import torch.amp as amp -from einops import rearrange -from torchvision import transforms -from torch.distributed._composable.fsdp import fully_shard -from lyra_2._src.networks.wan2pt1 import ( - WanLayerNorm, - WanSelfAttention, - WAN_CROSSATTENTION_CLASSES, - VideoSize, - sinusoidal_embedding_1d, - Head, - MLPProj, - VideoRopePosition3DEmb, -) -from lyra_2._src.callbacks.model_weights_stats import WeightTrainingStat -from lyra_2._src.modules.selective_activation_checkpoint import ( - SACConfig, - CheckpointMode, - mm_only_context_fn, -) -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper -from lyra_2._ext.imaginaire.utils import log -from einops import repeat -from lyra_2._src.utils.context_parallel import ( - split_inputs_cp, - cat_outputs_cp_with_grad, -) -from torch.distributed import ProcessGroup, get_process_group_ranks - - -class Lyra2AttentionBlock(nn.Module): - """Attention block copied from WanAttentionBlock with optional camera/buffer embedding. - - If cam_dim > 0, constructs a camera encoder and injects the camera embedding - into the self-attention input (pre-attention add). If buffer_dim > 0, injects - buffer embeddings similarly. - """ - def __init__( - self, - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, - cp_comm_type="p2p", - cam_dim: int = 0, - buffer_dim: int = 0, - buffer_sincos_multires: int = 0, - inject_kq_only: bool = False, - buffer_mlp_squeeze_dim: int = 0, - ): - super().__init__() - self.dim = dim - self.ffn_dim = ffn_dim - self.num_heads = num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - self.inject_kq_only = bool(inject_kq_only) - self.buffer_mlp_squeeze_dim = int(buffer_mlp_squeeze_dim) - - # layers - self.norm1 = WanLayerNorm(dim, eps) - self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps, cp_comm_type) - self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() - self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type]( - dim, num_heads, (-1, -1), qk_norm, eps, cp_comm_type - ) - self.norm2 = WanLayerNorm(dim, eps) - self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) - - # modulation - self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) - - # optional camera/buffer encoders - self.cam_dim = cam_dim - self.cam_encoder = nn.Linear(self.cam_dim, self.dim, bias=False) if self.cam_dim > 0 else None - self.buffer_dim = buffer_dim - self.buffer_sincos_multires = int(buffer_sincos_multires) - buffer_embed_dim = self.buffer_dim - if self.buffer_sincos_multires > 0 and self.buffer_dim > 0: - buffer_embed_dim = self.buffer_dim * 2 * self.buffer_sincos_multires - if self.buffer_dim > 0: - if self.buffer_mlp_squeeze_dim > 0: - self.buffer_encoder = nn.Sequential( - nn.Linear(buffer_embed_dim, self.buffer_mlp_squeeze_dim, bias=False), - nn.Linear(self.buffer_mlp_squeeze_dim, self.dim, bias=False), - ) - else: - self.buffer_encoder = nn.Linear(buffer_embed_dim, self.dim, bias=False) - else: - self.buffer_encoder = None - - def init_weights(self): - self.self_attn.init_weights() - self.cross_attn.init_weights() - - self.norm1.reset_parameters() - self.norm2.reset_parameters() - self.norm3.reset_parameters() - - std = 1.0 / (self.dim ** 0.5) - torch.nn.init.trunc_normal_(self.modulation, std=std) - if self.cam_encoder is not None: - torch.nn.init.trunc_normal_(self.cam_encoder.weight, std=std, a=-3 * std, b=3 * std) - if self.buffer_encoder is not None: - if isinstance(self.buffer_encoder, nn.Sequential): - for layer in self.buffer_encoder: - if isinstance(layer, nn.Linear): - torch.nn.init.trunc_normal_(layer.weight, std=std, a=-3 * std, b=3 * std) - else: - torch.nn.init.trunc_normal_(self.buffer_encoder.weight, std=std, a=-3 * std, b=3 * std) - - @staticmethod - def _sincos_embed(x: torch.Tensor, multires: int) -> torch.Tensor: - if multires <= 0: - return x - x_float = x.float() - embeds = [] - for i in range(int(multires)): - freq = (2.0 ** i) * math.pi - embeds.append(torch.sin(x_float * freq)) - embeds.append(torch.cos(x_float * freq)) - out = torch.cat(embeds, dim=-1) - return out.type_as(x) - - def forward( - self, - x, - e, - seq_lens, - video_size: VideoSize, - freqs, - context, - context_lens, - camera: Optional[torch.Tensor] = None, - buffer: Optional[torch.Tensor] = None, - ): - assert e.dtype == torch.float32 - with amp.autocast("cuda", dtype=torch.float32): - e = (self.modulation + e).chunk(6, dim=1) - assert e[0].dtype == torch.float32 - - # Self-attention with optional camera injection - if camera is not None: - assert self.cam_encoder is not None - cam_emb = self.cam_encoder(camera) - else: - if self.cam_encoder is not None: - raise ValueError("cam_encoder is enabled but camera tokens are None") - cam_emb = 0 - if buffer is not None: - if self.inject_kq_only: - validity = buffer[..., -1:] # [B, L, 1] - buffer = buffer[..., :-1] - assert self.buffer_encoder is not None - if self.buffer_sincos_multires > 0: - buffer = self._sincos_embed(buffer, self.buffer_sincos_multires) - buf_emb = self.buffer_encoder(buffer) - if self.inject_kq_only: - buf_emb = buf_emb * validity - else: - if self.buffer_sincos_multires > 0 and self.buffer_encoder is not None: - raise ValueError("buffer_sincos_multires>0 requires buffer tokens, but buffer is None") - buf_emb = 0 - - y = (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x) - - if self.inject_kq_only: - kq_bias = cam_emb + buf_emb - if isinstance(kq_bias, (int, float)) and kq_bias == 0: - kq_bias = None - y = self.self_attn(y, seq_lens, video_size, freqs, kq_bias=kq_bias) - else: - y = self.self_attn(y + cam_emb + buf_emb, seq_lens, video_size, freqs) - - with amp.autocast("cuda", dtype=torch.float32): - x = x + y * e[2].type_as(x) - - # cross-attn + ffn (same as base) - def cross_attn_ffn(x, context, context_lens, e): - x = x + self.cross_attn(self.norm3(x), context, context_lens) - y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).type_as(x)) - with amp.autocast("cuda", dtype=torch.float32): - x = x + y * e[5].type_as(x) - return x - - x = cross_attn_ffn(x, context, context_lens, e) - return x - - -class Lyra2WanModel(WeightTrainingStat): - """WAN backbone with Lyra2 modifications inlined. - - - Copies core logic from WanModel (no subclassing). - - Adds clean patch embeddings and Lyra2-aware forward. - - Optionally injects camera conditioning (Plücker) via attention blocks. - """ - - def __init__( - self, - model_type: str = "t2v", - patch_size: Tuple[int, int, int] = (1, 2, 2), - text_len: int = 512, - in_dim: int = 16, - dim: int = 2048, - ffn_dim: int = 8192, - freq_dim: int = 256, - text_dim: int = 4096, - out_dim: int = 16, - num_heads: int = 16, - num_layers: int = 32, - window_size: Tuple[int, int] = (-1, -1), - qk_norm: bool = True, - cross_attn_norm: bool = True, - eps: float = 1e-6, - concat_padding_mask: bool = False, - sac_config: SACConfig = SACConfig(), - cp_comm_type: str = "p2p", - postpone_checkpoint: bool = False, - conv_patchify: bool = False, - use_plucker_condition: bool = False, - buffer_in_dim: int = 0, - buffer_pixelshuffle: bool = False, - buffer_sincos_multires: int = 0, - use_correspondence: bool = False, - inject_kq_only: bool = False, - buffer_mlp_squeeze_dim: int = 0, - ): - super().__init__() - - assert model_type in ["t2v", "i2v", "flf2v"] - self.model_type = model_type - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps - self.concat_padding_mask = concat_padding_mask - self.cp_comm_type = cp_comm_type - self.conv_patchify = conv_patchify - self.use_plucker_condition = bool(use_plucker_condition) - self.buffer_in_dim = int(buffer_in_dim) - self.buffer_pixelshuffle = bool(buffer_pixelshuffle) - self.buffer_sincos_multires = int(buffer_sincos_multires) - self.use_correspondence = bool(use_correspondence) - self.inject_kq_only = bool(inject_kq_only) - self.buffer_mlp_squeeze_dim = int(buffer_mlp_squeeze_dim) - - # Clean-embedding holders (lazy init) - self.clean_patch_embeddings: nn.ModuleList | None = None - self.clean_kernel_sizes: list[int] | None = None - self.clean_kernel_types: list[str] | None = None - self.patch_embedding_buffer: nn.Linear | None = None - - # CP state - self.cp_group: Optional[ProcessGroup] = None - self._is_context_parallel_enabled: bool = False - - # embeddings - in_dim_eff = in_dim + 1 if self.concat_padding_mask else in_dim - if self.conv_patchify: - self.patch_embedding = nn.Conv3d(in_dim_eff, dim, kernel_size=patch_size, stride=patch_size) - else: - self.patch_embedding = nn.Linear(in_dim_eff * patch_size[0] * patch_size[1] * patch_size[2], dim) - - self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) - self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) - self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) - - # blocks - cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" - cam_dim = 1536 if self.use_plucker_condition else 0 - buffer_dim = 0 - if self.use_correspondence: - if self.buffer_pixelshuffle: - pt, ph, pw = self.patch_size - buffer_dim = int(self.buffer_in_dim) * pt * ph * pw - else: - buffer_dim = self.dim - self.blocks = nn.ModuleList( - [ - Lyra2AttentionBlock( - cross_attn_type, - dim, - ffn_dim, - num_heads, - window_size, - qk_norm, - cross_attn_norm, - eps, - self.cp_comm_type, - cam_dim=cam_dim, - buffer_dim=buffer_dim, - buffer_sincos_multires=self.buffer_sincos_multires, - inject_kq_only=self.inject_kq_only, - buffer_mlp_squeeze_dim=self.buffer_mlp_squeeze_dim, - ) - for _ in range(num_layers) - ] - ) - - # head - self.head = Head(dim, out_dim, patch_size, eps) - - # rope position embedding - assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 - d = dim // num_heads - self.rope_position_embedding = VideoRopePosition3DEmb( - head_dim=d, - len_h=128, - len_w=128, - len_t=32, - ) - - if model_type == "i2v" or model_type == "flf2v": - self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == "flf2v") - - # initialize weights - self.init_weights() - # SAC (same behavior as base) - self.sac_config = sac_config - if not postpone_checkpoint: - self.enable_selective_checkpoint(sac_config, self.blocks) - - # ------------------------- Clean Embeddings ------------------------- - def init_clean_patch_embeddings(self, clean_latent_frame_kernel_sizes: List[int], clean_latent_frame_kernel_types: List[str] | None = None) -> None: - """Construct clean patch embedding layers without copying weights. - - This only creates `nn.Linear` layers with the correct input feature size - for each enlarged patch. Weight copying from the base `patch_embedding` - is deferred to `copy_weights_to_clean_patch_embeddings` and should be - called at training start. - """ - if self.clean_patch_embeddings is not None: - return - - pt, ph, pw = self.patch_size - in_dim = self.in_dim + (1 if self.concat_padding_mask else 0) - base_linear: nn.Linear = self.patch_embedding - assert isinstance(base_linear, nn.Linear) - - # Create holders - self.clean_patch_embeddings = nn.ModuleList() - # Ensure newly created layers match dtype/device of the base embedding - base_dtype = base_linear.weight.dtype - base_device = base_linear.weight.device - base_bias_is_not_none = base_linear.bias is not None - - # Default to temporal kernels if types are not provided (backward compatible) - if clean_latent_frame_kernel_types is None: - clean_latent_frame_kernel_types = ["k"] * len(clean_latent_frame_kernel_sizes) - - assert len(clean_latent_frame_kernel_types) == len(clean_latent_frame_kernel_sizes), "kernel sizes/types length mismatch" - - for k, t in zip(clean_latent_frame_kernel_sizes, clean_latent_frame_kernel_types): - if t == "s": # spatial-only packing (T unchanged) - new_pt, new_ph, new_pw = pt, ph * k, pw * k - else: # default temporal-style packing (THW) - new_pt, new_ph, new_pw = pt * k, ph * k, pw * k - new_in_features = in_dim * new_pt * new_ph * new_pw - clean_lin = nn.Linear(new_in_features, self.dim, bias=base_bias_is_not_none) - clean_lin = clean_lin.to(dtype=base_dtype, device=base_device) - self.clean_patch_embeddings.append(clean_lin) - - # Mark structure ready but weights not yet copied - self.clean_kernel_sizes = list(clean_latent_frame_kernel_sizes) - self.clean_kernel_types = list(clean_latent_frame_kernel_types) - log.info( - f"Constructed {len(clean_latent_frame_kernel_sizes)} clean patch embedding layers (weights not yet copied)." - ) - - def init_patch_embedding_buffer(self, buffer_in_dim: int) -> None: - """Construct an extra patch embedding for buffer inputs.""" - if buffer_in_dim <= 0 or self.patch_embedding_buffer is not None: - return - pt, ph, pw = self.patch_size - base_linear: nn.Linear = self.patch_embedding - assert isinstance(base_linear, nn.Linear) - - # Ensure newly created layer matches dtype/device of the base embedding - base_dtype = base_linear.weight.dtype - base_device = base_linear.weight.device - base_bias_is_not_none = base_linear.bias is not None - - in_dim_eff = self.in_dim + (1 if self.concat_padding_mask else 0) - total_in_features = (in_dim_eff + int(buffer_in_dim)) * pt * ph * pw - buf_lin = nn.Linear(total_in_features, self.dim, bias=base_bias_is_not_none) - buf_lin = buf_lin.to(dtype=base_dtype, device=base_device) - self.patch_embedding_buffer = buf_lin - self.buffer_in_dim = int(buffer_in_dim) - log.info( - f"Constructed patch_embedding_buffer with extra_in_dim={self.buffer_in_dim} (weights not yet copied)." - ) - - def copy_weights_to_clean_patch_embeddings(self) -> None: - """Copy/base-initialize clean patch embeddings from `self.patch_embedding`. - - Tiling/averaging follows Conv3d style weight expansion: - - Reshape base weight [dim, c*pt*ph*pw] to [dim, c, pt, ph, pw] - - For temporal kernels ('k'): tile k along (pt, ph, pw), divide by k^3 - - For spatial kernels ('s'): tile k along (ph, pw), divide by k^2 (pt unchanged) - Bias is copied directly if present. - """ - assert self.clean_patch_embeddings is not None, "Call init_clean_patch_embeddings first." - assert self.clean_kernel_sizes is not None, "clean_kernel_sizes must be set in init_clean_patch_embeddings." - assert self.clean_kernel_types is not None, "clean_kernel_types must be set in init_clean_patch_embeddings." - - pt, ph, pw = self.patch_size - in_dim = self.in_dim + (1 if self.concat_padding_mask else 0) - base_linear: nn.Linear = self.patch_embedding - assert isinstance(base_linear, nn.Linear) - - with torch.no_grad(): - base_weight = base_linear.weight.detach() # [dim, in_dim*pt*ph*pw] - base_bias = base_linear.bias.detach() if base_linear.bias is not None else None - - for clean_lin, k, t in zip(self.clean_patch_embeddings, self.clean_kernel_sizes, self.clean_kernel_types): - # Prepare tiled weights - tiled = rearrange( - base_weight, - "o (c pt ph pw) -> o c pt ph pw", - c=in_dim, - pt=pt, - ph=ph, - pw=pw, - ) - if t == "s": - # Only H, W are expanded; T unchanged - tiled = repeat( - tiled, - "o c pt ph pw -> o c pt (ph hk) (pw wk)", - hk=k, - wk=k, - ) - divisor = (k ** 2) - else: - # Temporal-style: expand along T, H, W - tiled = repeat( - tiled, - "o c pt ph pw -> o c (pt tk) (ph hk) (pw wk)", - tk=k, - hk=k, - wk=k, - ) - divisor = (k ** 3) - tiled = rearrange(tiled, "o c pt ph pw -> o (c pt ph pw)") - tiled = tiled / divisor - clean_lin.weight.copy_(tiled) - clean_lin.bias.copy_(base_bias) - - if self.patch_embedding_buffer is not None and not self.buffer_pixelshuffle: - buf_lin = self.patch_embedding_buffer - buf_lin.weight.zero_() - buf_lin.weight[:, : base_weight.shape[1]].copy_(base_weight) - if buf_lin.bias is not None and base_bias is not None: - buf_lin.bias.copy_(base_bias) - log.info( - f"patch_embedding_buffer weight shape={tuple(buf_lin.weight.shape)}, " - f"base patch_embedding weight shape={tuple(base_weight.shape)}" - ) - - log.info("Copied base patch_embedding weights into clean_patch_embeddings and marked initialized.") - - # --------------------------- Utilities ---------------------------- - @staticmethod - def _pad_for_linear_patch(x: torch.Tensor, kernel: Tuple[int, int, int]) -> torch.Tensor: - """Pad a BCHWT tensor so T/H/W are divisible by kernel. - - Args: - x: Tensor [B, C, T, H, W] - kernel: (kt, kh, kw) - """ - _, _, t, h, w = x.shape - kt, kh, kw = kernel - pad_t = (kt - (t % kt)) % kt - pad_h = (kh - (h % kh)) % kh - pad_w = (kw - (w % kw)) % kw - if pad_t or pad_h or pad_w: - x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") - return x - - def _patchify_linear(self, x: torch.Tensor, patch: Tuple[int, int, int], lin: nn.Linear) -> Tuple[torch.Tensor, Tuple[int, int, int]]: - """Patchify via einops+Linear, returning tokens and (f,h,w) grid size. - - Args: - x: [B, C, T, H, W] - patch: (pt, ph, pw) - lin: Linear mapping from flattened patch to dim - Returns: - x_tokens: [B, (f*h*w), dim] - grid_size: (f, h, w) - """ - pt, ph, pw = patch - x = self._pad_for_linear_patch(x, patch) - b, c, t, h, w = x.shape - f, hh, ww = t // pt, h // ph, w // pw - # b c (f pt) (hh ph) (ww pw) -> b f hh ww (c pt ph pw) - x = rearrange(x, "b c (f pt) (hh ph) (ww pw) -> b f hh ww (c pt ph pw)", f=f, pt=pt, hh=hh, ph=ph, ww=ww, pw=pw) - x = lin(x) - x = rearrange(x, "b f h w d -> b (f h w) d") - return x, (f, hh, ww) - - def _pixelshuffle_tokens(self, x: torch.Tensor, patch: Tuple[int, int, int]) -> Tuple[torch.Tensor, Tuple[int, int, int]]: - """Rearrange patches into channels without a linear projection.""" - pt, ph, pw = patch - x = self._pad_for_linear_patch(x, patch) - b, c, t, h, w = x.shape - f, hh, ww = t // pt, h // ph, w // pw - x = rearrange( - x, - "b c (f pt) (hh ph) (ww pw) -> b (f hh ww) (c pt ph pw)", - f=f, - pt=pt, - hh=hh, - ph=ph, - ww=ww, - pw=pw, - ) - return x, (f, hh, ww) - - # ------------------------- Lyra2 Path ------------------------- - def _patchify_lyra2( - self, - x: torch.Tensor, - framepack_indices: torch.Tensor, - framepack_splits: List[int], - framepack_kernel_ids: List[int], - framepack_kernel_types: List[str] | None = None, - camera: Optional[torch.Tensor] = None, - buffer_B_C_T_H_W: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Tuple[int, int], Tuple[int, int, int]]: - """Lyra2-aware patchify. - - Splits time dimension according to `framepack_splits`. For each chunk, - - if kernel_id == -1: use base patch size and base patch_embedding - - else: use clean_patch_embeddings[kernel_id] with enlarged patch sizes - If corresponding kernel type is 's', enlarge only H and W (T unchanged) - - Returns: - x_tokens: [B, L, dim] - freqs_tokens: [L, 1, 1, head_dim] as produced by rope embedder per token - camera_tokens: [B, L, cam_dim] if camera is provided, else None - buffer_tokens: [B, L, dim] if buffer is injected, else None - gen_range: (gen_start, gen_end) token range for generation part - gen_grid: (f, h, w) grid for generation part - """ - assert self.clean_patch_embeddings is not None, ( - "clean_patch_embeddings must be initialized before using Lyra2" - ) - xs = x[:,:,framepack_indices].split(framepack_splits, dim=2) # split along T - inds = framepack_indices.split(framepack_splits) - - # Determine base T,H,W for token grid and precompute base RoPE freqs once - _, _, T_total, H, W = x.shape - pt, ph, pw = self.patch_size - f_base = T_total // pt - h_base = H // ph - w_base = W // pw - freqs_base = self.rope_position_embedding.generate_embeddings( - B_T_H_W_C=torch.Size([x.shape[0], f_base, h_base, w_base, self.dim // self.num_heads]) - ) - # (f*h*w,1,1,d) -> (1,d,f,h,w) - freqs_base_5d = rearrange(freqs_base, "(f h w) 1 1 d -> 1 d f h w", f=f_base, h=h_base, w=w_base) - - token_chunks: List[torch.Tensor] = [] - freq_chunks: List[torch.Tensor] = [] - cam_chunks: List[torch.Tensor] = [] if camera is not None else [] - use_buffer_tokens = bool(self.use_correspondence and buffer_B_C_T_H_W is not None) - buf_chunks: List[torch.Tensor] = [] if use_buffer_tokens else [] - buf_validity_chunks: List[torch.Tensor] = [] if use_buffer_tokens else [] - buf_splits = None - buffer_full_match = False - if buffer_B_C_T_H_W is not None: - buffer_full_match = (int(buffer_B_C_T_H_W.shape[2]) == int(x.shape[2])) - if buffer_full_match: - buf_splits = buffer_B_C_T_H_W[:, :, framepack_indices].split(framepack_splits, dim=2) - gen_start = None - gen_end = None - total_tokens = 0 - gen_grid = (0, 0, 0) - buffer_token_dim = None - if use_buffer_tokens: - if self.buffer_pixelshuffle: - buffer_token_dim = int(self.buffer_in_dim) * pt * ph * pw - else: - buffer_token_dim = int(self.dim) - - for i, x_chunk in enumerate(xs): - kid = framepack_kernel_ids[i] - ktype = None - if framepack_kernel_types is not None and i < len(framepack_kernel_types): - ktype = framepack_kernel_types[i] - if kid == -1: - # Generated/new segment uses base embedding - if buffer_B_C_T_H_W is not None and self.use_correspondence: - x_tokens, (f, h, w) = self._patchify_linear(x_chunk, self.patch_size, self.patch_embedding) - buf = buf_splits[i] if buffer_full_match else buffer_B_C_T_H_W - buf = buf.to(dtype=x_chunk.dtype, device=x_chunk.device) - assert buf.shape[2] == x_chunk.shape[2], ( - f"Buffer T={buf.shape[2]} must match latent T={x_chunk.shape[2]}" - ) - assert buf.shape[-2:] == x_chunk.shape[-2:], "Buffer spatial size must match latent size." - if not self.buffer_pixelshuffle: - raise ValueError("use_correspondence requires buffer_pixelshuffle=True") - buf_tokens, _ = self._pixelshuffle_tokens(buf, self.patch_size) - buf_chunks.append(buf_tokens) - buf_validity_chunks.append( - torch.ones(buf_tokens.shape[0], buf_tokens.shape[1], 1, device=buf_tokens.device, dtype=buf_tokens.dtype) - ) - else: - x_tokens, (f, h, w) = self._patchify_linear(x_chunk, self.patch_size, self.patch_embedding) - token_chunks.append(x_tokens) - if gen_start is None: - gen_start = total_tokens - total_tokens += x_tokens.shape[1] - gen_end = total_tokens - gen_grid = (f, h, w) - - # Slice base precomputed freqs along T using provided indices (with padding if needed) - t_idx = inds[i].to(device=freqs_base_5d.device, dtype=torch.long) - if f > t_idx.numel(): - pad_t = f - t_idx.numel() - t_idx = torch.cat([t_idx, t_idx[-1:].repeat(pad_t)], dim=0) - freqs_sel = freqs_base_5d[:, :, t_idx, :, :] - freqs_tokens = rearrange(freqs_sel[0], "d f h w -> (f h w) 1 1 d") - freq_chunks.append(freqs_tokens) - - # Camera tokens for generated segment (no pooling; slice T to match f) - if camera is not None: - cam_base_5d = camera # [B, D_cam, f_base, h_base, w_base] - cam_t_idx = t_idx.to(device=cam_base_5d.device, dtype=torch.long) - cam_sel = cam_base_5d[:, :, cam_t_idx, :, :] - cam_tokens = rearrange(cam_sel, "b d f h w -> b (f h w) d").type_as(x_tokens) - cam_chunks.append(cam_tokens) - else: - # History/clean segment uses enlarged clean embedding - assert self.clean_kernel_sizes is not None - kernel_factor = int(self.clean_kernel_sizes[kid]) - clean_lin = self.clean_patch_embeddings[kid] - if (self.clean_kernel_types is not None and self.clean_kernel_types[kid] == "s") or ktype == "s": - # Spatial-only packing - enlarged_patch = ( - self.patch_size[0], - self.patch_size[1] * kernel_factor, - self.patch_size[2] * kernel_factor, - ) - pool_k = ( - 1, - enlarged_patch[1] // self.patch_size[1], - enlarged_patch[2] // self.patch_size[2], - ) - else: - # Temporal packing - enlarged_patch = ( - self.patch_size[0] * kernel_factor, - self.patch_size[1] * kernel_factor, - self.patch_size[2] * kernel_factor, - ) - pool_k = ( - enlarged_patch[0] // self.patch_size[0], - enlarged_patch[1] // self.patch_size[1], - enlarged_patch[2] // self.patch_size[2], - ) - x_tokens, (f, h, w) = self._patchify_linear(x_chunk, enlarged_patch, clean_lin) - token_chunks.append(x_tokens) - total_tokens += x_tokens.shape[1] - if use_buffer_tokens: - if buffer_full_match: - buf = buf_splits[i].to(dtype=x_chunk.dtype, device=x_chunk.device) - if ktype == "s": - if kernel_factor <= 1: - buf_tokens, _ = self._pixelshuffle_tokens(buf, self.patch_size) - else: - pool_k_buf = (1, pool_k[1], pool_k[2]) - buf_pooled = torch.nn.functional.avg_pool3d( - buf.float(), - kernel_size=pool_k_buf, - stride=pool_k_buf, - ) - buf_tokens, _ = self._pixelshuffle_tokens(buf_pooled.type_as(buf), self.patch_size) - buf_is_real = True - else: - buf_tokens = torch.full( - (x_tokens.shape[0], x_tokens.shape[1], int(buffer_token_dim)), - -1.0, - device=x_tokens.device, - dtype=x_tokens.dtype, - ) - buf_is_real = False - else: - # Non-full buffer (gen-only): history buffer tokens should be empty. - fill_val = -1.0 if ktype != "s" else 0.0 - buf_tokens = torch.full( - (x_tokens.shape[0], x_tokens.shape[1], int(buffer_token_dim)), - fill_val, - device=x_tokens.device, - dtype=x_tokens.dtype, - ) - buf_is_real = False - buf_chunks.append(buf_tokens) - validity_val = 1.0 if buf_is_real else 0.0 - buf_validity_chunks.append(torch.full( - (buf_tokens.shape[0], buf_tokens.shape[1], 1), - validity_val, - device=buf_tokens.device, - dtype=buf_tokens.dtype, - )) - - # Pool from base freqs with kernel equal to ratio enlarged/base - # Slice along T using provided indices and pad so dims divisible by pool_k (except when pool_k[0]==1) - t_idx = inds[i].to(device=freqs_base_5d.device, dtype=torch.long) - if pool_k[0] > 1: - pad_t = (-t_idx.numel()) % pool_k[0] - if pad_t: - t_idx = torch.cat([t_idx, t_idx[-1:].repeat(pad_t)], dim=0) - freqs_sel = freqs_base_5d[:, :, t_idx, :, :] - pad_h = (-h_base) % pool_k[1] - pad_w = (-w_base) % pool_k[2] - if pad_h or pad_w: - freqs_sel = torch.nn.functional.pad( - freqs_sel, - (0, pad_w, 0, pad_h, 0, 0), # pad W, H (T handled by index pad) - mode="replicate", - ) - freqs_pooled = torch.nn.functional.avg_pool3d( - freqs_sel.float(), - kernel_size=(pool_k[0], pool_k[1], pool_k[2]), - stride=(pool_k[0], pool_k[1], pool_k[2]), - ) - freqs_tokens = rearrange(freqs_pooled[0], "d f h w -> (f h w) 1 1 d") - freq_chunks.append(freqs_tokens) - - # Camera pooling mirrors freqs pooling - if camera is not None: - cam_base_5d = camera # [B, D_cam, f_base, h_base, w_base] - cam_t_idx = inds[i].to(device=cam_base_5d.device, dtype=torch.long) - if pool_k[0] > 1: - pad_t_cam = (-cam_t_idx.numel()) % pool_k[0] - if pad_t_cam: - cam_t_idx = torch.cat([cam_t_idx, cam_t_idx[-1:].repeat(pad_t_cam)], dim=0) - cam_sel = cam_base_5d[:, :, cam_t_idx, :, :] - cam_pad_h = (-h_base) % pool_k[1] - cam_pad_w = (-w_base) % pool_k[2] - if cam_pad_h or cam_pad_w: - cam_sel = torch.nn.functional.pad( - cam_sel, - (0, cam_pad_w, 0, cam_pad_h, 0, 0), - mode="replicate", - ) - cam_pooled = torch.nn.functional.avg_pool3d( - cam_sel.float(), - kernel_size=(pool_k[0], pool_k[1], pool_k[2]), - stride=(pool_k[0], pool_k[1], pool_k[2]), - ) - cam_tokens = rearrange(cam_pooled, "b d f h w -> b (f h w) d").type_as(x_tokens) - cam_chunks.append(cam_tokens) - x_tokens = torch.cat(token_chunks, dim=1) - freqs_tokens = torch.cat(freq_chunks, dim=0) - camera_tokens = torch.cat(cam_chunks, dim=1) if camera is not None else None - buffer_tokens = torch.cat(buf_chunks, dim=1) if use_buffer_tokens else None - - # When inject_kq_only is enabled, append a per-token validity indicator channel - # (1.0 = real data, 0.0 = dummy) so the attention block can mask out - # dummy positions after buffer encoding. - if self.inject_kq_only and buffer_tokens is not None: - validity_tokens = torch.cat(buf_validity_chunks, dim=1) - buffer_tokens = torch.cat([buffer_tokens, validity_tokens], dim=-1) - - assert gen_start is not None and gen_end is not None - return x_tokens, freqs_tokens, camera_tokens, buffer_tokens, (gen_start, gen_end), gen_grid - - # ---------------------------- Forward ----------------------------- - def forward( - self, - x_B_C_T_H_W: torch.Tensor, - timesteps_B_T: torch.Tensor, - crossattn_emb: torch.Tensor, - seq_len: int | None = None, - frame_cond_crossattn_emb_B_L_D: torch.Tensor | None = None, - y_B_C_T_H_W: torch.Tensor | None = None, - y_buffer_B_C_T_H_W: torch.Tensor | None = None, - padding_mask: Optional[torch.Tensor] = None, - is_uncond: bool = False, - slg_layers=None, - **kwargs, - ): - # Choose path: base-like or Lyra2 - framepack_keys = {"framepack_indices", "framepack_splits", "framepack_kernel_ids"} - use_framepack = framepack_keys.issubset(set(kwargs.keys())) - - assert timesteps_B_T.shape[1] == 1 - t_B = timesteps_B_T[:, 0] - - if y_B_C_T_H_W is not None: - x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, y_B_C_T_H_W], dim=1) - - if not use_framepack: - # Base WanModel forward (with minor compatibility tweaks) - if self.concat_padding_mask and padding_mask is not None: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) - - if self.conv_patchify: - x_B_D_T_H_W = self.patch_embedding(x_B_C_T_H_W) - x_B_T_H_W_D = rearrange(x_B_D_T_H_W, "b d t h w -> b t h w d") - else: - x_B_T_H_W_D = rearrange( - x_B_C_T_H_W, - "b c (t kt) (h kh) (w kw) -> b t h w (c kt kh kw)", - kt=self.patch_size[0], - kh=self.patch_size[1], - kw=self.patch_size[2], - ) - x_B_T_H_W_D = self.patch_embedding(x_B_T_H_W_D) - - video_size = VideoSize(T=x_B_T_H_W_D.shape[1], H=x_B_T_H_W_D.shape[2], W=x_B_T_H_W_D.shape[3]) - x_B_L_D = rearrange(x_B_T_H_W_D, "b t h w d -> b (t h w) d") - seq_lens = torch.tensor([u.size(0) for u in x_B_L_D], dtype=torch.long) - seq_len = seq_lens.max().item() - assert seq_lens.max() == seq_len - - with amp.autocast("cuda", dtype=torch.float32): - e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float()) - e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim)) - - if crossattn_emb.dim() == 4: - crossattn_emb = crossattn_emb.squeeze(1) - context_lens = None - context_B_L_D = self.text_embedding(crossattn_emb) - if frame_cond_crossattn_emb_B_L_D is not None: - context_clip = self.img_emb(frame_cond_crossattn_emb_B_L_D) - context_B_L_D = torch.concat([context_clip, context_B_L_D], dim=1) - - kwargs_blocks = dict( - e=e0_B_6_D, - seq_lens=seq_lens, - video_size=video_size, - freqs=self.rope_position_embedding(x_B_T_H_W_D), - context=context_B_L_D, - context_lens=context_lens, - ) - - for block_idx, block in enumerate(self.blocks): - if slg_layers is not None and block_idx in slg_layers and is_uncond: - continue - x_B_L_D = block(x_B_L_D, **kwargs_blocks) - - x_B_L_D = self.head(x_B_L_D, e_B_D) - t, h, w = video_size - x_B_C_T_H_W = rearrange( - x_B_L_D, - "b (t h w) (nt nh nw d) -> b d (t nt) (h nh) (w nw)", - nt=self.patch_size[0], - nh=self.patch_size[1], - nw=self.patch_size[2], - t=t, - h=h, - w=w, - d=self.out_dim, - ) - return x_B_C_T_H_W - - # Lyra2 path - # Optional camera extraction (Plücker condition path). Expect last 384 channels appended to x. - camera_5d = None - if self.use_plucker_condition: - camera_ch = 384 - assert x_B_C_T_H_W.size(1) >= camera_ch, "Input missing appended camera channels (384)." - camera = x_B_C_T_H_W[:, -camera_ch:] - x_B_C_T_H_W = x_B_C_T_H_W[:, :-camera_ch] - camera_5d = rearrange( - camera, - "b c t (h h2) (w w2) -> b (c h2 w2) t h w", - h2=2, w2=2, - ) - - if self.concat_padding_mask and padding_mask is not None: - padding_mask = transforms.functional.resize( - padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST - ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) - - framepack_indices = kwargs["framepack_indices"] - framepack_splits = kwargs["framepack_splits"] - framepack_kernel_ids = kwargs["framepack_kernel_ids"] - framepack_kernel_types = kwargs.get("framepack_kernel_types", None) - - assert self.clean_patch_embeddings is not None - - x_tokens, freqs_tokens, camera_tokens, buffer_tokens, (gen_start, gen_end), (f_gen, h_gen, w_gen) = self._patchify_lyra2( - x_B_C_T_H_W, - framepack_indices, - framepack_splits, - framepack_kernel_ids, - framepack_kernel_types, - camera=camera_5d, - buffer_B_C_T_H_W=y_buffer_B_C_T_H_W, - ) - - # Context Parallel after Lyra2 patchify: split tokens along L if enabled - cp_enabled = getattr(self, "is_context_parallel_enabled", False) - cp_group = getattr(self, "cp_group", None) - if cp_enabled and cp_group is not None: - L = x_tokens.size(1) - cp_size = cp_group.size() - assert L % cp_size == 0, f"Token length {L} must be divisible by cp_size {cp_size}" - assert freqs_tokens.shape[0] % cp_size == 0, ( - f"Freq tokens length {freqs_tokens.shape[0]} must be divisible by cp_size {cp_size}" - ) - x_tokens = split_inputs_cp(x_tokens, seq_dim=1, cp_group=cp_group) - freqs_tokens = split_inputs_cp(freqs_tokens, seq_dim=0, cp_group=cp_group) - if camera_tokens is not None: - assert camera_tokens.size(1) % cp_size == 0, ( - f"Camera tokens length {camera_tokens.size(1)} must be divisible by cp_size {cp_size}" - ) - camera_tokens = split_inputs_cp(camera_tokens, seq_dim=1, cp_group=cp_group) - if buffer_tokens is not None: - assert buffer_tokens.size(1) % cp_size == 0, ( - f"Buffer tokens length {buffer_tokens.size(1)} must be divisible by cp_size {cp_size}" - ) - buffer_tokens = split_inputs_cp(buffer_tokens, seq_dim=1, cp_group=cp_group) - - with amp.autocast("cuda", dtype=torch.float32): - e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float()) - e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim)) - - if crossattn_emb.dim() == 4: - crossattn_emb = crossattn_emb.squeeze(1) - context_B_L_D = self.text_embedding(crossattn_emb) - if frame_cond_crossattn_emb_B_L_D is not None: - context_clip = self.img_emb(frame_cond_crossattn_emb_B_L_D) - context_B_L_D = torch.concat([context_clip, context_B_L_D], dim=1) - - assert x_tokens.shape[0] == 1 - seq_lens = torch.tensor([x_tokens.size(1)] * x_tokens.size(0), dtype=torch.long, device=x_tokens.device) - if self.use_plucker_condition: - kwargs_blocks = dict( - e=e0_B_6_D, - seq_lens=seq_lens, - video_size=VideoSize(T=1, H=1, W=x_tokens.size(1)), - freqs=freqs_tokens, - context=context_B_L_D, - context_lens=None, - camera=camera_tokens, - buffer=buffer_tokens, - ) - else: - kwargs_blocks = dict( - e=e0_B_6_D, - seq_lens=seq_lens, - video_size=VideoSize(T=1, H=1, W=x_tokens.size(1)), - freqs=freqs_tokens, - context=context_B_L_D, - context_lens=None, - buffer=buffer_tokens, - ) - - x_B_L_D = x_tokens - for block in self.blocks: - x_B_L_D = block(x_B_L_D, **kwargs_blocks) - - x_B_L_D = self.head(x_B_L_D, e_B_D) - - if cp_enabled and cp_group is not None: - x_B_L_D = cat_outputs_cp_with_grad(x_B_L_D, seq_dim=1, cp_group=cp_group) - - x_gen = x_B_L_D[:, gen_start:gen_end] - x_B_C_T_H_W = rearrange( - x_gen, - "b (f h w) (pt ph pw c) -> b c (f pt) (h ph) (w pw)", - f=f_gen, - h=h_gen, - w=w_gen, - pt=self.patch_size[0], - ph=self.patch_size[1], - pw=self.patch_size[2], - c=self.out_dim, - ) - return x_B_C_T_H_W - - def init_weights(self): - # Match base WanModel.init_weights - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - - for block in self.blocks: - block.init_weights() - self.head.init_weights() - - if isinstance(self.patch_embedding, nn.Linear): - nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) - if self.patch_embedding.bias is not None: - nn.init.zeros_(self.patch_embedding.bias) - if self.patch_embedding_buffer is not None: - nn.init.xavier_uniform_(self.patch_embedding_buffer.weight.flatten(1)) - if self.patch_embedding_buffer.bias is not None: - nn.init.zeros_(self.patch_embedding_buffer.bias) - - for m in self.text_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - for m in self.time_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - for m in self.time_projection.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - nn.init.zeros_(self.head.head.weight) - if self.head.head.bias is not None: - nn.init.zeros_(self.head.head.bias) - - def fully_shard(self, mesh, **fsdp_kwargs): - for i, block in enumerate(self.blocks): - fully_shard(block, mesh=mesh, reshard_after_forward=True, **fsdp_kwargs) - fully_shard(self.head, mesh=mesh, reshard_after_forward=False, **fsdp_kwargs) - fully_shard(self.text_embedding, mesh=mesh, reshard_after_forward=True, **fsdp_kwargs) - fully_shard(self.time_embedding, mesh=mesh, reshard_after_forward=True, **fsdp_kwargs) - fully_shard(self.patch_embedding, mesh=mesh, reshard_after_forward=True, **fsdp_kwargs) - fully_shard(self.time_projection, mesh=mesh, reshard_after_forward=True, **fsdp_kwargs) - if self.clean_patch_embeddings is not None: - for lin in self.clean_patch_embeddings: - fully_shard(lin, mesh=mesh, reshard_after_forward=True, **fsdp_kwargs) - - def enable_context_parallel(self, process_group: Optional[ProcessGroup] = None): - # For Lyra2, we split after patchify; disable CP inside rope embedder - self.rope_position_embedding.disable_context_parallel() - cp_ranks = get_process_group_ranks(process_group) - for block in self.blocks: - block.self_attn.set_context_parallel_group( - process_group=process_group, - ranks=cp_ranks, - stream=torch.cuda.Stream(), - ) - - self._is_context_parallel_enabled = True - self.cp_group = process_group - - def disable_context_parallel(self): - # Lyra2 CP is applied post-patchify via token-splitting; simply drop the group flag - self.cp_group = None - self._is_context_parallel_enabled = False - - @property - def is_context_parallel_enabled(self) -> bool: - return self._is_context_parallel_enabled - - def enable_selective_checkpoint(self, sac_config: SACConfig, blocks: nn.ModuleList): - if sac_config.mode == CheckpointMode.NONE: - return - log.info( - f"Enable selective checkpoint with {sac_config.mode}, for every {sac_config.every_n_blocks} blocks. Total blocks: {len(blocks)}" - ) - _context_fn = sac_config.get_context_fn() - for block_id, block in blocks.named_children(): - if int(block_id) % sac_config.every_n_blocks == 0: - log.info(f"Enable selective checkpoint for block {block_id}") - block = ptd_checkpoint_wrapper( - block, - context_fn=_context_fn, - preserve_rng_state=False, - ) - blocks.register_module(block_id, block) - self.register_module( - "head", - ptd_checkpoint_wrapper( - self.head, - context_fn=_context_fn, - preserve_rng_state=False, - ), - ) diff --git a/lyra_2/_src/schedulers/__init__.py b/lyra_2/_src/schedulers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/schedulers/rectified_flow.py b/lyra_2/_src/schedulers/rectified_flow.py deleted file mode 100644 index c348cc9b8fb801f9c92e05eb7cfeaa2bd4cdcc69..0000000000000000000000000000000000000000 --- a/lyra_2/_src/schedulers/rectified_flow.py +++ /dev/null @@ -1,261 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable -from diffusers import FlowMatchEulerDiscreteScheduler -import torch -from contextlib import contextmanager - - -class TrainTimeWeight: - def __init__( - self, - noise_scheduler, - weight: str = "uniform", - ): - self.weight = weight - self.noise_scheduler = noise_scheduler - if self.weight == "reweighting": - x = self.noise_scheduler.timesteps.cuda() - y = torch.exp(-2 * (( x - self.noise_scheduler.config.num_train_timesteps / 2) / self.noise_scheduler.config.num_train_timesteps) ** 2) - y_shifted = y - y.min() - bsmntw_weighing = y_shifted * (self.noise_scheduler.config.num_train_timesteps / y_shifted.sum()) - self.linear_timesteps_weights = bsmntw_weighing.cpu() - - def __call__( - self, - t, - tensor_kwargs - ) -> torch.Tensor: - if self.weight == "uniform": - wts = torch.ones_like(t) - elif self.weight == "reweighting": - timestep_id = torch.argmin((self.noise_scheduler.timesteps.to(**tensor_kwargs) - t).abs()) - wts = self.linear_timesteps_weights[timestep_id] - else: - raise NotImplementedError(f"Time weight '{self.weight}' is not implemented.") - - return wts - - -class TrainTimeSampler: - def __init__( - self, - distribution: str = "uniform", - max_timestep_boundary: float = 1.0, - min_timestep_boundary: float = 0.0, - - ): - self.distribution = distribution - self.max_timestep_boundary = max_timestep_boundary - self.min_timestep_boundary = min_timestep_boundary - - @torch.no_grad() - def __call__( - self, - batch_size: int, - device: torch.device = torch.device("cpu"), - dtype: torch.dtype = torch.float32, - ) -> torch.Tensor: - """ - Sample time tensor for training - - Returns: - torch.Tensor: Time tensor, shape (batch_size,) - """ - if self.distribution == "uniform": - t = torch.rand((batch_size,)) * (self.max_timestep_boundary - self.min_timestep_boundary) + self.min_timestep_boundary - elif self.distribution == "logitnormal": - t = torch.sigmoid(torch.randn((batch_size,), device=device, dtype=dtype)) # .to(device=device, dtype=dtype) - else: - raise NotImplementedError(f"Time distribution '{self.dist}' is not implemented.") - - return t - - -class RectifiedFlow: - def __init__( - self, - velocity_field: Callable, - train_time_distribution: TrainTimeSampler | str = "uniform", - max_timestep_boundary: float = 1.0, - min_timestep_boundary: float = 0.0, - train_time_weight_method: str = "uniform", - use_dynamic_shift: bool = False, - shift: int= 3, - device: torch.device = torch.device("cpu"), - dtype: torch.dtype = torch.float32, - ): - r"""Initialize the RectifiedFlow class. - - Args: - velocity_field (`Callable`): - A function that predicts the velocity given the current state and time. - train_time_distribution (`TrainTimeSampler` or `str`, *optional*, defaults to `"uniform"`): - Distribution for sampling training times. - Can be an instance of `TrainTimeSampler` or a string specifying the distribution type. - train_time_weight (`TrainTimeWeight` or `str`, *optional*, defaults to `"uniform"`): - Weight applied to training times. - Can be an instance of `TrainTimeWeight` or a string specifying the weight type. - """ - self.velocity_field = velocity_field - self.train_time_sampler: TrainTimeSampler = ( - train_time_distribution - if isinstance(train_time_distribution, TrainTimeSampler) - else TrainTimeSampler(train_time_distribution,max_timestep_boundary,min_timestep_boundary) - ) - - - if use_dynamic_shift: - self.noise_scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=use_dynamic_shift) - else: - self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) - self.train_time_weight = TrainTimeWeight(self.noise_scheduler, train_time_weight_method) - self.use_t_in_reverse_order = True - - self.device = torch.device(device) if isinstance(device, str) else device - self.dtype = torch.dtype(dtype) if isinstance(dtype, str) else dtype - - @contextmanager - def temporary_use_t_in_reverse_order(self, use_t_in_reverse_order: bool): - """ - Context manager to temporarily set use_t_in_reverse_order value. - - Args: - use_t_in_reverse_order (bool): The temporary value to set for use_t_in_reverse_order. - - Example: - with rectified_flow.temporary_use_t_in_reverse_order(False): - # use_t_in_reverse_order is temporarily set to False - # ... do some operations ... - # use_t_in_reverse_order is restored to its original value - """ - original_value = self.use_t_in_reverse_order - try: - self.use_t_in_reverse_order = use_t_in_reverse_order - yield - finally: - self.use_t_in_reverse_order = original_value - - def sample_train_time(self, batch_size: int): - r"""This method calls the `TrainTimeSampler` to sample training times. - - Returns: - t (`torch.Tensor`): - A tensor of sampled training times with shape `(batch_size,)`, - matching the class specified `device` and `dtype`. - """ - time = self.train_time_sampler(batch_size, device=self.device, dtype=self.dtype) - return time - - def get_discrete_timestamp(self, u, tensor_kwargs): - r"""This method map time from 0,1 to discrete steps - """ - assert 0 <= u.min() and u.max() <= 1, "Time must be in [0, 1]" - indices = (u.squeeze() * self.noise_scheduler.config.num_train_timesteps).long() - if not self.use_t_in_reverse_order: - indices = self.noise_scheduler.config.num_train_timesteps - indices - 1 - timesteps = self.noise_scheduler.timesteps.to(**tensor_kwargs)[indices] - return timesteps.unsqueeze(0) - - def get_sigmas(self, timesteps, tensor_kwargs): - - sigmas = self.noise_scheduler.sigmas.to(**tensor_kwargs) - schedule_timesteps = self.noise_scheduler.timesteps.to(**tensor_kwargs) - - step_indices = [(schedule_timesteps == t).nonzero() for t in timesteps] - - sigma = sigmas[step_indices].flatten() - return sigma - - - def get_interpolation( - self, - x_0: torch.Tensor, - x_1: torch.Tensor, - sigmas: torch.Tensor | None, - t: torch.Tensor | None = None, - ): - r""" - This method computes interpolation `X_t` and their time derivatives `dotX_t` at the specified time points `t`. - Note that `x_0` is the noise, and `x_1` is the clean data. This is aligned with the notation in the recified flow community, - but different from the notation in the diffusion community. - - Args: - x_0 (`torch.Tensor`): - noise, shape `(B, D1, D2, ..., Dn)`, where `B` is the batch size, and `D1, D2, ..., Dn` are the data dimensions. - x_1 (`torch.Tensor`): - clean data, with the same shape as `x_0` - sigmas (`torch.Tensor`): - A tensor of sigmas, with shape `(B,)`, where each value is in `[0, 1]`. - t (`torch.Tensor`): - A tensor of time, with shape `(B,)`, where each value is in `[0, 1]`. - - Returns: - (x_t, dot_x_t) (`Tuple[torch.Tensor, torch.Tensor]`): - - x_t (`torch.Tensor`): The interpolated state, with shape `(B, D1, D2, ..., Dn)`. - - dot_x_t (torch.Tensor): The time derivative of the interpolated state, with the same shape as `x_t`. - """ - if sigmas is None: - assert t is not None, "t must be provided when sigmas is None." - timesteps = self.get_discrete_timestamp(t, {"device": self.device, "dtype": self.dtype}) - sigmas = self.get_sigmas(timesteps, {"device": self.device, "dtype": self.dtype}) - else: - assert t is None, "t must be None when sigmas is provided." - sigmas = sigmas.to(device=self.device, dtype=self.dtype) - - assert x_0.shape == x_1.shape, "x_0 and x_1 must have the same shape." - assert x_0.shape[0] == x_1.shape[0], "Batch size of x_0 and x_1 must match." - assert sigmas.shape[0] == x_1.shape[0], "Batch size of sigmas must match x_1." - # Reshape t to match dimensions of x_1 - sigmas = sigmas.view(sigmas.shape[0], *([1] * (len(x_1.shape) - 1))) - x_t = x_0 * sigmas + x_1 * (1 - sigmas) - dot_x_t = x_0 - x_1 - return x_t, dot_x_t - - def get_x0_from_flow_prediction( - self, - x_t: torch.Tensor, - dot_x_t: torch.Tensor, - t: torch.Tensor | None = None, - sigmas: torch.Tensor | None = None, - ): - r""" - Convert flow matching's prediction to x0 prediction. - x_t: the input noisy data with shape [B, D1, D2, ..., Dn] - dot_x_t: the prediction with shape [B, D1, D2, ..., Dn] - t: the timestep with shape [B,], where each value is in [0, 1] - sigmas: the sigmas with shape [B,], where each value is in [0, 1] - - pred = noise(x_0) - x_1 - x_t = (1-sigma_t) * x_1 + sigma_t * noise(x_0) - we have x_1 = x_t - sigma_t * pred - see derivations https://chatgpt.com/share/67bf8589-3d04-8008-bc6e-4cf1a24e2d0e - """ - assert t is not None or sigmas is not None, "Either t or sigmas must be provided." - if sigmas is not None: - assert t is None, "t and sigmas cannot be provided at the same time." - sigmas = sigmas.to(device=self.device, dtype=self.dtype) - else: - timesteps = self.get_discrete_timestamp(t, {"device": self.device, "dtype": self.dtype}) - sigmas = self.get_sigmas(timesteps, {"device": self.device, "dtype": self.dtype}) - - # Reshape t to match dimensions of x_1 - sigmas = sigmas.view(sigmas.shape[0], *([1] * (len(x_t.shape) - 1))) - original_dtype = x_t.dtype - x_t, dot_x_t, sigmas = map(lambda x: x.to(dtype=torch.float64), (x_t, dot_x_t, sigmas)) - x_1_pred = x_t - sigmas * dot_x_t - x_1_pred = x_1_pred.to(dtype=original_dtype) - return x_1_pred diff --git a/lyra_2/_src/schedulers/self_forcing_scheduler.py b/lyra_2/_src/schedulers/self_forcing_scheduler.py deleted file mode 100644 index 6a7a096dcd4c1e77aa29f31b0c5ab9b2ed9900fe..0000000000000000000000000000000000000000 --- a/lyra_2/_src/schedulers/self_forcing_scheduler.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -"""Flow-matching scheduler used by DMD distillation (4-step) inference. - -Ported from imaginaire4's `self_forcing_scheduler.py` for Lyra-2. -""" - -import torch - - -class FlowMatchScheduler: - def __init__( - self, - num_inference_steps=100, - num_train_timesteps=1000, - shift=3.0, - sigma_max=1.0, - sigma_min=0.003 / 1.002, - inverse_timesteps=False, - extra_one_step=False, - reverse_sigmas=False, - ): - self.num_train_timesteps = num_train_timesteps - self.shift = shift - self.sigma_max = sigma_max - self.sigma_min = sigma_min - self.inverse_timesteps = inverse_timesteps - self.extra_one_step = extra_one_step - self.reverse_sigmas = reverse_sigmas - self.set_timesteps(num_inference_steps) - - def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False): - sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength - if self.extra_one_step: - self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] - else: - self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) - if self.inverse_timesteps: - self.sigmas = torch.flip(self.sigmas, dims=[0]) - self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) - if self.reverse_sigmas: - self.sigmas = 1 - self.sigmas - self.timesteps = self.sigmas * self.num_train_timesteps - if training: - x = self.timesteps - y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) - y_shifted = y - y.min() - bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) - self.linear_timesteps_weights = bsmntw_weighing - - def step(self, model_output, timestep, sample, to_final=False): - if timestep.ndim == 2: - timestep = timestep.flatten(0, 1) - self.sigmas = self.sigmas.to(model_output.device) - self.timesteps = self.timesteps.to(model_output.device) - timestep_id = torch.argmin( - (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 - ) - sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) - if to_final or (timestep_id + 1 >= len(self.timesteps)).any(): - sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 - else: - sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1) - prev_sample = sample + model_output * (sigma_ - sigma) - return prev_sample - - def add_noise(self, original_samples, noise, timestep): - """Diffusion forward corruption process. - - Inputs are batched over [B*T, C, H, W]; timestep has shape [B*T]. - """ - if timestep.ndim == 2: - timestep = timestep.flatten(0, 1) - self.sigmas = self.sigmas.to(noise.device) - self.timesteps = self.timesteps.to(noise.device) - timestep_id = torch.argmin( - (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 - ) - sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) - sample = (1 - sigma) * original_samples + sigma * noise - return sample.type_as(noise) - - def training_target(self, sample, noise, timestep): - return noise - sample - - def training_weight(self, timestep): - if timestep.ndim == 2: - timestep = timestep.flatten(0, 1) - self.linear_timesteps_weights = self.linear_timesteps_weights.to(timestep.device) - timestep_id = torch.argmin( - (self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0 - ) - return self.linear_timesteps_weights[timestep_id] diff --git a/lyra_2/_src/tokenizers/__init__.py b/lyra_2/_src/tokenizers/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/tokenizers/base_vae.py b/lyra_2/_src/tokenizers/base_vae.py deleted file mode 100644 index aad74f773574437e631c0bf3c4e95d18c395781d..0000000000000000000000000000000000000000 --- a/lyra_2/_src/tokenizers/base_vae.py +++ /dev/null @@ -1,445 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from abc import ABC, abstractmethod - -import torch -import torch.nn.functional as F - -from lyra_2._ext.imaginaire.utils.distributed import rank0_first -from lyra_2._ext.imaginaire.utils.s3_utils import load_from_s3_with_cache - - -class BaseVAE(torch.nn.Module, ABC): - """ - Abstract base class for a Variational Autoencoder (VAE). - - All subclasses should implement the methods to define the behavior for encoding - and decoding, along with specifying the latent channel size. - """ - - def __init__(self, channel: int = 3, name: str = "vae"): - super().__init__() - self.channel = channel - self.name = name - - @property - def latent_ch(self) -> int: - """ - Returns the number of latent channels in the VAE. - """ - return self.channel - - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encodes the input tensor into a latent representation. - - Args: - - state (torch.Tensor): The input tensor to encode. - - Returns: - - torch.Tensor: The encoded latent tensor. - """ - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decodes the latent representation back to the original space. - - Args: - - latent (torch.Tensor): The latent tensor to decode. - - Returns: - - torch.Tensor: The decoded tensor. - """ - pass - - @property - def spatial_compression_factor(self) -> int: - """ - Returns the spatial reduction factor for the VAE. - """ - raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.") - - -class BasePretrainedImageVAE(BaseVAE): - """ - A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values - from a remote store, handles data type conversions, and normalization - using provided mean and standard deviation values for latent space representation. - Derived classes should load pre-trained encoder and decoder components from a remote store - - Attributes: - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). - """ - - def __init__( - self, - name: str, - mean_std_fp: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - load_mean_std: bool = True, - ) -> None: - super().__init__(latent_ch, name) - dtype = torch.bfloat16 if is_bf16 else torch.float32 - self.dtype = dtype - self.is_image = is_image - self.mean_std_fp = mean_std_fp - self.name = name - self.load_mean_std = load_mean_std - - self.backend_args = None - - self.register_mean_std(mean_std_fp) - - def register_mean_std(self, mean_std_fp: str) -> None: - target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1] - - if self.load_mean_std: - extention = mean_std_fp.split(".")[-1] - latent_mean, latent_std = load_from_s3_with_cache( - mean_std_fp, - f"vae/{self.name}_mean_std.{extention}", - easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, - backend_args=self.backend_args, - ) - self.register_buffer( - "latent_mean", - latent_mean.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - self.register_buffer( - "latent_std", - latent_std.to(self.dtype).reshape(*target_shape), - persistent=False, - ) - else: - # Use zeros for mean and ones for std when load_mean_std=False - device = torch.device(torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu") - self.register_buffer( - "latent_mean", - torch.zeros(*target_shape, dtype=self.dtype, device=device), - persistent=False, - ) - self.register_buffer( - "latent_std", - torch.ones(*target_shape, dtype=self.dtype, device=device), - persistent=False, - ) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - Encode the input state to latent space; also handle the dtype conversion, mean and std scaling - """ - in_dtype = state.dtype - latent_mean = self.latent_mean.to(in_dtype) - latent_std = self.latent_std.to(in_dtype) - encoded_state = self.encoder(state.to(self.dtype)) - if isinstance(encoded_state, torch.Tensor): - pass - elif isinstance(encoded_state, tuple): - assert isinstance(encoded_state[0], torch.Tensor) - encoded_state = encoded_state[0] - else: - raise ValueError("Invalid type of encoded state") - return (encoded_state.to(in_dtype) - latent_mean) / latent_std - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - """ - Decode the input latent to state; also handle the dtype conversion, mean and std scaling - """ - in_dtype = latent.dtype - latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype) - return self.decoder(latent.to(self.dtype)).to(in_dtype) - - def reset_dtype(self, *args, **kwargs): - """ - Resets the data type of the encoder and decoder to the model's default data type. - - Args: - *args, **kwargs: Unused, present to allow flexibility in method calls. - """ - del args, kwargs - self.decoder.to(self.dtype) - self.encoder.to(self.dtype) - - -class JITVAE(BasePretrainedImageVAE): - """ - A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder - and decoder components from a remote store, handles data type conversions, and normalization - using provided mean and standard deviation values for latent space representation. - - Attributes: - encoder (Module): The JIT compiled encoder loaded from storage. - decoder (Module): The JIT compiled decoder loaded from storage. - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - enc_fp (str): File path to the encoder's JIT file on the remote store. - dec_fp (str): File path to the decoder's JIT file on the remote store. - name (str): Name of the model, used for differentiating cache file paths. - mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). - """ - - def __init__( - self, - enc_fp: str, - dec_fp: str, - name: str, - mean_std_fp: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - load_mean_std: bool = True, - ): - super().__init__( - name, - mean_std_fp, - latent_ch, - is_image, - is_bf16, - load_mean_std=load_mean_std, - ) - self.load_encoder(enc_fp) - self.load_decoder(dec_fp) - - def load_encoder(self, enc_fp: str) -> None: - """ - Load the encoder from the remote store. - - Args: - - enc_fp (str): File path to the encoder's JIT file on the remote store. - """ - self.encoder = load_from_s3_with_cache( - enc_fp, - f"vae/{self.name}_enc.jit", - easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, - backend_args=self.backend_args, - ) - self.encoder.eval() - for param in self.encoder.parameters(): - param.requires_grad = False - self.encoder.to(self.dtype) - - def load_decoder(self, dec_fp: str) -> None: - """ - Load the decoder from the remote store. - - Args: - - dec_fp (str): File path to the decoder's JIT file on the remote store. - """ - self.decoder = load_from_s3_with_cache( - dec_fp, - f"vae/{self.name}_dec.jit", - easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, - backend_args=self.backend_args, - ) - self.decoder.eval() - for param in self.decoder.parameters(): - param.requires_grad = False - self.decoder.to(self.dtype) - - -class StateDictVAE(BasePretrainedImageVAE): - """ - A Variational Autoencoder (VAE) that loads pre-trained weights into - provided encoder and decoder components from a remote store, handles data type conversions, - and normalization using provided mean and standard deviation values for latent space representation. - - Attributes: - encoder (Module): The encoder with weights loaded from storage. - decoder (Module): The decoder with weights loaded from storage. - latent_mean (Tensor): The mean used for normalizing the latent representation. - latent_std (Tensor): The standard deviation used for normalizing the latent representation. - dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled. - - Args: - enc_fp (str): File path to the encoder's JIT file on the remote store. - dec_fp (str): File path to the decoder's JIT file on the remote store. - vae (Module): Instance of VAE with not loaded weights - name (str): Name of the model, used for differentiating cache file paths. - mean_std_fp (str): File path to the pickle file containing mean and std of the latent space. - latent_ch (int, optional): Number of latent channels (default is 16). - is_image (bool, optional): Flag to indicate whether the output is an image (default is True). - is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True). - """ - - def __init__( - self, - enc_fp: str, - dec_fp: str, - vae: torch.nn.Module, - name: str, - mean_std_fp: str, - latent_ch: int = 16, - is_image: bool = True, - is_bf16: bool = True, - ): - super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16) - - self.load_encoder_and_decoder(enc_fp, dec_fp, vae) - - def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None: - """ - Load the encoder from the remote store. - - Args: - - vae_fp (str): File path to the vae's state dict file on the remote store. - - vae (str): VAE module into which weights will be loaded. - """ - state_dict_enc = load_from_s3_with_cache( - enc_fp, - f"vae/{self.name}_enc.jit", - easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, - backend_args=self.backend_args, - ) - - state_dict_dec = load_from_s3_with_cache( - dec_fp, - f"vae/{self.name}_dec.jit", - easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())}, - backend_args=self.backend_args, - ) - - jit_weights_state_dict = state_dict_enc.state_dict() | state_dict_dec.state_dict() - jit_weights_state_dict = { - k: v - for k, v in jit_weights_state_dict.items() - # Global variables captured by JIT - if k - not in ( - "encoder.patcher.wavelets", - "encoder.patcher._arange", - "decoder.unpatcher.wavelets", - "decoder.unpatcher._arange", - ) - } - - vae.load_state_dict(jit_weights_state_dict) - vae.eval() - for param in vae.parameters(): - param.requires_grad = False - vae.to(self.dtype) - - self.vae = vae - self.encoder = self.vae.encode - self.decoder = self.vae.decode - - def reset_dtype(self, *args, **kwargs): - """ - Resets the data type of the encoder and decoder to the model's default data type. - - Args: - *args, **kwargs: Unused, present to allow flexibility in method calls. - """ - del args, kwargs - self.vae.to(self.dtype) - - -class SDVAE(BaseVAE): - def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None: - super().__init__(channel=4, name="sd_vae") - self.dtype = torch.bfloat16 - self.register_buffer( - "scale", - torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1), - persistent=False, - ) - self.register_buffer( - "bias", - -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale, - persistent=False, - ) - self.batch_size = batch_size - self.count_std = count_std - self.is_downsample = is_downsample - self.load_vae() - self.reset_dtype() - - def reset_dtype(self, *args, **kwargs): - del args, kwargs - self.vae.to(self.dtype) - - @rank0_first - def load_vae(self) -> None: - os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" - os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" - import diffusers - - vae_name = "stabilityai/sd-vae-ft-mse" - try: - vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True) - except: # noqa: E722 - # Could not load the model from cache; try without local_files_only. - vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name) - self.vae = vae.eval().requires_grad_(False) - - @torch.no_grad() - def encode(self, state: torch.Tensor) -> torch.Tensor: - """ - state : pixel range [-1, 1] - """ - if self.is_downsample: - _h, _w = state.shape[-2:] - state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False) - in_dtype = state.dtype - state = state.to(self.dtype) - state = (state + 1.0) / 2.0 - latent_dist = self.vae.encode(state)["latent_dist"] - mean, std = latent_dist.mean, latent_dist.std - if self.count_std: - latent = mean + torch.randn_like(mean) * std - else: - latent = mean - latent = latent * self.scale - latent = latent + self.bias - return latent.to(in_dtype) - - @torch.no_grad() - def decode(self, latent: torch.Tensor) -> torch.Tensor: - in_dtype = latent.dtype - latent = latent.to(self.dtype) - latent = latent - self.bias - latent = latent / self.scale - latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)]) - if self.is_downsample: - _h, _w = latent.shape[-2:] - latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False) - return latent.to(in_dtype) * 2 - 1.0 - - @property - def spatial_compression_factor(self) -> int: - return 8 diff --git a/lyra_2/_src/tokenizers/interface.py b/lyra_2/_src/tokenizers/interface.py deleted file mode 100644 index ec31d5f6446f67be7b5262fbd527eeb51bafda22..0000000000000000000000000000000000000000 --- a/lyra_2/_src/tokenizers/interface.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from abc import ABC, abstractmethod - -import torch - -class VideoTokenizerInterface(ABC): - def __init__(self): - self.backend_args = None - - @abstractmethod - def reset_dtype(self): - """ - Reset the dtype of the model to the dtype its weights were trained with or quantized to. - """ - pass - - @abstractmethod - def encode(self, state: torch.Tensor) -> torch.Tensor: - pass - - @abstractmethod - def decode(self, latent: torch.Tensor) -> torch.Tensor: - pass - - @abstractmethod - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - pass - - @abstractmethod - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - pass - - @property - @abstractmethod - def spatial_compression_factor(self): - pass - - @property - @abstractmethod - def temporal_compression_factor(self): - pass - - @property - @abstractmethod - def spatial_resolution(self): - pass - - @property - @abstractmethod - def pixel_chunk_duration(self): - pass - - @property - @abstractmethod - def latent_chunk_duration(self): - pass - - @property - @abstractmethod - def latent_ch(self) -> int: - pass - - @property - def is_chunk_overlap(self): - return False - - @property - def is_causal(self): - return True diff --git a/lyra_2/_src/tokenizers/wan2pt1.py b/lyra_2/_src/tokenizers/wan2pt1.py deleted file mode 100644 index 1900b966a7d22614d0e06b546db1b58bf435e2a8..0000000000000000000000000000000000000000 --- a/lyra_2/_src/tokenizers/wan2pt1.py +++ /dev/null @@ -1,907 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. - -from contextlib import nullcontext - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - -from lyra_2._ext.imaginaire.lazy_config import LazyCall as L -from lyra_2._ext.imaginaire.lazy_config import LazyDict -from lyra_2._ext.imaginaire.utils import log -from lyra_2._ext.imaginaire.utils.distributed import broadcast, get_rank, sync_model_states -from lyra_2._ext.imaginaire.utils.easy_io import easy_io -from lyra_2._src.tokenizers.interface import VideoTokenizerInterface - -__all__ = [ - "WanVAE", -] - -CACHE_T = 2 - - -class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolusion. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) - - def forward(self, x, cache_x=None): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - - return super().forward(x) - - -class RMS_norm(nn.Module): - def __init__(self, dim, channel_first=True, images=True, bias=False): - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 - - def forward(self, x): - return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) - - -class Resample(nn.Module): - def __init__(self, dim, mode): - assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == "upsample2d": - self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - elif mode == "upsample3d": - self.resample = nn.Sequential( - Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) - ) - self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - elif mode == "downsample2d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == "downsample3d": - self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - - else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() - if self.mode == "upsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 - else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": - cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) - if feat_cache[idx] == "Rep": - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.resample(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - - if self.mode == "downsample3d": - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x - - def init_weight(self, conv): - conv_weight = conv.weight - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) - conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - -class ResidualBlock(nn.Module): - def __init__(self, in_dim, out_dim, dropout=0.0): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - - # layers - self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), - nn.SiLU(), - CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), - nn.SiLU(), - nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1), - ) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) - for layer in self.residual: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + h - - -class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - # layers - self.norm = RMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params - nn.init.zeros_(self.proj.weight) - - def forward(self, x): - identity = x - b, c, t, h, w = x.size() - x = rearrange(x, "b c t h w -> (b t) c h w") - x = self.norm(x) - # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) - - # apply attention - x = F.scaled_dot_product_attention( - q, - k, - v, - ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) - - # output - x = self.proj(x) - x = rearrange(x, "(b t) c h w-> b c t h w", t=t) - return x + identity - - -class Encoder3d(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0, - ): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - - # dimensions - dims = [dim * u for u in [1] + dim_mult] - scale = 1.0 - - # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) - - # downsample blocks - downsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - for _ in range(num_res_blocks): - downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - downsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # downsample block - if i != len(dim_mult) - 1: - mode = "downsample3d" if temperal_downsample[i] else "downsample2d" - downsamples.append(Resample(out_dim, mode=mode)) - scale /= 2.0 - self.downsamples = nn.Sequential(*downsamples) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout) - ) - - # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1) - ) - - def forward(self, x, feat_cache=None, feat_idx=[0]): - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - # downsamples - for layer in self.downsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - # middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - # head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -class Decoder3d(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0, - ): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_upsample = temperal_upsample - - # dimensions - dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2 ** (len(dim_mult) - 2) - - # init block - self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout) - ) - - # upsample blocks - upsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - if i == 1 or i == 2 or i == 3: - in_dim = in_dim // 2 - for _ in range(num_res_blocks + 1): - upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - upsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # upsample block - if i != len(dim_mult) - 1: - mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - upsamples.append(Resample(out_dim, mode=mode)) - scale *= 2.0 - self.upsamples = nn.Sequential(*upsamples) - - # output blocks - self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) - - def forward(self, x, feat_cache=None, feat_idx=[0]): - # conv1 - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - # middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - # upsamples - for layer in self.upsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - # head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat( - [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 - ) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -def count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, CausalConv3d): - count += 1 - return count - - -class WanVAE_(nn.Module): - def __init__( - self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0, - temporal_window=4, - ): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - self.temperal_upsample = temperal_downsample[::-1] - self.temporal_window = temporal_window - # modules - self.encoder = Encoder3d( - dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout - ) - self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) - self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) - - def forward(self, x): - mu, log_var = self.encode(x) - z = self.reparameterize(mu, log_var) - x_recon = self.decode(z) - return x_recon, mu, log_var - - def encode(self, x, scale): - batch_size = x.shape[0] - - if batch_size >= 8: - chunk_size = 4 - chunks = [] - for start_idx in range(0, batch_size, chunk_size): - end_idx = min(start_idx + chunk_size, batch_size) - chunks.append(self._encode_single_batch(x[start_idx:end_idx], scale)) - return torch.cat(chunks, dim=0) - else: - return self._encode_single_batch(x, scale) - - def _encode_single_batch(self, x, scale): - """Encode a single batch.""" - self.clear_cache() - # cache - t = x.shape[2] - iter_ = 1 + (t - 1) // self.temporal_window - # Split x along T into chunks: [1, temporal_window, temporal_window, ...] - for i in range(iter_): - self._enc_conv_idx = [0] - if i == 0: - out = self._i0_encode(x) - else: - out_ = self.encoder( - x[:, :, 1 + self.temporal_window * (i - 1) : 1 + self.temporal_window * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, - ) - out = torch.cat([out, out_], 2) - if (t - 1) % self.temporal_window: - self._enc_conv_idx = [0] - out_ = self.encoder( - x[:, :, 1 + self.temporal_window * (iter_ - 1) :, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx, - ) - out = torch.cat([out, out_], 2) - mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) - else: - mu = (mu - scale[0]) * scale[1] - self.clear_cache() - return mu - - @torch.compiler.disable - def _i0_encode(self, x): - """ - If enabled torch.compile uses significantly more memory for this step, so we disable it - """ - out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) - return out - - def decode(self, z, scale): - batch_size = z.shape[0] - - if batch_size >= 8: - chunk_size = 4 - log.info(f"Decoding with chunking, batch size: {batch_size}, chunk size: {chunk_size}") - chunks = [] - for start_idx in range(0, batch_size, chunk_size): - end_idx = min(start_idx + chunk_size, batch_size) - chunks.append(self._decode_single_batch(z[start_idx:end_idx], scale)) - return torch.cat(chunks, dim=0) - else: - return self._decode_single_batch(z, scale) - - def _decode_single_batch(self, z, scale): - """Decode a single batch.""" - self.clear_cache() - # z: [b,c,t,h,w] - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) - else: - z = z / scale[1] + scale[0] - iter_ = z.shape[2] - x = self.conv2(z) - for i in range(iter_): - self._conv_idx = [0] - if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - else: - out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) - out = torch.cat([out, out_], 2) - self.clear_cache() - return out - - def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return eps * std + mu - - def sample(self, imgs, deterministic=False): - mu, log_var = self.encode(imgs) - if deterministic: - return mu - std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) - return mu + std * torch.randn_like(std) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - -def _video_vae( - pretrained_path=None, - z_dim=None, - device="cpu", - load_mean_std=False, - mean_std_path=None, - image_mean_std_path: str = "./checkpoints/vae/images_mean_std.pt", - video_mean_std_path: str = "./checkpoints/vae/video_mean_std.pt", - **kwargs, -): - """ - Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. - """ - # params - cfg = dict( - dim=96, - z_dim=z_dim, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[False, True, True], - dropout=0.0, - ) - cfg.update(**kwargs) - - if mean_std_path is not None: - image_mean_std_path = mean_std_path.replace("mean_std.pt", "images_mean_std.pt") - video_mean_std_path = mean_std_path.replace("mean_std.pt", "video_mean_std.pt") - - # init model - with torch.device("meta"): - model = WanVAE_(**cfg) - - if pretrained_path is None: - model.to_empty(device=device) - if load_mean_std: - img_mean, img_std = torch.randn(1, 16, 1, 1, 1, device=device), torch.randn(1, 16, 1, 1, 1, device=device) - video_mean, video_std = ( - torch.randn(1, 16, 32, 1, 1, device=device), - torch.randn(1, 16, 32, 1, 1, device=device), - ) - else: - if get_rank() == 0: - ckpt = easy_io.load( - pretrained_path, - map_location=device, - ) - if load_mean_std: - img_mean, img_std = easy_io.load(image_mean_std_path, map_location=device) - video_mean, video_std = easy_io.load(video_mean_std_path, map_location=device) - img_mean = img_mean.reshape(1, 16, 1, 1, 1) - img_std = img_std.reshape(1, 16, 1, 1, 1) - video_mean = video_mean.reshape(1, 16, 32, 1, 1) - video_std = video_std.reshape(1, 16, 32, 1, 1) - - # load checkpoint - log.info(f"loading {pretrained_path}") - model.load_state_dict(ckpt, assign=True) - else: - model.to_empty(device=device) - if load_mean_std: - img_mean, img_std = ( - torch.randn(1, 16, 1, 1, 1, device=device), - torch.randn(1, 16, 1, 1, 1, device=device), - ) - video_mean, video_std = ( - torch.randn(1, 16, 32, 1, 1, device=device), - torch.randn(1, 16, 32, 1, 1, device=device), - ) - sync_model_states(model) - - if load_mean_std: - log.info("broadcast mean and std for wan2pt1") - broadcast(img_mean, 0) - broadcast(img_std, 0) - broadcast(video_mean, 0) - broadcast(video_std, 0) - return model, img_mean, img_std, video_mean, video_std - - return ( - model, - torch.zeros(1, 1, 1, 1, 1, device=device), - torch.ones(1, 1, 1, 1, 1, device=device), - torch.zeros(1, 1, 50, 1, 1, device=device), - torch.ones(1, 1, 50, 1, 1, device=device), - ) - - -class WanVAE: - def __init__( - self, - z_dim=16, - vae_pth="./checkpoints/vae/vae.pth", - load_mean_std=False, - mean_std_path=None, - image_mean_std_path: str = "./checkpoints/vae/images_mean_std.pt", - video_mean_std_path: str = "./checkpoints/vae/video_mean_std.pt", - dtype=torch.float, - device="cuda", - is_amp=True, - benchmark: bool = False, - temporal_window: int = 4, - ): - self.dtype = dtype - self.device = device - self.temporal_window = temporal_window - - mean = [ - -0.7571, - -0.7089, - -0.9113, - 0.1075, - -0.1745, - 0.9653, - -0.1517, - 1.5508, - 0.4134, - -0.0715, - 0.5517, - -0.3632, - -0.1922, - -0.9497, - 0.2503, - -0.2921, - ] - std = [ - 2.8184, - 1.4541, - 2.3275, - 2.6558, - 1.2196, - 1.7708, - 2.6052, - 2.0743, - 3.2687, - 2.1526, - 2.8652, - 1.5579, - 1.6382, - 1.1253, - 2.8251, - 1.9160, - ] - self.mean = torch.tensor(mean, dtype=dtype, device=device) - self.std = torch.tensor(std, dtype=dtype, device=device) - self.scale = [self.mean, 1.0 / self.std] - - # init model - self.model, self.img_mean, self.img_std, self.video_mean, self.video_std = _video_vae( - pretrained_path=vae_pth, - z_dim=z_dim, - load_mean_std=load_mean_std, - mean_std_path=mean_std_path, - image_mean_std_path=image_mean_std_path, - video_mean_std_path=video_mean_std_path, - device=device, - temporal_window=temporal_window, - ) - self.model = self.model.eval().requires_grad_(False) - self.is_amp = is_amp - if not is_amp: - self.model = self.model.to(dtype=dtype) - self.context = nullcontext() - else: - self.context = torch.amp.autocast("cuda", dtype=dtype) - - def count_param(self): - return sum(p.numel() for p in self.model.parameters()) - - @torch.no_grad() - def encode(self, videos): - """ - videos: A list of videos each with shape [C, T, H, W]. - """ - - in_dtype = videos.dtype - with self.context: - if not self.is_amp: - videos = videos.to(self.dtype) - latent = self.model.encode(videos, self.scale) - latent = latent.to(in_dtype) - return latent - - @torch.no_grad() - def decode(self, zs): - in_dtype = zs.dtype - with self.context: - if not self.is_amp: - zs = zs.to(self.dtype) - video_recon = self.model.decode(zs, self.scale) - video_recon = video_recon.to(in_dtype) - return video_recon - - -class Wan2pt1VAEInterface(VideoTokenizerInterface): - def __init__(self, chunk_duration: int = 81, load_mean_std=False, **kwargs): - self.model = WanVAE( - dtype=torch.bfloat16, - is_amp=False, - load_mean_std=load_mean_std, - vae_pth=kwargs.get( - "vae_pth", - "./checkpoints/vae/vae.pth", - ), - mean_std_path=kwargs.get("mean_std_path"), - image_mean_std_path=kwargs.get( - "image_mean_std_path", - "./checkpoints/vae/images_mean_std.pt", - ), - video_mean_std_path=kwargs.get( - "video_mean_std_path", - "./checkpoints/vae/video_mean_std.pt", - ), - temporal_window=kwargs.get("temporal_window", 4), - ) - if kwargs.get("compile_encode", False) and hasattr(torch, "compile"): - torch_compile_available = True - try: - # PyTorch >= 2.7 - torch._dynamo.config.recompile_limit = 32 - except AttributeError: - try: - torch._dynamo.config.cache_size_limit = 32 - except AttributeError: - log.warning( - "`compile_encode=True` requested, but Torch Dynamo is unavailable – skipping compilation." - ) - torch_compile_available = False - if torch_compile_available: - log.warning( - "The 'model.config.tokenizer.compile_encode' config option is deprecated. Please switch to using CompileTokenizer callback." - ) - self.encode = torch.compile(self.encode, dynamic=False) - del kwargs - self.chunk_duration = chunk_duration - - @property - def dtype(self): - return self.model.dtype - - def reset_dtype(self): - pass - - def encode(self, state: torch.Tensor) -> torch.Tensor: - latents = self.model.encode(state) - num_frames = latents.shape[2] - if num_frames == 1: - return (latents - self.model.img_mean.type_as(latents)) / self.model.img_std.type_as(latents) - else: - return (latents - self.model.video_mean[:, :, :num_frames].type_as(latents)) / self.model.video_std[ - :, :, :num_frames - ].type_as(latents) - - def decode(self, latent: torch.Tensor) -> torch.Tensor: - num_frames = latent.shape[2] - if num_frames == 1: - return self.model.decode( - (latent * self.model.img_std.type_as(latent)) + self.model.img_mean.type_as(latent) - ) - else: - return self.model.decode( - (latent * self.model.video_std[:, :, :num_frames].type_as(latent)) - + self.model.video_mean[:, :, :num_frames].type_as(latent) - ) - - def get_latent_num_frames(self, num_pixel_frames: int) -> int: - return 1 + (num_pixel_frames - 1) // 4 - - def get_pixel_num_frames(self, num_latent_frames: int) -> int: - return (num_latent_frames - 1) * 4 + 1 - - @property - def spatial_compression_factor(self): - return 8 - - @property - def temporal_compression_factor(self): - return 4 - - @property - def pixel_chunk_duration(self): - return self.chunk_duration - - @property - def latent_chunk_duration(self): - return self.get_latent_num_frames(self.chunk_duration) - - @property - def latent_ch(self): - return 16 - - @property - def spatial_resolution(self): - return 512 - - @property - def name(self): - return "wan2pt1_tokenizer" - - -Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)(name="wan2pt1_tokenizer", compile_encode=False) diff --git a/lyra_2/_src/utils/__init__.py b/lyra_2/_src/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/lyra_2/_src/utils/context_parallel.py b/lyra_2/_src/utils/context_parallel.py deleted file mode 100644 index 60781f8799f87c6a88054f0721d8b04ae5b5b8f3..0000000000000000000000000000000000000000 --- a/lyra_2/_src/utils/context_parallel.py +++ /dev/null @@ -1,192 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional - -import torch -from torch import Tensor -from torch.distributed import ProcessGroup, all_gather, broadcast_object_list, get_process_group_ranks, get_world_size -from torch.distributed.utils import _verify_param_shape_across_processes - -from lyra_2._ext.imaginaire.utils import distributed - - -def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: - """ - Split input tensor along the sequence dimension for context parallelism. - - This function divides the input tensor into equal parts along the specified - sequence dimension, based on the number of ranks in the context parallelism group. - It then selects the part corresponding to the current rank. - - Args: - x: Input tensor to be split. - seq_dim: The dimension along which to split the input (sequence dimension). - cp_group: The process group for context parallelism. - - Returns: - A slice of the input tensor corresponding to the current rank. - - Raises: - AssertionError: If the sequence dimension is not divisible by the number of ranks. - """ - cp_ranks = get_process_group_ranks(cp_group) - cp_size = len(cp_ranks) - - assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" - x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) - seq_idx = torch.tensor([cp_group.rank()], device=x.device) - x = x.index_select(seq_dim, seq_idx) - # Note that the new sequence length is the original sequence length / cp_size - x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - return x - - -def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: - """ - Concatenate outputs from different ranks in the checkpoint parallelism group. - - This function gathers tensors from all ranks in the checkpoint parallelism group - and concatenates them along the specified sequence dimension. - - Args: - x: Input tensor to be concatenated. - seq_dim: The dimension along which to concatenate the tensors (sequence dimension). - cp_group: The process group for checkpoint parallelism. - - Returns: - A tensor that is the concatenation of tensors from all ranks in the cp_group. - - Raises: - RuntimeError: If the gather operation fails. - """ - # Get the world size (number of processes in the group) - world_size = get_world_size(cp_group) - - # Create a list to store tensors from all ranks - gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)] - - # Gather tensors from all ranks - try: - all_gather(gathered_tensors, x, group=cp_group) - except RuntimeError as e: - raise RuntimeError(f"Failed to gather tensors: {e}") - - # Concatenate the gathered tensors along the specified dimension - return torch.cat(gathered_tensors, dim=seq_dim) - - -def cat_outputs_cp_with_grad(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor: - """ - Concatenate outputs from different ranks in the context parallelism group. - - This function gathers tensors from all ranks in the checkpoint parallelism group - and concatenates them along the specified sequence dimension. - - It retains computational graph locally for each rank by replacing the portion of the tensor with original output. - - Args: - x: Input tensor to be concatenated. - seq_dim: The dimension along which to concatenate the tensors (sequence dimension). - cp_group: The process group for checkpoint parallelism. - - Returns: - A tensor that is the concatenation of tensors from all ranks in the cp_group. - - Raises: - RuntimeError: If the gather operation fails. - """ - # Get the world size (number of processes in the group) - cp_size = cp_group.size() - assert cp_size > 0, "cp_size should be greater than 0" - - # Create a list to store tensors from all ranks - gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] - - # Gather tensors from all ranks - try: - all_gather(gathered_tensors, x, group=cp_group) - except RuntimeError as e: - raise RuntimeError(f"Failed to gather tensors: {e}") - - rank = cp_group.rank() - gathered_tensors[rank] = x - # Concatenate the gathered tensors along the specified dimension - return torch.cat(gathered_tensors, dim=seq_dim) - - -def robust_broadcast(tensor: torch.Tensor, src: int, pg: ProcessGroup, is_check_shape: bool = False) -> torch.Tensor: - """ - Perform a robust broadcast operation that works regardless of tensor shapes on different ranks. - - Args: - tensor (torch.Tensor): The tensor to broadcast (on src rank) or receive (on other ranks). - src (int): The source rank for the broadcast. Defaults to 0. - - Returns: - torch.Tensor: The broadcasted tensor on all ranks. - """ - # First, broadcast the shape of the tensor - if distributed.get_rank() == src: - shape = torch.tensor(tensor.shape, dtype=torch.long).cuda() - else: - shape = torch.empty(tensor.dim(), dtype=torch.long).cuda() - if is_check_shape: - _verify_param_shape_across_processes(pg, [shape]) - torch.distributed.broadcast(shape, src, group=pg) - - # Resize the tensor on non-src ranks if necessary - if distributed.get_rank() != src: - tensor = tensor.new_empty(shape.tolist()).type_as(tensor) - - # Now broadcast the tensor data - torch.distributed.broadcast(tensor, src, group=pg) - - return tensor - - -def broadcast( - item: torch.Tensor | str | None, process_group: Optional[ProcessGroup] = None -) -> torch.Tensor | str | None: - """ - Broadcast the item from the minimum rank in the specified group(s). - """ - if process_group is None: - return item - - min_rank = min(get_process_group_ranks(process_group)) - if isinstance(item, torch.Tensor): # assume the device is cuda - item = robust_broadcast(item, min_rank, process_group) - elif item is not None: - broadcastable_list = [item] - broadcast_object_list(broadcastable_list, min_rank, group=process_group) - item = broadcastable_list[0] - return item - - -def broadcast_split_tensor( - tensor: torch.Tensor, - seq_dim: int, - process_group: Optional[ProcessGroup] = None, -) -> torch.Tensor: - """ - Broadcast the tensor from the minimum rank in the specified group(s). - """ - if tensor is None: - return tensor - min_rank = min(get_process_group_ranks(process_group)) - tensor = robust_broadcast(tensor, min_rank, process_group) - return split_inputs_cp(tensor, seq_dim, process_group) diff --git a/lyra_2/_src/utils/dtensor_helper.py b/lyra_2/_src/utils/dtensor_helper.py deleted file mode 100644 index e56cc85b78a4778f4edf7525e26d5b6b6c18f4d0..0000000000000000000000000000000000000000 --- a/lyra_2/_src/utils/dtensor_helper.py +++ /dev/null @@ -1,90 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import itertools -from typing import Any - -import torch -import torch.distributed as dist -from torch.distributed import DeviceMesh - -from lyra_2._ext.imaginaire.utils.misc import get_local_tensor_if_DTensor - - -class DTensorFastEmaModelUpdater: - """ - Similar as FastEmaModelUpdater - """ - - def __init__(self): - # Flag to indicate whether the cache is taken or not. Useful to avoid cache overwrite - self.is_cached = False - - def copy_to(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module) -> None: - with torch.no_grad(): - for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - tgt_params.to_local().data.copy_(src_params.to_local().data) - - @torch.no_grad() - def update_average(self, src_model: torch.nn.Module, tgt_model: torch.nn.Module, beta: float = 0.9999) -> None: - target_list = [] - source_list = [] - for tgt_params, src_params in zip(tgt_model.parameters(), src_model.parameters()): - assert tgt_params.dtype == torch.float32, ( - f"EMA model only works in FP32 dtype, got {tgt_params.dtype} instead." - ) - target_list.append(tgt_params.to_local()) - source_list.append(src_params.to_local().data) - torch._foreach_mul_(target_list, beta) - torch._foreach_add_(target_list, source_list, alpha=1.0 - beta) - - @torch.no_grad() - def cache(self, parameters: Any, is_cpu: bool = False) -> None: - assert self.is_cached is False, "EMA cache is already taken. Did you forget to restore it?" - device = "cpu" if is_cpu else "cuda" - self.collected_params = [param.to_local().clone().to(device) for param in parameters] - self.is_cached = True - - @torch.no_grad() - def restore(self, parameters: Any) -> None: - assert self.is_cached, "EMA cache is not taken yet." - for c_param, param in zip(self.collected_params, parameters, strict=False): - local_param = param.to_local() - local_param.copy_(c_param.data.type_as(local_param)) - self.collected_params = [] - # Release the cache after we call restore - self.is_cached = False - - -def broadcast_dtensor_model_states(model: torch.nn.Module, mesh: DeviceMesh): - """Broadcast model states from replicate mesh's rank 0.""" - replicate_group = mesh.get_group("replicate") - all_ranks = dist.get_process_group_ranks(replicate_group) - if len(all_ranks) == 1: - return - - for _, tensor in itertools.chain(model.named_parameters(), model.named_buffers()): - # Get src rank which is the first rank in each replication group - src_rank = all_ranks[0] - # Broadcast the local tensor - local_tensor = get_local_tensor_if_DTensor(tensor) - dist.broadcast( - local_tensor, - src=src_rank, - group=replicate_group, - ) diff --git a/lyra_2/_src/utils/misc.py b/lyra_2/_src/utils/misc.py deleted file mode 100644 index 2efbf99e5fc5bb3c2bcc598eddb50b5166a99a76..0000000000000000000000000000000000000000 --- a/lyra_2/_src/utils/misc.py +++ /dev/null @@ -1,143 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -import os -from functools import wraps - -import torch -import torchvision -from PIL import Image - -from lyra_2._ext.imaginaire.utils import log - -_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", "webp"] - - -def resize_input(video: torch.Tensor, resolution: list[int]): - r""" - Resizes and crops the input video tensor while preserving aspect ratio. - - The video is first resized so that the smaller dimension matches the target resolution, - preserving the aspect ratio. Then, it's center-cropped to the target resolution. - - Args: - video (torch.Tensor): Input video tensor of shape (T, C, H, W). - resolution (list[int]): Target resolution [H, W]. - - Returns: - torch.Tensor: Resized and cropped video tensor of shape (T, C, target_H, target_W). - """ - - orig_h, orig_w = video.shape[2], video.shape[3] - target_h, target_w = resolution - - scaling_ratio = max((target_w / orig_w), (target_h / orig_h)) - resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w))) - video_resized = torchvision.transforms.functional.resize(video, resizing_shape) - video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution) - return video_cropped - - -def read_and_process_image(img_path: str, resolution: list[int], num_video_frames: int, resize: bool = True): - """ - Reads an image, converts it to a video tensor, and processes it for model input. - - The image is loaded, converted to a tensor, and replicated to match the - `num_video_frames`. It's then optionally resized and permuted to the - standard video format (B, C, T, H, W). - - Args: - img_path (str): Path to the input image file. - resolution (list[int]): Target resolution [H, W] for resizing. - num_video_frames (int): The number of frames the output video tensor should have. - resize (bool, optional): Whether to resize the image to the target resolution. Defaults to True. - - Returns: - torch.Tensor: Processed video tensor of shape (1, C, T, H, W). - - Raises: - ValueError: If the image extension is not one of the supported types. - """ - ext = os.path.splitext(img_path)[1] - if ext.lower() not in _IMAGE_EXTENSIONS: - raise ValueError(f"Invalid image extension: {ext}") - - # Read the image - img = Image.open(img_path) - - # Convert to tensor - img = torchvision.transforms.functional.to_tensor(img) - # Create a video tensor by repeating the first frame - vid_input = img.unsqueeze(0) # Add temporal dimension T=1 - - # Repeat the first frame to match the desired number of video frames - # Note: The actual content for frames > 0 will be generated by the model. - vid_input = torch.cat([vid_input, torch.zeros_like(vid_input).repeat(num_video_frames - 1, 1, 1, 1)], dim=0) - vid_input = (vid_input * 255.0).to(torch.uint8) # Convert to uint8 range if needed (might depend on model) - if resize: - # Resize and crop to the target resolution - vid_input = resize_input(vid_input, resolution) - - # Convert to {B, C, T, H, W} format expected by the model - vid_input = vid_input.unsqueeze(0).permute(0, 2, 1, 3, 4) # Add batch dim B=1 and permute - return vid_input - - -class sync_timer: - """ - Synchronized timer to count the inference time of `nn.Module.forward` or else. - set env var SYNC_TIMER=1 to enable logging! - - Example as context manager: - ```python - with timer('name'): - run() - ``` - - Example as decorator: - ```python - @timer('name') - def run(): - pass - ``` - """ - - def __init__(self, name=None, flag_env="SYNC_TIMER"): - self.name = name - self.flag_env = flag_env - - def __enter__(self): - if os.environ.get(self.flag_env, "0") == "1": - self.start = torch.cuda.Event(enable_timing=True) - self.end = torch.cuda.Event(enable_timing=True) - self.start.record() - return lambda: self.time - - def __exit__(self, exc_type, exc_value, exc_tb): - if os.environ.get(self.flag_env, "0") == "1": - self.end.record() - torch.cuda.synchronize() - self.time = self.start.elapsed_time(self.end) - if self.name is not None: - log.info(f"{self.name} takes {self.time / 1000:.4f}s", rank0_only=False) - - def __call__(self, func): - @wraps(func) - def wrapper(*args, **kwargs): - with self: - result = func(*args, **kwargs) - return result - - return wrapper diff --git a/lyra_2/_src/utils/model_loader.py b/lyra_2/_src/utils/model_loader.py deleted file mode 100644 index a4eb98d3f9793adc408bea4a63092f6ece37499d..0000000000000000000000000000000000000000 --- a/lyra_2/_src/utils/model_loader.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import importlib -import os - -import torch -import torch.distributed.checkpoint as dcp - -from lyra_2._ext.imaginaire.checkpointer.dcp import DefaultLoadPlanner, DistributedCheckpointer, ModelWrapper -from lyra_2._ext.imaginaire.lazy_config import instantiate -from lyra_2._ext.imaginaire.utils import log, misc -from lyra_2._ext.imaginaire.utils.config_helper import get_config_module, override -from lyra_2._ext.imaginaire.utils.easy_io import easy_io - - -def load_model_from_checkpoint( - experiment_name, - checkpoint_path, - config_file="lyra_2/_src/configs/t2v_wan/config.py", - enable_fsdp=False, - instantiate_ema=True, - load_ema_to_reg=False, - seed=0, - experiment_opts: list[str] = [], - strict=True, -): - """ - experiment_name: experiment name - checkpoint_dir: path to iteration_model - config_file: config file path - enable_fsdp: enable fsdp - seed: random seed - """ - config_module = get_config_module(config_file) - config = importlib.import_module(config_module).make_config() - config = override(config, ["--", f"experiment={experiment_name}"] + experiment_opts) - - if instantiate_ema is False and config.model.config.ema.enabled: - config.model.config.ema.enabled = False - - # Check that the config is valid - config.validate() - # Freeze the config so developers don't change it during training. - config.freeze() # type: ignore - misc.set_random_seed(seed=seed, by_rank=True) - # Initialize cuDNN. - torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic - torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark - # Floating-point precision settings. - torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True - - if not enable_fsdp: - # disable fsdp - config.model.config.fsdp_shard_size = 1 - with misc.timer("instantiate model"): - model = instantiate(config.model).cuda() - # Convert the model parameters to bf16 - model.on_train_start() - - if checkpoint_path.endswith(".pth"): - log.info(f"Loading model from consolidated checkpoint {checkpoint_path}") - - model.load_state_dict(easy_io.load(checkpoint_path), strict=strict) - else: - log.info(f"Loading model from dcp checkpoint {checkpoint_path}") - - checkpointer = DistributedCheckpointer(config.checkpoint, config.job, callbacks=None, disable_async=True) - cur_key_ckpt_full_path = os.path.join(checkpoint_path, "model") - storage_reader = checkpointer.get_storage_reader(cur_key_ckpt_full_path) - - _model_wrapper = ModelWrapper(model, load_ema_to_reg=load_ema_to_reg) - _state_dict = _model_wrapper.state_dict() - dcp.load( - _state_dict, - storage_reader=storage_reader, - planner=DefaultLoadPlanner(allow_partial_load=True), - ) - _model_wrapper.load_state_dict(_state_dict) - - torch.cuda.empty_cache() - - return model, config diff --git a/lyra_2/_src/utils/optim_instantiate.py b/lyra_2/_src/utils/optim_instantiate.py deleted file mode 100644 index f3dd9c4034ddcd344acd91688998df1d5571302a..0000000000000000000000000000000000000000 --- a/lyra_2/_src/utils/optim_instantiate.py +++ /dev/null @@ -1,211 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re - -import torch -from omegaconf import ListConfig -from torch import nn - -from lyra_2._ext.imaginaire.utils import log - - -def get_regular_param_group(net: nn.Module): - """ - seperate the parameters of the network into two groups: decay and no_decay. - based on nano_gpt codebase. - """ - param_dict = {pn: p for pn, p in net.named_parameters()} - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - - decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] - nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] - return decay_params, nodecay_params - - -def get_base_optimizer( - model: nn.Module, - lr: float, - weight_decay: float, - optim_type: str = "adamw", - **kwargs, -) -> torch.optim.Optimizer: - net_decay_param, net_nodecay_param = get_regular_param_group(model) - - num_decay_params = sum(p.numel() for p in net_decay_param) - num_nodecay_params = sum(p.numel() for p in net_nodecay_param) - net_param_total = num_decay_params + num_nodecay_params - log.info(f"total num parameters : {net_param_total:,}") - - param_group = [ - { - "params": net_decay_param + net_nodecay_param, - "lr": lr, - "weight_decay": weight_decay, - }, - ] - - if optim_type == "adamw": - opt_cls = torch.optim.AdamW - else: - raise ValueError(f"Unknown optimizer type: {optim_type}") - - for k, v in kwargs.items(): - if isinstance(v, ListConfig): - kwargs[k] = list(v) - - return opt_cls(param_group, **kwargs) - - -def get_multiple_optimizer( - model: nn.Module, - lr: float, - weight_decay: float, - optim_type: str = "adamw", - lr_overrides: list[dict] = None, - **kwargs, -) -> torch.optim.Optimizer: - """ - Get an optimizer with multiple learning rates for different parts of the model, - allowing pattern matching for parameter names. - - The logic is: - 1. All parameters are initially considered for the default learning rate. - 2. We iterate through lr_overrides. If a parameter's name matches a pattern, - it's moved to a group with the specified learning rate. A parameter is only - assigned to the *first* pattern it matches. - - Args: - model (nn.Module): The model to optimize. - lr (float): The default learning rate. - weight_decay (float): The default weight decay. - optim_type (str): The type of optimizer to use ('adamw' or 'fusedadam'). - lr_overrides (list[dict], optional): A list of dicts with keys: - - 'pattern' (str): The pattern to match (required) - - 'lr' (float): The learning rate for matching params (required) - - 'match_type' (str): 'regex', 'contains', 'startswith', 'endswith' (default: 'contains') - - Example: - [ - {'pattern': 'cross_view_attn', 'lr': 2e-4, 'match_type': 'contains'}, - {'pattern': 'text_encoder', 'lr': 1e-5, 'match_type': 'contains'}, - ] - - This is Hydra-friendly and can be overridden from command line like: - optimizer.lr_overrides.0.lr=1e-4 - optimizer.lr_overrides.1.pattern=vision_encoder - **kwargs: Additional arguments for the optimizer. - - Returns: - torch.optim.Optimizer: The configured optimizer. - """ - param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} - - # Initialize groups for parameters with overridden LR - override_groups = {} # (lr, has_decay) -> [params] - override_groups_name = {} # (lr, has_decay) -> [name] - - # Initialize lists for parameters with default LR - default_decay_params = [] - default_nodecay_params = [] - - # Temporarily hold all params to check against overrides - unassigned_params = list(param_dict.items()) - - # First, assign params that match an override pattern - if lr_overrides: - override_list = lr_overrides - - for name, p in list(unassigned_params): - assigned = False - for override_item in override_list: - pattern = override_item["pattern"] - special_lr = override_item["lr"] - match_type = override_item.get("match_type", "contains") - - # Determine if the parameter name matches - matched = False - if match_type == "regex": - matched = re.match(pattern, name) is not None - elif match_type == "contains": - matched = pattern in name - elif match_type == "startswith": - matched = name.startswith(pattern) - elif match_type == "endswith": - matched = name.endswith(pattern) - else: - raise ValueError( - f"Unknown match_type: {match_type}. Must be one of: regex, contains, startswith, endswith" - ) - - if matched: - has_decay = p.dim() >= 2 - group_key = (special_lr, has_decay) - if group_key not in override_groups: - override_groups[group_key] = [] - if group_key not in override_groups_name: - override_groups_name[group_key] = [] - override_groups[group_key].append(p) - override_groups_name[group_key].append(name) - assigned = True - break # Assign to first matching pattern - if assigned: - # Remove from unassigned list; this is a bit inefficient but clear - unassigned_params = [(n, param) for n, param in unassigned_params if n != name] - - # Assign all remaining params to default groups - for name, p in unassigned_params: - if p.dim() >= 2: - default_decay_params.append(p) - else: - default_nodecay_params.append(p) - - # Build final param_groups list for the optimizer - final_param_groups = [] - if default_decay_params: - final_param_groups.append({"params": default_decay_params, "lr": lr, "weight_decay": weight_decay}) - if default_nodecay_params: - final_param_groups.append({"params": default_nodecay_params, "lr": lr, "weight_decay": 0.0}) - - for (special_lr, has_decay), params in override_groups.items(): - final_param_groups.append( - {"params": params, "lr": special_lr, "weight_decay": weight_decay if has_decay else 0.0} - ) - - # print the parameter names in each group - for (special_lr, has_decay), params in override_groups.items(): - log.critical(f"special_lr {special_lr}: {override_groups_name[(special_lr, has_decay)]}") - - # Log parameter group information - total_params = 0 - log.critical("Optimizer parameter groups:") - for i, group in enumerate(final_param_groups): - group_params = sum(p.numel() for p in group["params"]) - total_params += group_params - log.critical( - f" Group {i}: num_params={group_params:,}, lr={group['lr']:.1e}, weight_decay={group['weight_decay']}" - ) - log.critical(f"Total trainable parameters: {total_params:,}") - - if optim_type == "adamw": - opt_cls = torch.optim.AdamW - else: - raise ValueError(f"Unknown optimizer type: {optim_type}") - - for k, v in kwargs.items(): - if isinstance(v, ListConfig): - kwargs[k] = list(v) - - return opt_cls(final_param_groups, **kwargs) diff --git a/lyra_2/_src/utils/torch_future.py b/lyra_2/_src/utils/torch_future.py deleted file mode 100644 index c6133eac70c33c39b194c780ac6838f56f0d6f3f..0000000000000000000000000000000000000000 --- a/lyra_2/_src/utils/torch_future.py +++ /dev/null @@ -1,233 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math - -# Backported gradient-clipping utilities from newer PyTorch / torchtitan. -from typing import Dict, Iterable, List, Optional, Tuple, Union - -import torch -import torch.distributed as dist -from torch import Tensor -from torch.distributed import DeviceMesh -from torch.distributed._tensor.api import DTensor -from torch.nn.utils.clip_grad import ( - _device_has_foreach_support, - _group_tensors_by_device_and_dtype, - _has_foreach_support, - _no_grad, - _tensor_or_tensors, -) - - -# https://github.com/pytorch/torchtitan/blob/d4c86e3758a84cf23e2e879ab3c995cba9d5e410/torchtitan/utils.py#L354 -@torch.no_grad() -def clip_grad_norm_( - parameters: Union[torch.Tensor, Iterable[torch.Tensor]], - max_norm: float, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: Optional[bool] = None, - pp_mesh: Optional[DeviceMesh] = None, -) -> torch.Tensor: - """ - Clip the gradient norm of an iterable of parameters. - - Gradient norm clipping requires computing the gradient norm over the entire model. - `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. - We need to manually reduce the gradient norm across PP stages. - See https://github.com/pytorch/torchtitan/issues/596 for details. - - Args: - parameters: an iterable of Tensors or a single Tensor that will have gradients normalized - max_norm (float): max norm of the gradients - norm_type (float): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - error_if_nonfinite (bool): if True, an error is thrown if the total - norm of the gradients from :attr:`parameters` is ``nan``, - ``inf``, or ``-inf``. Default: False (will switch to True in the future) - foreach (bool): use the faster foreach-based implementation. - If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently - fall back to the slow implementation for other device types. - Default: ``None`` - pp_mesh: pipeline parallel device mesh. If not None, will reduce gradient norm across PP stages. - - Returns: - Total norm of the parameter gradients (viewed as a single vector). - - """ - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - else: - parameters = list(parameters) # prevent generators from being exhausted - - grads = [p.grad for p in parameters if p.grad is not None] - total_norm = get_total_norm(grads, norm_type, error_if_nonfinite, foreach) - - # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`. - # We can simply reduce the DTensor to get the total norm in this tensor's process group - # and then convert it to a local tensor. - # NOTE: It has two purposes: - # 1. to make sure the total norm is computed correctly when PP is used (see below) - # 2. to return a reduced total_norm tensor whose .item() would return the correct value - if isinstance(total_norm, DTensor): - # Will reach here if any non-PP parallelism is used. - # If only using PP, total_norm will be a local tensor. - total_norm = total_norm.full_tensor() - - if pp_mesh is not None: - if math.isinf(norm_type): - dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) - else: - total_norm **= norm_type - dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type - - clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) - return total_norm - - -# from: pytorch/pytorch/torch/nn/utils/clip_grad.py -@_no_grad -def _get_total_norm( - tensors: _tensor_or_tensors, - norm_type: float = 2.0, - error_if_nonfinite: bool = False, - foreach: Optional[bool] = None, -) -> torch.Tensor: - r"""Compute the norm of an iterable of tensors. - - The norm is computed over the norms of the individual tensors, as if the norms of - the individual tensors were concatenated into a single vector. - - Args: - tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will be normalized - norm_type (float): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - error_if_nonfinite (bool): if True, an error is thrown if the total - norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``. - Default: ``False`` - foreach (bool): use the faster foreach-based implementation. - If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently - fall back to the slow implementation for other device types. - Default: ``None`` - - Returns: - Total norm of the tensors (viewed as a single vector). - """ - if isinstance(tensors, torch.Tensor): - tensors = [tensors] - else: - tensors = list(tensors) - norm_type = float(norm_type) - if len(tensors) == 0: - return torch.tensor(0.0) - first_device = tensors[0].device - grouped_tensors: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] = ( - _group_tensors_by_device_and_dtype( - [tensors] # type: ignore[list-item] - ) - ) # type: ignore[assignment] - - norms: List[Tensor] = [] - for (device, _), ([device_tensors], _) in grouped_tensors.items(): - if (foreach is None and _has_foreach_support(device_tensors, device)) or ( - foreach and _device_has_foreach_support(device) - ): - norms.extend(torch._foreach_norm(device_tensors, norm_type)) - elif foreach: - raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") - else: - norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors]) - - total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) - - if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): - raise RuntimeError( - f"The total norm of order {norm_type} for gradients from " - "`parameters` is non-finite, so it cannot be clipped. To disable " - "this error and scale the gradients by the non-finite norm anyway, " - "set `error_if_nonfinite=False`" - ) - return total_norm - - -get_total_norm = _get_total_norm - - -# from: pytorch/pytorch/torch/nn/utils/clip_grad.py -@_no_grad -def _clip_grads_with_norm_( - parameters: _tensor_or_tensors, - max_norm: float, - total_norm: torch.Tensor, - foreach: Optional[bool] = None, -) -> None: - r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm. - - The gradients will be scaled by the following calculation - - .. math:: - grad = grad * \frac{max\_norm}{total\_norm + 1e-6} - - Gradients are modified in-place. - - This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated - total norm. - - Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float): max norm of the gradients - total_norm (Tensor): total norm of the gradients to use for clipping - foreach (bool): use the faster foreach-based implementation. - If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently - fall back to the slow implementation for other device types. - Default: ``None`` - - Returns: - None - """ - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - grads = [p.grad for p in parameters if p.grad is not None] - max_norm = float(max_norm) - if len(grads) == 0: - return - grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] = ( - _group_tensors_by_device_and_dtype([grads]) - ) # type: ignore[assignment] - - clip_coef = max_norm / (total_norm + 1e-6) - # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so - # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization - # when the gradients do not reside in CPU memory. - clip_coef_clamped = torch.clamp(clip_coef, max=1.0) - for (device, _), ([device_grads], _) in grouped_grads.items(): - if (foreach is None and _has_foreach_support(device_grads, device)) or ( - foreach and _device_has_foreach_support(device) - ): - torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) - elif foreach: - raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") - else: - clip_coef_clamped_device = clip_coef_clamped.to(device) - for g in device_grads: - g.mul_(clip_coef_clamped_device) - - -clip_grads_with_norm_ = _clip_grads_with_norm_