diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 8670756b30428404d909ca183be3a74590cee29b..0000000000000000000000000000000000000000 Binary files a/.DS_Store and /dev/null differ diff --git a/build_apexx.sh b/build_apexx.sh deleted file mode 100644 index b3b2a138f5bf10f6e72bf205844e8c063df95173..0000000000000000000000000000000000000000 --- a/build_apexx.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -set -e - -# ๐Ÿงพ Config -# APEX_COMMIT=f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0 # Known stable commit - -echo "๐Ÿงน Cleaning any previous apex build..." -rm -rf apex -rm -rf *.egg-info build dist - -echo "๐Ÿ“ฅ Cloning NVIDIA/apex..." -git clone https://github.com/NVIDIA/apex.git -cd apex -# git checkout $APEX_COMMIT - -echo "๐Ÿ›  Installing build dependencies..." -sudo apt-get update -sudo apt-get install -y \ - build-essential \ - ninja-build \ - python3-dev \ - libffi-dev \ - libncurses5-dev \ - libncursesw5-dev \ - libreadline-dev \ - libssl-dev \ - libsqlite3-dev \ - zlib1g-dev \ - libbz2-dev \ - liblzma-dev \ - git - -echo "๐Ÿ Upgrading pip and wheel..." -pip install -U pip setuptools wheel - -echo "๐Ÿงช Verifying PyTorch + CUDA availability..." -python -c "import torch; print('PyTorch:', torch.__version__, '| CUDA:', torch.version.cuda)" - -echo "๐Ÿ”ง Building Apex with CUDA and C++ extensions..." -python setup.py bdist_wheel --cuda_ext --cpp_ext - -echo "โœ… Done! Built wheel:" -ls -lh dist/*.whl - -cd .. diff --git a/commonx/__init__.py b/commonx/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/commonx/cache.py b/commonx/cache.py deleted file mode 100644 index 89592fe8747a0b68b8553729abe908c6f06a5aa5..0000000000000000000000000000000000000000 --- a/commonx/cache.py +++ /dev/null @@ -1,47 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Callable - - -class Cache: - """Caching reusable args for faster inference""" - - def __init__(self, disable=False, prefix="", cache=None): - self.cache = cache if cache is not None else {} - self.disable = disable - self.prefix = prefix - - def __call__(self, key: str, fn: Callable): - if self.disable: - return fn() - - key = self.prefix + key - try: - result = self.cache[key] - except KeyError: - result = fn() - self.cache[key] = result - return result - - def namespace(self, namespace: str): - return Cache( - disable=self.disable, - prefix=self.prefix + namespace + ".", - cache=self.cache, - ) - - def get(self, key: str): - key = self.prefix + key - return self.cache[key] diff --git a/commonx/config.py b/commonx/config.py deleted file mode 100644 index f963e8229b8352ef514422609bcbaf9b8c761b15..0000000000000000000000000000000000000000 --- a/commonx/config.py +++ /dev/null @@ -1,110 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Configuration utility functions -""" - -import importlib -from typing import Any, Callable, List, Union -from omegaconf import DictConfig, ListConfig, OmegaConf - -OmegaConf.register_new_resolver("eval", eval) - - -def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: - """ - Load a configuration. Will resolve inheritance. - """ - config = OmegaConf.load(path) - if argv is not None: - config_argv = OmegaConf.from_dotlist(argv) - config = OmegaConf.merge(config, config_argv) - config = resolve_recursive(config, resolve_inheritance) - return config - - -def resolve_recursive( - config: Any, - resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]], -) -> Any: - config = resolver(config) - if isinstance(config, DictConfig): - for k in config.keys(): - v = config.get(k) - if isinstance(v, (DictConfig, ListConfig)): - config[k] = resolve_recursive(v, resolver) - if isinstance(config, ListConfig): - for i in range(len(config)): - v = config.get(i) - if isinstance(v, (DictConfig, ListConfig)): - config[i] = resolve_recursive(v, resolver) - return config - - -def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any: - """ - Recursively resolve inheritance if the config contains: - __inherit__: path/to/parent.yaml or a ListConfig of such paths. - """ - if isinstance(config, DictConfig): - inherit = config.pop("__inherit__", None) - - if inherit: - inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit] - - parent_config = None - for parent_path in inherit_list: - assert isinstance(parent_path, str) - parent_config = ( - load_config(parent_path) - if parent_config is None - else OmegaConf.merge(parent_config, load_config(parent_path)) - ) - - if len(config.keys()) > 0: - config = OmegaConf.merge(parent_config, config) - else: - config = parent_config - return config - - -def import_item(path: str, name: str) -> Any: - """ - Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass - """ - return getattr(importlib.import_module(path), name) - - -def create_object(config: DictConfig) -> Any: - """ - Create an object from config. 
- The config is expected to contains the following: - __object__: - path: path.to.module - name: MyClass - args: as_config | as_params (default to as_config) - """ - item = import_item( - path=config.__object__.path, - name=config.__object__.name, - ) - args = config.__object__.get("args", "as_config") - if args == "as_config": - return item(config) - if args == "as_params": - config = OmegaConf.to_object(config) - config.pop("__object__") - return item(**config) - raise NotImplementedError(f"Unknown args type: {args}") \ No newline at end of file diff --git a/commonx/decorators.py b/commonx/decorators.py deleted file mode 100644 index 332a32d7b838cf7f8be902b9ae4895bad5edcd2e..0000000000000000000000000000000000000000 --- a/commonx/decorators.py +++ /dev/null @@ -1,147 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Decorators. -""" - -import functools -import threading -import time -from typing import Callable -import torch - -from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank -from common.logger import get_logger - -logger = get_logger(__name__) - - -def log_on_entry(func: Callable) -> Callable: - """ - Functions with this decorator will log the function name at entry. - When using multiple decorators, this must be applied innermost to properly capture the name. - """ - - def log_on_entry_wrapper(*args, **kwargs): - logger.info(f"Entering {func.__name__}") - return func(*args, **kwargs) - - return log_on_entry_wrapper - - -def barrier_on_entry(func: Callable) -> Callable: - """ - Functions with this decorator will start executing when all ranks are ready to enter. - """ - - def barrier_on_entry_wrapper(*args, **kwargs): - barrier_if_distributed() - return func(*args, **kwargs) - - return barrier_on_entry_wrapper - - -def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable: - """ - Helper function for local_rank_zero_only and global_rank_zero_only. - """ - - def conditional_execute_wrapper(*args, **kwargs): - # Only execute if needed. - result = func(*args, **kwargs) if execute else None - # All GPUs must wait. - barrier_if_distributed() - # Return results. - return result - - return conditional_execute_wrapper - - -def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable: - """ - Helper function for some functions with special constraints, - especially functions called by other global_rank_zero_only / local_rank_zero_only ones, - in case they are wrongly invoked in other scenarios. - """ - - def asserted_execute_wrapper(*args, **kwargs): - assert condition, err_msg - result = func(*args, **kwargs) - return result - - return asserted_execute_wrapper - - -def local_rank_zero_only(func: Callable) -> Callable: - """ - Functions with this decorator will only execute on local rank zero. 
- """ - return _conditional_execute_wrapper_factory(get_local_rank() == 0, func) - - -def global_rank_zero_only(func: Callable) -> Callable: - """ - Functions with this decorator will only execute on global rank zero. - """ - return _conditional_execute_wrapper_factory(get_global_rank() == 0, func) - - -def assert_only_global_rank_zero(func: Callable) -> Callable: - """ - Functions with this decorator are only accessible to processes with global rank zero. - """ - return _asserted_wrapper_factory( - get_global_rank() == 0, func, err_msg="Not accessible to processes with global_rank != 0" - ) - - -def assert_only_local_rank_zero(func: Callable) -> Callable: - """ - Functions with this decorator are only accessible to processes with local rank zero. - """ - return _asserted_wrapper_factory( - get_local_rank() == 0, func, err_msg="Not accessible to processes with local_rank != 0" - ) - - -def new_thread(func: Callable) -> Callable: - """ - Functions with this decorator will run in a new thread. - The function will return the thread, which can be joined to wait for completion. - """ - - def new_thread_wrapper(*args, **kwargs): - thread = threading.Thread(target=func, args=args, kwargs=kwargs) - thread.start() - return thread - - return new_thread_wrapper - - -def log_runtime(func: Callable) -> Callable: - """ - Functions with this decorator will logging the runtime. - """ - - @functools.wraps(func) - def wrapped(*args, **kwargs): - torch.distributed.barrier() - start = time.perf_counter() - result = func(*args, **kwargs) - torch.distributed.barrier() - logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.") - return result - - return wrapped diff --git a/commonx/diffusion/__init__.py b/commonx/diffusion/__init__.py deleted file mode 100644 index 034e36ef7f9eb0b3ae94280165e622a362e9fc1e..0000000000000000000000000000000000000000 --- a/commonx/diffusion/__init__.py +++ /dev/null @@ -1,56 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Diffusion package. 
-""" - -from .config import ( - create_sampler_from_config, - create_sampling_timesteps_from_config, - create_schedule_from_config, -) -from .samplers.base import Sampler -from .samplers.euler import EulerSampler -from .schedules.base import Schedule -from .schedules.lerp import LinearInterpolationSchedule -from .timesteps.base import SamplingTimesteps, Timesteps -from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps -from .types import PredictionType, SamplingDirection -from .utils import classifier_free_guidance, classifier_free_guidance_dispatcher, expand_dims - -__all__ = [ - # Configs - "create_sampler_from_config", - "create_sampling_timesteps_from_config", - "create_schedule_from_config", - # Schedules - "Schedule", - "DiscreteVariancePreservingSchedule", - "LinearInterpolationSchedule", - # Samplers - "Sampler", - "EulerSampler", - # Timesteps - "Timesteps", - "SamplingTimesteps", - # Types - "PredictionType", - "SamplingDirection", - "UniformTrailingSamplingTimesteps", - # Utils - "classifier_free_guidance", - "classifier_free_guidance_dispatcher", - "expand_dims", -] diff --git a/commonx/diffusion/config.py b/commonx/diffusion/config.py deleted file mode 100644 index f1d0468d88b5dd5f0d787c75ed3df06742d0a483..0000000000000000000000000000000000000000 --- a/commonx/diffusion/config.py +++ /dev/null @@ -1,74 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Utility functions for creating schedules and samplers from config. -""" - -import torch -from omegaconf import DictConfig - -from .samplers.base import Sampler -from .samplers.euler import EulerSampler -from .schedules.base import Schedule -from .schedules.lerp import LinearInterpolationSchedule -from .timesteps.base import SamplingTimesteps -from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps - - -def create_schedule_from_config( - config: DictConfig, - device: torch.device, - dtype: torch.dtype = torch.float32, -) -> Schedule: - """ - Create a schedule from configuration. - """ - if config.type == "lerp": - return LinearInterpolationSchedule(T=config.get("T", 1.0)) - - raise NotImplementedError - - -def create_sampler_from_config( - config: DictConfig, - schedule: Schedule, - timesteps: SamplingTimesteps, -) -> Sampler: - """ - Create a sampler from configuration. 
- """ - if config.type == "euler": - return EulerSampler( - schedule=schedule, - timesteps=timesteps, - prediction_type=config.prediction_type, - ) - raise NotImplementedError - - -def create_sampling_timesteps_from_config( - config: DictConfig, - schedule: Schedule, - device: torch.device, - dtype: torch.dtype = torch.float32, -) -> SamplingTimesteps: - if config.type == "uniform_trailing": - return UniformTrailingSamplingTimesteps( - T=schedule.T, - steps=config.steps, - shift=config.get("shift", 1.0), - device=device, - ) - raise NotImplementedError \ No newline at end of file diff --git a/commonx/diffusion/samplers/base.py b/commonx/diffusion/samplers/base.py deleted file mode 100644 index 8e65f19896b6d5844e769762e76d699b96abc733..0000000000000000000000000000000000000000 --- a/commonx/diffusion/samplers/base.py +++ /dev/null @@ -1,108 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Sampler base class. -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Callable -import torch -from tqdm import tqdm - -from ..schedules.base import Schedule -from ..timesteps.base import SamplingTimesteps -from ..types import PredictionType, SamplingDirection -from ..utils import assert_schedule_timesteps_compatible - - -@dataclass -class SamplerModelArgs: - x_t: torch.Tensor - t: torch.Tensor - i: int - - -class Sampler(ABC): - """ - Samplers are ODE/SDE solvers. - """ - - def __init__( - self, - schedule: Schedule, - timesteps: SamplingTimesteps, - prediction_type: PredictionType, - return_endpoint: bool = True, - ): - assert_schedule_timesteps_compatible( - schedule=schedule, - timesteps=timesteps, - ) - self.schedule = schedule - self.timesteps = timesteps - self.prediction_type = prediction_type - self.return_endpoint = return_endpoint - - @abstractmethod - def sample( - self, - x: torch.Tensor, - f: Callable[[SamplerModelArgs], torch.Tensor], - ) -> torch.Tensor: - """ - Generate a new sample given the the intial sample x and score function f. - """ - - def get_next_timestep( - self, - t: torch.Tensor, - ) -> torch.Tensor: - """ - Get the next sample timestep. - Support multiple different timesteps t in a batch. - If no more steps, return out of bound value -1 or T+1. - """ - T = self.timesteps.T - steps = len(self.timesteps) - curr_idx = self.timesteps.index(t) - next_idx = curr_idx + 1 - bound = -1 if self.timesteps.direction == SamplingDirection.backward else T + 1 - - s = self.timesteps[next_idx.clamp_max(steps - 1)] - s = s.where(next_idx < steps, bound) - return s - - def get_endpoint( - self, - pred: torch.Tensor, - x_t: torch.Tensor, - t: torch.Tensor, - ) -> torch.Tensor: - """ - Get to the endpoint of the probability flow. 
- """ - x_0, x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) - return x_0 if self.timesteps.direction == SamplingDirection.backward else x_T - - def get_progress_bar(self): - """ - Get progress bar for sampling. - """ - return tqdm( - iterable=range(len(self.timesteps) - (0 if self.return_endpoint else 1)), - dynamic_ncols=True, - desc=self.__class__.__name__, - ) diff --git a/commonx/diffusion/samplers/euler.py b/commonx/diffusion/samplers/euler.py deleted file mode 100644 index 5994979a43658b7ebb75316cefea737d1c54681b..0000000000000000000000000000000000000000 --- a/commonx/diffusion/samplers/euler.py +++ /dev/null @@ -1,89 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - - -""" -Euler ODE solver. -""" - -from typing import Callable -import torch -from einops import rearrange -from torch.nn import functional as F - -from models.dit_v2 import na - -from ..types import PredictionType -from ..utils import expand_dims -from .base import Sampler, SamplerModelArgs - - -class EulerSampler(Sampler): - """ - The Euler method is the simplest ODE solver. - - """ - - def sample( - self, - x: torch.Tensor, - f: Callable[[SamplerModelArgs], torch.Tensor], - ) -> torch.Tensor: - timesteps = self.timesteps.timesteps - progress = self.get_progress_bar() - i = 0 - for t, s in zip(timesteps[:-1], timesteps[1:]): - pred = f(SamplerModelArgs(x, t, i)) - x = self.step_to(pred, x, t, s) - i += 1 - progress.update() - - if self.return_endpoint: - t = timesteps[-1] - pred = f(SamplerModelArgs(x, t, i)) - x = self.get_endpoint(pred, x, t) - progress.update() - return x - - def step( - self, - pred: torch.Tensor, - x_t: torch.Tensor, - t: torch.Tensor, - ) -> torch.Tensor: - """ - Step to the next timestep. - """ - return self.step_to(pred, x_t, t, self.get_next_timestep(t)) - - def step_to( - self, - pred: torch.Tensor, - x_t: torch.Tensor, - t: torch.Tensor, - s: torch.Tensor, - ) -> torch.Tensor: - """ - Steps from x_t at timestep t to x_s at timestep s. Returns x_s. - """ - t = expand_dims(t, x_t.ndim) - s = expand_dims(s, x_t.ndim) - T = self.schedule.T - # Step from x_t to x_s. - pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) - pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T)) - # Clamp x_s to x_0 and x_T if s is out of bound. - pred_x_s = pred_x_s.where(s >= 0, pred_x_0) - pred_x_s = pred_x_s.where(s <= T, pred_x_T) - return pred_x_s diff --git a/commonx/diffusion/schedules/base.py b/commonx/diffusion/schedules/base.py deleted file mode 100644 index bcf6c6b6460977c6e2687e225c5c913a928bf812..0000000000000000000000000000000000000000 --- a/commonx/diffusion/schedules/base.py +++ /dev/null @@ -1,131 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Schedule base class. -""" - -from abc import ABC, abstractmethod, abstractproperty -from typing import Tuple, Union -import torch - -from ..types import PredictionType -from ..utils import expand_dims - - -class Schedule(ABC): - """ - Diffusion schedules are uniquely defined by T, A, B: - - x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T] - - Schedules can be continuous or discrete. - """ - - @abstractproperty - def T(self) -> Union[int, float]: - """ - Maximum timestep inclusive. - Schedule is continuous if float, discrete if int. - """ - - @abstractmethod - def A(self, t: torch.Tensor) -> torch.Tensor: - """ - Interpolation coefficient A. - Returns tensor with the same shape as t. - """ - - @abstractmethod - def B(self, t: torch.Tensor) -> torch.Tensor: - """ - Interpolation coefficient B. - Returns tensor with the same shape as t. - """ - - # ---------------------------------------------------- - - def snr(self, t: torch.Tensor) -> torch.Tensor: - """ - Signal to noise ratio. - Returns tensor with the same shape as t. - """ - return (self.A(t) ** 2) / (self.B(t) ** 2) - - def isnr(self, snr: torch.Tensor) -> torch.Tensor: - """ - Inverse signal to noise ratio. - Returns tensor with the same shape as snr. - Subclass may implement. - """ - raise NotImplementedError - - # ---------------------------------------------------- - - def is_continuous(self) -> bool: - """ - Whether the schedule is continuous. - """ - return isinstance(self.T, float) - - def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - """ - Diffusion forward function. - """ - t = expand_dims(t, x_0.ndim) - return self.A(t) * x_0 + self.B(t) * x_T - - def convert_from_pred( - self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Convert from prediction. Return predicted x_0 and x_T. - """ - t = expand_dims(t, x_t.ndim) - A_t = self.A(t) - B_t = self.B(t) - - if pred_type == PredictionType.x_T: - pred_x_T = pred - pred_x_0 = (x_t - B_t * pred_x_T) / A_t - elif pred_type == PredictionType.x_0: - pred_x_0 = pred - pred_x_T = (x_t - A_t * pred_x_0) / B_t - elif pred_type == PredictionType.v_cos: - pred_x_0 = A_t * x_t - B_t * pred - pred_x_T = A_t * pred + B_t * x_t - elif pred_type == PredictionType.v_lerp: - pred_x_0 = (x_t - B_t * pred) / (A_t + B_t) - pred_x_T = (x_t + A_t * pred) / (A_t + B_t) - else: - raise NotImplementedError - - return pred_x_0, pred_x_T - - def convert_to_pred( - self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType - ) -> torch.FloatTensor: - """ - Convert to prediction target given x_0 and x_T. 
- """ - if pred_type == PredictionType.x_T: - return x_T - if pred_type == PredictionType.x_0: - return x_0 - if pred_type == PredictionType.v_cos: - t = expand_dims(t, x_0.ndim) - return self.A(t) * x_T - self.B(t) * x_0 - if pred_type == PredictionType.v_lerp: - return x_T - x_0 - raise NotImplementedError diff --git a/commonx/diffusion/schedules/lerp.py b/commonx/diffusion/schedules/lerp.py deleted file mode 100644 index 56b42bc17538b3217b2209234fc723ac3f58a746..0000000000000000000000000000000000000000 --- a/commonx/diffusion/schedules/lerp.py +++ /dev/null @@ -1,55 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Linear interpolation schedule (lerp). -""" - -from typing import Union -import torch - -from .base import Schedule - - -class LinearInterpolationSchedule(Schedule): - """ - Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow. - It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3. - - - - x_t = (1 - t) * x_0 + t * x_T - - Can be either continuous or discrete. - """ - - def __init__(self, T: Union[int, float] = 1.0): - self._T = T - - @property - def T(self) -> Union[int, float]: - return self._T - - def A(self, t: torch.Tensor) -> torch.Tensor: - return 1 - (t / self.T) - - def B(self, t: torch.Tensor) -> torch.Tensor: - return t / self.T - - # ---------------------------------------------------- - - def isnr(self, snr: torch.Tensor) -> torch.Tensor: - t = self.T / (1 + snr**0.5) - t = t if self.is_continuous() else t.round().int() - return t diff --git a/commonx/diffusion/timesteps/base.py b/commonx/diffusion/timesteps/base.py deleted file mode 100644 index d1a598103547694d5ef4dc5db0be1e5be2deb60c..0000000000000000000000000000000000000000 --- a/commonx/diffusion/timesteps/base.py +++ /dev/null @@ -1,72 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Sequence, Union -import torch - -from ..types import SamplingDirection - - -class Timesteps(ABC): - """ - Timesteps base class. - """ - - def __init__(self, T: Union[int, float]): - assert T > 0 - self._T = T - - @property - def T(self) -> Union[int, float]: - """ - Maximum timestep inclusive. - int if discrete, float if continuous. - """ - return self._T - - def is_continuous(self) -> bool: - """ - Whether the schedule is continuous. - """ - return isinstance(self.T, float) - - -class SamplingTimesteps(Timesteps): - """ - Sampling timesteps. - It defines the discretization of sampling steps. - """ - - def __init__( - self, - T: Union[int, float], - timesteps: torch.Tensor, - direction: SamplingDirection, - ): - assert timesteps.ndim == 1 - super().__init__(T) - self.timesteps = timesteps - self.direction = direction - - def __len__(self) -> int: - """ - Number of sampling steps. - """ - return len(self.timesteps) - - def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor: - """ - The timestep at the sampling step. 
- Returns a scalar tensor if idx is int, - or tensor of the same size if idx is a tensor. - """ - return self.timesteps[idx] - - def index(self, t: torch.Tensor) -> torch.Tensor: - """ - Find index by t. - Return index of the same shape as t. - Index is -1 if t not found in timesteps. - """ - i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True) - idx = torch.full_like(t, fill_value=-1, dtype=torch.int) - idx.view(-1)[i] = j.int() - return idx diff --git a/commonx/diffusion/timesteps/sampling/trailing.py b/commonx/diffusion/timesteps/sampling/trailing.py deleted file mode 100644 index 248d986aedaaff8f417c32a42e9d9e3a61012f58..0000000000000000000000000000000000000000 --- a/commonx/diffusion/timesteps/sampling/trailing.py +++ /dev/null @@ -1,49 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import torch - -from ...types import SamplingDirection -from ..base import SamplingTimesteps - - -class UniformTrailingSamplingTimesteps(SamplingTimesteps): - """ - Uniform trailing sampling timesteps. - Defined in (https://arxiv.org/abs/2305.08891) - - Shift is proposed in SD3 for RF schedule. - Defined in (https://arxiv.org/pdf/2403.03206) eq.23 - """ - - def __init__( - self, - T: int, - steps: int, - shift: float = 1.0, - device: torch.device = "cpu", - ): - # Create trailing timesteps. - timesteps = torch.arange(1.0, 0.0, -1.0 / steps, device=device) - - # Shift timesteps. - timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) - - # Scale to T range. - if isinstance(T, float): - timesteps = timesteps * T - else: - timesteps = timesteps.mul(T + 1).sub(1).round().int() - - super().__init__(T=T, timesteps=timesteps, direction=SamplingDirection.backward) diff --git a/commonx/diffusion/types.py b/commonx/diffusion/types.py deleted file mode 100644 index 076295f2be24dadc79da20a5f335b391eb9543bb..0000000000000000000000000000000000000000 --- a/commonx/diffusion/types.py +++ /dev/null @@ -1,59 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Type definitions. -""" - -from enum import Enum - - -class PredictionType(str, Enum): - """ - x_0: - Predict data sample. - x_T: - Predict noise sample. 
- Proposed by DDPM (https://arxiv.org/abs/2006.11239) - Proved problematic by zsnr paper (https://arxiv.org/abs/2305.08891) - v_cos: - Predict velocity dx/dt based on the cosine schedule (A_t * x_T - B_t * x_0). - Proposed by progressive distillation (https://arxiv.org/abs/2202.00512) - v_lerp: - Predict velocity dx/dt based on the lerp schedule (x_T - x_0). - Proposed by rectified flow (https://arxiv.org/abs/2209.03003) - """ - - x_0 = "x_0" - x_T = "x_T" - v_cos = "v_cos" - v_lerp = "v_lerp" - - -class SamplingDirection(str, Enum): - """ - backward: Sample from x_T to x_0 for data generation. - forward: Sample from x_0 to x_T for noise inversion. - """ - - backward = "backward" - forward = "forward" - - @staticmethod - def reverse(direction): - if direction == SamplingDirection.backward: - return SamplingDirection.forward - if direction == SamplingDirection.forward: - return SamplingDirection.backward - raise NotImplementedError diff --git a/commonx/diffusion/utils.py b/commonx/diffusion/utils.py deleted file mode 100644 index 69d4aec34f59b293e2354744a4329008063a30e3..0000000000000000000000000000000000000000 --- a/commonx/diffusion/utils.py +++ /dev/null @@ -1,84 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Utility functions. -""" - -from typing import Callable -import torch - - -def expand_dims(tensor: torch.Tensor, ndim: int): - """ - Expand tensor to target ndim. New dims are added to the right. - For example, if the tensor shape was (8,), target ndim is 4, return (8, 1, 1, 1). - """ - shape = tensor.shape + (1,) * (ndim - tensor.ndim) - return tensor.reshape(shape) - - -def assert_schedule_timesteps_compatible(schedule, timesteps): - """ - Check if schedule and timesteps are compatible. - """ - if schedule.T != timesteps.T: - raise ValueError("Schedule and timesteps must have the same T.") - if schedule.is_continuous() != timesteps.is_continuous(): - raise ValueError("Schedule and timesteps must have the same continuity.") - - -def classifier_free_guidance( - pos: torch.Tensor, - neg: torch.Tensor, - scale: float, - rescale: float = 0.0, -): - """ - Apply classifier-free guidance. - """ - # Classifier-free guidance (https://arxiv.org/abs/2207.12598) - cfg = neg + scale * (pos - neg) - - # Classifier-free guidance rescale (https://arxiv.org/pdf/2305.08891.pdf) - if rescale != 0.0: - pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True) - cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True) - factor = pos_std / cfg_std - factor = rescale * factor + (1 - rescale) - cfg *= factor - - return cfg - - -def classifier_free_guidance_dispatcher( - pos: Callable, - neg: Callable, - scale: float, - rescale: float = 0.0, -): - """ - Optionally execute models depending on classifer-free guidance scale. - """ - # If scale is 1, no need to execute neg model. - if scale == 1.0: - return pos() - - # Otherwise, execute both pos nad neg models and apply cfg. 
- return classifier_free_guidance( - pos=pos(), - neg=neg(), - scale=scale, - rescale=rescale, - ) diff --git a/commonx/distributed/__init__.py b/commonx/distributed/__init__.py deleted file mode 100644 index a5b4f873ae3e5524c88942bb27ec98ac98c3b5b5..0000000000000000000000000000000000000000 --- a/commonx/distributed/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Distributed package. -""" - -from .basic import ( - barrier_if_distributed, - convert_to_ddp, - get_device, - get_global_rank, - get_local_rank, - get_world_size, - init_torch, -) - -__all__ = [ - "barrier_if_distributed", - "convert_to_ddp", - "get_device", - "get_global_rank", - "get_local_rank", - "get_world_size", - "init_torch", -] diff --git a/commonx/distributed/advanced.py b/commonx/distributed/advanced.py deleted file mode 100644 index f55fe20ab45494c96124b072d628273d49def1fa..0000000000000000000000000000000000000000 --- a/commonx/distributed/advanced.py +++ /dev/null @@ -1,208 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Advanced distributed functions for sequence parallel. -""" - -from typing import Optional, List -import torch -import torch.distributed as dist -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.fsdp import ShardingStrategy - -from .basic import get_global_rank, get_world_size - - -_DATA_PARALLEL_GROUP = None -_SEQUENCE_PARALLEL_GROUP = None -_SEQUENCE_PARALLEL_CPU_GROUP = None -_MODEL_SHARD_CPU_INTER_GROUP = None -_MODEL_SHARD_CPU_INTRA_GROUP = None -_MODEL_SHARD_INTER_GROUP = None -_MODEL_SHARD_INTRA_GROUP = None -_SEQUENCE_PARALLEL_GLOBAL_RANKS = None - - -def get_data_parallel_group() -> Optional[dist.ProcessGroup]: - """ - Get data parallel process group. - """ - return _DATA_PARALLEL_GROUP - - -def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]: - """ - Get sequence parallel process group. - """ - return _SEQUENCE_PARALLEL_GROUP - - -def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]: - """ - Get sequence parallel CPU process group. - """ - return _SEQUENCE_PARALLEL_CPU_GROUP - - -def get_data_parallel_rank() -> int: - """ - Get data parallel rank. 
- """ - group = get_data_parallel_group() - return dist.get_rank(group) if group else get_global_rank() - - -def get_data_parallel_world_size() -> int: - """ - Get data parallel world size. - """ - group = get_data_parallel_group() - return dist.get_world_size(group) if group else get_world_size() - - -def get_sequence_parallel_rank() -> int: - """ - Get sequence parallel rank. - """ - group = get_sequence_parallel_group() - return dist.get_rank(group) if group else 0 - - -def get_sequence_parallel_world_size() -> int: - """ - Get sequence parallel world size. - """ - group = get_sequence_parallel_group() - return dist.get_world_size(group) if group else 1 - - -def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]: - """ - Get the CPU intra process group of model sharding. - """ - return _MODEL_SHARD_CPU_INTRA_GROUP - - -def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]: - """ - Get the CPU inter process group of model sharding. - """ - return _MODEL_SHARD_CPU_INTER_GROUP - - -def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]: - """ - Get the GPU intra process group of model sharding. - """ - return _MODEL_SHARD_INTRA_GROUP - - -def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]: - """ - Get the GPU inter process group of model sharding. - """ - return _MODEL_SHARD_INTER_GROUP - - -def init_sequence_parallel(sequence_parallel_size: int): - """ - Initialize sequence parallel. - """ - global _DATA_PARALLEL_GROUP - global _SEQUENCE_PARALLEL_GROUP - global _SEQUENCE_PARALLEL_CPU_GROUP - global _SEQUENCE_PARALLEL_GLOBAL_RANKS - assert dist.is_initialized() - world_size = dist.get_world_size() - rank = dist.get_rank() - data_parallel_size = world_size // sequence_parallel_size - for i in range(data_parallel_size): - start_rank = i * sequence_parallel_size - end_rank = (i + 1) * sequence_parallel_size - ranks = range(start_rank, end_rank) - group = dist.new_group(ranks) - cpu_group = dist.new_group(ranks, backend="gloo") - if rank in ranks: - _SEQUENCE_PARALLEL_GROUP = group - _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group - _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks) - - -def init_model_shard_group( - *, - sharding_strategy: ShardingStrategy, - device_mesh: Optional[DeviceMesh] = None, -): - """ - Initialize process group of model sharding. 
- """ - global _MODEL_SHARD_INTER_GROUP - global _MODEL_SHARD_INTRA_GROUP - global _MODEL_SHARD_CPU_INTER_GROUP - global _MODEL_SHARD_CPU_INTRA_GROUP - assert dist.is_initialized() - world_size = dist.get_world_size() - if device_mesh is not None: - num_shards_per_group = device_mesh.shape[1] - elif sharding_strategy == ShardingStrategy.NO_SHARD: - num_shards_per_group = 1 - elif sharding_strategy in [ - ShardingStrategy.HYBRID_SHARD, - ShardingStrategy._HYBRID_SHARD_ZERO2, - ]: - num_shards_per_group = torch.cuda.device_count() - else: - num_shards_per_group = world_size - num_groups = world_size // num_shards_per_group - device_mesh = (num_groups, num_shards_per_group) - - gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra")) - cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra")) - - _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter") - _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra") - _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter") - _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra") - -def get_sequence_parallel_global_ranks() -> List[int]: - """ - Get all global ranks of the sequence parallel process group - that the caller rank belongs to. - """ - if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None: - return [dist.get_rank()] - return _SEQUENCE_PARALLEL_GLOBAL_RANKS - - -def get_next_sequence_parallel_rank() -> int: - """ - Get the next global rank of the sequence parallel process group - that the caller rank belongs to. - """ - sp_global_ranks = get_sequence_parallel_global_ranks() - sp_rank = get_sequence_parallel_rank() - sp_size = get_sequence_parallel_world_size() - return sp_global_ranks[(sp_rank + 1) % sp_size] - - -def get_prev_sequence_parallel_rank() -> int: - """ - Get the previous global rank of the sequence parallel process group - that the caller rank belongs to. - """ - sp_global_ranks = get_sequence_parallel_global_ranks() - sp_rank = get_sequence_parallel_rank() - sp_size = get_sequence_parallel_world_size() - return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size] \ No newline at end of file diff --git a/commonx/distributed/basic.py b/commonx/distributed/basic.py deleted file mode 100644 index f829aec01eba2cc44d7274b6a0430155c6d42af6..0000000000000000000000000000000000000000 --- a/commonx/distributed/basic.py +++ /dev/null @@ -1,84 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Distributed basic functions. -""" - -import os -from datetime import timedelta -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel - - -def get_global_rank() -> int: - """ - Get the global rank, the global index of the GPU. - """ - return int(os.environ.get("RANK", "0")) - - -def get_local_rank() -> int: - """ - Get the local rank, the local index of the GPU. 
- """ - return int(os.environ.get("LOCAL_RANK", "0")) - - -def get_world_size() -> int: - """ - Get the world size, the total amount of GPUs. - """ - return int(os.environ.get("WORLD_SIZE", "1")) - - -def get_device() -> torch.device: - """ - Get current rank device. - """ - return torch.device("cuda", get_local_rank()) - - -def barrier_if_distributed(*args, **kwargs): - """ - Synchronizes all processes if under distributed context. - """ - if dist.is_initialized(): - return dist.barrier(*args, **kwargs) - - -def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)): - """ - Common PyTorch initialization configuration. - """ - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.cuda.set_device(get_local_rank()) - dist.init_process_group( - backend="nccl", - rank=get_global_rank(), - world_size=get_world_size(), - timeout=timeout, - ) - - -def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel: - return DistributedDataParallel( - module=module, - device_ids=[get_local_rank()], - output_device=get_local_rank(), - **kwargs, - ) diff --git a/commonx/distributed/meta_init_utils.py b/commonx/distributed/meta_init_utils.py deleted file mode 100644 index 794cd0b8162de596064e494c0b8140a04b9c36a0..0000000000000000000000000000000000000000 --- a/commonx/distributed/meta_init_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import torch -from rotary_embedding_torch import RotaryEmbedding -from torch import nn -from torch.distributed.fsdp._common_utils import _is_fsdp_flattened - -__all__ = ["meta_non_persistent_buffer_init_fn"] - - -def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module: - """ - Used for materializing `non-persistent tensor buffers` while model resuming. - - Since non-persistent tensor buffers are not saved in state_dict, - when initializing model with meta device, user should materialize those buffers manually. - - Currently, only `rope.dummy` is this special case. - """ - with torch.no_grad(): - for submodule in module.modules(): - if not isinstance(submodule, RotaryEmbedding): - continue - for buffer_name, buffer in submodule.named_buffers(recurse=False): - if buffer.is_meta and "dummy" in buffer_name: - materialized_buffer = torch.zeros_like(buffer, device="cpu") - setattr(submodule, buffer_name, materialized_buffer) - assert not any(b.is_meta for n, b in module.named_buffers()) - return module diff --git a/commonx/distributed/ops.py b/commonx/distributed/ops.py deleted file mode 100644 index 9b2ae02a6f77de3a8a31d217e0e1f2a6b359c3be..0000000000000000000000000000000000000000 --- a/commonx/distributed/ops.py +++ /dev/null @@ -1,494 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Distributed ops for supporting sequence parallel. -""" - -from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import torch -import torch.distributed as dist -from torch import Tensor - -from common.cache import Cache -from common.distributed.advanced import ( - get_sequence_parallel_group, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, -) - -from .basic import get_device - -_SEQ_DATA_BUF = defaultdict(lambda: [None, None, None]) -_SEQ_DATA_META_SHAPES = defaultdict() -_SEQ_DATA_META_DTYPES = defaultdict() -_SEQ_DATA_ASYNC_COMMS = defaultdict(list) -_SYNC_BUFFER = defaultdict(dict) - - -def single_all_to_all( - local_input: Tensor, - scatter_dim: int, - gather_dim: int, - group: dist.ProcessGroup, - async_op: bool = False, -): - """ - A function to do all-to-all on a tensor - """ - seq_world_size = dist.get_world_size(group) - prev_scatter_dim = scatter_dim - if scatter_dim != 0: - local_input = local_input.transpose(0, scatter_dim) - if gather_dim == 0: - gather_dim = scatter_dim - scatter_dim = 0 - - inp_shape = list(local_input.shape) - inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size - input_t = local_input.reshape( - [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :] - ).contiguous() - output = torch.empty_like(input_t) - comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - if async_op: - # let user's code transpose & reshape - return output, comm, prev_scatter_dim - - # first dim is seq_world_size, so we can split it directly - output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0) - if prev_scatter_dim: - output = output.transpose(0, prev_scatter_dim).contiguous() - return output - - -def _all_to_all( - local_input: Tensor, - scatter_dim: int, - gather_dim: int, - group: dist.ProcessGroup, -): - seq_world_size = dist.get_world_size(group) - input_list = [ - t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim) - ] - output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] - dist.all_to_all(output_list, input_list, group=group) - return torch.cat(output_list, dim=gather_dim).contiguous() - - -class SeqAllToAll(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - group: dist.ProcessGroup, - local_input: Tensor, - scatter_dim: int, - gather_dim: int, - async_op: bool, - ) -> Tensor: - ctx.group = group - ctx.scatter_dim = scatter_dim - ctx.gather_dim = gather_dim - ctx.async_op = async_op - if async_op: - output, comm, prev_scatter_dim = single_all_to_all( - local_input, scatter_dim, gather_dim, group, async_op=async_op - ) - ctx.prev_scatter_dim = prev_scatter_dim - return output, comm - - return _all_to_all(local_input, scatter_dim, gather_dim, group) - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - if 
ctx.async_op: - input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0) - if ctx.prev_scatter_dim: - input_t = input_t.transpose(0, ctx.prev_scatter_dim) - else: - input_t = grad_output[0] - return ( - None, - _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group), - None, - None, - None, - ) - - -class Slice(torch.autograd.Function): - @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor: - ctx.group = group - ctx.rank = dist.get_rank(group) - seq_world_size = dist.get_world_size(group) - ctx.seq_world_size = seq_world_size - ctx.dim = dim - dim_size = local_input.shape[dim] - return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous() - - @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]: - dim_size = list(grad_output.size()) - split_size = dim_size[0] - dim_size[0] = dim_size[0] * ctx.seq_world_size - output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device()) - dist._all_gather_base(output, grad_output, group=ctx.group) - return (None, torch.cat(output.split(split_size), dim=ctx.dim), None) - - -class Gather(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - group: dist.ProcessGroup, - local_input: Tensor, - dim: int, - grad_scale: Optional[bool] = False, - ) -> Tensor: - ctx.group = group - ctx.rank = dist.get_rank(group) - ctx.dim = dim - ctx.grad_scale = grad_scale - seq_world_size = dist.get_world_size(group) - ctx.seq_world_size = seq_world_size - dim_size = list(local_input.size()) - split_size = dim_size[0] - ctx.part_size = dim_size[dim] - dim_size[0] = dim_size[0] * seq_world_size - output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device()) - dist._all_gather_base(output, local_input.contiguous(), group=ctx.group) - return torch.cat(output.split(split_size), dim=dim) - - @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]: - if ctx.grad_scale: - grad_output = grad_output * ctx.seq_world_size - return ( - None, - grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(), - None, - None, - ) - - -def gather_seq_scatter_heads_qkv( - qkv_tensor: Tensor, - *, - seq_dim: int, - qkv_shape: Optional[Tensor] = None, - cache: Cache = Cache(disable=True), - restore_shape: bool = True, -): - """ - A func to sync splited qkv tensor - qkv_tensor: the tensor we want to do alltoall with. The last dim must - be the projection_idx, which we will split into 3 part. 
After - spliting, the gather idx will be projecttion_idx + 1 - seq_dim: gather_dim for all2all comm - restore_shape: if True, output will has the same shape length as input - """ - group = get_sequence_parallel_group() - if not group: - return qkv_tensor - world = get_sequence_parallel_world_size() - orig_shape = qkv_tensor.shape - scatter_dim = qkv_tensor.dim() - bef_all2all_shape = list(orig_shape) - qkv_proj_dim = bef_all2all_shape[-1] - bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3] - qkv_tensor = qkv_tensor.view(bef_all2all_shape) - qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False) - if restore_shape: - out_shape = list(orig_shape) - out_shape[seq_dim] *= world - out_shape[-1] = qkv_proj_dim // world - qkv_tensor = qkv_tensor.view(out_shape) - - # remove padding - if qkv_shape is not None: - unpad_dim_size = cache( - "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item() - ) - if unpad_dim_size % world != 0: - padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size - qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size) - return qkv_tensor - - -def slice_inputs(x: Tensor, dim: int, padding: bool = True): - """ - A func to slice the input sequence in sequence parallel - """ - group = get_sequence_parallel_group() - if group is None: - return x - sp_rank = get_sequence_parallel_rank() - sp_world = get_sequence_parallel_world_size() - dim_size = x.shape[dim] - unit = (dim_size + sp_world - 1) // sp_world - if padding and dim_size % sp_world: - padding_size = sp_world - (dim_size % sp_world) - x = _pad_tensor(x, dim, padding_size) - slc = [slice(None)] * len(x.shape) - slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1)) - return x[slc] - - -def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int): - """ - A func to remove the padding part of the tensor based on its original shape - """ - group = get_sequence_parallel_group() - if group is None: - return x - sp_world = get_sequence_parallel_world_size() - if unpad_dim_size % sp_world == 0: - return x - padding_size = sp_world - (unpad_dim_size % sp_world) - assert (padding_size + unpad_dim_size) % sp_world == 0 - return _unpad_tensor(x, dim=dim, padding_size=padding_size) - - -def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor: - """ - A func to sync attention result with alltoall in sequence parallel - """ - group = get_sequence_parallel_group() - if not group: - return x - dim_size = x.size(seq_dim) - sp_world = get_sequence_parallel_world_size() - if dim_size % sp_world != 0: - padding_size = sp_world - (dim_size % sp_world) - x = _pad_tensor(x, seq_dim, padding_size) - return SeqAllToAll.apply(group, x, seq_dim, head_dim, False) - - -def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor: - """ - A func to sync embedding input with alltoall in sequence parallel - """ - group = get_sequence_parallel_group() - if not group: - return x - return SeqAllToAll.apply(group, x, head_dim, seq_dim, False) - - -def scatter_heads(x: Tensor, dim: int) -> Tensor: - """ - A func to split heads before attention in sequence parallel - """ - group = get_sequence_parallel_group() - if not group: - return x - return Slice.apply(group, x, dim) - - -def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor: - """ - A func to gather heads for the attention result in sequence parallel - """ - group = get_sequence_parallel_group() - if not group: - return x - return 
Gather.apply(group, x, dim, grad_scale) - - -def gather_outputs( - x: Tensor, - *, - gather_dim: int, - padding_dim: Optional[int] = None, - unpad_shape: Optional[Tensor] = None, - cache: Cache = Cache(disable=True), - scale_grad=True, -): - """ - A func to gather the outputs for the model result in sequence parallel - """ - group = get_sequence_parallel_group() - if not group: - return x - x = Gather.apply(group, x, gather_dim, scale_grad) - if padding_dim is not None: - unpad_dim_size = cache( - "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item() - ) - x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size) - return x - - -def _pad_tensor(x: Tensor, dim: int, padding_size: int): - shape = list(x.shape) - shape[dim] = padding_size - pad = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat([x, pad], dim=dim) - - -def _unpad_tensor(x: Tensor, dim: int, padding_size): - slc = [slice(None)] * len(x.shape) - slc[dim] = slice(0, -padding_size) - return x[slc] - - -def _broadcast_data(data, shape, dtype, src, group, async_op): - comms = [] - if isinstance(data, (list, tuple)): - for i, sub_shape in enumerate(shape): - comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op) - elif isinstance(data, dict): - for key, sub_data in data.items(): - comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op) - elif isinstance(data, Tensor): - comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op)) - return comms - - -def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]: - if isinstance(data, (list, tuple)): - return [_traverse(sub_data, op) for sub_data in data] - elif isinstance(data, dict): - return {key: _traverse(sub_data, op) for key, sub_data in data.items()} - elif isinstance(data, Tensor): - return op(data) - else: - return None - - -def _get_shapes(data): - return _traverse(data, op=lambda x: x.shape) - - -def _get_dtypes(data): - return _traverse(data, op=lambda x: x.dtype) - - -def _construct_broadcast_buffer(shapes, dtypes, device): - if isinstance(shapes, torch.Size): - return torch.empty(shapes, dtype=dtypes, device=device) - - if isinstance(shapes, (list, tuple)): - buffer = [] - for i, sub_shape in enumerate(shapes): - buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device)) - elif isinstance(shapes, dict): - buffer = {} - for key, sub_shape in shapes.items(): - buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device) - else: - return None - return buffer - - -class SPDistForward: - """A forward tool to sync different result across sp group - - Args: - module: a function or module to process users input - sp_step: current training step to judge which rank to broadcast its result to all - name: a distinct str to save meta and async comm - comm_shape: if different ranks have different shape, mark this arg to True - device: the device for current rank, can be empty - """ - - def __init__( - self, - name: str, - comm_shape: bool, - device: torch.device = None, - ): - self.name = name - self.comm_shape = comm_shape - if device: - self.device = device - else: - self.device = get_device() - - def __call__(self, inputs) -> Any: - group = get_sequence_parallel_group() - if not group: - yield inputs - else: - device = self.device - sp_world = get_sequence_parallel_world_size() - sp_rank = get_sequence_parallel_rank() - for local_step in range(sp_world): - src_rank = dist.get_global_rank(group, local_step) - is_src = sp_rank == 
local_step - local_shapes = [] - local_dtypes = [] - if local_step == 0: - local_result = inputs - _SEQ_DATA_BUF[self.name][-1] = local_result - local_shapes = _get_shapes(local_result) - local_dtypes = _get_dtypes(local_result) - if self.comm_shape: - group_shapes_lists = [None] * sp_world - dist.all_gather_object(group_shapes_lists, local_shapes, group=group) - _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists - else: - _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world - _SEQ_DATA_META_DTYPES[self.name] = local_dtypes - shapes = _SEQ_DATA_META_SHAPES[self.name][local_step] - dtypes = _SEQ_DATA_META_DTYPES[self.name] - buf_id = local_step % 2 - if local_step == 0: - sync_data = ( - local_result - if is_src - else _construct_broadcast_buffer(shapes, dtypes, device) - ) - _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False) - _SEQ_DATA_BUF[self.name][buf_id] = sync_data - - # wait for async comm ops - if _SEQ_DATA_ASYNC_COMMS[self.name]: - for comm in _SEQ_DATA_ASYNC_COMMS[self.name]: - comm.wait() - # before return the sync result, do async broadcast for next batch - if local_step < sp_world - 1: - next_buf_id = 1 - buf_id - shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1] - src_rank = dist.get_global_rank(group, local_step + 1) - is_src = sp_rank == local_step + 1 - next_sync_data = ( - _SEQ_DATA_BUF[self.name][-1] - if is_src - else _construct_broadcast_buffer(shapes, dtypes, device) - ) - _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data( - next_sync_data, shapes, dtypes, src_rank, group, True - ) - _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data - yield _SEQ_DATA_BUF[self.name][buf_id] - - -sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True) - - -def sync_data(data, sp_idx, name="tmp"): - group = get_sequence_parallel_group() - if group is None: - return data - # if sp_idx in _SYNC_BUFFER[name]: - # return _SYNC_BUFFER[name][sp_idx] - sp_rank = get_sequence_parallel_rank() - src_rank = dist.get_global_rank(group, sp_idx) - objects = [data] if sp_rank == sp_idx else [None] - dist.broadcast_object_list(objects, src=src_rank, group=group) - # _SYNC_BUFFER[name] = {sp_idx: objects[0]} - return objects[0] diff --git a/commonx/logger.py b/commonx/logger.py deleted file mode 100644 index faf795f0aecb2b16471c99802f2240880d701830..0000000000000000000000000000000000000000 --- a/commonx/logger.py +++ /dev/null @@ -1,44 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Logging utility functions. 
-""" - -import logging -import sys -from typing import Optional - -from common.distributed import get_global_rank, get_local_rank, get_world_size - -_default_handler = logging.StreamHandler(sys.stdout) -_default_handler.setFormatter( - logging.Formatter( - "%(asctime)s " - + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "") - + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "") - + "[%(threadName).12s][%(name)s][%(levelname).5s] " - + "%(message)s" - ) -) - - -def get_logger(name: Optional[str] = None) -> logging.Logger: - """ - Get a logger. - """ - logger = logging.getLogger(name) - logger.addHandler(_default_handler) - logger.setLevel(logging.INFO) - return logger diff --git a/commonx/partition.py b/commonx/partition.py deleted file mode 100644 index 648c87fe2a61294c09704b9af3e47f5a8570c215..0000000000000000000000000000000000000000 --- a/commonx/partition.py +++ /dev/null @@ -1,59 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -""" -Partition utility functions. -""" - -from typing import Any, List - - -def partition_by_size(data: List[Any], size: int) -> List[List[Any]]: - """ - Partition a list by size. - When indivisible, the last group contains fewer items than the target size. - - Examples: - - data: [1,2,3,4,5] - - size: 2 - - return: [[1,2], [3,4], [5]] - """ - assert size > 0 - return [data[i : (i + size)] for i in range(0, len(data), size)] - - -def partition_by_groups(data: List[Any], groups: int) -> List[List[Any]]: - """ - Partition a list by groups. - When indivisible, some groups may have more items than others. - - Examples: - - data: [1,2,3,4,5] - - groups: 2 - - return: [[1,3,5], [2,4]] - """ - assert groups > 0 - return [data[i::groups] for i in range(groups)] - - -def shift_list(data: List[Any], n: int) -> List[Any]: - """ - Rotate a list by n elements. - - Examples: - - data: [1,2,3,4,5] - - n: 3 - - return: [4,5,1,2,3] - """ - return data[(n % len(data)) :] + data[: (n % len(data))] diff --git a/commonx/seed.py b/commonx/seed.py deleted file mode 100644 index 52866de72fcf98f4a2ceff51a55986780a8b701a..0000000000000000000000000000000000000000 --- a/commonx/seed.py +++ /dev/null @@ -1,30 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -import random -from typing import Optional -import numpy as np -import torch - -from common.distributed import get_global_rank - - -def set_seed(seed: Optional[int], same_across_ranks: bool = False): - """Function that sets the seed for pseudo-random number generators.""" - if seed is not None: - seed += get_global_rank() if not same_across_ranks else 0 - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - diff --git a/configs_3bx/main.yaml b/configs_3bx/main.yaml deleted file mode 100644 index 78579065f27852990354bd565b5375a679a76035..0000000000000000000000000000000000000000 --- a/configs_3bx/main.yaml +++ /dev/null @@ -1,88 +0,0 @@ -__object__: - path: projects.video_diffusion_sr.train - name: VideoDiffusionTrainer - -dit: - model: - __object__: - path: models.dit_v2.nadit - name: NaDiT - args: as_params - vid_in_channels: 33 - vid_out_channels: 16 - vid_dim: 2560 - vid_out_norm: fusedrms - txt_in_dim: 5120 - txt_in_norm: fusedln - txt_dim: ${.vid_dim} - emb_dim: ${eval:'6 * ${.vid_dim}'} - heads: 20 - head_dim: 128 # llm-like - expand_ratio: 4 - norm: fusedrms - norm_eps: 1.0e-05 - ada: single - qk_bias: False - qk_norm: fusedrms - patch_size: [ 1,2,2 ] - num_layers: 32 # llm-like - mm_layers: 10 - mlp_type: swiglu - msa_type: None - block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full - window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full - window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full - rope_type: mmrope3d - rope_dim: 128 - compile: False - gradient_checkpoint: True - fsdp: - sharding_strategy: _HYBRID_SHARD_ZERO2 - -ema: - decay: 0.9998 - -vae: - model: - __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml - freeze_encoder: False - # gradient_checkpoint: True - slicing: - split_size: 4 - memory_device: same - memory_limit: - conv_max_mem: 0.5 - norm_max_mem: 0.5 - checkpoint: ./ckpts/ema_vae.pth - scaling_factor: 0.9152 - compile: False - grouping: False - dtype: bfloat16 - -diffusion: - schedule: - type: lerp - T: 1000.0 - sampler: - type: euler - prediction_type: v_lerp - timesteps: - training: - type: logitnormal - loc: 0.0 - scale: 1.0 - sampling: - type: uniform_trailing - steps: 50 - transform: True - loss: - type: v_lerp - cfg: - scale: 7.5 - rescale: 0 - -condition: - i2v: 0.0 - v2v: 0.0 - sr: 1.0 - noise_scale: 0.25 diff --git a/configs_7bx/main.yaml b/configs_7bx/main.yaml deleted file mode 100644 index 51c5eaf880788ff941bcce84b2548e3f21646339..0000000000000000000000000000000000000000 --- a/configs_7bx/main.yaml +++ /dev/null @@ -1,85 +0,0 @@ -__object__: - path: projects.video_diffusion_sr.train - name: VideoDiffusionTrainer - -dit: - model: - __object__: - path: models.dit.nadit - name: NaDiT - args: as_params - vid_in_channels: 33 - vid_out_channels: 16 - vid_dim: 3072 - txt_in_dim: 5120 - txt_dim: ${.vid_dim} - emb_dim: ${eval:'6 * ${.vid_dim}'} - heads: 24 - head_dim: 128 # llm-like - expand_ratio: 4 - norm: fusedrms - norm_eps: 1e-5 - ada: single - qk_bias: False - qk_rope: True - qk_norm: fusedrms - patch_size: [ 1,2,2 ] - num_layers: 36 # llm-like - shared_mlp: False - shared_qkv: False - mlp_type: normal - block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full - window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full - window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full - compile: False - gradient_checkpoint: True - fsdp: - sharding_strategy: _HYBRID_SHARD_ZERO2 - -ema: - decay: 
0.9998 - -vae: - model: - __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml - freeze_encoder: False - # gradient_checkpoint: True - slicing: - split_size: 4 - memory_device: same - memory_limit: - conv_max_mem: 0.5 - norm_max_mem: 0.5 - checkpoint: ./ckpts/ema_vae.pth - scaling_factor: 0.9152 - compile: False - grouping: False - dtype: bfloat16 - -diffusion: - schedule: - type: lerp - T: 1000.0 - sampler: - type: euler - prediction_type: v_lerp - timesteps: - training: - type: logitnormal - loc: 0.0 - scale: 1.0 - sampling: - type: uniform_trailing - steps: 50 - transform: True - loss: - type: v_lerp - cfg: - scale: 7.5 - rescale: 0 - -condition: - i2v: 0.0 - v2v: 0.0 - sr: 1.0 - noise_scale: 0.25 diff --git a/datax/image/transforms/area_resize.py b/datax/image/transforms/area_resize.py deleted file mode 100644 index 9f621dae1b0af40f58e090405db1ac7338110980..0000000000000000000000000000000000000000 --- a/datax/image/transforms/area_resize.py +++ /dev/null @@ -1,135 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import math -import random -from typing import Union -import torch -from PIL import Image -from torchvision.transforms import functional as TVF -from torchvision.transforms.functional import InterpolationMode - - -class AreaResize: - def __init__( - self, - max_area: float, - downsample_only: bool = False, - interpolation: InterpolationMode = InterpolationMode.BICUBIC, - ): - self.max_area = max_area - self.downsample_only = downsample_only - self.interpolation = interpolation - - def __call__(self, image: Union[torch.Tensor, Image.Image]): - - if isinstance(image, torch.Tensor): - height, width = image.shape[-2:] - elif isinstance(image, Image.Image): - width, height = image.size - else: - raise NotImplementedError - - scale = math.sqrt(self.max_area / (height * width)) - - # keep original height and width for small pictures. - scale = 1 if scale >= 1 and self.downsample_only else scale - - resized_height, resized_width = round(height * scale), round(width * scale) - - return TVF.resize( - image, - size=(resized_height, resized_width), - interpolation=self.interpolation, - ) - - -class AreaRandomCrop: - def __init__( - self, - max_area: float, - ): - self.max_area = max_area - - def get_params(self, input_size, output_size): - """Get parameters for ``crop`` for a random crop. - - Args: - img (PIL Image): Image to be cropped. - output_size (tuple): Expected output size of the crop. - - Returns: - tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
- """ - # w, h = _get_image_size(img) - h, w = input_size - th, tw = output_size - if w <= tw and h <= th: - return 0, 0, h, w - - i = random.randint(0, h - th) - j = random.randint(0, w - tw) - return i, j, th, tw - - def __call__(self, image: Union[torch.Tensor, Image.Image]): - if isinstance(image, torch.Tensor): - height, width = image.shape[-2:] - elif isinstance(image, Image.Image): - width, height = image.size - else: - raise NotImplementedError - - resized_height = math.sqrt(self.max_area / (width / height)) - resized_width = (width / height) * resized_height - - # print('>>>>>>>>>>>>>>>>>>>>>') - # print((height, width)) - # print( (resized_height, resized_width)) - - resized_height, resized_width = round(resized_height), round(resized_width) - i, j, h, w = self.get_params((height, width), (resized_height, resized_width)) - image = TVF.crop(image, i, j, h, w) - return image - -class ScaleResize: - def __init__( - self, - scale: float, - ): - self.scale = scale - - def __call__(self, image: Union[torch.Tensor, Image.Image]): - if isinstance(image, torch.Tensor): - height, width = image.shape[-2:] - interpolation_mode = InterpolationMode.BILINEAR - antialias = True if image.ndim == 4 else "warn" - elif isinstance(image, Image.Image): - width, height = image.size - interpolation_mode = InterpolationMode.LANCZOS - antialias = "warn" - else: - raise NotImplementedError - - scale = self.scale - - # keep original height and width for small pictures - - resized_height, resized_width = round(height * scale), round(width * scale) - image = TVF.resize( - image, - size=(resized_height, resized_width), - interpolation=interpolation_mode, - antialias=antialias, - ) - return image diff --git a/datax/image/transforms/divisible_crop.py b/datax/image/transforms/divisible_crop.py deleted file mode 100644 index d1815b03ee1ce99486143aca24b9023ab0b3973c..0000000000000000000000000000000000000000 --- a/datax/image/transforms/divisible_crop.py +++ /dev/null @@ -1,40 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -from typing import Union -import torch -from PIL import Image -from torchvision.transforms import functional as TVF - - -class DivisibleCrop: - def __init__(self, factor): - if not isinstance(factor, tuple): - factor = (factor, factor) - - self.height_factor, self.width_factor = factor[0], factor[1] - - def __call__(self, image: Union[torch.Tensor, Image.Image]): - if isinstance(image, torch.Tensor): - height, width = image.shape[-2:] - elif isinstance(image, Image.Image): - width, height = image.size - else: - raise NotImplementedError - - cropped_height = height - (height % self.height_factor) - cropped_width = width - (width % self.width_factor) - - image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width)) - return image diff --git a/datax/image/transforms/na_resize.py b/datax/image/transforms/na_resize.py deleted file mode 100644 index d230e25e3ca1710ad6261d8e14541a97732b9a30..0000000000000000000000000000000000000000 --- a/datax/image/transforms/na_resize.py +++ /dev/null @@ -1,50 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Literal -from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize - -from .area_resize import AreaResize -from .side_resize import SideResize - - -def NaResize( - resolution: int, - mode: Literal["area", "side"], - downsample_only: bool, - interpolation: InterpolationMode = InterpolationMode.BICUBIC, -): - if mode == "area": - return AreaResize( - max_area=resolution**2, - downsample_only=downsample_only, - interpolation=interpolation, - ) - if mode == "side": - return SideResize( - size=resolution, - downsample_only=downsample_only, - interpolation=interpolation, - ) - if mode == "square": - return Compose( - [ - Resize( - size=resolution, - interpolation=interpolation, - ), - CenterCrop(resolution), - ] - ) - raise ValueError(f"Unknown resize mode: {mode}") diff --git a/datax/image/transforms/side_resize.py b/datax/image/transforms/side_resize.py deleted file mode 100644 index 6e07402b2187a048b99d995d68ead12f790f5724..0000000000000000000000000000000000000000 --- a/datax/image/transforms/side_resize.py +++ /dev/null @@ -1,54 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
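DivisibleCrop above center-crops both sides down to the nearest multiple of the given factor, which is the usual way to keep shapes compatible with patch and VAE downsampling factors. Illustrative arithmetic only:

height, width, factor = 721, 1283, 16
cropped = (height - height % factor, width - width % factor)
print(cropped)  # (720, 1280)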
- -from typing import Union -import torch -from PIL import Image -from torchvision.transforms import InterpolationMode -from torchvision.transforms import functional as TVF - - -class SideResize: - def __init__( - self, - size: int, - downsample_only: bool = False, - interpolation: InterpolationMode = InterpolationMode.BICUBIC, - ): - self.size = size - self.downsample_only = downsample_only - self.interpolation = interpolation - - def __call__(self, image: Union[torch.Tensor, Image.Image]): - """ - Args: - image (PIL Image or Tensor): Image to be scaled. - - Returns: - PIL Image or Tensor: Rescaled image. - """ - if isinstance(image, torch.Tensor): - height, width = image.shape[-2:] - elif isinstance(image, Image.Image): - width, height = image.size - else: - raise NotImplementedError - - if self.downsample_only and min(width, height) < self.size: - # keep original height and width for small pictures. - size = min(width, height) - else: - size = self.size - - return TVF.resize(image, size, self.interpolation) diff --git a/datax/video/transforms/rearrange.py b/datax/video/transforms/rearrange.py deleted file mode 100644 index 895347991d71043742777f103d32b62c80284660..0000000000000000000000000000000000000000 --- a/datax/video/transforms/rearrange.py +++ /dev/null @@ -1,24 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -from einops import rearrange - - -class Rearrange: - def __init__(self, pattern: str, **kwargs): - self.pattern = pattern - self.kwargs = kwargs - - def __call__(self, x): - return rearrange(x, self.pattern, **self.kwargs) diff --git a/environmentx.yml b/environmentx.yml deleted file mode 100644 index b34b69ab1aae4d93201173a75ab27ff32354333d..0000000000000000000000000000000000000000 --- a/environmentx.yml +++ /dev/null @@ -1,238 +0,0 @@ -name: seedvr -channels: - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - ld_impl_linux-64=2.38=h1181459_1 - - libffi=3.4.4=h6a678d5_1 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libstdcxx-ng=11.2.0=h1234567_1 - - ncurses=6.4=h6a678d5_0 - - openssl=3.0.14=h5eee18b_0 - - pip=24.0=py39h06a4308_0 - - python=3.9.19=h955ad1f_1 - - readline=8.2=h5eee18b_0 - - setuptools=69.5.1=py39h06a4308_0 - - sqlite=3.45.3=h5eee18b_0 - - tk=8.6.14=h39e8969_0 - - tzdata=2024a=h04d1e81_0 - - wheel=0.43.0=py39h06a4308_0 - - xz=5.4.6=h5eee18b_1 - - zlib=1.2.13=h5eee18b_1 - - pip: - - absl-py==2.1.0 - - accelerate==0.33.0 - - addict==2.4.0 - - antlr4-python3-runtime==4.9.3 - - anykeystore==0.2 - - apex==0.1 - - asttokens==2.4.1 - - astunparse==1.6.3 - - attrs==23.2.0 - - av==12.0.0 - - basicsr==1.4.2 - - beartype==0.18.5 - - beautifulsoup4==4.12.3 - - bitsandbytes==0.44.1 - - black==24.4.2 - - bs4==0.0.2 - - bson==0.5.10 - - certifi==2024.6.2 - - cffi==1.16.0 - - cfgv==3.4.0 - - charset-normalizer==3.3.2 - - click==8.1.7 - - colorama==0.4.6 - - contourpy==1.2.1 - - cryptacular==1.6.2 - - cryptography==39.0.2 - - cycler==0.12.1 - - decorator==4.4.2 - - deepdiff==7.0.1 - - deprecated==1.2.14 - - diffusers==0.33.1 - - distlib==0.3.8 - - dnspython==2.6.1 - - docker-pycreds==0.4.0 - - docstring-parser==0.16 - - einops==0.7.0 - - exceptiongroup==1.2.1 - - executing==2.0.1 - - expecttest==0.2.1 - - facexlib==0.3.0 - - ffmpeg-python==0.2.0 - - filelock==3.15.4 - - filterpy==1.4.5 - - flake8==7.1.0 - - flash-attn==2.5.9.post1 - - flatbuffers==24.3.25 - - fonttools==4.53.0 - - fsspec==2023.6.0 - - ftfy==6.2.0 - - future==1.0.0 - - gast==0.5.4 - - gdown==5.2.0 - - gitdb==4.0.11 - - gitpython==3.1.43 - - google-pasta==0.2.0 - - greenlet==3.0.3 - - grpcio==1.64.1 - - h5py==3.11.0 - - hf-xet==1.1.2 - - huggingface-hub==0.32.2 - - hupper==1.12.1 - - hypothesis==6.100.1 - - icecream==2.1.3 - - identify==2.5.36 - - idna==3.7 - - imageio==2.34.0 - - imageio-ffmpeg==0.5.1 - - importlib-metadata==7.2.1 - - importlib-resources==6.4.0 - - iniconfig==2.0.0 - - ipaddress==1.0.23 - - ipython==8.18.1 - - isort==5.13.2 - - jedi==0.19.1 - - jinja2==3.1.4 - - jsonargparse==4.14.1 - - keras==3.3.3 - - kiwisolver==1.4.5 - - lazy-loader==0.4 - - libclang==18.1.1 - - lightning-utilities==0.11.2 - - llvmlite==0.43.0 - - lmdb==1.5.1 - - lpips==0.1.4 - - markdown==3.6 - - markdown-it-py==3.0.0 - - markupsafe==2.1.5 - - matplotlib==3.9.0 - - matplotlib-inline==0.1.7 - - mccabe==0.7.0 - - mdurl==0.1.2 - - mediapy==1.2.0 - - ml-dtypes==0.3.2 - - moviepy==1.0.3 - - mpmath==1.3.0 - - msgpack==1.0.8 - - mypy-extensions==1.0.0 - - namex==0.0.8 - - networkx==3.2.1 - - nodeenv==1.9.1 - - numba==0.60.0 - - numpy==1.24.4 - - oauthlib==3.2.2 - - omegaconf==2.3.0 - - openai-clip==1.0.1 - - opencv-python==4.9.0.80 - - opencv-python-headless==4.10.0.84 - - opt-einsum==3.3.0 - - optree==0.11.0 - - ordered-set==4.1.0 - - packaging==22.0 - - pandas==1.5.3 - - parameterized==0.9.0 - - parso==0.8.4 - - pastedeploy==3.1.0 - - pathspec==0.12.1 - - pathtools==0.1.2 - - 
pbkdf2==1.3 - - pexpect==4.9.0 - - pillow==10.3.0 - - plaster==1.1.2 - - plaster-pastedeploy==1.0.1 - - platformdirs==4.2.2 - - pluggy==1.5.0 - - proglog==0.1.10 - - promise==2.3 - - prompt-toolkit==3.0.47 - - protobuf==3.20.3 - - psutil==6.0.0 - - ptyprocess==0.7.0 - - pure-eval==0.2.2 - - pyarrow==11.0.0 - - pycocotools==2.0.7 - - pycodestyle==2.12.0 - - pycparser==2.22 - - pydantic==1.10.17 - - pyflakes==3.2.0 - - pygments==2.18.0 - - pyiqa==0.1.13 - - pyjwt==2.8.0 - - pyopenssl==23.2.0 - - pyparsing==3.1.2 - - pyramid==2.0.2 - - pyramid-mailer==0.15.1 - - pysocks==1.7.1 - - pytest==8.3.3 - - python-dateutil==2.9.0.post0 - - python-etcd==0.4.5 - - python3-openid==3.2.0 - - pytz==2024.1 - - pyyaml==6.0.1 - - regex==2024.5.15 - - repoze-sendmail==4.4.1 - - requests==2.32.3 - - requests-oauthlib==2.0.0 - - rich==13.7.1 - - rotary-embedding-torch==0.5.3 - - safetensors==0.4.3 - - scenedetect==0.6.4 - - schedule==1.2.2 - - scikit-image==0.24.0 - - scipy==1.13.1 - - sentencepiece==0.2.0 - - sentry-sdk==2.6.0 - - setproctitle==1.3.3 - - shortuuid==1.0.13 - - six==1.16.0 - - smmap==5.0.1 - - sortedcontainers==2.4.0 - - soupsieve==2.5 - - sqlalchemy==2.0.31 - - stack-data==0.6.3 - - sympy==1.12.1 - - tabulate==0.9.0 - - tb-nightly==2.20.0a20250528 - - tenacity==8.4.1 - - tensorboard==2.16.2 - - tensorboard-data-server==0.7.2 - - tensorflow==2.16.1 - - tensorflow-io-gcs-filesystem==0.37.0 - - termcolor==2.4.0 - - tifffile==2024.8.30 - - tiktoken==0.9.0 - - timm==1.0.11 - - tokenizers==0.20.3 - - tomli==2.0.1 - - torch==2.4.0+cu121 - - torch-fidelity==0.3.0 - - torchaudio==2.4.0+cu121 - - torchmetrics==1.3.2 - - torchvision==0.19.0+cu121 - - tqdm==4.66.4 - - traitlets==5.14.3 - - transaction==4.0 - - transformers==4.46.2 - - transformers-stream-generator==0.0.5 - - translationstring==1.4 - - triton==3.0.0 - - typing-extensions==4.12.2 - - urllib3==1.26.19 - - velruse==1.1.1 - - venusian==3.1.0 - - virtualenv==20.26.3 - - wcwidth==0.2.13 - - webob==1.8.7 - - werkzeug==3.0.3 - - wrapt==1.16.0 - - wtforms==3.1.2 - - wtforms-recaptcha==0.3.2 - - yapf==0.43.0 - - zipp==3.19.2 - - zope-deprecation==5.0 - - zope-interface==6.4.post2 - - zope-sqlalchemy==3.1 \ No newline at end of file diff --git a/modelsx/dit/attention.py b/modelsx/dit/attention.py deleted file mode 100644 index ac0cadbcd62e7d40700108d2857cb587f794fcee..0000000000000000000000000000000000000000 --- a/modelsx/dit/attention.py +++ /dev/null @@ -1,46 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -import torch -import torch.nn.functional as F - -from flash_attn import flash_attn_varlen_func - -from torch import nn - -class TorchAttention(nn.Module): - def tflops(self, args, kwargs, output) -> float: - assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs" - q = kwargs.get("query") or args[0] - k = kwargs.get("key") or args[1] - b, h, sq, d = q.shape - b, h, sk, d = k.shape - return b * h * (4 * d * (sq / 1e6) * (sk / 1e6)) - - def forward(self, *args, **kwargs): - return F.scaled_dot_product_attention(*args, **kwargs) - - -class FlashAttentionVarlen(nn.Module): - def tflops(self, args, kwargs, output) -> float: - cu_seqlens_q = kwargs["cu_seqlens_q"] - cu_seqlens_k = kwargs["cu_seqlens_k"] - _, h, d = output.shape - seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6 - seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6 - return h * (4 * d * (seqlens_q * seqlens_k).sum()) - - def forward(self, *args, **kwargs): - kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled() - return flash_attn_varlen_func(*args, **kwargs) \ No newline at end of file diff --git a/modelsx/dit/blocks/__init__.py b/modelsx/dit/blocks/__init__.py deleted file mode 100644 index 3195b400a407b871a6c19b67cf25239c5c3f196d..0000000000000000000000000000000000000000 --- a/modelsx/dit/blocks/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from .mmdit_window_block import MMWindowTransformerBlock - -dit_blocks = { - "mmdit_window": MMWindowTransformerBlock, -} - - -def get_block(block_type: str): - if block_type in dit_blocks: - return dit_blocks[block_type] - raise NotImplementedError(f"{block_type} is not supported") diff --git a/modelsx/dit/blocks/mmdit_window_block.py b/modelsx/dit/blocks/mmdit_window_block.py deleted file mode 100644 index eacaa093658f62fb483086215cfb6ac72a2dc9fd..0000000000000000000000000000000000000000 --- a/modelsx/dit/blocks/mmdit_window_block.py +++ /dev/null @@ -1,233 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
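The tflops helpers above estimate attention cost as 4*b*h*sq*sk*d FLOPs (the QK^T and PV matmuls), with the /1e6 factors putting the result directly in TFLOPs. A worked number for an assumed shape:

b, h, sq, sk, d = 2, 24, 4096, 4096, 128
tflops = b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
print(tflops)  # ~0.41 TFLOPs for this call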
- -from typing import Tuple, Union -import torch -from einops import rearrange -from torch import nn -from torch.nn import functional as F -from torch.nn.modules.utils import _triple - -from common.distributed.ops import ( - gather_heads, - gather_heads_scatter_seq, - gather_seq_scatter_heads_qkv, - scatter_heads, -) - -from ..attention import TorchAttention -from ..mlp import get_mlp -from ..mm import MMArg, MMModule -from ..modulation import ada_layer_type -from ..normalization import norm_layer_type -from ..rope import RotaryEmbedding3d - - -class MMWindowAttention(nn.Module): - def __init__( - self, - vid_dim: int, - txt_dim: int, - heads: int, - head_dim: int, - qk_bias: bool, - qk_rope: bool, - qk_norm: norm_layer_type, - qk_norm_eps: float, - window: Union[int, Tuple[int, int, int]], - window_method: str, - shared_qkv: bool, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - inner_dim = heads * head_dim - qkv_dim = inner_dim * 3 - - self.window = _triple(window) - self.window_method = window_method - assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) - - self.head_dim = head_dim - self.proj_qkv = MMModule(nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_qkv) - self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_qkv) - self.norm_q = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True) - self.norm_k = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True) - self.rope = RotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None - self.attn = TorchAttention() - - def forward( - self, - vid: torch.FloatTensor, # b T H W c - txt: torch.FloatTensor, # b L c - txt_mask: torch.BoolTensor, # b L - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - # Project q, k, v. - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - vid_qkv = gather_seq_scatter_heads_qkv(vid_qkv, seq_dim=2) - _, T, H, W, _ = vid_qkv.shape - _, L, _ = txt.shape - - if self.window_method == "win": - nt, nh, nw = self.window - tt, hh, ww = T // nt, H // nh, W // nw - elif self.window_method == "win_by_size": - tt, hh, ww = self.window - tt, hh, ww = ( - tt if tt > 0 else T, - hh if hh > 0 else H, - ww if ww > 0 else W, - ) - nt, nh, nw = T // tt, H // hh, W // ww - else: - raise NotImplementedError - - vid_qkv = rearrange(vid_qkv, "b T H W (o h d) -> o b h (T H W) d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "b L (o h d) -> o b h L d", o=3, d=self.head_dim) - txt_qkv = scatter_heads(txt_qkv, dim=2) - - vid_q, vid_k, vid_v = vid_qkv.unbind() - txt_q, txt_k, txt_v = txt_qkv.unbind() - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - if self.rope: - vid_q, vid_k = self.rope(vid_q, vid_k, (T, H, W)) - - def vid_window(v): - return rearrange( - v, - "b h (nt tt nh hh nw ww) d -> b h (nt nh nw) (tt hh ww) d", - hh=hh, - ww=ww, - tt=tt, - nh=nh, - nw=nw, - nt=nt, - ) - - def txt_window(t): - return rearrange(t, "b h L d -> b h 1 L d").expand(-1, -1, nt * nh * nw, -1, -1) - - # Process video attention. 
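# Each video window attends to its own video tokens plus the full text sequence:
# the keys/values below concatenate the window's video K/V with the text K/V, and the
# mask left-pads txt_mask with tt*hh*ww True entries so every video position is kept
# while padded text tokens remain masked out.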
- vid_msk = F.pad(txt_mask, (tt * hh * ww, 0), value=True) - vid_msk = rearrange(vid_msk, "b l -> b 1 1 1 l").expand(-1, 1, 1, tt * hh * ww, -1) - vid_out = self.attn( - vid_window(vid_q), - torch.cat([vid_window(vid_k), txt_window(txt_k)], dim=-2), - torch.cat([vid_window(vid_v), txt_window(txt_v)], dim=-2), - vid_msk, - ) - vid_out = rearrange( - vid_out, - "b h (nt nh nw) (tt hh ww) d -> b (nt tt) (nh hh) (nw ww) (h d)", - hh=hh, - ww=ww, - tt=tt, - nh=nh, - nw=nw, - ) - vid_out = gather_heads_scatter_seq(vid_out, head_dim=4, seq_dim=2) - - # Process text attention. - txt_msk = F.pad(txt_mask, (T * H * W, 0), value=True) - txt_msk = rearrange(txt_msk, "b l -> b 1 1 l").expand(-1, 1, L, -1) - txt_out = self.attn( - txt_q, - torch.cat([vid_k, txt_k], dim=-2), - torch.cat([vid_v, txt_v], dim=-2), - txt_msk, - ) - txt_out = rearrange(txt_out, "b h L d -> b L (h d)") - txt_out = gather_heads(txt_out, dim=2) - - # Project output. - vid_out, txt_out = self.proj_out(vid_out, txt_out) - return vid_out, txt_out - - -class MMWindowTransformerBlock(nn.Module): - def __init__( - self, - *, - vid_dim: int, - txt_dim: int, - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm: norm_layer_type, - norm_eps: float, - ada: ada_layer_type, - qk_bias: bool, - qk_rope: bool, - qk_norm: norm_layer_type, - window: Union[int, Tuple[int, int, int]], - window_method: str, - shared_qkv: bool, - shared_mlp: bool, - mlp_type: str, - **kwargs, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False) - self.attn = MMWindowAttention( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - qk_norm_eps=norm_eps, - window=window, - window_method=window_method, - shared_qkv=shared_qkv, - ) - self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False) - self.mlp = MMModule( - get_mlp(mlp_type), - dim=dim, - expand_ratio=expand_ratio, - shared_weights=shared_mlp, - ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"]) - - def forward( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - txt_mask: torch.BoolTensor, - emb: torch.FloatTensor, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_attn, txt_attn = self.attn_norm(vid, txt) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="in") - vid_attn, txt_attn = self.attn(vid_attn, txt_attn, txt_mask=txt_mask) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="out") - vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) - - vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="in") - vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="out") - vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) - - return vid_mlp, txt_mlp diff --git a/modelsx/dit/embedding.py b/modelsx/dit/embedding.py deleted file mode 100644 index e972244f5767c9f34e5e77bb180ae720ce88b89c..0000000000000000000000000000000000000000 --- a/modelsx/dit/embedding.py +++ /dev/null @@ -1,62 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
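The two window conventions in MMWindowAttention above are easy to mix up: "win" takes window counts and derives sizes, while "win_by_size" takes sizes (0 meaning the full axis) and derives counts. Assumed numbers to make that concrete:

T, H, W = 16, 45, 80                            # latent frames, height, width (assumed)

# window_method == "win": window counts given, sizes derived.
nt, nh, nw = 4, 3, 4
tt, hh, ww = T // nt, H // nh, W // nw          # (4, 15, 20) tokens per window

# window_method == "win_by_size": window sizes given, 0 means "use the whole axis".
tt2, hh2, ww2 = 4, 0, 0
tt2, hh2, ww2 = tt2 or T, hh2 or H, ww2 or W    # (4, 45, 80)
nt2, nh2, nw2 = T // tt2, H // hh2, W // ww2    # (4, 1, 1) windows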
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Optional, Union -import torch -from diffusers.models.embeddings import get_timestep_embedding -from torch import nn - - -def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): - return emb1 if emb2 is None else emb1 + emb2 - - -class TimeEmbedding(nn.Module): - def __init__( - self, - sinusoidal_dim: int, - hidden_dim: int, - output_dim: int, - ): - super().__init__() - self.sinusoidal_dim = sinusoidal_dim - self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) - self.proj_hid = nn.Linear(hidden_dim, hidden_dim) - self.proj_out = nn.Linear(hidden_dim, output_dim) - self.act = nn.SiLU() - - def forward( - self, - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], - device: torch.device, - dtype: torch.dtype, - ) -> torch.FloatTensor: - if not torch.is_tensor(timestep): - timestep = torch.tensor([timestep], device=device, dtype=dtype) - if timestep.ndim == 0: - timestep = timestep[None] - - emb = get_timestep_embedding( - timesteps=timestep, - embedding_dim=self.sinusoidal_dim, - flip_sin_to_cos=False, - downscale_freq_shift=0, - ) - emb = emb.to(dtype) - emb = self.proj_in(emb) - emb = self.act(emb) - emb = self.proj_hid(emb) - emb = self.act(emb) - emb = self.proj_out(emb) - return emb diff --git a/modelsx/dit/mlp.py b/modelsx/dit/mlp.py deleted file mode 100644 index 2d05cb021f3e3c6ac05c0e7ae1aa8a6d29475b87..0000000000000000000000000000000000000000 --- a/modelsx/dit/mlp.py +++ /dev/null @@ -1,62 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
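TimeEmbedding above turns a scalar diffusion timestep into a vector of width output_dim; in the configs earlier in this diff that width is emb_dim = 6 * vid_dim, matching what AdaSingle asserts later. A shape check with assumed dims, on CPU, assuming the module is importable under the path shown in this diff:

import torch
from modelsx.dit.embedding import TimeEmbedding  # assumed import path

emb_layer = TimeEmbedding(sinusoidal_dim=256, hidden_dim=1024, output_dim=6 * 2560)
emb = emb_layer(timestep=500, device=torch.device("cpu"), dtype=torch.float32)
print(emb.shape)  # torch.Size([1, 15360])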
- -from typing import Optional -import torch -import torch.nn.functional as F -from torch import nn - - -def get_mlp(mlp_type: Optional[str] = "normal"): - if mlp_type == "normal": - return MLP - elif mlp_type == "swiglu": - return SwiGLUMLP - - -class MLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - ): - super().__init__() - self.proj_in = nn.Linear(dim, dim * expand_ratio) - self.act = nn.GELU("tanh") - self.proj_out = nn.Linear(dim * expand_ratio, dim) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_in(x) - x = self.act(x) - x = self.proj_out(x) - return x - - -class SwiGLUMLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - multiple_of: int = 256, - ): - super().__init__() - hidden_dim = int(2 * dim * expand_ratio / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) - self.proj_out = nn.Linear(hidden_dim, dim, bias=False) - self.proj_in = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) - return x diff --git a/modelsx/dit/mm.py b/modelsx/dit/mm.py deleted file mode 100644 index 49be1f5915a61d8ea27f3e3718f35e5c9af662e7..0000000000000000000000000000000000000000 --- a/modelsx/dit/mm.py +++ /dev/null @@ -1,67 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
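SwiGLUMLP above shrinks the nominal expansion by 2/3 to pay for the extra gate projection, then rounds the hidden width up to a multiple of multiple_of. With the 3B config values (vid_dim 2560, expand_ratio 4) this gives:

dim, expand_ratio, multiple_of = 2560, 4, 256
hidden_dim = int(2 * dim * expand_ratio / 3)                                # 6826
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # rounded up to 6912
params = 3 * dim * hidden_dim                                               # three bias-free projections
print(hidden_dim, params)  # 6912 53084160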
- -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple -import torch -from torch import nn - - -@dataclass -class MMArg: - vid: Any - txt: Any - - -def get_args(key: str, args: List[Any]) -> List[Any]: - return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] - - -def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} - - -class MMModule(nn.Module): - def __init__( - self, - module: Callable[..., nn.Module], - *args, - shared_weights: bool = False, - **kwargs, - ): - super().__init__() - self.shared_weights = shared_weights - if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) - self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - else: - self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs)) - - def forward( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - *args, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.vid if not self.shared_weights else self.all - txt_module = self.txt if not self.shared_weights else self.all - vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) - txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) - return vid, txt diff --git a/modelsx/dit/modulation.py b/modelsx/dit/modulation.py deleted file mode 100644 index cd3b41f6c457396ac65403d88edc3d5ad3382262..0000000000000000000000000000000000000000 --- a/modelsx/dit/modulation.py +++ /dev/null @@ -1,97 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Callable, List, Optional -import torch -from einops import rearrange -from torch import nn - -from common.cache import Cache -from common.distributed.ops import slice_inputs - -# (dim: int, emb_dim: int) -ada_layer_type = Callable[[int, int], nn.Module] - - -def get_ada_layer(ada_layer: str) -> ada_layer_type: - if ada_layer == "single": - return AdaSingle - raise NotImplementedError(f"{ada_layer} is not supported") - - -def expand_dims(x: torch.Tensor, dim: int, ndim: int): - """ - Expand tensor "x" to "ndim" by adding empty dims at "dim". - Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d). 
- """ - shape = x.shape - shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] - return x.reshape(shape) - - -class AdaSingle(nn.Module): - def __init__( - self, - dim: int, - emb_dim: int, - layers: List[str], - ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" - super().__init__() - self.dim = dim - self.emb_dim = emb_dim - self.layers = layers - for l in layers: - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) - self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)) - self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) - - def forward( - self, - hid: torch.FloatTensor, # b ... c - emb: torch.FloatTensor, # b d - layer: str, - mode: str, - cache: Cache = Cache(disable=True), - branch_tag: str = "", - hid_len: Optional[torch.LongTensor] = None, # b - ) -> torch.FloatTensor: - idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] - emb = expand_dims(emb, 1, hid.ndim + 1) - - if hid_len is not None: - emb = cache( - f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), - dim=0, - ), - ) - - shiftA, scaleA, gateA = emb.unbind(-1) - shiftB, scaleB, gateB = ( - getattr(self, f"{layer}_shift"), - getattr(self, f"{layer}_scale"), - getattr(self, f"{layer}_gate"), - ) - - if mode == "in": - return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) - if mode == "out": - return hid.mul_(gateA + gateB) - raise NotImplementedError - - def extra_repr(self) -> str: - return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}" \ No newline at end of file diff --git a/modelsx/dit/na.py b/modelsx/dit/na.py deleted file mode 100644 index 0dbd546c4705b3b9c7c19a9823f9d113a0447616..0000000000000000000000000000000000000000 --- a/modelsx/dit/na.py +++ /dev/null @@ -1,241 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from itertools import chain -from typing import Callable, Dict, List, Tuple -import einops -import torch - - -def flatten( - hid: List[torch.FloatTensor], # List of (*** c) -) -> Tuple[ - torch.FloatTensor, # (L c) - torch.LongTensor, # (b n) -]: - assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) - hid = torch.cat([x.flatten(0, -2) for x in hid]) - return hid, shape - - -def unflatten( - hid: torch.FloatTensor, # (L c) or (L ... c) - hid_shape: torch.LongTensor, # (b n) -) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) - hid_len = hid_shape.prod(-1) - hid = hid.split(hid_len.tolist()) - hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] - return hid - - -def concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> torch.FloatTensor: # (L ... 
c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - return torch.cat(list(chain(*zip(vid, txt)))) - - -def concat_idx( - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) - src_idx = torch.argsort(tgt_idx) - return ( - lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), - lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), - ) - - -def unconcat( - all: torch.FloatTensor, # (L ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - torch.FloatTensor, # (VL ... c) - torch.FloatTensor, # (TL ... c) -]: - interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist()))) - all = all.split(interleave_len) - vid = torch.cat(all[0::2]) - txt = torch.cat(all[1::2]) - return vid, txt - - -def repeat_concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: List, # (n) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - txt = [[x] * n for x, n in zip(txt, txt_repeat)] - txt = list(chain(*txt)) - return torch.cat(list(chain(*zip(vid, txt)))) - - -def repeat_concat_idx( - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: torch.LongTensor, # (n) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) - src_idx = torch.argsort(tgt_idx) - txt_idx_len = len(tgt_idx) - len(vid_idx) - repeat_txt_len = (txt_len * txt_repeat).tolist() - - def unconcat_coalesce(all): - """ - Un-concat vid & txt, and coalesce the repeated txt. - e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8] - txt [9 10] - repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10] - 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10] - split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10] - 2. reshape & mean for each sample to coalesce the repeated txt. - """ - vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) - txt_out_coalesced = [] - for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): - txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) - txt_out_coalesced.append(txt) - return vid_out, torch.cat(txt_out_coalesced) - - # Note: Backward of torch.index_select is non-deterministic when existing repeated index, - # the difference may cumulative like torch.repeat_interleave, so we use vanilla index here. 
- return ( - lambda vid, txt: torch.cat([vid, txt])[tgt_idx], - lambda all: unconcat_coalesce(all), - ) - - -def rearrange( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, int], -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)]) - - -def rearrange_idx( - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, int], -) -> Tuple[Callable, Callable, torch.LongTensor]: - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - ) - - -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - - -def pack( - samples: List[torch.Tensor], # List of (h w c). -) -> Tuple[ - List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)] - List[List[int]], # reversal indices. -]: - batches = {} - indices = {} - for i, sample in enumerate(samples): - shape = sample.shape - batches[shape] = batches.get(shape, []) - indices[shape] = indices.get(shape, []) - batches[shape].append(sample) - indices[shape].append(i) - - batches = list(map(torch.stack, batches.values())) - indices = list(indices.values()) - return batches, indices - - -def unpack( - batches: List[torch.Tensor], - indices: List[List[int]], -) -> List[torch.Tensor]: - samples = [None] * (max(chain(*indices)) + 1) - for batch, index in zip(batches, indices): - for sample, i in zip(batch.unbind(), index): - samples[i] = sample - return samples - - -def window( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid = unflatten(hid, hid_shape) - hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows - - -def window_idx( - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - tgt_windows, - ) diff --git a/modelsx/dit/nablocks/__init__.py b/modelsx/dit/nablocks/__init__.py deleted file mode 100644 index afa206db157786d9e4cf830bec09bd3a390bd9a8..0000000000000000000000000000000000000000 --- a/modelsx/dit/nablocks/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
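pack and unpack above group variable-shape samples into same-shape batches and keep the indices needed to restore the original order, which is useful for batching mixed-resolution samples. A small round trip with assumed shapes (import path taken from this diff, so treat it as an assumption):

import torch
from modelsx.dit.na import pack, unpack  # assumed import path

samples = [torch.zeros(8, 8, 3), torch.zeros(4, 4, 3), torch.zeros(8, 8, 3)]
batches, indices = pack(samples)
print([tuple(b.shape) for b in batches])  # [(2, 8, 8, 3), (1, 4, 4, 3)]
print(indices)                            # [[0, 2], [1]]
restored = unpack(batches, indices)       # original order and shapes recovered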
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from .mmsr_block import NaMMSRTransformerBlock - -nadit_blocks = { - "mmdit_sr": NaMMSRTransformerBlock, -} - - -def get_nablock(block_type: str): - if block_type in nadit_blocks: - return nadit_blocks[block_type] - raise NotImplementedError(f"{block_type} is not supported") diff --git a/modelsx/dit/nablocks/mmsr_block.py b/modelsx/dit/nablocks/mmsr_block.py deleted file mode 100644 index b75652efc070188268bb84b35352b543e1a3746b..0000000000000000000000000000000000000000 --- a/modelsx/dit/nablocks/mmsr_block.py +++ /dev/null @@ -1,248 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Tuple, Union -import torch -from einops import rearrange -from torch.nn import functional as F - -# from ..cache import Cache -from common.cache import Cache -from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv - -from .. 
import na -from ..attention import FlashAttentionVarlen -from ..blocks.mmdit_window_block import MMWindowAttention, MMWindowTransformerBlock -from ..mm import MMArg -from ..modulation import ada_layer_type -from ..normalization import norm_layer_type -from ..rope import NaRotaryEmbedding3d -from ..window import get_window_op - - -class NaSwinAttention(MMWindowAttention): - def __init__( - self, - vid_dim: int, - txt_dim: int, - heads: int, - head_dim: int, - qk_bias: bool, - qk_rope: bool, - qk_norm: norm_layer_type, - qk_norm_eps: float, - window: Union[int, Tuple[int, int, int]], - window_method: str, - shared_qkv: bool, - **kwargs, - ): - super().__init__( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - qk_norm_eps=qk_norm_eps, - window=window, - window_method=window_method, - shared_qkv=shared_qkv, - ) - self.rope = NaRotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None - self.attn = FlashAttentionVarlen() - self.window_op = get_window_op(window_method) - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - vid_qkv = gather_seq_scatter_heads_qkv( - vid_qkv, - seq_dim=0, - qkv_shape=vid_shape, - cache=cache.namespace("vid"), - ) - txt_qkv = gather_seq_scatter_heads_qkv( - txt_qkv, - seq_dim=0, - qkv_shape=txt_shape, - cache=cache.namespace("txt"), - ) - - # re-org the input seq for window attn - cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") - - def make_window(x: torch.Tensor): - t, h, w, _ = x.shape - window_slices = self.window_op((t, h, w), self.window) - return [x[st, sh, sw] for (st, sh, sw) in window_slices] - - window_partition, window_reverse, window_shape, window_count = cache_win( - "win_transform", - lambda: na.window_idx(vid_shape, make_window), - ) - vid_qkv_win = window_partition(vid_qkv) - - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - - vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) - concat_win, unconcat_win = cache_win( - "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count) - ) - - # window rope - if self.rope: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - - out = self.attn( - q=concat_win(vid_q, txt_q).bfloat16(), - k=concat_win(vid_k, txt_k).bfloat16(), - v=concat_win(vid_v, txt_v).bfloat16(), - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), - max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()), - max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()), - ).type_as(vid_q) - - # text pooling - vid_out, txt_out = unconcat_win(out) - - vid_out = rearrange(vid_out, "l h 
d -> l (h d)") - txt_out = rearrange(txt_out, "l h d -> l (h d)") - vid_out = window_reverse(vid_out) - - vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0) - txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - - return vid_out, txt_out - - -class NaMMSRTransformerBlock(MMWindowTransformerBlock): - def __init__( - self, - *, - vid_dim: int, - txt_dim: int, - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm: norm_layer_type, - norm_eps: float, - ada: ada_layer_type, - qk_bias: bool, - qk_rope: bool, - qk_norm: norm_layer_type, - shared_qkv: bool, - shared_mlp: bool, - mlp_type: str, - **kwargs, - ): - super().__init__( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - shared_qkv=shared_qkv, - shared_mlp=shared_mlp, - mlp_type=mlp_type, - **kwargs, - ) - - self.attn = NaSwinAttention( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - qk_norm_eps=norm_eps, - shared_qkv=shared_qkv, - **kwargs, - ) - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - emb: torch.FloatTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.LongTensor, - torch.LongTensor, - ]: - hid_len = MMArg( - cache("vid_len", lambda: vid_shape.prod(-1)), - cache("txt_len", lambda: txt_shape.prod(-1)), - ) - ada_kwargs = { - "emb": emb, - "hid_len": hid_len, - "cache": cache, - "branch_tag": MMArg("vid", "txt"), - } - - vid_attn, txt_attn = self.attn_norm(vid, txt) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) - vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) - vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) - - vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) - vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) - vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) - - return vid_mlp, txt_mlp, vid_shape, txt_shape diff --git a/modelsx/dit/nadit.py b/modelsx/dit/nadit.py deleted file mode 100644 index 7e778236db6a70f49a364db6e84bf7539c0b58ac..0000000000000000000000000000000000000000 --- a/modelsx/dit/nadit.py +++ /dev/null @@ -1,350 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -from dataclasses import dataclass -from typing import Optional, Tuple, Union, Callable -import torch -from torch import nn - -from common.cache import Cache -from common.distributed.ops import slice_inputs - -from . import na -from .embedding import TimeEmbedding -from .modulation import get_ada_layer -from .nablocks import get_nablock -from .normalization import get_norm_layer -from .patch import NaPatchIn, NaPatchOut - -# Fake func, no checkpointing is required for inference -def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): - return module(*args, **kwargs) - -@dataclass -class NaDiTOutput: - vid_sample: torch.Tensor - - -class NaDiT(nn.Module): - """ - Native Resolution Diffusion Transformer (NaDiT) - """ - - gradient_checkpointing = False - - def __init__( - self, - vid_in_channels: int, - vid_out_channels: int, - vid_dim: int, - txt_in_dim: Optional[int], - txt_dim: Optional[int], - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm: Optional[str], - norm_eps: float, - ada: str, - qk_bias: bool, - qk_rope: bool, - qk_norm: Optional[str], - patch_size: Union[int, Tuple[int, int, int]], - num_layers: int, - block_type: Union[str, Tuple[str]], - shared_qkv: bool = False, - shared_mlp: bool = False, - mlp_type: str = "normal", - window: Optional[Tuple] = None, - window_method: Optional[Tuple[str]] = None, - temporal_window_size: int = None, - temporal_shifted: bool = False, - **kwargs, - ): - ada = get_ada_layer(ada) - norm = get_norm_layer(norm) - qk_norm = get_norm_layer(qk_norm) - if isinstance(block_type, str): - block_type = [block_type] * num_layers - elif len(block_type) != num_layers: - raise ValueError("The ``block_type`` list should equal to ``num_layers``.") - super().__init__() - self.vid_in = NaPatchIn( - in_channels=vid_in_channels, - patch_size=patch_size, - dim=vid_dim, - ) - self.txt_in = ( - nn.Linear(txt_in_dim, txt_dim) - if txt_in_dim and txt_in_dim != txt_dim - else nn.Identity() - ) - self.emb_in = TimeEmbedding( - sinusoidal_dim=256, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - ) - - if window is None or isinstance(window[0], int): - window = [window] * num_layers - if window_method is None or isinstance(window_method, str): - window_method = [window_method] * num_layers - if temporal_window_size is None or isinstance(temporal_window_size, int): - temporal_window_size = [temporal_window_size] * num_layers - if temporal_shifted is None or isinstance(temporal_shifted, bool): - temporal_shifted = [temporal_shifted] * num_layers - - self.blocks = nn.ModuleList( - [ - get_nablock(block_type[i])( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - shared_qkv=shared_qkv, - shared_mlp=shared_mlp, - mlp_type=mlp_type, - window=window[i], - window_method=window_method[i], - temporal_window_size=temporal_window_size[i], - temporal_shifted=temporal_shifted[i], - **kwargs, - ) - for i in range(num_layers) - ] - ) - self.vid_out = NaPatchOut( - out_channels=vid_out_channels, - patch_size=patch_size, - dim=vid_dim, - ) - - self.need_txt_repeat = block_type[0] in [ - "mmdit_stwin", - "mmdit_stwin_spatial", - "mmdit_stwin_3d_spatial", - ] - - def set_gradient_checkpointing(self, enable: bool): - self.gradient_checkpointing = enable - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - 
vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b - disable_cache: bool = True, # for test - ): - # Text input. - if txt_shape.size(-1) == 1 and self.need_txt_repeat: - txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - # slice vid after patching in when using sequence parallelism - txt = slice_inputs(txt, dim=0) - txt = self.txt_in(txt) - - # Video input. - # Sequence parallel slicing is done inside patching class. - vid, vid_shape = self.vid_in(vid, vid_shape) - - # Embedding input. - emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - - # Body - cache = Cache(disable=disable_cache) - for i, block in enumerate(self.blocks): - vid, txt, vid_shape, txt_shape = gradient_checkpointing( - enabled=(self.gradient_checkpointing and self.training), - module=block, - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) - - vid, vid_shape = self.vid_out(vid, vid_shape, cache) - return NaDiTOutput(vid_sample=vid) - - -class NaDiTUpscaler(nn.Module): - """ - Native Resolution Diffusion Transformer (NaDiT) - """ - - gradient_checkpointing = False - - def __init__( - self, - vid_in_channels: int, - vid_out_channels: int, - vid_dim: int, - txt_in_dim: Optional[int], - txt_dim: Optional[int], - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm: Optional[str], - norm_eps: float, - ada: str, - qk_bias: bool, - qk_rope: bool, - qk_norm: Optional[str], - patch_size: Union[int, Tuple[int, int, int]], - num_layers: int, - block_type: Union[str, Tuple[str]], - shared_qkv: bool = False, - shared_mlp: bool = False, - mlp_type: str = "normal", - window: Optional[Tuple] = None, - window_method: Optional[Tuple[str]] = None, - temporal_window_size: int = None, - temporal_shifted: bool = False, - **kwargs, - ): - ada = get_ada_layer(ada) - norm = get_norm_layer(norm) - qk_norm = get_norm_layer(qk_norm) - if isinstance(block_type, str): - block_type = [block_type] * num_layers - elif len(block_type) != num_layers: - raise ValueError("The ``block_type`` list should equal to ``num_layers``.") - super().__init__() - self.vid_in = NaPatchIn( - in_channels=vid_in_channels, - patch_size=patch_size, - dim=vid_dim, - ) - self.txt_in = ( - nn.Linear(txt_in_dim, txt_dim) - if txt_in_dim and txt_in_dim != txt_dim - else nn.Identity() - ) - self.emb_in = TimeEmbedding( - sinusoidal_dim=256, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - ) - - self.emb_scale = TimeEmbedding( - sinusoidal_dim=256, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - ) - - if window is None or isinstance(window[0], int): - window = [window] * num_layers - if window_method is None or isinstance(window_method, str): - window_method = [window_method] * num_layers - if temporal_window_size is None or isinstance(temporal_window_size, int): - temporal_window_size = [temporal_window_size] * num_layers - if temporal_shifted is None or isinstance(temporal_shifted, bool): - temporal_shifted = [temporal_shifted] * num_layers - - self.blocks = nn.ModuleList( - [ - get_nablock(block_type[i])( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - shared_qkv=shared_qkv, - shared_mlp=shared_mlp, - mlp_type=mlp_type, - window=window[i], - window_method=window_method[i], - 
temporal_window_size=temporal_window_size[i], - temporal_shifted=temporal_shifted[i], - **kwargs, - ) - for i in range(num_layers) - ] - ) - self.vid_out = NaPatchOut( - out_channels=vid_out_channels, - patch_size=patch_size, - dim=vid_dim, - ) - - self.need_txt_repeat = block_type[0] in [ - "mmdit_stwin", - "mmdit_stwin_spatial", - "mmdit_stwin_3d_spatial", - ] - - def set_gradient_checkpointing(self, enable: bool): - self.gradient_checkpointing = enable - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b - downscale: Union[int, float, torch.IntTensor, torch.FloatTensor], # b - disable_cache: bool = False, # for test - ): - - # Text input. - if txt_shape.size(-1) == 1 and self.need_txt_repeat: - txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - # slice vid after patching in when using sequence parallelism - txt = slice_inputs(txt, dim=0) - txt = self.txt_in(txt) - - # Video input. - # Sequence parallel slicing is done inside patching class. - vid, vid_shape = self.vid_in(vid, vid_shape) - - # Embedding input. - emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - emb_scale = self.emb_scale(downscale, device=vid.device, dtype=vid.dtype) - emb = emb + emb_scale - - # Body - cache = Cache(disable=disable_cache) - for i, block in enumerate(self.blocks): - vid, txt, vid_shape, txt_shape = gradient_checkpointing( - enabled=(self.gradient_checkpointing and self.training), - module=block, - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) - - vid, vid_shape = self.vid_out(vid, vid_shape, cache) - return NaDiTOutput(vid_sample=vid) diff --git a/modelsx/dit/normalization.py b/modelsx/dit/normalization.py deleted file mode 100644 index 98827a9c71f9fd6e461937774d022b68844aee34..0000000000000000000000000000000000000000 --- a/modelsx/dit/normalization.py +++ /dev/null @@ -1,63 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
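In the upscaler's forward above, the extra downscale condition is treated exactly like the timestep: each scalar goes through its own sinusoidal-plus-MLP embedder and the two vectors are summed into one conditioning embedding shared by every block. A rough, self-contained sketch of that idea follows; SimpleTimeEmbedding is a hypothetical stand-in for the repo's TimeEmbedding, and all sizes are illustrative.

import math
import torch
from torch import nn

class SimpleTimeEmbedding(nn.Module):
    def __init__(self, sinusoidal_dim: int, output_dim: int):
        super().__init__()
        self.sinusoidal_dim = sinusoidal_dim
        self.mlp = nn.Sequential(
            nn.Linear(sinusoidal_dim, output_dim), nn.SiLU(), nn.Linear(output_dim, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Classic sinusoidal features of the scalar condition, then a small MLP.
        half = self.sinusoidal_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        angles = x.float()[:, None] * freqs[None, :]
        return self.mlp(torch.cat([angles.sin(), angles.cos()], dim=-1))

emb_in = SimpleTimeEmbedding(sinusoidal_dim=256, output_dim=512)
emb_scale = SimpleTimeEmbedding(sinusoidal_dim=256, output_dim=512)

timestep = torch.tensor([500.0])    # diffusion timestep, per sample
downscale = torch.tensor([4.0])     # upscaling factor the model is conditioned on
emb = emb_in(timestep) + emb_scale(downscale)   # shared conditioning passed to every block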
- -from typing import Callable, Optional -from diffusers.models.normalization import RMSNorm -from torch import nn - -# (dim: int, eps: float, elementwise_affine: bool) -norm_layer_type = Callable[[int, float, bool], nn.Module] - - -def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type: - - def _norm_layer(dim: int, eps: float, elementwise_affine: bool): - if norm_type is None: - return nn.Identity() - - if norm_type == "layer": - return nn.LayerNorm( - normalized_shape=dim, - eps=eps, - elementwise_affine=elementwise_affine, - ) - - if norm_type == "rms": - return RMSNorm( - dim=dim, - eps=eps, - elementwise_affine=elementwise_affine, - ) - - if norm_type == "fusedln": - from apex.normalization import FusedLayerNorm - - return FusedLayerNorm( - normalized_shape=dim, - elementwise_affine=elementwise_affine, - eps=eps, - ) - - if norm_type == "fusedrms": - from apex.normalization import FusedRMSNorm - - return FusedRMSNorm( - normalized_shape=dim, - elementwise_affine=elementwise_affine, - eps=eps, - ) - - raise NotImplementedError(f"{norm_type} is not supported") - - return _norm_layer diff --git a/modelsx/dit/patch.py b/modelsx/dit/patch.py deleted file mode 100644 index d98158e34a94e0447ed82b92fbfa289bf1a2be1d..0000000000000000000000000000000000000000 --- a/modelsx/dit/patch.py +++ /dev/null @@ -1,112 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Tuple, Union -import torch -from einops import rearrange -from torch import nn -from torch.nn.modules.utils import _triple - -from common.cache import Cache -from common.distributed.ops import gather_outputs, slice_inputs - -from . 
import na - - -class PatchIn(nn.Module): - def __init__( - self, - in_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = nn.Linear(in_channels * t * h * w, dim) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) - vid = self.proj(vid) - return vid - - -class PatchOut(nn.Module): - def __init__( - self, - out_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = nn.Linear(dim, out_channels * t * h * w) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) - return vid - - -class NaPatchIn(PatchIn): - def forward( - self, - vid: torch.Tensor, # l c - vid_shape: torch.LongTensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - if not (t == h == w == 1): - vid, vid_shape = na.rearrange( - vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w - ) - # slice vid after patching in when using sequence parallelism - vid = slice_inputs(vid, dim=0) - vid = self.proj(vid) - return vid, vid_shape - - -class NaPatchOut(PatchOut): - def forward( - self, - vid: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), - ) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, - ]: - t, h, w = self.patch_size - vid = self.proj(vid) - # gather vid before patching out when enabling sequence parallelism - vid = gather_outputs( - vid, - gather_dim=0, - padding_dim=0, - unpad_shape=vid_shape, - cache=cache.namespace("vid"), - ) - if not (t == h == w == 1): - vid, vid_shape = na.rearrange( - vid, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w - ) - return vid, vid_shape diff --git a/modelsx/dit/rope.py b/modelsx/dit/rope.py deleted file mode 100644 index 32a4815a1b349001cb86ea6d752fb4f91f6e655e..0000000000000000000000000000000000000000 --- a/modelsx/dit/rope.py +++ /dev/null @@ -1,101 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from functools import lru_cache -from typing import Tuple -import torch -from einops import rearrange -from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb -from torch import nn - -from common.cache import Cache - - -class RotaryEmbeddingBase(nn.Module): - def __init__(self, dim: int, rope_dim: int): - super().__init__() - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="pixel", - max_freq=256, - ) - # 1. Set model.requires_grad_(True) after model creation will make - # the `requires_grad=False` for rope freqs no longer hold. - # 2. 
Even if we don't set requires_grad_(True) explicitly, - # FSDP is not memory efficient when handling fsdp_wrap - # with mixed requires_grad=True/False. - # With above consideration, it is easier just remove the freqs - # out of nn.Parameters when `learned_freq=False` - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - - @lru_cache(maxsize=128) - def get_axial_freqs(self, *dims): - return self.rope.get_axial_freqs(*dims) - - -class RotaryEmbedding3d(RotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - - def forward( - self, - q: torch.FloatTensor, # b h l d - k: torch.FloatTensor, # b h l d - size: Tuple[int, int, int], - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - T, H, W = size - freqs = self.get_axial_freqs(T, H, W) - q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - q = apply_rotary_emb(freqs, q) - k = apply_rotary_emb(freqs, k) - q = rearrange(q, "b h T H W d -> b h (T H W) d") - k = rearrange(k, "b h T H W d -> b h (T H W) d") - return q, k - - -class NaRotaryEmbedding3d(RotaryEmbedding3d): - def forward( - self, - q: torch.FloatTensor, # L h d - k: torch.FloatTensor, # L h d - shape: torch.LongTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) - q = rearrange(q, "L h d -> h L d") - k = rearrange(k, "L h d -> h L d") - q = apply_rotary_emb(freqs, q.float()).to(q.dtype) - k = apply_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "h L d -> L h d") - k = rearrange(k, "h L d -> L h d") - return q, k - - def get_freqs( - self, - shape: torch.LongTensor, - ) -> torch.Tensor: - freq_list = [] - for f, h, w in shape.tolist(): - freqs = self.get_axial_freqs(f, h, w) - freq_list.append(freqs.view(-1, freqs.size(-1))) - return torch.cat(freq_list, dim=0) diff --git a/modelsx/dit/window.py b/modelsx/dit/window.py deleted file mode 100644 index b7475921ae283cf76d82bff7521233c133f54bfd..0000000000000000000000000000000000000000 --- a/modelsx/dit/window.py +++ /dev/null @@ -1,83 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
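For native-resolution batches the 3D rotary embedding above cannot assume a single (T, H, W): get_freqs builds the axial frequency table per sample and concatenates the flattened tables so they line up with the packed token sequence. A short sketch of that step, assuming the same rotary-embedding-torch package the code imports; the head size and shapes are illustrative.

import torch
from rotary_embedding_torch import RotaryEmbedding

head_dim = 64
# As above: half the head dim is rotated, shared across the three axes.
rope = RotaryEmbedding(dim=head_dim // 2 // 3, freqs_for="pixel", max_freq=256)

shapes = [(4, 8, 8), (2, 16, 12)]              # per-sample (T, H, W) after patching
freq_list = []
for t, h, w in shapes:
    freqs = rope.get_axial_freqs(t, h, w)       # (t, h, w, d_rot) axial table for one sample
    freq_list.append(freqs.reshape(-1, freqs.size(-1)))
packed_freqs = torch.cat(freq_list, dim=0)      # aligned with the packed (L, ...) token layout
print(packed_freqs.shape)                       # (4*8*8 + 2*16*12, d_rot)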
- -from math import ceil -from typing import Tuple -import math - -def get_window_op(name: str): - if name == "720pwin_by_size_bysize": - return make_720Pwindows_bysize - if name == "720pswin_by_size_bysize": - return make_shifted_720Pwindows_bysize - raise ValueError(f"Unknown windowing method: {name}") - - -# -------------------------------- Windowing -------------------------------- # -def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt((45 * 80) / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, 30) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. - return [ - ( - slice(it * wt, min((it + 1) * wt, t)), - slice(ih * wh, min((ih + 1) * wh, h)), - slice(iw * ww, min((iw + 1) * ww, w)), - ) - for iw in range(nw) - if min((iw + 1) * ww, w) > iw * ww - for ih in range(nh) - if min((ih + 1) * wh, h) > ih * wh - for it in range(nt) - if min((it + 1) * wt, t) > it * wt - ] - -def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt((45 * 80) / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, 30) / resized_nt) # window size. - - st, sh, sw = ( # shift size. - 0.5 if wt < t else 0, - 0.5 if wh < h else 0, - 0.5 if ww < w else 0, - ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. - nt + 1 if st > 0 else 1, - nh + 1 if sh > 0 else 1, - nw + 1 if sw > 0 else 1, - ) - return [ - ( - slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), - slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), - slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), - ) - for iw in range(nw) - if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) - for ih in range(nh) - if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) - for it in range(nt) - if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) - ] diff --git a/modelsx/dit_v2/attention.py b/modelsx/dit_v2/attention.py deleted file mode 100644 index 9201fe095778db21ebd3384d163b0ccac4b35664..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/attention.py +++ /dev/null @@ -1,46 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
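The two windowing helpers above size windows as if every clip were 720p: the spatial extent is rescaled to a 45x80 latent grid (and at most about 30 latent frames) before dividing by the requested window counts, so bigger inputs keep roughly constant tokens per window and simply get more windows. A worked example with illustrative numbers:

import math
from math import ceil

t, h, w = 30, 90, 160                  # latent (frames, height, width) of a hypothetical input
num_windows = (1, 3, 3)                # target number of windows at the 720p reference size

scale = math.sqrt((45 * 80) / (h * w))             # 0.5: this input covers 4x the 720p area
wh = ceil(round(h * scale) / num_windows[1])       # 15 rows per window
ww = ceil(round(w * scale) / num_windows[2])       # 27 cols per window
wt = ceil(min(t, 30) / num_windows[0])             # 30 frames per window
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)
print(wt, wh, ww, nt * nh * nw)                    # 30 15 27 -> 36 windows instead of 9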
- -import torch -import torch.nn.functional as F - -from flash_attn import flash_attn_varlen_func - -from torch import nn - -class TorchAttention(nn.Module): - def tflops(self, args, kwargs, output) -> float: - assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs" - q = kwargs.get("query") or args[0] - k = kwargs.get("key") or args[1] - b, h, sq, d = q.shape - b, h, sk, d = k.shape - return b * h * (4 * d * (sq / 1e6) * (sk / 1e6)) - - def forward(self, *args, **kwargs): - return F.scaled_dot_product_attention(*args, **kwargs) - - -class FlashAttentionVarlen(nn.Module): - def tflops(self, args, kwargs, output) -> float: - cu_seqlens_q = kwargs["cu_seqlens_q"] - cu_seqlens_k = kwargs["cu_seqlens_k"] - _, h, d = output.shape - seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6 - seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6 - return h * (4 * d * (seqlens_q * seqlens_k).sum()) - - def forward(self, *args, **kwargs): - kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled() - return flash_attn_varlen_func(*args, **kwargs) diff --git a/modelsx/dit_v2/embedding.py b/modelsx/dit_v2/embedding.py deleted file mode 100644 index e972244f5767c9f34e5e77bb180ae720ce88b89c..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/embedding.py +++ /dev/null @@ -1,62 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Optional, Union -import torch -from diffusers.models.embeddings import get_timestep_embedding -from torch import nn - - -def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): - return emb1 if emb2 is None else emb1 + emb2 - - -class TimeEmbedding(nn.Module): - def __init__( - self, - sinusoidal_dim: int, - hidden_dim: int, - output_dim: int, - ): - super().__init__() - self.sinusoidal_dim = sinusoidal_dim - self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) - self.proj_hid = nn.Linear(hidden_dim, hidden_dim) - self.proj_out = nn.Linear(hidden_dim, output_dim) - self.act = nn.SiLU() - - def forward( - self, - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], - device: torch.device, - dtype: torch.dtype, - ) -> torch.FloatTensor: - if not torch.is_tensor(timestep): - timestep = torch.tensor([timestep], device=device, dtype=dtype) - if timestep.ndim == 0: - timestep = timestep[None] - - emb = get_timestep_embedding( - timesteps=timestep, - embedding_dim=self.sinusoidal_dim, - flip_sin_to_cos=False, - downscale_freq_shift=0, - ) - emb = emb.to(dtype) - emb = self.proj_in(emb) - emb = self.act(emb) - emb = self.proj_hid(emb) - emb = self.act(emb) - emb = self.proj_out(emb) - return emb diff --git a/modelsx/dit_v2/mlp.py b/modelsx/dit_v2/mlp.py deleted file mode 100644 index 2d05cb021f3e3c6ac05c0e7ae1aa8a6d29475b87..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/mlp.py +++ /dev/null @@ -1,62 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Optional -import torch -import torch.nn.functional as F -from torch import nn - - -def get_mlp(mlp_type: Optional[str] = "normal"): - if mlp_type == "normal": - return MLP - elif mlp_type == "swiglu": - return SwiGLUMLP - - -class MLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - ): - super().__init__() - self.proj_in = nn.Linear(dim, dim * expand_ratio) - self.act = nn.GELU("tanh") - self.proj_out = nn.Linear(dim * expand_ratio, dim) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_in(x) - x = self.act(x) - x = self.proj_out(x) - return x - - -class SwiGLUMLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - multiple_of: int = 256, - ): - super().__init__() - hidden_dim = int(2 * dim * expand_ratio / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) - self.proj_out = nn.Linear(hidden_dim, dim, bias=False) - self.proj_in = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) - return x diff --git a/modelsx/dit_v2/mm.py b/modelsx/dit_v2/mm.py deleted file mode 100644 index 344f89a8fa22b9a5473b8d25f208085a630f0c85..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/mm.py +++ /dev/null @@ -1,74 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
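The SwiGLU variant above uses the common sizing trick: the hidden width is shrunk to two thirds of dim * expand_ratio (to offset the extra gating projection) and rounded up to a multiple of 256. A quick worked example with illustrative sizes:

dim, expand_ratio, multiple_of = 1024, 4, 256
hidden = int(2 * dim * expand_ratio / 3)                             # 2730
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)   # rounded up to 2816
print(hidden)
# A plain MLP of the same expand_ratio uses two dim x 4096 projections, while SwiGLU uses
# three dim x 2816 projections, so the parameter counts stay roughly comparable.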
- -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple -import torch -from torch import nn - - -@dataclass -class MMArg: - vid: Any - txt: Any - - -def get_args(key: str, args: List[Any]) -> List[Any]: - return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] - - -def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} - - -class MMModule(nn.Module): - def __init__( - self, - module: Callable[..., nn.Module], - *args, - shared_weights: bool = False, - vid_only: bool = False, - **kwargs, - ): - super().__init__() - self.shared_weights = shared_weights - self.vid_only = vid_only - if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) - self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - else: - self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - self.txt = ( - module(*get_args("txt", args), **get_kwargs("txt", kwargs)) - if not vid_only - else None - ) - - def forward( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - *args, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.vid if not self.shared_weights else self.all - vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) - if not self.vid_only: - txt_module = self.txt if not self.shared_weights else self.all - txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) - return vid, txt diff --git a/modelsx/dit_v2/modulation.py b/modelsx/dit_v2/modulation.py deleted file mode 100644 index 9e14bb005ef2d0a2c7205f593c483e8862a42858..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/modulation.py +++ /dev/null @@ -1,102 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Callable, List, Optional -import torch -from einops import rearrange -from torch import nn - -from common.cache import Cache -from common.distributed.ops import slice_inputs - -# (dim: int, emb_dim: int) -ada_layer_type = Callable[[int, int], nn.Module] - - -def get_ada_layer(ada_layer: str) -> ada_layer_type: - if ada_layer == "single": - return AdaSingle - raise NotImplementedError(f"{ada_layer} is not supported") - - -def expand_dims(x: torch.Tensor, dim: int, ndim: int): - """ - Expand tensor "x" to "ndim" by adding empty dims at "dim". - Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d). 
- """ - shape = x.shape - shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] - return x.reshape(shape) - - -class AdaSingle(nn.Module): - def __init__( - self, - dim: int, - emb_dim: int, - layers: List[str], - modes: List[str] = ["in", "out"], - ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" - super().__init__() - self.dim = dim - self.emb_dim = emb_dim - self.layers = layers - for l in layers: - if "in" in modes: - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) - self.register_parameter( - f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) - ) - if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) - - def forward( - self, - hid: torch.FloatTensor, # b ... c - emb: torch.FloatTensor, # b d - layer: str, - mode: str, - cache: Cache = Cache(disable=True), - branch_tag: str = "", - hid_len: Optional[torch.LongTensor] = None, # b - ) -> torch.FloatTensor: - idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] - emb = expand_dims(emb, 1, hid.ndim + 1) - - if hid_len is not None: - emb = cache( - f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), - dim=0, - ), - ) - - shiftA, scaleA, gateA = emb.unbind(-1) - shiftB, scaleB, gateB = ( - getattr(self, f"{layer}_shift", None), - getattr(self, f"{layer}_scale", None), - getattr(self, f"{layer}_gate", None), - ) - - if mode == "in": - return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) - if mode == "out": - return hid.mul_(gateA + gateB) - raise NotImplementedError - - def extra_repr(self) -> str: - return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}" \ No newline at end of file diff --git a/modelsx/dit_v2/na.py b/modelsx/dit_v2/na.py deleted file mode 100644 index 0dbd546c4705b3b9c7c19a9823f9d113a0447616..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/na.py +++ /dev/null @@ -1,241 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from itertools import chain -from typing import Callable, Dict, List, Tuple -import einops -import torch - - -def flatten( - hid: List[torch.FloatTensor], # List of (*** c) -) -> Tuple[ - torch.FloatTensor, # (L c) - torch.LongTensor, # (b n) -]: - assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) - hid = torch.cat([x.flatten(0, -2) for x in hid]) - return hid, shape - - -def unflatten( - hid: torch.FloatTensor, # (L c) or (L ... c) - hid_shape: torch.LongTensor, # (b n) -) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) - hid_len = hid_shape.prod(-1) - hid = hid.split(hid_len.tolist()) - hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] - return hid - - -def concat( - vid: torch.FloatTensor, # (VL ... 
c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - return torch.cat(list(chain(*zip(vid, txt)))) - - -def concat_idx( - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) - src_idx = torch.argsort(tgt_idx) - return ( - lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), - lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), - ) - - -def unconcat( - all: torch.FloatTensor, # (L ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - torch.FloatTensor, # (VL ... c) - torch.FloatTensor, # (TL ... c) -]: - interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist()))) - all = all.split(interleave_len) - vid = torch.cat(all[0::2]) - txt = torch.cat(all[1::2]) - return vid, txt - - -def repeat_concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: List, # (n) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - txt = [[x] * n for x, n in zip(txt, txt_repeat)] - txt = list(chain(*txt)) - return torch.cat(list(chain(*zip(vid, txt)))) - - -def repeat_concat_idx( - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: torch.LongTensor, # (n) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) - src_idx = torch.argsort(tgt_idx) - txt_idx_len = len(tgt_idx) - len(vid_idx) - repeat_txt_len = (txt_len * txt_repeat).tolist() - - def unconcat_coalesce(all): - """ - Un-concat vid & txt, and coalesce the repeated txt. - e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8] - txt [9 10] - repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10] - 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10] - split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10] - 2. reshape & mean for each sample to coalesce the repeated txt. - """ - vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) - txt_out_coalesced = [] - for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): - txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) - txt_out_coalesced.append(txt) - return vid_out, torch.cat(txt_out_coalesced) - - # Note: Backward of torch.index_select is non-deterministic when existing repeated index, - # the difference may cumulative like torch.repeat_interleave, so we use vanilla index here. 
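# --- Illustrative aside (not part of the original file) --------------------------------
# A minimal sketch of what concat_idx above precomputes: a gather index that interleaves
# each sample's video tokens with its text tokens, plus its argsort to undo the interleave.
# Toy lengths only; in the model they come from vid_shape.prod(-1) and txt_shape.prod(-1).
import torch

vid_len = torch.tensor([4, 2])              # video tokens per sample
txt_len = torch.tensor([3, 1])              # text tokens per sample
vid = torch.arange(0, 6).float()            # packed video tokens of both samples
txt = torch.arange(100, 104).float()        # packed text tokens of both samples

vid_idx = torch.arange(vid_len.sum())
txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum())
tgt_idx = torch.cat([part for pair in zip(vid_idx.split(vid_len.tolist()),
                                          txt_idx.split(txt_len.tolist())) for part in pair])
src_idx = torch.argsort(tgt_idx)

packed = torch.index_select(torch.cat([vid, txt]), 0, tgt_idx)   # vid0, txt0, vid1, txt1
vid_back, txt_back = torch.index_select(packed, 0, src_idx).split([len(vid_idx), len(txt_idx)])
assert torch.equal(vid_back, vid) and torch.equal(txt_back, txt)
# ----------------------------------------------------------------------------------------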
- return ( - lambda vid, txt: torch.cat([vid, txt])[tgt_idx], - lambda all: unconcat_coalesce(all), - ) - - -def rearrange( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, int], -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)]) - - -def rearrange_idx( - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, int], -) -> Tuple[Callable, Callable, torch.LongTensor]: - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - ) - - -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - - -def pack( - samples: List[torch.Tensor], # List of (h w c). -) -> Tuple[ - List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)] - List[List[int]], # reversal indices. -]: - batches = {} - indices = {} - for i, sample in enumerate(samples): - shape = sample.shape - batches[shape] = batches.get(shape, []) - indices[shape] = indices.get(shape, []) - batches[shape].append(sample) - indices[shape].append(i) - - batches = list(map(torch.stack, batches.values())) - indices = list(indices.values()) - return batches, indices - - -def unpack( - batches: List[torch.Tensor], - indices: List[List[int]], -) -> List[torch.Tensor]: - samples = [None] * (max(chain(*indices)) + 1) - for batch, index in zip(batches, indices): - for sample, i in zip(batch.unbind(), index): - samples[i] = sample - return samples - - -def window( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid = unflatten(hid, hid_shape) - hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows - - -def window_idx( - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - tgt_windows, - ) diff --git a/modelsx/dit_v2/nablocks/__init__.py b/modelsx/dit_v2/nablocks/__init__.py deleted file mode 100644 index c1a9da26ef760575192042ea32b01bd9cd1a267d..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/nablocks/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from .mmsr_block import NaMMSRTransformerBlock - - -nadit_blocks = { - "mmdit_sr": NaMMSRTransformerBlock, -} - - -def get_nablock(block_type: str): - if block_type in nadit_blocks: - return nadit_blocks[block_type] - raise NotImplementedError(f"{block_type} is not supported") diff --git a/modelsx/dit_v2/nablocks/attention/__init__.py b/modelsx/dit_v2/nablocks/attention/__init__.py deleted file mode 100644 index a7561025245d888d26ade38f25668efb216cd907..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/nablocks/attention/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from .mmattn import NaMMAttention - -attns = { - "mm_full": NaMMAttention, -} - - -def get_attn(attn_type: str): - if attn_type in attns: - return attns[attn_type] - raise NotImplementedError(f"{attn_type} is not supported") diff --git a/modelsx/dit_v2/nablocks/attention/mmattn.py b/modelsx/dit_v2/nablocks/attention/mmattn.py deleted file mode 100644 index 4fea9cb9c6fa2f82dd1aba46d658a04a19a11305..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/nablocks/attention/mmattn.py +++ /dev/null @@ -1,266 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Optional, Tuple, Union -import torch -from einops import rearrange -from torch import nn -from torch.nn import functional as F -from torch.nn.modules.utils import _triple - -from common.cache import Cache -from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv - -from ... 
import na -from ...attention import FlashAttentionVarlen -from ...mm import MMArg, MMModule -from ...normalization import norm_layer_type -from ...rope import get_na_rope -from ...window import get_window_op -from itertools import chain - - -class NaMMAttention(nn.Module): - def __init__( - self, - vid_dim: int, - txt_dim: int, - heads: int, - head_dim: int, - qk_bias: bool, - qk_norm: norm_layer_type, - qk_norm_eps: float, - rope_type: Optional[str], - rope_dim: int, - shared_weights: bool, - **kwargs, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - inner_dim = heads * head_dim - qkv_dim = inner_dim * 3 - self.head_dim = head_dim - self.proj_qkv = MMModule( - nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights - ) - self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights) - self.norm_q = MMModule( - qk_norm, - dim=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - ) - self.norm_k = MMModule( - qk_norm, - dim=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - ) - - self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) - self.attn = FlashAttentionVarlen() - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - vid_qkv = gather_seq_scatter_heads_qkv( - vid_qkv, - seq_dim=0, - qkv_shape=vid_shape, - cache=cache.namespace("vid"), - ) - txt_qkv = gather_seq_scatter_heads_qkv( - txt_qkv, - seq_dim=0, - qkv_shape=txt_shape, - cache=cache.namespace("txt"), - ) - vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - if self.rope: - if self.rope.mm: - vid_q, vid_k, txt_q, txt_k = self.rope( - vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache - ) - else: - vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache) - - vid_len = cache("vid_len", lambda: vid_shape.prod(-1)) - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - all_len = cache("all_len", lambda: vid_len + txt_len) - - concat, unconcat = cache("mm_pnp", lambda: na.concat_idx(vid_len, txt_len)) - - attn = self.attn( - q=concat(vid_q, txt_q).bfloat16(), - k=concat(vid_k, txt_k).bfloat16(), - v=concat(vid_v, txt_v).bfloat16(), - cu_seqlens_q=cache("mm_seqlens", lambda: F.pad(all_len.cumsum(0), (1, 0)).int()), - cu_seqlens_k=cache("mm_seqlens", lambda: F.pad(all_len.cumsum(0), (1, 0)).int()), - max_seqlen_q=cache("mm_maxlen", lambda: all_len.max().item()), - max_seqlen_k=cache("mm_maxlen", lambda: all_len.max().item()), - ).type_as(vid_q) - - attn = rearrange(attn, "l h d -> l (h d)") - vid_out, txt_out = unconcat(attn) - vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0) - txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - return vid_out, txt_out - - -class NaSwinAttention(NaMMAttention): - def __init__( - self, - *args, - window: Union[int, Tuple[int, int, int]], - window_method: str, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.window = _triple(window) - self.window_method = window_method - 
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) - - self.window_op = get_window_op(window_method) - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - vid_qkv = gather_seq_scatter_heads_qkv( - vid_qkv, - seq_dim=0, - qkv_shape=vid_shape, - cache=cache.namespace("vid"), - ) - txt_qkv = gather_seq_scatter_heads_qkv( - txt_qkv, - seq_dim=0, - qkv_shape=txt_shape, - cache=cache.namespace("txt"), - ) - - # re-org the input seq for window attn - cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") - - def make_window(x: torch.Tensor): - t, h, w, _ = x.shape - window_slices = self.window_op((t, h, w), self.window) - return [x[st, sh, sw] for (st, sh, sw) in window_slices] - - window_partition, window_reverse, window_shape, window_count = cache_win( - "win_transform", - lambda: na.window_idx(vid_shape, make_window), - ) - vid_qkv_win = window_partition(vid_qkv) - - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - - vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) - concat_win, unconcat_win = cache_win( - "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count) - ) - - # window rope - if self.rope: - if self.rope.mm: - # repeat text q and k for window mmrope - _, num_h, _ = txt_q.shape - txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") - txt_q_repeat = na.unflatten(txt_q_repeat, txt_shape) - txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] - txt_q_repeat = list(chain(*txt_q_repeat)) - txt_q_repeat, txt_shape_repeat = na.flatten(txt_q_repeat) - txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) - - txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") - txt_k_repeat = na.unflatten(txt_k_repeat, txt_shape) - txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] - txt_k_repeat = list(chain(*txt_k_repeat)) - txt_k_repeat, _ = na.flatten(txt_k_repeat) - txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) - - vid_q, vid_k, txt_q, txt_k = self.rope( - vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win - ) - else: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - - out = self.attn( - q=concat_win(vid_q, txt_q).bfloat16(), - k=concat_win(vid_k, txt_k).bfloat16(), - v=concat_win(vid_v, txt_v).bfloat16(), - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), - max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()), - max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()), - ).type_as(vid_q) - - # text pooling - vid_out, txt_out = unconcat_win(out) - - vid_out = rearrange(vid_out, "l h d -> l (h d)") 
- txt_out = rearrange(txt_out, "l h d -> l (h d)") - vid_out = window_reverse(vid_out) - - vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0) - txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - - return vid_out, txt_out \ No newline at end of file diff --git a/modelsx/dit_v2/nablocks/mmsr_block.py b/modelsx/dit_v2/nablocks/mmsr_block.py deleted file mode 100644 index 407c5b3eac3d0e572a148283ac322cf50a77d8a4..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/nablocks/mmsr_block.py +++ /dev/null @@ -1,119 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Tuple -import torch -import torch.nn as nn - -# from ..cache import Cache -from common.cache import Cache - -from .attention.mmattn import NaSwinAttention -from ..mm import MMArg -from ..modulation import ada_layer_type -from ..normalization import norm_layer_type -from ..mm import MMArg, MMModule -from ..mlp import get_mlp - - -class NaMMSRTransformerBlock(nn.Module): - def __init__( - self, - *, - vid_dim: int, - txt_dim: int, - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm: norm_layer_type, - norm_eps: float, - ada: ada_layer_type, - qk_bias: bool, - qk_norm: norm_layer_type, - mlp_type: str, - shared_weights: bool, - rope_type: str, - rope_dim: int, - is_last_layer: bool, - **kwargs, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) - - self.attn = NaSwinAttention( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_norm=qk_norm, - qk_norm_eps=norm_eps, - rope_type=rope_type, - rope_dim=rope_dim, - shared_weights=shared_weights, - window=kwargs.pop("window", None), - window_method=kwargs.pop("window_method", None), - ) - - self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) - self.mlp = MMModule( - get_mlp(mlp_type), - dim=dim, - expand_ratio=expand_ratio, - shared_weights=shared_weights, - vid_only=is_last_layer - ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer) - self.is_last_layer = is_last_layer - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - emb: torch.FloatTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.LongTensor, - torch.LongTensor, - ]: - hid_len = MMArg( - cache("vid_len", lambda: vid_shape.prod(-1)), - cache("txt_len", lambda: txt_shape.prod(-1)), - ) - ada_kwargs = { - "emb": emb, - "hid_len": hid_len, - "cache": cache, - "branch_tag": MMArg("vid", "txt"), - } - - vid_attn, txt_attn = 
self.attn_norm(vid, txt) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) - vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) - vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) - - vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) - vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) - vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) - - return vid_mlp, txt_mlp, vid_shape, txt_shape diff --git a/modelsx/dit_v2/nadit.py b/modelsx/dit_v2/nadit.py deleted file mode 100644 index fe9d7f85fa38e330069d1888cdd996468c719144..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/nadit.py +++ /dev/null @@ -1,246 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union, Callable -import torch -from torch import nn - -from common.cache import Cache -from common.distributed.ops import slice_inputs - -from . 
import na -from .embedding import TimeEmbedding -from .modulation import get_ada_layer -from .nablocks import get_nablock -from .normalization import get_norm_layer -from .patch import get_na_patch_layers - -# Fake func, no checkpointing is required for inference -def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): - return module(*args, **kwargs) - -@dataclass -class NaDiTOutput: - vid_sample: torch.Tensor - - -class NaDiT(nn.Module): - """ - Native Resolution Diffusion Transformer (NaDiT) - """ - - gradient_checkpointing = False - - def __init__( - self, - vid_in_channels: int, - vid_out_channels: int, - vid_dim: int, - txt_in_dim: Union[int, List[int]], - txt_dim: Optional[int], - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm: Optional[str], - norm_eps: float, - ada: str, - qk_bias: bool, - qk_norm: Optional[str], - patch_size: Union[int, Tuple[int, int, int]], - num_layers: int, - block_type: Union[str, Tuple[str]], - mm_layers: Union[int, Tuple[bool]], - mlp_type: str = "normal", - patch_type: str = "v1", - rope_type: Optional[str] = "rope3d", - rope_dim: Optional[int] = None, - window: Optional[Tuple] = None, - window_method: Optional[Tuple[str]] = None, - msa_type: Optional[Tuple[str]] = None, - mca_type: Optional[Tuple[str]] = None, - txt_in_norm: Optional[str] = None, - txt_in_norm_scale_factor: int = 0.01, - txt_proj_type: Optional[str] = "linear", - vid_out_norm: Optional[str] = None, - **kwargs, - ): - ada = get_ada_layer(ada) - norm = get_norm_layer(norm) - qk_norm = get_norm_layer(qk_norm) - rope_dim = rope_dim if rope_dim is not None else head_dim // 2 - if isinstance(block_type, str): - block_type = [block_type] * num_layers - elif len(block_type) != num_layers: - raise ValueError("The ``block_type`` list should equal to ``num_layers``.") - super().__init__() - NaPatchIn, NaPatchOut = get_na_patch_layers(patch_type) - self.vid_in = NaPatchIn( - in_channels=vid_in_channels, - patch_size=patch_size, - dim=vid_dim, - ) - if not isinstance(txt_in_dim, int): - self.txt_in = nn.ModuleList([]) - for in_dim in txt_in_dim: - txt_norm_layer = get_norm_layer(txt_in_norm)(txt_dim, norm_eps, True) - if txt_proj_type == "linear": - txt_proj_layer = nn.Linear(in_dim, txt_dim) - else: - txt_proj_layer = nn.Sequential( - nn.Linear(in_dim, in_dim), nn.GELU("tanh"), nn.Linear(in_dim, txt_dim) - ) - torch.nn.init.constant_(txt_norm_layer.weight, txt_in_norm_scale_factor) - self.txt_in.append( - nn.Sequential( - txt_proj_layer, - txt_norm_layer, - ) - ) - else: - self.txt_in = ( - nn.Linear(txt_in_dim, txt_dim) - if txt_in_dim and txt_in_dim != txt_dim - else nn.Identity() - ) - self.emb_in = TimeEmbedding( - sinusoidal_dim=256, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - ) - - if window is None or isinstance(window[0], int): - window = [window] * num_layers - if window_method is None or isinstance(window_method, str): - window_method = [window_method] * num_layers - - if msa_type is None or isinstance(msa_type, str): - msa_type = [msa_type] * num_layers - if mca_type is None or isinstance(mca_type, str): - mca_type = [mca_type] * num_layers - - self.blocks = nn.ModuleList( - [ - get_nablock(block_type[i])( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_norm=qk_norm, - shared_weights=not ( - (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] - ), - 
mlp_type=mlp_type, - window=window[i], - window_method=window_method[i], - msa_type=msa_type[i], - mca_type=mca_type[i], - rope_type=rope_type, - rope_dim=rope_dim, - is_last_layer=(i == num_layers - 1), - **kwargs, - ) - for i in range(num_layers) - ] - ) - - self.vid_out_norm = None - if vid_out_norm is not None: - self.vid_out_norm = get_norm_layer(vid_out_norm)( - dim=vid_dim, - eps=norm_eps, - elementwise_affine=True, - ) - self.vid_out_ada = ada( - dim=vid_dim, - emb_dim=emb_dim, - layers=["out"], - modes=["in"], - ) - - self.vid_out = NaPatchOut( - out_channels=vid_out_channels, - patch_size=patch_size, - dim=vid_dim, - ) - - def set_gradient_checkpointing(self, enable: bool): - self.gradient_checkpointing = enable - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: Union[torch.FloatTensor, List[torch.FloatTensor]], # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: Union[torch.LongTensor, List[torch.LongTensor]], # b 1 - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b - disable_cache: bool = False, # for test - ): - cache = Cache(disable=disable_cache) - - # slice vid after patching in when using sequence parallelism - if isinstance(txt, list): - assert isinstance(self.txt_in, nn.ModuleList) - txt = [ - na.unflatten(fc(i), s) for fc, i, s in zip(self.txt_in, txt, txt_shape) - ] # B L D - txt, txt_shape = na.flatten([torch.cat(t, dim=0) for t in zip(*txt)]) - txt = slice_inputs(txt, dim=0) - else: - txt = slice_inputs(txt, dim=0) - txt = self.txt_in(txt) - - # Video input. - # Sequence parallel slicing is done inside patching class. - vid, vid_shape = self.vid_in(vid, vid_shape, cache) - - # Embedding input. - emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - - # Body - for i, block in enumerate(self.blocks): - vid, txt, vid_shape, txt_shape = gradient_checkpointing( - enabled=(self.gradient_checkpointing and self.training), - module=block, - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) - - # Video output norm. - if self.vid_out_norm: - vid = self.vid_out_norm(vid) - vid = self.vid_out_ada( - vid, - emb=emb, - layer="out", - mode="in", - hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), - cache=cache, - branch_tag="vid", - ) - - # Video output. - vid, vid_shape = self.vid_out(vid, vid_shape, cache) - return NaDiTOutput(vid_sample=vid) diff --git a/modelsx/dit_v2/normalization.py b/modelsx/dit_v2/normalization.py deleted file mode 100644 index 98827a9c71f9fd6e461937774d022b68844aee34..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/normalization.py +++ /dev/null @@ -1,63 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -from typing import Callable, Optional -from diffusers.models.normalization import RMSNorm -from torch import nn - -# (dim: int, eps: float, elementwise_affine: bool) -norm_layer_type = Callable[[int, float, bool], nn.Module] - - -def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type: - - def _norm_layer(dim: int, eps: float, elementwise_affine: bool): - if norm_type is None: - return nn.Identity() - - if norm_type == "layer": - return nn.LayerNorm( - normalized_shape=dim, - eps=eps, - elementwise_affine=elementwise_affine, - ) - - if norm_type == "rms": - return RMSNorm( - dim=dim, - eps=eps, - elementwise_affine=elementwise_affine, - ) - - if norm_type == "fusedln": - from apex.normalization import FusedLayerNorm - - return FusedLayerNorm( - normalized_shape=dim, - elementwise_affine=elementwise_affine, - eps=eps, - ) - - if norm_type == "fusedrms": - from apex.normalization import FusedRMSNorm - - return FusedRMSNorm( - normalized_shape=dim, - elementwise_affine=elementwise_affine, - eps=eps, - ) - - raise NotImplementedError(f"{norm_type} is not supported") - - return _norm_layer diff --git a/modelsx/dit_v2/patch/__init__.py b/modelsx/dit_v2/patch/__init__.py deleted file mode 100644 index 4e3c9783163f1e671f2d946dfad39ca33b12843d..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/patch/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -def get_na_patch_layers(patch_type="v1"): - assert patch_type in ["v1"] - if patch_type == "v1": - from .patch_v1 import NaPatchIn, NaPatchOut - return NaPatchIn, NaPatchOut diff --git a/modelsx/dit_v2/patch/patch_v1.py b/modelsx/dit_v2/patch/patch_v1.py deleted file mode 100644 index 0231bc0905e70e1fc702fe088fb2d0dac30fcc71..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/patch/patch_v1.py +++ /dev/null @@ -1,127 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Tuple, Union -import torch -from einops import rearrange -from torch import nn -from torch.nn.modules.utils import _triple - -from common.cache import Cache -from common.distributed.ops import gather_outputs, slice_inputs - -from .. 
import na - - -class PatchIn(nn.Module): - def __init__( - self, - in_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = nn.Linear(in_channels * t * h * w, dim) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - if t > 1: - assert vid.size(2) % t == 1 - vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) - vid = self.proj(vid) - return vid - - -class PatchOut(nn.Module): - def __init__( - self, - out_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = nn.Linear(dim, out_channels * t * h * w) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) - if t > 1: - vid = vid[:, :, (t - 1) :] - return vid - - -class NaPatchIn(PatchIn): - def forward( - self, - vid: torch.Tensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - ) -> torch.Tensor: - cache = cache.namespace("patch") - vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) - t, h, w = self.patch_size - if not (t == h == w == 1): - vid = na.unflatten(vid, vid_shape) - for i in range(len(vid)): - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) - vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) - vid, vid_shape = na.flatten(vid) - - # slice vid after patching in when using sequence parallelism - vid = slice_inputs(vid, dim=0) - vid = self.proj(vid) - return vid, vid_shape - - -class NaPatchOut(PatchOut): - def forward( - self, - vid: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - ) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, - ]: - cache = cache.namespace("patch") - vid_shape_before_patchify = cache.get("vid_shape_before_patchify") - - t, h, w = self.patch_size - vid = self.proj(vid) - # gather vid before patching out when enabling sequence parallelism - vid = gather_outputs( - vid, gather_dim=0, padding_dim=0, unpad_shape=vid_shape, cache=cache.namespace("vid") - ) - if not (t == h == w == 1): - vid = na.unflatten(vid, vid_shape) - for i in range(len(vid)): - vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] - vid, vid_shape = na.flatten(vid) - - return vid, vid_shape diff --git a/modelsx/dit_v2/rope.py b/modelsx/dit_v2/rope.py deleted file mode 100644 index ceb5458ba2829417a93124b9e06a86b74a523765..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/rope.py +++ /dev/null @@ -1,150 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from functools import lru_cache -from typing import Optional, Tuple -import torch -from einops import rearrange -from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb -from torch import nn - -from common.cache import Cache - - -class RotaryEmbeddingBase(nn.Module): - def __init__(self, dim: int, rope_dim: int): - super().__init__() - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="pixel", - max_freq=256, - ) - # 1. Set model.requires_grad_(True) after model creation will make - # the `requires_grad=False` for rope freqs no longer hold. - # 2. Even if we don't set requires_grad_(True) explicitly, - # FSDP is not memory efficient when handling fsdp_wrap - # with mixed requires_grad=True/False. - # With above consideration, it is easier just remove the freqs - # out of nn.Parameters when `learned_freq=False` - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - - @lru_cache(maxsize=128) - def get_axial_freqs(self, *dims): - return self.rope.get_axial_freqs(*dims) - - -class RotaryEmbedding3d(RotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - self.mm = False - - def forward( - self, - q: torch.FloatTensor, # b h l d - k: torch.FloatTensor, # b h l d - size: Tuple[int, int, int], - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - T, H, W = size - freqs = self.get_axial_freqs(T, H, W) - q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - q = apply_rotary_emb(freqs, q.float()).to(q.dtype) - k = apply_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "b h T H W d -> b h (T H W) d") - k = rearrange(k, "b h T H W d -> b h (T H W) d") - return q, k - - -class MMRotaryEmbeddingBase(RotaryEmbeddingBase): - def __init__(self, dim: int, rope_dim: int): - super().__init__(dim, rope_dim) - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="lang", - theta=10000, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - self.mm = True - - -class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - - def forward( - self, - vid_q: torch.FloatTensor, # L h d - vid_k: torch.FloatTensor, # L h d - vid_shape: torch.LongTensor, # B 3 - txt_q: torch.FloatTensor, # L h d - txt_k: torch.FloatTensor, # L h d - txt_shape: torch.LongTensor, # B 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_freqs, txt_freqs = cache( - "mmrope_freqs_3d", - lambda: self.get_freqs(vid_shape, txt_shape), - ) - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) - vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) - vid_q = rearrange(vid_q, "h L d -> L h d") - vid_k = rearrange(vid_k, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = 
rearrange(txt_k, "L h d -> h L d") - txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) - txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) - txt_q = rearrange(txt_q, "h L d -> L h d") - txt_k = rearrange(txt_k, "h L d -> L h d") - return vid_q, vid_k, txt_q, txt_k - - def get_freqs( - self, - vid_shape: torch.LongTensor, - txt_shape: torch.LongTensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - ]: - vid_freqs = self.get_axial_freqs(1024, 128, 128) - txt_freqs = self.get_axial_freqs(1024) - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) - txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) - - -def get_na_rope(rope_type: Optional[str], dim: int): - if rope_type is None: - return None - if rope_type == "mmrope3d": - return NaMMRotaryEmbedding3d(dim=dim) - raise NotImplementedError(f"{rope_type} is not supported.") diff --git a/modelsx/dit_v2/window.py b/modelsx/dit_v2/window.py deleted file mode 100644 index b7475921ae283cf76d82bff7521233c133f54bfd..0000000000000000000000000000000000000000 --- a/modelsx/dit_v2/window.py +++ /dev/null @@ -1,83 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from math import ceil -from typing import Tuple -import math - -def get_window_op(name: str): - if name == "720pwin_by_size_bysize": - return make_720Pwindows_bysize - if name == "720pswin_by_size_bysize": - return make_shifted_720Pwindows_bysize - raise ValueError(f"Unknown windowing method: {name}") - - -# -------------------------------- Windowing -------------------------------- # -def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt((45 * 80) / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, 30) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. 
- return [ - ( - slice(it * wt, min((it + 1) * wt, t)), - slice(ih * wh, min((ih + 1) * wh, h)), - slice(iw * ww, min((iw + 1) * ww, w)), - ) - for iw in range(nw) - if min((iw + 1) * ww, w) > iw * ww - for ih in range(nh) - if min((ih + 1) * wh, h) > ih * wh - for it in range(nt) - if min((it + 1) * wt, t) > it * wt - ] - -def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt((45 * 80) / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, 30) / resized_nt) # window size. - - st, sh, sw = ( # shift size. - 0.5 if wt < t else 0, - 0.5 if wh < h else 0, - 0.5 if ww < w else 0, - ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. - nt + 1 if st > 0 else 1, - nh + 1 if sh > 0 else 1, - nw + 1 if sw > 0 else 1, - ) - return [ - ( - slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), - slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), - slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), - ) - for iw in range(nw) - if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) - for ih in range(nh) - if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) - for it in range(nt) - if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) - ] diff --git a/modelsx/video_vae_v3/modules/attn_video_vae.py b/modelsx/video_vae_v3/modules/attn_video_vae.py deleted file mode 100644 index edaf817452af1df8c85746f07d017e8802d989b0..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/attn_video_vae.py +++ /dev/null @@ -1,1345 +0,0 @@ -# Copyright (c) 2023 HuggingFace Team -# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. -# SPDX-License-Identifier: Apache License, Version 2.0 (the "License") -# -# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025 -# -# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text -# available at http://www.apache.org/licenses/LICENSE-2.0. -# -# This modified file is released under the same license. 
- - -from contextlib import nullcontext -from typing import Literal, Optional, Tuple, Union -import diffusers -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.attention_processor import Attention, SpatialNorm -from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution -from diffusers.models.downsampling import Downsample2D -from diffusers.models.lora import LoRACompatibleConv -from diffusers.models.modeling_outputs import AutoencoderKLOutput -from diffusers.models.resnet import ResnetBlock2D -from diffusers.models.unets.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D -from diffusers.models.upsampling import Upsample2D -from diffusers.utils import is_torch_version -from diffusers.utils.accelerate_utils import apply_forward_hook -from einops import rearrange - -from common.distributed.advanced import get_sequence_parallel_world_size -from common.logger import get_logger -from models.video_vae_v3.modules.causal_inflation_lib import ( - InflatedCausalConv3d, - causal_norm_wrapper, - init_causal_conv3d, - remove_head, -) -from models.video_vae_v3.modules.context_parallel_lib import ( - causal_conv_gather_outputs, - causal_conv_slice_inputs, -) -from models.video_vae_v3.modules.global_config import set_norm_limit -from models.video_vae_v3.modules.types import ( - CausalAutoencoderOutput, - CausalDecoderOutput, - CausalEncoderOutput, - MemoryState, - _inflation_mode_t, - _memory_device_t, - _receptive_field_t, -) - -logger = get_logger(__name__) # pylint: disable=invalid-name - - -class Upsample3D(Upsample2D): - """A 3D upsampling layer with an optional convolution.""" - - def __init__( - self, - *args, - inflation_mode: _inflation_mode_t = "tail", - temporal_up: bool = False, - spatial_up: bool = True, - slicing: bool = False, - **kwargs, - ): - super().__init__(*args, **kwargs) - conv = self.conv if self.name == "conv" else self.Conv2d_0 - - assert type(conv) is not nn.ConvTranspose2d - # Note: lora_layer is not passed into constructor in the original implementation. - # So we make a simplification. - conv = init_causal_conv3d( - self.channels, - self.out_channels, - 3, - padding=1, - inflation_mode=inflation_mode, - ) - - self.temporal_up = temporal_up - self.spatial_up = spatial_up - self.temporal_ratio = 2 if temporal_up else 1 - self.spatial_ratio = 2 if spatial_up else 1 - self.slicing = slicing - - assert not self.interpolate - # [Override] MAGViT v2 implementation - if not self.interpolate: - upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = nn.Conv3d( - self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 - ) - identity = ( - torch.eye(self.channels) - .repeat(upscale_ratio, 1) - .reshape_as(self.upscale_conv.weight) - ) - self.upscale_conv.weight.data.copy_(identity) - nn.init.zeros_(self.upscale_conv.bias) - - if self.name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - def forward( - self, - hidden_states: torch.FloatTensor, - output_size: Optional[int] = None, - memory_state: MemoryState = MemoryState.DISABLED, - **kwargs, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. 
- hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv_transpose: - return self.conv(hidden_states) - - if self.slicing: - split_size = hidden_states.size(2) // 2 - hidden_states = list( - hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) - ) - else: - hidden_states = [hidden_states] - - for i in range(len(hidden_states)): - hidden_states[i] = self.upscale_conv(hidden_states[i]) - hidden_states[i] = rearrange( - hidden_states[i], - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, - ) - - # [Overridden] For causal temporal conv - if self.temporal_up and memory_state != MemoryState.ACTIVE: - hidden_states[0] = remove_head(hidden_states[0]) - - if not self.slicing: - hidden_states = hidden_states[0] - - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states, memory_state=memory_state) - else: - hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) - - if not self.slicing: - return hidden_states - else: - return torch.cat(hidden_states, dim=2) - - -class Downsample3D(Downsample2D): - """A 3D downsampling layer with an optional convolution.""" - - def __init__( - self, - *args, - inflation_mode: _inflation_mode_t = "tail", - spatial_down: bool = False, - temporal_down: bool = False, - **kwargs, - ): - super().__init__(*args, **kwargs) - conv = self.conv - self.temporal_down = temporal_down - self.spatial_down = spatial_down - - self.temporal_ratio = 2 if temporal_down else 1 - self.spatial_ratio = 2 if spatial_down else 1 - - self.temporal_kernel = 3 if temporal_down else 1 - self.spatial_kernel = 3 if spatial_down else 1 - - if type(conv) in [nn.Conv2d, LoRACompatibleConv]: - # Note: lora_layer is not passed into constructor in the original implementation. - # So we make a simplification. - conv = init_causal_conv3d( - self.channels, - self.out_channels, - kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - padding=( - 1 if self.temporal_down else 0, - self.padding if self.spatial_down else 0, - self.padding if self.spatial_down else 0, - ), - inflation_mode=inflation_mode, - ) - elif type(conv) is nn.AvgPool2d: - assert self.channels == self.out_channels - conv = nn.AvgPool3d( - kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - ) - else: - raise NotImplementedError - - if self.name == "conv": - self.Conv2d_0 = conv - self.conv = conv - else: - self.conv = conv - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state: MemoryState = MemoryState.DISABLED, - **kwargs, - ) -> torch.FloatTensor: - - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. 
- hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv and self.padding == 0 and self.spatial_down: - pad = (0, 1, 0, 1) - hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - - hidden_states = self.conv(hidden_states, memory_state=memory_state) - - return hidden_states - - -class ResnetBlock3D(ResnetBlock2D): - def __init__( - self, - *args, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - slicing: bool = False, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.conv1 = init_causal_conv3d( - self.in_channels, - self.out_channels, - kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), - stride=1, - padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), - inflation_mode=inflation_mode, - ) - - self.conv2 = init_causal_conv3d( - self.out_channels, - self.conv2.out_channels, - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - if self.up: - assert type(self.upsample) is Upsample2D - self.upsample = Upsample3D( - self.in_channels, - use_conv=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - elif self.down: - assert type(self.downsample) is Downsample2D - self.downsample = Downsample3D( - self.in_channels, - use_conv=False, - padding=1, - name="op", - inflation_mode=inflation_mode, - ) - - if self.use_in_shortcut: - self.conv_shortcut = init_causal_conv3d( - self.in_channels, - self.conv_shortcut.out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=(self.conv_shortcut.bias is not None), - inflation_mode=inflation_mode, - ) - - def forward( - self, input_tensor, temb, memory_state: MemoryState = MemoryState.DISABLED, **kwargs - ): - hidden_states = input_tensor - - hidden_states = causal_norm_wrapper(self.norm1, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - # upsample_nearest_nhwc fails with large batch sizes. 
- # see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, memory_state=memory_state) - hidden_states = self.upsample(hidden_states, memory_state=memory_state) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, memory_state=memory_state) - hidden_states = self.downsample(hidden_states, memory_state=memory_state) - - hidden_states = self.conv1(hidden_states, memory_state=memory_state) - - if self.time_emb_proj is not None: - if not self.skip_time_act: - temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb)[:, :, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - hidden_states = causal_norm_wrapper(self.norm2, hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class DownEncoderBlock3D(DownEncoderBlock2D): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_down: bool = True, - spatial_down: bool = True, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - num_layers=num_layers, - resnet_eps=resnet_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_pre_norm=resnet_pre_norm, - output_scale_factor=output_scale_factor, - add_downsample=add_downsample, - downsample_padding=downsample_padding, - ) - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - # [Override] Replace module. 
- Downsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - temporal_down=temporal_down, - spatial_down=spatial_down, - inflation_mode=inflation_mode, - ) - ] - ) - else: - self.downsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state: MemoryState = MemoryState.DISABLED, - **kwargs, - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UpDecoderBlock3D(UpDecoderBlock2D): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: Optional[int] = None, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up: bool = True, - spatial_up: bool = True, - slicing: bool = False, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - num_layers=num_layers, - resnet_eps=resnet_eps, - resnet_time_scale_shift=resnet_time_scale_shift, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_pre_norm=resnet_pre_norm, - output_scale_factor=output_scale_factor, - add_upsample=add_upsample, - temb_channels=temb_channels, - ) - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - # [Override] Replace module. 
- ResnetBlock3D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - slicing=slicing, - ) - ) - - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_upsample: - # [Override] Replace module & use learnable upsample - self.upsamplers = nn.ModuleList( - [ - Upsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - temporal_up=temporal_up, - spatial_up=spatial_up, - interpolate=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - ] - ) - else: - self.upsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - memory_state: MemoryState = MemoryState.DISABLED, - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ] - attentions = [] - - if attention_head_dim is None: - logger.warn( - f"It is not recommend to pass `attention_head_dim=None`. " - f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}." 
- ) - attention_head_dim = in_channels - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - Attention( - in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=( - resnet_groups if resnet_time_scale_shift == "default" else None - ), - spatial_norm_dim=( - temb_channels if resnet_time_scale_shift == "spatial" else None - ), - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, memory_state: MemoryState = MemoryState.DISABLED): - video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") - hidden_states = attn(hidden_states, temb=temb) - hidden_states = rearrange( - hidden_states, "(b f) c h w -> b c f h w", f=video_length - ) - hidden_states = resnet(hidden_states, temb, memory_state=memory_state) - - return hidden_states - - -class Encoder3D(nn.Module): - r""" - [Override] override most logics to support extra condition input and causal conv - - The `Encoder` layer of a variational autoencoder that encodes - its input into a latent representation. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): - The types of down blocks to use. - See `~diffusers.models.unet_2d_blocks.get_down_block` - for available options. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. - act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. - See `~diffusers.models.activations.get_activation` for available options. - double_z (`bool`, *optional*, defaults to `True`): - Whether to double the number of output channels for the last block. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), - block_out_channels: Tuple[int, ...] 
= (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - mid_block_add_attention=True, - # [Override] add extra_cond_dim, temporal down num - temporal_down_num: int = 2, - extra_cond_dim: int = None, - gradient_checkpoint: bool = False, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_down_num = temporal_down_num - - self.conv_in = init_causal_conv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - self.extra_cond_dim = extra_cond_dim - - self.conv_extra_cond = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - # [Override] to support temporal down block design - is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last ones - - assert down_block_type == "DownEncoderBlock3D" - - down_block = DownEncoderBlock3D( - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - # Note: Don't know why set it as 0 - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - temporal_down=is_temporal_down_block, - spatial_down=True, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.down_blocks.append(down_block) - - def zero_module(module): - # Zero out the parameters of a module and return it. 
- for p in module.parameters(): - p.detach().zero_() - return module - - self.conv_extra_cond.append( - zero_module( - nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) - ) - if self.extra_cond_dim is not None and self.extra_cond_dim > 0 - else None - ) - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = init_causal_conv3d( - block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - def forward( - self, - sample: torch.FloatTensor, - extra_cond=None, - memory_state: MemoryState = MemoryState.DISABLED, - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - sample = self.conv_in(sample, memory_state=memory_state) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, memory_state, use_reentrant=False - ) - if extra_block is not None: - sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # sample = torch.utils.checkpoint.checkpoint( - # create_custom_forward(self.mid_block), sample, use_reentrant=False - # ) - - else: - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = down_block(sample, memory_state=memory_state) - if extra_block is not None: - sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - - -class Decoder3D(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that - decodes its latent representation into an output sample. - - Args: - in_channels (`int`, *optional*, defaults to 3): - The number of input channels. - out_channels (`int`, *optional*, defaults to 3): - The number of output channels. - up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): - The types of up blocks to use. - See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. - block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): - The number of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): - The number of layers per block. - norm_num_groups (`int`, *optional*, defaults to 32): - The number of groups for normalization. 
- act_fn (`str`, *optional*, defaults to `"silu"`): - The activation function to use. - See `~diffusers.models.activations.get_activation` for available options. - norm_type (`str`, *optional*, defaults to `"group"`): - The normalization type to use. Can be either `"group"` or `"spatial"`. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - mid_block_add_attention=True, - # [Override] add temporal up block - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_up_num = temporal_up_num - - self.conv_in = init_causal_conv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - print(f"slicing_up_num: {slicing_up_num}") - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - is_temporal_up_block = i < self.temporal_up_num - is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num - # Note: Keep symmetric - - assert up_block_type == "UpDecoderBlock3D" - up_block = UpDecoderBlock3D( - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=norm_type, - temb_channels=temb_channels, - temporal_up=is_temporal_up_block, - slicing=is_slicing_up_block, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - self.conv_out = init_causal_conv3d( - block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - # Note: Just copy from Decoder. 
- def forward( - self, - sample: torch.FloatTensor, - latent_embeds: Optional[torch.FloatTensor] = None, - memory_state: MemoryState = MemoryState.DISABLED, - ) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - - sample = self.conv_in(sample, memory_state=memory_state) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - if is_torch_version(">=", "1.11.0"): - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), - sample, - latent_embeds, - memory_state, - use_reentrant=False, - ) - else: - # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(up_block), sample, latent_embeds, memory_state - ) - else: - # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - - -class AutoencoderKL(diffusers.AutoencoderKL): - """ - We simply inherit the model code from diffusers - """ - - def __init__(self, attention: bool = True, *args, **kwargs): - super().__init__(*args, **kwargs) - - # A hacky way to remove attention. - if not attention: - self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) - self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) - - def load_state_dict(self, state_dict, strict=True): - # Newer version of diffusers changed the model keys, - # causing incompatibility with old checkpoints. - # They provided a method for conversion. We call conversion before loading state_dict. 
- convert_deprecated_attention_blocks = getattr( - self, "_convert_deprecated_attention_blocks", None - ) - if callable(convert_deprecated_attention_blocks): - convert_deprecated_attention_blocks(state_dict) - return super().load_state_dict(state_dict, strict) - - -class VideoAutoencoderKL(diffusers.AutoencoderKL): - """ - We simply inherit the model code from diffusers - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str] = ("DownEncoderBlock3D",), - up_block_types: Tuple[str] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - act_fn: str = "silu", - latent_channels: int = 4, - norm_num_groups: int = 32, - sample_size: int = 32, - scaling_factor: float = 0.18215, - force_upcast: float = True, - attention: bool = True, - temporal_scale_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "full", - slicing_sample_min_size: int = 32, - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, - *args, - **kwargs, - ): - extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None - self.slicing_sample_min_size = slicing_sample_min_size - self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) - - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - # [Override] make sure it can be normally initialized - down_block_types=tuple( - [down_block_type.replace("3D", "2D") for down_block_type in down_block_types] - ), - up_block_types=tuple( - [up_block_type.replace("3D", "2D") for up_block_type in up_block_types] - ), - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - latent_channels=latent_channels, - norm_num_groups=norm_num_groups, - sample_size=sample_size, - scaling_factor=scaling_factor, - force_upcast=force_upcast, - *args, - **kwargs, - ) - - # pass init params to Encoder - self.encoder = Encoder3D( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - extra_cond_dim=extra_cond_dim, - # [Override] add temporal_down_num parameter - temporal_down_num=temporal_scale_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # pass init params to Decoder - self.decoder = Decoder3D( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - # [Override] add temporal_up_num parameter - temporal_up_num=temporal_scale_num, - slicing_up_num=slicing_up_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - self.quant_conv = ( - init_causal_conv3d( - in_channels=2 * latent_channels, - out_channels=2 * latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_quant_conv - else None - ) - self.post_quant_conv = ( - init_causal_conv3d( - in_channels=latent_channels, - out_channels=latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_post_quant_conv - else None - ) - - # A hacky way to remove attention. 
- if not attention: - self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) - self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) - - @apply_forward_hook - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: - h = self.slicing_encode(x) - posterior = DiagonalGaussianDistribution(h) - - if not return_dict: - return (posterior,) - - return AutoencoderKLOutput(latent_dist=posterior) - - @apply_forward_hook - def decode( - self, z: torch.Tensor, return_dict: bool = True - ) -> Union[DecoderOutput, torch.Tensor]: - decoded = self.slicing_decode(z) - - if not return_dict: - return (decoded,) - - return DecoderOutput(sample=decoded) - - def _encode( - self, x: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED - ) -> torch.Tensor: - _x = x.to(self.device) - _x = causal_conv_slice_inputs(_x, self.slicing_sample_min_size, memory_state=memory_state) - h = self.encoder(_x, memory_state=memory_state) - if self.quant_conv is not None: - output = self.quant_conv(h, memory_state=memory_state) - else: - output = h - output = causal_conv_gather_outputs(output) - return output.to(x.device) - - def _decode( - self, z: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED - ) -> torch.Tensor: - _z = z.to(self.device) - _z = causal_conv_slice_inputs(_z, self.slicing_latent_min_size, memory_state=memory_state) - if self.post_quant_conv is not None: - _z = self.post_quant_conv(_z, memory_state=memory_state) - output = self.decoder(_z, memory_state=memory_state) - output = causal_conv_gather_outputs(output) - return output.to(z.device) - - def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: - x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) - encoded_slices = [ - self._encode( - torch.cat((x[:, :, :1], x_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for x_idx in range(1, len(x_slices)): - encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) - ) - return torch.cat(encoded_slices, dim=2) - else: - return self._encode(x) - - def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) - decoded_slices = [ - self._decode( - torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for z_idx in range(1, len(z_slices)): - decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) - ) - return torch.cat(decoded_slices, dim=2) - else: - return self._decode(z) - - def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs - ): - # x: [b c t h w] - if mode == "encode": - h = self.encode(x) - return h.latent_dist - elif mode == "decode": - h = self.decode(x) - return h.sample - else: - h = self.encode(x) - h = self.decode(h.latent_dist.mode()) - return h.sample - - def load_state_dict(self, state_dict, strict=False): - # Newer version of diffusers changed the model keys, - # causing 
incompatibility with old checkpoints. - # They provided a method for conversion. - # We call conversion before loading state_dict. - convert_deprecated_attention_blocks = getattr( - self, "_convert_deprecated_attention_blocks", None - ) - if callable(convert_deprecated_attention_blocks): - convert_deprecated_attention_blocks(state_dict) - return super().load_state_dict(state_dict, strict) - - -class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - def __init__( - self, - *args, - spatial_downsample_factor: int, - temporal_downsample_factor: int, - freeze_encoder: bool, - **kwargs, - ): - self.spatial_downsample_factor = spatial_downsample_factor - self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder - super().__init__(*args, **kwargs) - - def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput: - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) - x = self.decode(z).sample - return CausalAutoencoderOutput(x, z, p) - - def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput: - if x.ndim == 4: - x = x.unsqueeze(2) - p = super().encode(x).latent_dist - z = p.sample().squeeze(2) - return CausalEncoderOutput(z, p) - - def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput: - if z.ndim == 4: - z = z.unsqueeze(2) - x = super().decode(z).sample.squeeze(2) - return CausalDecoderOutput(x) - - def preprocess(self, x: torch.Tensor): - # x should in [B, C, T, H, W], [B, C, H, W] - assert x.ndim == 4 or x.size(2) % 4 == 1 - return x - - def postprocess(self, x: torch.Tensor): - # x should in [B, C, T, H, W], [B, C, H, W] - return x - - def set_causal_slicing( - self, - *, - split_size: Optional[int], - memory_device: _memory_device_t, - ): - assert ( - split_size is None or memory_device is not None - ), "if split_size is set, memory_device must not be None." - if split_size is not None: - self.enable_slicing() - self.slicing_sample_min_size = split_size - self.slicing_latent_min_size = split_size // self.temporal_downsample_factor - else: - self.disable_slicing() - for module in self.modules(): - if isinstance(module, InflatedCausalConv3d): - module.set_memory_device(memory_device) - - def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): - set_norm_limit(norm_max_mem) - for m in self.modules(): - if isinstance(m, InflatedCausalConv3d): - m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) diff --git a/modelsx/video_vae_v3/modules/causal_inflation_lib.py b/modelsx/video_vae_v3/modules/causal_inflation_lib.py deleted file mode 100644 index fdd3cbe2512b119d76729c4103325ac22e0b12fe..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/causal_inflation_lib.py +++ /dev/null @@ -1,460 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
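# ---------------------------------------------------------------------------
# Illustrative usage sketch for the VideoAutoencoderKLWrapper defined in the
# preceding file (added for clarity; not part of the original sources).
# Constructor values other than the three wrapper-specific keywords are
# assumptions and must match the checkpoint being loaded.
#
#   import torch
#
#   vae = VideoAutoencoderKLWrapper(
#       spatial_downsample_factor=8,       # assumed
#       temporal_downsample_factor=4,      # assumed; 2 ** temporal_scale_num
#       freeze_encoder=True,
#       latent_channels=16,                # assumed
#   ).eval()
#
#   # Stream the causal convs over clips of 16 frames, keeping each conv's
#   # memory bank on the same device as the activations.
#   vae.set_causal_slicing(split_size=16, memory_device="same")
#   # Cap peak activation memory (in GiB) for convolutions and group norms.
#   vae.set_memory_limit(conv_max_mem=2.0, norm_max_mem=2.0)
#
#   x = torch.randn(1, 3, 17, 256, 256)    # [B, C, T, H, W], T % 4 == 1
#   with torch.no_grad():
#       z, posterior = vae.encode(x)       # CausalEncoderOutput
#       recon = vae.decode(z).sample       # CausalDecoderOutput
# ---------------------------------------------------------------------------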
- -import math -from contextlib import contextmanager -from typing import List, Optional, Union -import torch -import torch.distributed as dist -import torch.nn.functional as F -from diffusers.models.normalization import RMSNorm -from einops import rearrange -from torch import Tensor, nn -from torch.nn import Conv3d - -from common.distributed.advanced import ( - get_next_sequence_parallel_rank, - get_prev_sequence_parallel_rank, - get_sequence_parallel_group, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, -) -from common.logger import get_logger -from models.video_vae_v3.modules.context_parallel_lib import cache_send_recv, get_cache_size -from models.video_vae_v3.modules.global_config import get_norm_limit -from models.video_vae_v3.modules.types import MemoryState, _inflation_mode_t, _memory_device_t - -logger = get_logger(__name__) - - -@contextmanager -def ignore_padding(model): - orig_padding = model.padding - model.padding = (0, 0, 0) - try: - yield - finally: - model.padding = orig_padding - - -class InflatedCausalConv3d(Conv3d): - def __init__( - self, - *args, - inflation_mode: _inflation_mode_t, - memory_device: _memory_device_t = "same", - **kwargs, - ): - self.inflation_mode = inflation_mode - self.memory = None - super().__init__(*args, **kwargs) - self.temporal_padding = self.padding[0] - self.memory_device = memory_device - self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. - self.memory_limit = float("inf") - - def set_memory_limit(self, value: float): - self.memory_limit = value - - def set_memory_device(self, memory_device: _memory_device_t): - self.memory_device = memory_device - - def memory_limit_conv( - self, - x, - *, - split_dim=3, - padding=(0, 0, 0, 0, 0, 0), - prev_cache=None, - ): - # Compatible with no limit. - if math.isinf(self.memory_limit): - if prev_cache is not None: - x = torch.cat([prev_cache, x], dim=split_dim - 1) - return super().forward(x) - - # Compute tensor shape after concat & padding. - shape = torch.tensor(x.size()) - if prev_cache is not None: - shape[split_dim - 1] += prev_cache.size(split_dim - 1) - shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) - memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB - logger.debug( - f"x:{(shape, x.dtype)} {memory_occupy:.3f}GiB " - f"prev_cache:{prev_cache.shape if prev_cache is not None else None}" - ) - if memory_occupy < self.memory_limit or split_dim == x.ndim: - if prev_cache is not None: - x = torch.cat([prev_cache, x], dim=split_dim - 1) - x = F.pad(x, padding, value=0.0) - with ignore_padding(self): - return super().forward(x) - - logger.debug( - f"Exceed memory limit {memory_occupy} > {self.memory_limit}, split dim {split_dim}" - ) - - # Split input (& prev_cache). - num_splits = math.ceil(memory_occupy / self.memory_limit) - size_per_split = x.size(split_dim) // num_splits - split_sizes = [size_per_split] * (num_splits - 1) - split_sizes += [x.size(split_dim) - sum(split_sizes)] - - x = list(x.split(split_sizes, dim=split_dim)) - logger.debug(f"Conv inputs: {[inp.size() for inp in x]} {x[0].dtype}") - if prev_cache is not None: - prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) - - # Loop Fwd. - cache = None - for idx in range(len(x)): - # Concat prev cache from last dim - if prev_cache is not None: - x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) - - # Get padding pattern. 
- lpad_dim = (x[idx].ndim - split_dim - 1) * 2 - rpad_dim = lpad_dim + 1 - padding = list(padding) - padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 - padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 - pad_len = padding[lpad_dim] + padding[rpad_dim] - padding = tuple(padding) - - # Prepare cache for next slice (this dim). - next_cache = None - cache_len = cache.size(split_dim) if cache is not None else 0 - next_catch_size = get_cache_size( - conv_module=self, - input_len=x[idx].size(split_dim) + cache_len, - pad_len=pad_len, - dim=split_dim - 2, - ) - if next_catch_size != 0: - assert next_catch_size <= x[idx].size(split_dim) - next_cache = ( - x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) - ) - - # Recursive. - x[idx] = self.memory_limit_conv( - x[idx], - split_dim=split_dim + 1, - padding=padding, - prev_cache=cache, - ) - - # Update cache. - cache = next_cache - - logger.debug(f"Conv outputs, concat(dim={split_dim}): {[d.size() for d in x]}") - return torch.cat(x, split_dim) - - def forward( - self, - input: Union[Tensor, List[Tensor]], - memory_state: MemoryState = MemoryState.UNSET, - ) -> Tensor: - assert memory_state != MemoryState.UNSET - if memory_state != MemoryState.ACTIVE: - self.memory = None - if ( - math.isinf(self.memory_limit) - and torch.is_tensor(input) - and get_sequence_parallel_group() is None - ): - return self.basic_forward(input, memory_state) - return self.slicing_forward(input, memory_state) - - def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): - mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory, times=-1) - else: - input = extend_head(input, times=self.temporal_padding * 2) - memory = ( - input[:, :, mem_size:].detach() - if (mem_size != 0 and memory_state != MemoryState.DISABLED) - else None - ) - if ( - memory_state != MemoryState.DISABLED - and not self.training - and (self.memory_device is not None) - ): - self.memory = memory - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - return super().forward(input) - - def slicing_forward( - self, - input: Union[Tensor, List[Tensor]], - memory_state: MemoryState = MemoryState.UNSET, - ) -> Tensor: - squeeze_out = False - if torch.is_tensor(input): - input = [input] - squeeze_out = True - - cache_size = self.kernel_size[0] - self.stride[0] - cache = cache_send_recv( - input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 - ) - - # For slice=4 and sp=2, and 17 frames in total - # sp0 sp1 - # slice 0: [`0 0` 0 1 2 {3 4}] [{3 4} 5 6 (7 8)] extend=`0 0` cache={3 4} memory=(7 8) - # slice 1: [(7 8) 9 10 {11 12}] [{11 12} 13 14 15 16] - sp_rank = get_sequence_parallel_rank() - sp_size = get_sequence_parallel_world_size() - sp_group = get_sequence_parallel_group() - send_dst = get_next_sequence_parallel_rank() - recv_src = get_prev_sequence_parallel_rank() - if ( - memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing - and not self.training - and (self.memory_device is not None) - and sp_rank in [0, sp_size - 1] - and cache_size != 0 - ): - if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: - input[0] = torch.cat([cache, input[0]], dim=2) - cache = None - assert cache_size <= input[-1].size(2) - if sp_size == 1: - self.memory = input[-1][:, :, 
-cache_size:].detach().contiguous() - else: - if sp_rank == sp_size - 1: - dist.send( - input[-1][:, :, -cache_size:].detach().contiguous(), - send_dst, - group=sp_group, - ) - if sp_rank == 0: - shape = list(input[0].size()) - shape[2] = cache_size - self.memory = torch.empty( - *shape, device=input[0].device, dtype=input[0].dtype - ).contiguous() - dist.recv(self.memory, recv_src, group=sp_group) - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - - padding = tuple(x for x in reversed(self.padding) for _ in range(2)) - for i in range(len(input)): - # Prepare cache for next input slice. - next_cache = None - cache_size = 0 - if i < len(input) - 1: - cache_len = cache.size(2) if cache is not None else 0 - cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) - if cache_size != 0: - if cache_size > input[i].size(2) and cache is not None: - input[i] = torch.cat([cache, input[i]], dim=2) - cache = None - assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" - next_cache = input[i][:, :, -cache_size:] - - # Conv forward for this input slice. - input[i] = self.memory_limit_conv( - input[i], - padding=padding, - prev_cache=cache, - ) - - # Update cache. - cache = next_cache - - return input[0] if squeeze_out else input - - def tflops(self, args, kwargs, output) -> float: - if torch.is_tensor(output): - output_numel = output.numel() - elif isinstance(output, list): - output_numel = sum(o.numel() for o in output) - else: - raise NotImplementedError - return (2 * math.prod(self.kernel_size) * self.in_channels * (output_numel / 1e6)) / 1e6 - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - if self.inflation_mode != "none": - state_dict = modify_state_dict( - self, - state_dict, - prefix, - inflate_weight_fn=inflate_weight, - inflate_bias_fn=inflate_bias, - ) - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - (strict and self.inflation_mode == "none"), - missing_keys, - unexpected_keys, - error_msgs, - ) - - -def init_causal_conv3d( - *args, - inflation_mode: _inflation_mode_t, - **kwargs, -): - """ - Initialize a Causal-3D convolution layer. - Parameters: - inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. - - none: No inflation will be conducted. - The loading logic of state dict will fall back to default. - - tail / replicate: Refer to the definition of `InflatedCausalConv3d`. 
- """ - return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs) - - -def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: - input_dtype = x.dtype - if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)): - if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") - x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") - return x.to(input_dtype) - if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") - x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") - return x.to(input_dtype) - if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): - if x.ndim <= 4: - return norm_layer(x).to(input_dtype) - if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") - memory_occupy = x.numel() * x.element_size() / 1024**3 - if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > get_norm_limit(): - num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) - logger.debug(f"large tensor {x.shape}, norm in {num_chunks} chunks") - assert norm_layer.num_groups % num_chunks == 0 - num_groups_per_chunk = norm_layer.num_groups // num_chunks - - x = list(x.chunk(num_chunks, dim=1)) - weights = norm_layer.weight.chunk(num_chunks, dim=0) - biases = norm_layer.bias.chunk(num_chunks, dim=0) - for i, (w, b) in enumerate(zip(weights, biases)): - x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) - x[i] = x[i].to(input_dtype) - x = torch.cat(x, dim=1) - else: - x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x.to(input_dtype) - raise NotImplementedError - - -def remove_head(tensor: Tensor, times: int = 1) -> Tensor: - """ - Remove duplicated first frame features in the up-sampling process. - """ - sp_rank = get_sequence_parallel_rank() - if times == 0 or sp_rank > 0: - return tensor - return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) - - -def extend_head(tensor: Tensor, times: int = 2, memory: Optional[Tensor] = None) -> Tensor: - """ - When memory is None: - - Duplicate first frame features in the down-sampling process. - When memory is not None: - - Concatenate memory features with the input features to keep temporal consistency. - """ - if memory is not None: - return torch.cat((memory.to(tensor), tensor), dim=2) - assert times >= 0, "Invalid input for function 'extend_head'!" - if times == 0: - return tensor - else: - tile_repeat = [1] * tensor.ndim - tile_repeat[2] = times - return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) - - -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution weight matrix to a 3D one. - Parameters: - weight_2d: The weight matrix of 2D conv to be inflated. - weight_3d: The weight matrix of 3D conv to be initialized. - inflation_mode: the mode of inflation - """ - assert inflation_mode in ["tail", "replicate"] - assert weight_3d.shape[:2] == weight_2d.shape[:2] - with torch.no_grad(): - if inflation_mode == "replicate": - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - else: - weight_3d.fill_(0.0) - weight_3d[:, :, -1].copy_(weight_2d) - return weight_3d - - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution bias tensor to a 3D one - Parameters: - bias_2d: The bias tensor of 2D conv to be inflated. - bias_3d: The bias tensor of 3D conv to be initialized. 
- inflation_mode: Placeholder to align `inflate_weight`. - """ - assert bias_3d.shape == bias_2d.shape - with torch.no_grad(): - bias_3d.copy_(bias_2d) - return bias_3d - - -def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - """ - the main function to inflated 2D parameters to 3D. - """ - weight_name = prefix + "weight" - bias_name = prefix + "bias" - if weight_name in state_dict: - weight_2d = state_dict[weight_name] - if weight_2d.dim() == 4: - # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) - weight_3d = inflate_weight_fn( - weight_2d=weight_2d, - weight_3d=layer.weight, - inflation_mode=layer.inflation_mode, - ) - state_dict[weight_name] = weight_3d - else: - return state_dict - # It's a 3d state dict, should not do inflation on both bias and weight. - if bias_name in state_dict: - bias_2d = state_dict[bias_name] - if bias_2d.dim() == 1: - # Assuming the 2D biases are 1D tensors (out_channels,) - bias_3d = inflate_bias_fn( - bias_2d=bias_2d, - bias_3d=layer.bias, - inflation_mode=layer.inflation_mode, - ) - state_dict[bias_name] = bias_3d - return state_dict diff --git a/modelsx/video_vae_v3/modules/context_parallel_lib.py b/modelsx/video_vae_v3/modules/context_parallel_lib.py deleted file mode 100644 index 55cfe481ee2ade7434166bfda0b83589b423137c..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/context_parallel_lib.py +++ /dev/null @@ -1,164 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import List -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor - -from common.distributed import get_device -from common.distributed.advanced import ( - get_next_sequence_parallel_rank, - get_prev_sequence_parallel_rank, - get_sequence_parallel_group, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, -) -from common.distributed.ops import Gather -from common.logger import get_logger -from models.video_vae_v3.modules.types import MemoryState - -logger = get_logger(__name__) - - -def causal_conv_slice_inputs(x, split_size, memory_state): - sp_size = get_sequence_parallel_world_size() - sp_group = get_sequence_parallel_group() - sp_rank = get_sequence_parallel_rank() - if sp_group is None: - return x - - assert memory_state != MemoryState.UNSET - leave_out = 1 if memory_state != MemoryState.ACTIVE else 0 - - # Should have at least sp_size slices. 
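    # Worked example (added for clarity, not part of the original file), matching
    # the "slice=4 and sp=2, and 17 frames" comment in InflatedCausalConv3d.slicing_forward:
    #   x.size(2) = 17, split_size = 4, sp_size = 2, memory_state = INITIALIZING
    #   leave_out   = 1                    -> the first slice carries the extra first frame
    #   num_slices  = (17 - 1) // 4 = 4
    #   split_sizes = [5, 4, 4, 4, 0]      -> plus a trailing remainder slice (here empty)
    #   grouped     = [5, 4] and [4, 4, 0] -> 9 frames on sp_rank 0, 8 frames on sp_rank 1
    # Each rank then convolves its contiguous chunk of frames, exchanging boundary
    # frames via cache_send_recv so the result matches the unsliced computation.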
- num_slices = (x.size(2) - leave_out) // split_size - assert num_slices >= sp_size, f"{num_slices} < {sp_size}" - - split_sizes = [split_size + leave_out] + [split_size] * (num_slices - 1) - split_sizes += [x.size(2) - sum(split_sizes)] - assert sum(split_sizes) == x.size(2) - - split_sizes = torch.tensor(split_sizes) - slices_per_rank = len(split_sizes) // sp_size - split_sizes = split_sizes.split( - [slices_per_rank] * (sp_size - 1) + [len(split_sizes) - slices_per_rank * (sp_size - 1)] - ) - split_sizes = list(map(lambda s: s.sum().item(), split_sizes)) - logger.debug(f"split_sizes: {split_sizes}") - return x.split(split_sizes, dim=2)[sp_rank] - - -def causal_conv_gather_outputs(x): - sp_group = get_sequence_parallel_group() - sp_size = get_sequence_parallel_world_size() - if sp_group is None: - return x - - # Communicate shapes. - unpad_lens = torch.empty((sp_size,), device=get_device(), dtype=torch.long) - local_unpad_len = torch.tensor([x.size(2)], device=get_device(), dtype=torch.long) - torch.distributed.all_gather_into_tensor(unpad_lens, local_unpad_len, group=sp_group) - - # Padding to max_len for gather. - max_len = unpad_lens.max() - x_pad = F.pad(x, (0, 0, 0, 0, 0, max_len - x.size(2))).contiguous() - - # Gather outputs. - x_pad = Gather.apply(sp_group, x_pad, 2, True) - - # Remove padding. - x_pad_lists = list(x_pad.chunk(sp_size, dim=2)) - for i, (x_pad, unpad_len) in enumerate(zip(x_pad_lists, unpad_lens)): - x_pad_lists[i] = x_pad[:, :, :unpad_len] - - return torch.cat(x_pad_lists, dim=2) - - -def get_output_len(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 - return output_len - - -def get_cache_size(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 - remain_len = ( - input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) - ) - overlap_len = dilated_kernerl_size - conv_module.stride[dim] - cache_len = overlap_len + remain_len # >= 0 - logger.debug( - f"I:{input_len}, " - f"P:{pad_len}, " - f"K:{conv_module.kernel_size[dim]}, " - f"S:{conv_module.stride[dim]}, " - f"O:{output_len}, " - f"Cache:{cache_len}" - ) - assert output_len > 0 - return cache_len - - -def cache_send_recv(tensor: List[Tensor], cache_size, times, memory=None): - sp_group = get_sequence_parallel_group() - sp_rank = get_sequence_parallel_rank() - sp_size = get_sequence_parallel_world_size() - send_dst = get_next_sequence_parallel_rank() - recv_src = get_prev_sequence_parallel_rank() - recv_buffer = None - recv_req = None - - logger.debug( - f"[sp{sp_rank}] cur_tensors:{[(t.size(), t.dtype) for t in tensor]}, times: {times}" - ) - if sp_rank == 0 or sp_group is None: - if memory is not None: - recv_buffer = memory.to(tensor[0]) - elif times > 0: - tile_repeat = [1] * tensor[0].ndim - tile_repeat[2] = times - recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) - - if cache_size != 0 and sp_group is not None: - if sp_rank > 0: - shape = list(tensor[0].size()) - shape[2] = cache_size - recv_buffer = torch.empty( - *shape, device=tensor[0].device, dtype=tensor[0].dtype - ).contiguous() - recv_req = dist.irecv(recv_buffer, recv_src, group=sp_group) - if sp_rank < sp_size - 1: - if cache_size > tensor[-1].size(2) and 
len(tensor) == 1: - logger.debug(f"[sp{sp_rank}] force concat before send {tensor[-1].size()}") - if recv_req is not None: - recv_req.wait() - tensor[0] = torch.cat([recv_buffer, tensor[0]], dim=2) - recv_buffer = None - assert cache_size <= tensor[-1].size( - 2 - ), f"Not enough value to cache, got {tensor[-1].size()}, cache_size={cache_size}" - dist.isend( - tensor[-1][:, :, -cache_size:].detach().contiguous(), send_dst, group=sp_group - ) - if recv_req is not None: - recv_req.wait() - - logger.debug( - f"[sp{sp_rank}] recv_src:{recv_src}, " - f"recv_buffer:{recv_buffer.size() if recv_buffer is not None else None}" - ) - return recv_buffer diff --git a/modelsx/video_vae_v3/modules/global_config.py b/modelsx/video_vae_v3/modules/global_config.py deleted file mode 100644 index 863117570a8aadde38b8eae8f1aa16480cd9f7ca..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/global_config.py +++ /dev/null @@ -1,28 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import Optional - -_NORM_LIMIT = float("inf") - - -def get_norm_limit(): - return _NORM_LIMIT - - -def set_norm_limit(value: Optional[float] = None): - global _NORM_LIMIT - if value is None: - value = float("inf") - _NORM_LIMIT = value diff --git a/modelsx/video_vae_v3/modules/inflated_layers.py b/modelsx/video_vae_v3/modules/inflated_layers.py deleted file mode 100644 index 8dfa4841a4ba3e4f758831396497b246614c7bb5..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/inflated_layers.py +++ /dev/null @@ -1,106 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
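# ---------------------------------------------------------------------------
# Minimal sketch (added for clarity, not part of the original sources): the
# cache length that context_parallel_lib.get_cache_size() derives for the
# temporal dimension, recomputed in plain Python for two typical convs in
# this VAE. The helper name is hypothetical.

def _cache_len_example(input_len: int, pad_len: int, kernel: int, stride: int, dilation: int = 1) -> int:
    """Trailing frames a slice hands to the next slice/rank (same math as get_cache_size)."""
    dilated_kernel = dilation * (kernel - 1) + 1
    output_len = (input_len + pad_len - dilated_kernel) // stride + 1
    remain = input_len + pad_len - ((output_len - 1) * stride + dilated_kernel)
    return (dilated_kernel - stride) + remain

# kernel=3, stride=1 causal conv: the last 2 frames overlap with the next slice.
assert _cache_len_example(input_len=8, pad_len=0, kernel=3, stride=1) == 2
# kernel=3, stride=2 temporal downsampling conv over 9 frames: 1 frame carries over.
assert _cache_len_example(input_len=9, pad_len=0, kernel=3, stride=2) == 1
# ---------------------------------------------------------------------------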
- -from functools import partial -from typing import Literal, Optional -from torch import Tensor -from torch.nn import Conv3d - -from models.video_vae_v3.modules.inflated_lib import ( - MemoryState, - extend_head, - inflate_bias, - inflate_weight, - modify_state_dict, -) - -_inflation_mode_t = Literal["none", "tail", "replicate"] -_memory_device_t = Optional[Literal["cpu", "same"]] - - -class InflatedCausalConv3d(Conv3d): - def __init__( - self, - *args, - inflation_mode: _inflation_mode_t, - memory_device: _memory_device_t = "same", - **kwargs, - ): - self.inflation_mode = inflation_mode - self.memory = None - super().__init__(*args, **kwargs) - self.temporal_padding = self.padding[0] - self.memory_device = memory_device - self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. - - def set_memory_device(self, memory_device: _memory_device_t): - self.memory_device = memory_device - - def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor: - mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory) - else: - input = extend_head(input, times=self.temporal_padding * 2) - memory = ( - input[:, :, mem_size:].detach() - if (mem_size != 0 and memory_state != MemoryState.DISABLED) - else None - ) - if ( - memory_state != MemoryState.DISABLED - and not self.training - and (self.memory_device is not None) - ): - self.memory = memory - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - return super().forward(input) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - if self.inflation_mode != "none": - state_dict = modify_state_dict( - self, - state_dict, - prefix, - inflate_weight_fn=partial(inflate_weight, position="tail"), - inflate_bias_fn=partial(inflate_bias, position="tail"), - ) - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - (strict and self.inflation_mode == "none"), - missing_keys, - unexpected_keys, - error_msgs, - ) - - -def init_causal_conv3d( - *args, - inflation_mode: _inflation_mode_t, - **kwargs, -): - """ - Initialize a Causal-3D convolution layer. - Parameters: - inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. - - none: No inflation will be conducted. - The loading logic of state dict will fall back to default. - - tail / replicate: Refer to the definition of `InflatedCausalConv3d`. - """ - return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs) diff --git a/modelsx/video_vae_v3/modules/inflated_lib.py b/modelsx/video_vae_v3/modules/inflated_lib.py deleted file mode 100644 index cbdaf3138bb5994c4702185426f854a4660cc6a4..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/inflated_lib.py +++ /dev/null @@ -1,156 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# // See the License for the specific language governing permissions and -# // limitations under the License. - -from enum import Enum -from typing import Optional -import numpy as np -import torch -from diffusers.models.normalization import RMSNorm -from einops import rearrange -from torch import Tensor, nn - -from common.logger import get_logger - -logger = get_logger(__name__) - - -class MemoryState(Enum): - """ - State[Disabled]: No memory bank will be enabled. - State[Initializing]: The model is handling the first clip, - need to reset / initialize the memory bank. - State[Active]: There has been some data in the memory bank. - """ - - DISABLED = 0 - INITIALIZING = 1 - ACTIVE = 2 - - -def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: - if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)): - if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") - x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") - return x - if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") - x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") - return x - if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): - if x.ndim <= 4: - return norm_layer(x) - if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") - x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x - raise NotImplementedError - - -def remove_head(tensor: Tensor, times: int = 1) -> Tensor: - """ - Remove duplicated first frame features in the up-sampling process. - """ - if times == 0: - return tensor - return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) - - -def extend_head( - tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None -) -> Tensor: - """ - When memory is None: - - Duplicate first frame features in the down-sampling process. - When memory is not None: - - Concatenate memory features with the input features to keep temporal consistency. - """ - if times == 0: - return tensor - if memory is not None: - return torch.cat((memory.to(tensor), tensor), dim=2) - else: - tile_repeat = np.ones(tensor.ndim).astype(int) - tile_repeat[2] = times - return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2) - - -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution weight matrix to a 3D one. - Parameters: - weight_2d: The weight matrix of 2D conv to be inflated. - weight_3d: The weight matrix of 3D conv to be initialized. - inflation_mode: the mode of inflation - """ - assert inflation_mode in ["constant", "replicate"] - assert weight_3d.shape[:2] == weight_2d.shape[:2] - with torch.no_grad(): - if inflation_mode == "replicate": - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - else: - weight_3d.fill_(0.0) - weight_3d[:, :, -1].copy_(weight_2d) - return weight_3d - - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): - """ - Inflate a 2D convolution bias tensor to a 3D one - Parameters: - bias_2d: The bias tensor of 2D conv to be inflated. - bias_3d: The bias tensor of 3D conv to be initialized. - inflation_mode: Placeholder to align `inflate_weight`. 
- """ - assert bias_3d.shape == bias_2d.shape - with torch.no_grad(): - bias_3d.copy_(bias_2d) - return bias_3d - - -def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - """ - the main function to inflated 2D parameters to 3D. - """ - weight_name = prefix + "weight" - bias_name = prefix + "bias" - if weight_name in state_dict: - weight_2d = state_dict[weight_name] - if weight_2d.dim() == 4: - # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) - weight_3d = inflate_weight_fn( - weight_2d=weight_2d, - weight_3d=layer.weight, - inflation_mode=layer.inflation_mode, - ) - state_dict[weight_name] = weight_3d - else: - return state_dict - # It's a 3d state dict, should not do inflation on both bias and weight. - if bias_name in state_dict: - bias_2d = state_dict[bias_name] - if bias_2d.dim() == 1: - # Assuming the 2D biases are 1D tensors (out_channels,) - bias_3d = inflate_bias_fn( - bias_2d=bias_2d, - bias_3d=layer.bias, - inflation_mode=layer.inflation_mode, - ) - state_dict[bias_name] = bias_3d - return state_dict diff --git a/modelsx/video_vae_v3/modules/types.py b/modelsx/video_vae_v3/modules/types.py deleted file mode 100644 index 5a030d2d284f9535f2a84c1f9befcd3f82d8d9ff..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/types.py +++ /dev/null @@ -1,76 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from enum import Enum -from typing import Dict, Literal, NamedTuple, Optional -import torch - -_receptive_field_t = Literal["half", "full"] -_inflation_mode_t = Literal["none", "tail", "replicate"] -_memory_device_t = Optional[Literal["cpu", "same"]] -_gradient_checkpointing_t = Optional[Literal["half", "full"]] -_selective_checkpointing_t = Optional[Literal["coarse", "fine"]] - -class DiagonalGaussianDistribution: - def __init__(self, mean: torch.Tensor, logvar: torch.Tensor): - self.mean = mean - self.logvar = torch.clamp(logvar, -30.0, 20.0) - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - - def mode(self) -> torch.Tensor: - return self.mean - - def sample(self) -> torch.FloatTensor: - return self.mean + self.std * torch.randn_like(self.mean) - - def kl(self) -> torch.Tensor: - return 0.5 * torch.sum( - self.mean**2 + self.var - 1.0 - self.logvar, - dim=list(range(1, self.mean.ndim)), - ) - -class MemoryState(Enum): - """ - State[Disabled]: No memory bank will be enabled. - State[Initializing]: The model is handling the first clip, need to reset the memory bank. - State[Active]: There has been some data in the memory bank. - State[Unset]: Error state, indicating users didn't pass correct memory state in. 
- """ - - DISABLED = 0 - INITIALIZING = 1 - ACTIVE = 2 - UNSET = 3 - - -class QuantizerOutput(NamedTuple): - latent: torch.Tensor - extra_loss: torch.Tensor - statistics: Dict[str, torch.Tensor] - - -class CausalAutoencoderOutput(NamedTuple): - sample: torch.Tensor - latent: torch.Tensor - posterior: Optional[DiagonalGaussianDistribution] - - -class CausalEncoderOutput(NamedTuple): - latent: torch.Tensor - posterior: Optional[DiagonalGaussianDistribution] - - -class CausalDecoderOutput(NamedTuple): - sample: torch.Tensor diff --git a/modelsx/video_vae_v3/modules/video_vae.py b/modelsx/video_vae_v3/modules/video_vae.py deleted file mode 100644 index 1b169431c637ba273de7e6a2340c64206746ef28..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/modules/video_vae.py +++ /dev/null @@ -1,955 +0,0 @@ -# Copyright (c) 2023 HuggingFace Team -# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. -# SPDX-License-Identifier: Apache License, Version 2.0 (the "License") -# -# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025 -# -# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text -# available at http://www.apache.org/licenses/LICENSE-2.0. -# -# This modified file is released under the same license. - -from contextlib import nullcontext -from typing import Optional, Tuple, Literal, Callable, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from einops import rearrange - -from common.distributed.advanced import get_sequence_parallel_world_size -from common.logger import get_logger -from models.video_vae_v3.modules.causal_inflation_lib import ( - InflatedCausalConv3d, - causal_norm_wrapper, - init_causal_conv3d, - remove_head, -) -from models.video_vae_v3.modules.context_parallel_lib import ( - causal_conv_gather_outputs, - causal_conv_slice_inputs, -) -from models.video_vae_v3.modules.global_config import set_norm_limit -from models.video_vae_v3.modules.types import ( - CausalAutoencoderOutput, - CausalDecoderOutput, - CausalEncoderOutput, - MemoryState, - _inflation_mode_t, - _memory_device_t, - _receptive_field_t, - _selective_checkpointing_t, -) - -logger = get_logger(__name__) # pylint: disable=invalid-name - -# Fake func, no checkpointing is required for inference -def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): - return module(*args, **kwargs) - -class ResnetBlock2D(nn.Module): - r""" - A Resnet block. - - Parameters: - in_channels (`int`): The number of channels in the input. - out_channels (`int`, *optional*, default to be `None`): - The number of output channels for the first conv2d layer. - If None, same as `in_channels`. - dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. 
- """ - - def __init__( - self, *, in_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0 - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - - self.nonlinearity = nn.SiLU() - - self.norm1 = torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - self.norm2 = torch.nn.GroupNorm( - num_groups=32, num_channels=out_channels, eps=1e-6, affine=True - ) - - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - - self.use_in_shortcut = self.in_channels != out_channels - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: - hidden = input_tensor - - hidden = self.norm1(hidden) - hidden = self.nonlinearity(hidden) - hidden = self.conv1(hidden) - - hidden = self.norm2(hidden) - hidden = self.nonlinearity(hidden) - hidden = self.dropout(hidden) - hidden = self.conv2(hidden) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = input_tensor + hidden - - return output_tensor - -class Upsample3D(nn.Module): - """A 3D upsampling layer.""" - - def __init__( - self, - channels: int, - inflation_mode: _inflation_mode_t = "tail", - temporal_up: bool = False, - spatial_up: bool = True, - slicing: bool = False, - ): - super().__init__() - self.channels = channels - self.conv = init_causal_conv3d( - self.channels, self.channels, kernel_size=3, padding=1, inflation_mode=inflation_mode - ) - - self.temporal_up = temporal_up - self.spatial_up = spatial_up - self.temporal_ratio = 2 if temporal_up else 1 - self.spatial_ratio = 2 if spatial_up else 1 - self.slicing = slicing - - upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = nn.Conv3d( - self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 - ) - identity = ( - torch.eye(self.channels).repeat(upscale_ratio, 1).reshape_as(self.upscale_conv.weight) - ) - - self.upscale_conv.weight.data.copy_(identity) - nn.init.zeros_(self.upscale_conv.bias) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state: MemoryState, - ) -> torch.FloatTensor: - return gradient_checkpointing( - self.custom_forward, - hidden_states, - memory_state, - enabled=self.training and self.gradient_checkpointing, - ) - - def custom_forward( - self, - hidden_states: torch.FloatTensor, - memory_state: MemoryState, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if self.slicing: - split_size = hidden_states.size(2) // 2 - hidden_states = list( - hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) - ) - else: - hidden_states = [hidden_states] - - for i in range(len(hidden_states)): - hidden_states[i] = self.upscale_conv(hidden_states[i]) - hidden_states[i] = rearrange( - hidden_states[i], - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, - ) - - # [Overridden] For causal temporal conv - if self.temporal_up and memory_state != MemoryState.ACTIVE: - hidden_states[0] = remove_head(hidden_states[0]) - - if self.slicing: 
- hidden_states = self.conv(hidden_states, memory_state=memory_state) - return torch.cat(hidden_states, dim=2) - else: - return self.conv(hidden_states[0], memory_state=memory_state) - - -class Downsample3D(nn.Module): - """A 3D downsampling layer.""" - - def __init__( - self, - channels: int, - inflation_mode: _inflation_mode_t = "tail", - temporal_down: bool = False, - spatial_down: bool = True, - ): - super().__init__() - self.channels = channels - self.temporal_down = temporal_down - self.spatial_down = spatial_down - - self.temporal_ratio = 2 if temporal_down else 1 - self.spatial_ratio = 2 if spatial_down else 1 - - self.temporal_kernel = 3 if temporal_down else 1 - self.spatial_kernel = 3 if spatial_down else 1 - - self.conv = init_causal_conv3d( - self.channels, - self.channels, - kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - padding=((1 if self.temporal_down else 0), 0, 0), - inflation_mode=inflation_mode, - ) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state: MemoryState, - ) -> torch.FloatTensor: - return gradient_checkpointing( - self.custom_forward, - hidden_states, - memory_state, - enabled=self.training and self.gradient_checkpointing, - ) - - def custom_forward( - self, - hidden_states: torch.FloatTensor, - memory_state: MemoryState, - ) -> torch.FloatTensor: - - assert hidden_states.shape[1] == self.channels - - if self.spatial_down: - hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) - - hidden_states = self.conv(hidden_states, memory_state=memory_state) - return hidden_states - - -class ResnetBlock3D(ResnetBlock2D): - def __init__( - self, - *args, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - **kwargs, - ): - super().__init__(*args, **kwargs) - self.conv1 = init_causal_conv3d( - self.in_channels, - self.out_channels, - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.conv2 = init_causal_conv3d( - self.out_channels, - self.out_channels, - kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), - stride=1, - padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), - inflation_mode=inflation_mode, - ) - - if self.use_in_shortcut: - self.conv_shortcut = init_causal_conv3d( - self.in_channels, - self.out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=(self.conv_shortcut.bias is not None), - inflation_mode=inflation_mode, - ) - self.gradient_checkpointing = False - - def forward(self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET): - return gradient_checkpointing( - self.custom_forward, - input_tensor, - memory_state, - enabled=self.training and self.gradient_checkpointing, - ) - - def custom_forward( - self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET - ): - assert memory_state != MemoryState.UNSET - hidden_states = input_tensor - - hidden_states = causal_norm_wrapper(self.norm1, hidden_states) - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states, memory_state=memory_state) - - hidden_states = causal_norm_wrapper(self.norm2, hidden_states) - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) - - if self.conv_shortcut is not None: - 
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) - - output_tensor = input_tensor + hidden_states - - return output_tensor - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - add_downsample: bool = True, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_down: bool = True, - spatial_down: bool = True, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - dropout=dropout, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - self.downsamplers = None - if add_downsample: - # Todo: Refactor this line before V5 Image VAE Training. - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - channels=out_channels, - inflation_mode=inflation_mode, - temporal_down=temporal_down, - spatial_down=spatial_down, - ) - ] - ) - - def forward( - self, hidden_states: torch.FloatTensor, memory_state: MemoryState - ) -> torch.FloatTensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states, memory_state=memory_state) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - add_upsample: bool = True, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up: bool = True, - spatial_up: bool = True, - slicing: bool = False, - ): - super().__init__() - resnets = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - ResnetBlock3D( - in_channels=input_channels, - out_channels=out_channels, - dropout=dropout, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - - self.resnets = nn.ModuleList(resnets) - - self.upsamplers = None - # Todo: Refactor this line before V5 Image VAE Training. 
- if add_upsample: - self.upsamplers = nn.ModuleList( - [ - Upsample3D( - channels=out_channels, - inflation_mode=inflation_mode, - temporal_up=temporal_up, - spatial_up=spatial_up, - slicing=slicing, - ) - ] - ) - - def forward( - self, hidden_states: torch.FloatTensor, memory_state: MemoryState - ) -> torch.FloatTensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states, memory_state=memory_state) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - def __init__( - self, - channels: int, - dropout: float = 0.0, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - self.resnets = nn.ModuleList( - [ - ResnetBlock3D( - in_channels=channels, - out_channels=channels, - dropout=dropout, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ), - ResnetBlock3D( - in_channels=channels, - out_channels=channels, - dropout=dropout, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ), - ] - ) - - def forward(self, hidden_states: torch.Tensor, memory_state: MemoryState): - for resnet in self.resnets: - hidden_states = resnet(hidden_states, memory_state) - return hidden_states - - -class Encoder3D(nn.Module): - r""" - The `Encoder` layer of a variational autoencoder that encodes - its input into a latent representation. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - double_z: bool = True, - temporal_down_num: int = 2, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), - ): - super().__init__() - self.layers_per_block = layers_per_block - - self.temporal_down_num = temporal_down_num - - self.conv_in = init_causal_conv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.down_blocks = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i in range(len(block_out_channels)): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last one - - down_block = DownEncoderBlock3D( - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - temporal_down=is_temporal_down_block, - spatial_down=True, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = UNetMidBlock3D( - channels=block_out_channels[-1], - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[-1], num_groups=32, eps=1e-6 - ) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = init_causal_conv3d( - block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - assert len(selective_checkpointing) == len(self.down_blocks) - self.set_gradient_checkpointing(selective_checkpointing) - - def 
set_gradient_checkpointing(self, checkpointing_types): - gradient_checkpointing = [] - for down_block, sac_type in zip(self.down_blocks, checkpointing_types): - if sac_type == "coarse": - gradient_checkpointing.append(True) - elif sac_type == "fine": - for n, m in down_block.named_modules(): - if hasattr(m, "gradient_checkpointing"): - m.gradient_checkpointing = True - logger.debug(f"set gradient_checkpointing: {n}") - gradient_checkpointing.append(False) - else: - gradient_checkpointing.append(False) - self.gradient_checkpointing = gradient_checkpointing - logger.info(f"[Encoder3D] gradient_checkpointing: {checkpointing_types}") - - def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - sample = self.conv_in(sample, memory_state=memory_state) - # down - for down_block, sac in zip(self.down_blocks, self.gradient_checkpointing): - sample = gradient_checkpointing( - down_block, - sample, - memory_state=memory_state, - enabled=self.training and sac, - ) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - - -class Decoder3D(nn.Module): - r""" - The `Decoder` layer of a variational autoencoder that - decodes its latent representation into an output sample. - """ - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up_num: int = 2, - slicing_up_num: int = 0, - selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_up_num = temporal_up_num - - self.conv_in = init_causal_conv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.up_blocks = nn.ModuleList([]) - - # mid - self.mid_block = UNetMidBlock3D( - channels=block_out_channels[-1], - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i in range(len(reversed_block_out_channels)): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - is_temporal_up_block = i < self.temporal_up_num - is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num - # Note: Keep symmetric - - up_block = UpDecoderBlock3D( - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block, - temporal_up=is_temporal_up_block, - slicing=is_slicing_up_block, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.up_blocks.append(up_block) - - # out - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=32, eps=1e-6 - ) - self.conv_act = nn.SiLU() - self.conv_out = init_causal_conv3d( - block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - assert len(selective_checkpointing) == len(self.up_blocks) - self.set_gradient_checkpointing(selective_checkpointing) - - def 
set_gradient_checkpointing(self, checkpointing_types): - gradient_checkpointing = [] - for up_block, sac_type in zip(self.up_blocks, checkpointing_types): - if sac_type == "coarse": - gradient_checkpointing.append(True) - elif sac_type == "fine": - for n, m in up_block.named_modules(): - if hasattr(m, "gradient_checkpointing"): - m.gradient_checkpointing = True - logger.debug(f"set gradient_checkpointing: {n}") - gradient_checkpointing.append(False) - else: - gradient_checkpointing.append(False) - self.gradient_checkpointing = gradient_checkpointing - logger.info(f"[Decoder3D] gradient_checkpointing: {checkpointing_types}") - - def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor: - r"""The forward method of the `Decoder` class.""" - - sample = self.conv_in(sample, memory_state=memory_state) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # up - for up_block, sac in zip(self.up_blocks, self.gradient_checkpointing): - sample = gradient_checkpointing( - up_block, - sample, - memory_state=memory_state, - enabled=self.training and sac, - ) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - - -class VideoAutoencoderKL(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - block_out_channels: Tuple[int] = (64,), - layers_per_block: int = 1, - latent_channels: int = 4, - use_quant_conv: bool = True, - use_post_quant_conv: bool = True, - enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), - dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), - temporal_scale_num: int = 3, - slicing_up_num: int = 0, - inflation_mode: _inflation_mode_t = "tail", - time_receptive_field: _receptive_field_t = "half", - slicing_sample_min_size: int = None, - spatial_downsample_factor: int = 16, - temporal_downsample_factor: int = 8, - freeze_encoder: bool = False, - ): - super().__init__() - self.spatial_downsample_factor = spatial_downsample_factor - self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder - if slicing_sample_min_size is None: - slicing_sample_min_size = temporal_downsample_factor - self.slicing_sample_min_size = slicing_sample_min_size - self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) - - # pass init params to Encoder - self.encoder = Encoder3D( - in_channels=in_channels, - out_channels=latent_channels, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - double_z=True, - temporal_down_num=temporal_scale_num, - selective_checkpointing=enc_selective_checkpointing, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # pass init params to Decoder - self.decoder = Decoder3D( - in_channels=latent_channels, - out_channels=out_channels, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - # [Override] add temporal_up_num parameter - temporal_up_num=temporal_scale_num, - slicing_up_num=slicing_up_num, - selective_checkpointing=dec_selective_checkpointing, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - self.quant_conv = ( - init_causal_conv3d( - in_channels=2 * latent_channels, - out_channels=2 * latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_quant_conv - else None - ) - self.post_quant_conv = 
( - init_causal_conv3d( - in_channels=latent_channels, - out_channels=latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_post_quant_conv - else None - ) - - self.use_slicing = False - - def enable_slicing(self): - self.use_slicing = True - - def disable_slicing(self): - self.use_slicing = False - - def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput: - if x.ndim == 4: - x = x.unsqueeze(2) - h = self.slicing_encode(x) - p = DiagonalGaussianDistribution(h) - z = p.sample() - return CausalEncoderOutput(z, p) - - def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput: - if z.ndim == 4: - z = z.unsqueeze(2) - x = self.slicing_decode(z) - return CausalDecoderOutput(x) - - def _encode(self, x: torch.Tensor, memory_state: MemoryState) -> torch.Tensor: - x = causal_conv_slice_inputs(x, self.slicing_sample_min_size, memory_state=memory_state) - h = self.encoder(x, memory_state=memory_state) - h = self.quant_conv(h, memory_state=memory_state) if self.quant_conv is not None else h - h = causal_conv_gather_outputs(h) - return h - - def _decode(self, z: torch.Tensor, memory_state: MemoryState) -> torch.Tensor: - z = causal_conv_slice_inputs(z, self.slicing_latent_min_size, memory_state=memory_state) - z = ( - self.post_quant_conv(z, memory_state=memory_state) - if self.post_quant_conv is not None - else z - ) - x = self.decoder(z, memory_state=memory_state) - x = causal_conv_gather_outputs(x) - return x - - def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: - x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) - encoded_slices = [ - self._encode( - torch.cat((x[:, :, :1], x_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for x_idx in range(1, len(x_slices)): - encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) - ) - return torch.cat(encoded_slices, dim=2) - else: - return self._encode(x, memory_state=MemoryState.DISABLED) - - def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = get_sequence_parallel_world_size() - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) - decoded_slices = [ - self._decode( - torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for z_idx in range(1, len(z_slices)): - decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) - ) - return torch.cat(decoded_slices, dim=2) - else: - return self._decode(z, memory_state=MemoryState.DISABLED) - - def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput: - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) - x = self.decode(z).sample - return CausalAutoencoderOutput(x, z, p) - - def preprocess(self, x: torch.Tensor): - # x should in [B, C, T, H, W], [B, C, H, W] - assert x.ndim == 4 or x.size(2) % self.temporal_downsample_factor == 1 - return x - - def postprocess(self, x: torch.Tensor): - # x should in [B, C, T, H, W], [B, C, H, W] - return x - - def set_causal_slicing( - self, - *, - split_size: Optional[int], - memory_device: _memory_device_t, - ): - assert ( - split_size is None or memory_device is not None - ), "if split_size is set, memory_device must not be None." 
- if split_size is not None: - self.enable_slicing() - self.slicing_sample_min_size = split_size - self.slicing_latent_min_size = split_size // self.temporal_downsample_factor - else: - self.disable_slicing() - for module in self.modules(): - if isinstance(module, InflatedCausalConv3d): - module.set_memory_device(memory_device) - - def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): - set_norm_limit(norm_max_mem) - for m in self.modules(): - if isinstance(m, InflatedCausalConv3d): - m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) - - -class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - def __init__( - self, *args, spatial_downsample_factor: int, temporal_downsample_factor: int, **kwargs - ): - self.spatial_downsample_factor = spatial_downsample_factor - self.temporal_downsample_factor = temporal_downsample_factor - super().__init__(*args, **kwargs) - - def forward(self, x) -> CausalAutoencoderOutput: - z, _, p = self.encode(x) - x, _ = self.decode(z) - return CausalAutoencoderOutput(x, z, None, p) - - def encode(self, x) -> CausalEncoderOutput: - if x.ndim == 4: - x = x.unsqueeze(2) - p = super().encode(x).latent_dist - z = p.sample().squeeze(2) - return CausalEncoderOutput(z, None, p) - - def decode(self, z) -> CausalDecoderOutput: - if z.ndim == 4: - z = z.unsqueeze(2) - x = super().decode(z).sample.squeeze(2) - return CausalDecoderOutput(x, None) - - def preprocess(self, x): - # x should in [B, C, T, H, W], [B, C, H, W] - assert x.ndim == 4 or x.size(2) % 4 == 1 - return x - - def postprocess(self, x): - # x should in [B, C, T, H, W], [B, C, H, W] - return x - - def set_causal_slicing( - self, - *, - split_size: Optional[int], - memory_device: Optional[Literal["cpu", "same"]], - ): - assert ( - split_size is None or memory_device is not None - ), "if split_size is set, memory_device must not be None." 
- if split_size is not None: - self.enable_slicing() - else: - self.disable_slicing() - self.slicing_sample_min_size = split_size - if split_size is not None: - self.slicing_latent_min_size = split_size // self.temporal_downsample_factor - for module in self.modules(): - if isinstance(module, InflatedCausalConv3d): - module.set_memory_device(memory_device) \ No newline at end of file diff --git a/modelsx/video_vae_v3/s8_c16_t4_inflation_sd3.yaml b/modelsx/video_vae_v3/s8_c16_t4_inflation_sd3.yaml deleted file mode 100644 index 58309522b791171f9d39f78ea1eaf57bab2a28fe..0000000000000000000000000000000000000000 --- a/modelsx/video_vae_v3/s8_c16_t4_inflation_sd3.yaml +++ /dev/null @@ -1,33 +0,0 @@ -__object__: - path: models.video_vae_v3.modules.attn_video_vae - name: VideoAutoencoderKLWrapper - args: as_params - -act_fn: silu -block_out_channels: - - 128 - - 256 - - 512 - - 512 -down_block_types: - - DownEncoderBlock3D - - DownEncoderBlock3D - - DownEncoderBlock3D - - DownEncoderBlock3D -in_channels: 3 -latent_channels: 16 -layers_per_block: 2 -norm_num_groups: 32 -out_channels: 3 -slicing_sample_min_size: 4 -temporal_scale_num: 2 -inflation_mode: pad -up_block_types: - - UpDecoderBlock3D - - UpDecoderBlock3D - - UpDecoderBlock3D - - UpDecoderBlock3D -spatial_downsample_factor: 8 -temporal_downsample_factor: 4 -use_quant_conv: False -use_post_quant_conv: False diff --git a/projectsx/inference_seedvr2_3b.py b/projectsx/inference_seedvr2_3b.py deleted file mode 100644 index 298cb217451bcf939b5bb0134aca63348c2a5639..0000000000000000000000000000000000000000 --- a/projectsx/inference_seedvr2_3b.py +++ /dev/null @@ -1,322 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import os -import torch -import mediapy -from einops import rearrange -from omegaconf import OmegaConf -print(os.getcwd()) -import datetime -from tqdm import tqdm -import gc - - -from data.image.transforms.divisible_crop import DivisibleCrop -from data.image.transforms.na_resize import NaResize -from data.video.transforms.rearrange import Rearrange -if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): - from projects.video_diffusion_sr.color_fix import wavelet_reconstruction - use_colorfix=True -else: - use_colorfix = False - print('Note!!!!!! 
Color fix is not avaliable!') -from torchvision.transforms import Compose, Lambda, Normalize -from torchvision.io.video import read_video -import argparse - - -from common.distributed import ( - get_device, - init_torch, -) - -from common.distributed.advanced import ( - get_data_parallel_rank, - get_data_parallel_world_size, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, - init_sequence_parallel, -) - -from projects.video_diffusion_sr.infer import VideoDiffusionInfer -from common.config import load_config -from common.distributed.ops import sync_data -from common.seed import set_seed -from common.partition import partition_by_groups, partition_by_size - - -def configure_sequence_parallel(sp_size): - if sp_size > 1: - init_sequence_parallel(sp_size) - -def configure_runner(sp_size): - config_path = os.path.join('./configs_3b', 'main.yaml') - config = load_config(config_path) - runner = VideoDiffusionInfer(config) - OmegaConf.set_readonly(runner.config, False) - - init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) - configure_sequence_parallel(sp_size) - runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth') - runner.configure_vae_model() - # Set memory limit. - if hasattr(runner.vae, "set_memory_limit"): - runner.vae.set_memory_limit(**runner.config.vae.memory_limit) - return runner - -def generation_step(runner, text_embeds_dict, cond_latents): - def _move_to_cuda(x): - return [i.to(get_device()) for i in x] - - noises = [torch.randn_like(latent) for latent in cond_latents] - aug_noises = [torch.randn_like(latent) for latent in cond_latents] - print(f"Generating with noise shape: {noises[0].size()}.") - noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) - noises, aug_noises, cond_latents = list( - map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) - ) - cond_noise_scale = 0.0 - - def _add_noise(x, aug_noise): - t = ( - torch.tensor([1000.0], device=get_device()) - * cond_noise_scale - ) - shape = torch.tensor(x.shape[1:], device=get_device())[None] - t = runner.timestep_transform(t, shape) - print( - f"Timestep shifting from" - f" {1000.0 * cond_noise_scale} to {t}." - ) - x = runner.schedule.forward(x, aug_noise, t) - return x - - conditions = [ - runner.get_condition( - noise, - task="sr", - latent_blur=_add_noise(latent_blur, aug_noise), - ) - for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) - ] - - with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): - video_tensors = runner.inference( - noises=noises, - conditions=conditions, - dit_offload=True, - **text_embeds_dict, - ) - - samples = [ - ( - rearrange(video[:, None], "c t h w -> t c h w") - if video.ndim == 3 - else rearrange(video, "c t h w -> t c h w") - ) - for video in video_tensors - ] - del video_tensors - - return samples - -def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, seed=666, res_h=1280, res_w=720, sp_size=1): - - def _build_pos_and_neg_prompt(): - # read positive prompt - positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ - hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ - skin pore detailing, hyper sharpness, perfect without deformations." 
- # read negative prompt - negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ - CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ - signature, jpeg artifacts, deformed, lowres, over-smooth" - return positive_text, negative_text - - def _build_test_prompts(video_path): - positive_text, negative_text = _build_pos_and_neg_prompt() - original_videos = [] - prompts = {} - video_list = os.listdir(video_path) - for f in video_list: - if f.endswith(".mp4"): - original_videos.append(f) - prompts[f] = positive_text - print(f"Total prompts to be generated: {len(original_videos)}") - return original_videos, prompts, negative_text - - def _extract_text_embeds(): - # Text encoder forward. - positive_prompts_embeds = [] - for texts_pos in tqdm(original_videos_local): - text_pos_embeds = torch.load('pos_emb.pt') - text_neg_embeds = torch.load('neg_emb.pt') - - positive_prompts_embeds.append( - {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} - ) - gc.collect() - torch.cuda.empty_cache() - return positive_prompts_embeds - - def cut_videos(videos, sp_size): - t = videos.size(1) - if t <= 4 * sp_size: - print(f"Cut input video size: {videos.size()}") - padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4 * sp_size) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 * sp_size - ((t - 1) % (4 * sp_size)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4 * sp_size) == 0 - return videos - - # classifier-free guidance - runner.config.diffusion.cfg.scale = cfg_scale - runner.config.diffusion.cfg.rescale = cfg_rescale - # sampling steps - runner.config.diffusion.timesteps.sampling.steps = sample_steps - runner.configure_diffusion() - - # set random seed - set_seed(seed, same_across_ranks=True) - os.makedirs(output_dir, exist_ok=True) - tgt_path = output_dir - - # get test prompts - original_videos, _, _ = _build_test_prompts(video_path) - - # divide the prompts into different groups - original_videos_group = partition_by_groups( - original_videos, - get_data_parallel_world_size() // get_sequence_parallel_world_size(), - ) - # store prompt mapping - original_videos_local = original_videos_group[ - get_data_parallel_rank() // get_sequence_parallel_world_size() - ] - original_videos_local = partition_by_size(original_videos_local, batch_size) - - # pre-extract the text embeddings - positive_prompts_embeds = _extract_text_embeds() - - video_transform = Compose( - [ - NaResize( - resolution=( - res_h * res_w - ) - ** 0.5, - mode="area", - # Upsample image, model only trained for high res. 
- downsample_only=False, - ), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), - DivisibleCrop((16, 16)), - Normalize(0.5, 0.5), - Rearrange("t c h w -> c t h w"), - ] - ) - - # generation loop - for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): - # read condition latents - cond_latents = [] - for video in videos: - video = ( - read_video( - os.path.join(video_path, video), output_format="TCHW" - )[0] - / 255.0 - ) - print(f"Read video size: {video.size()}") - cond_latents.append(video_transform(video.to(get_device()))) - - ori_lengths = [video.size(1) for video in cond_latents] - input_videos = cond_latents - cond_latents = [cut_videos(video, sp_size) for video in cond_latents] - - runner.dit.to("cpu") - print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") - runner.vae.to(get_device()) - cond_latents = runner.vae_encode(cond_latents) - runner.vae.to("cpu") - runner.dit.to(get_device()) - - for i, emb in enumerate(text_embeds["texts_pos"]): - text_embeds["texts_pos"][i] = emb.to(get_device()) - for i, emb in enumerate(text_embeds["texts_neg"]): - text_embeds["texts_neg"][i] = emb.to(get_device()) - - samples = generation_step(runner, text_embeds, cond_latents=cond_latents) - runner.dit.to("cpu") - del cond_latents - - # dump samples to the output directory - if get_sequence_parallel_rank() == 0: - for path, input, sample, ori_length in zip( - videos, input_videos, samples, ori_lengths - ): - if ori_length < sample.shape[0]: - sample = sample[:ori_length] - filename = os.path.join(tgt_path, os.path.basename(path)) - # color fix - input = ( - rearrange(input[:, None], "c t h w -> t c h w") - if input.ndim == 3 - else rearrange(input, "c t h w -> t c h w") - ) - if use_colorfix: - sample = wavelet_reconstruction( - sample.to("cpu"), input[: sample.size(0)].to("cpu") - ) - else: - sample = sample.to("cpu") - sample = ( - rearrange(sample[:, None], "t c h w -> t h w c") - if sample.ndim == 3 - else rearrange(sample, "t c h w -> t h w c") - ) - sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() - sample = sample.to(torch.uint8).numpy() - - if sample.shape[0] == 1: - mediapy.write_image(filename, sample.squeeze(0)) - else: - mediapy.write_video( - filename, sample, fps=24 - ) - gc.collect() - torch.cuda.empty_cache() - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--video_path", type=str, default="./test_videos") - parser.add_argument("--output_dir", type=str, default="./results") - parser.add_argument("--seed", type=int, default=666) - parser.add_argument("--res_h", type=int, default=720) - parser.add_argument("--res_w", type=int, default=1280) - parser.add_argument("--sp_size", type=int, default=1) - args = parser.parse_args() - - runner = configure_runner(args.sp_size) - generation_loop(runner, **vars(args)) diff --git a/projectsx/inference_seedvr2_7b.py b/projectsx/inference_seedvr2_7b.py deleted file mode 100644 index c4b73c25ce91bc0691a34e87d157edde488272cd..0000000000000000000000000000000000000000 --- a/projectsx/inference_seedvr2_7b.py +++ /dev/null @@ -1,321 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import os -import torch -import mediapy -from einops import rearrange -from omegaconf import OmegaConf -print(os.getcwd()) -import datetime -from tqdm import tqdm -from models.dit import na -import gc - -from data.image.transforms.divisible_crop import DivisibleCrop -from data.image.transforms.na_resize import NaResize -from data.video.transforms.rearrange import Rearrange -if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): - from projects.video_diffusion_sr.color_fix import wavelet_reconstruction - use_colorfix=True -else: - use_colorfix = False - print('Note!!!!!! Color fix is not avaliable!') -from torchvision.transforms import Compose, Lambda, Normalize -from torchvision.io.video import read_video - - -from common.distributed import ( - get_device, - init_torch, -) - -from common.distributed.advanced import ( - get_data_parallel_rank, - get_data_parallel_world_size, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, - init_sequence_parallel, -) - -from projects.video_diffusion_sr.infer import VideoDiffusionInfer -from common.config import load_config -from common.distributed.ops import sync_data -from common.seed import set_seed -from common.partition import partition_by_groups, partition_by_size -import argparse - -def configure_sequence_parallel(sp_size): - if sp_size > 1: - init_sequence_parallel(sp_size) - -def configure_runner(sp_size): - config_path = os.path.join('./configs_7b', 'main.yaml') - config = load_config(config_path) - runner = VideoDiffusionInfer(config) - OmegaConf.set_readonly(runner.config, False) - - init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) - configure_sequence_parallel(sp_size) - runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_7b.pth') - runner.configure_vae_model() - # Set memory limit. - if hasattr(runner.vae, "set_memory_limit"): - runner.vae.set_memory_limit(**runner.config.vae.memory_limit) - return runner - -def generation_step(runner, text_embeds_dict, cond_latents): - def _move_to_cuda(x): - return [i.to(get_device()) for i in x] - - noises = [torch.randn_like(latent) for latent in cond_latents] - aug_noises = [torch.randn_like(latent) for latent in cond_latents] - print(f"Generating with noise shape: {noises[0].size()}.") - noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) - noises, aug_noises, cond_latents = list( - map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) - ) - cond_noise_scale = 0.0 - - def _add_noise(x, aug_noise): - t = ( - torch.tensor([1000.0], device=get_device()) - * cond_noise_scale - ) - shape = torch.tensor(x.shape[1:], device=get_device())[None] - t = runner.timestep_transform(t, shape) - print( - f"Timestep shifting from" - f" {1000.0 * cond_noise_scale} to {t}." 
- ) - x = runner.schedule.forward(x, aug_noise, t) - return x - - conditions = [ - runner.get_condition( - noise, - task="sr", - latent_blur=_add_noise(latent_blur, aug_noise), - ) - for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) - ] - - with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): - video_tensors = runner.inference( - noises=noises, - conditions=conditions, - dit_offload=True, - **text_embeds_dict, - ) - - samples = [ - ( - rearrange(video[:, None], "c t h w -> t c h w") - if video.ndim == 3 - else rearrange(video, "c t h w -> t c h w") - ) - for video in video_tensors - ] - del video_tensors - - return samples - -def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, seed=666, res_h=1280, res_w=720, sp_size=1): - - def _build_pos_and_neg_prompt(): - # read positive prompt - positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ - hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ - skin pore detailing, hyper sharpness, perfect without deformations." - # read negative prompt - negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ - CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ - signature, jpeg artifacts, deformed, lowres, over-smooth" - return positive_text, negative_text - - def _build_test_prompts(video_path): - positive_text, negative_text = _build_pos_and_neg_prompt() - original_videos = [] - prompts = {} - video_list = os.listdir(video_path) - for f in video_list: - if f.endswith(".mp4"): - original_videos.append(f) - prompts[f] = positive_text - print(f"Total prompts to be generated: {len(original_videos)}") - return original_videos, prompts, negative_text - - def _extract_text_embeds(): - # Text encoder forward. 
- positive_prompts_embeds = [] - for texts_pos in tqdm(original_videos_local): - text_pos_embeds = torch.load('pos_emb.pt') - text_neg_embeds = torch.load('neg_emb.pt') - - positive_prompts_embeds.append( - {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} - ) - gc.collect() - torch.cuda.empty_cache() - return positive_prompts_embeds - - def cut_videos(videos, sp_size): - t = videos.size(1) - if t <= 4 * sp_size: - print(f"Cut input video size: {videos.size()}") - padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4 * sp_size) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 * sp_size - ((t - 1) % (4 * sp_size)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4 * sp_size) == 0 - return videos - - # classifier-free guidance - runner.config.diffusion.cfg.scale = cfg_scale - runner.config.diffusion.cfg.rescale = cfg_rescale - # sampling steps - runner.config.diffusion.timesteps.sampling.steps = sample_steps - runner.configure_diffusion() - - # set random seed - set_seed(seed, same_across_ranks=True) - os.makedirs(output_dir, exist_ok=True) - tgt_path = output_dir - - # get test prompts - original_videos, _, _ = _build_test_prompts(video_path) - - # divide the prompts into different groups - original_videos_group = partition_by_groups( - original_videos, - get_data_parallel_world_size() // get_sequence_parallel_world_size(), - ) - # store prompt mapping - original_videos_local = original_videos_group[ - get_data_parallel_rank() // get_sequence_parallel_world_size() - ] - original_videos_local = partition_by_size(original_videos_local, batch_size) - - # pre-extract the text embeddings - positive_prompts_embeds = _extract_text_embeds() - - video_transform = Compose( - [ - NaResize( - resolution=( - res_h * res_w - ) - ** 0.5, - mode="area", - # Upsample image, model only trained for high res. 
- downsample_only=False, - ), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), - DivisibleCrop((16, 16)), - Normalize(0.5, 0.5), - Rearrange("t c h w -> c t h w"), - ] - ) - - # generation loop - for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): - # read condition latents - cond_latents = [] - for video in videos: - video = ( - read_video( - os.path.join(video_path, video), output_format="TCHW" - )[0] - / 255.0 - ) - print(f"Read video size: {video.size()}") - cond_latents.append(video_transform(video.to(get_device()))) - - ori_lengths = [video.size(1) for video in cond_latents] - input_videos = cond_latents - cond_latents = [cut_videos(video, sp_size) for video in cond_latents] - - runner.dit.to("cpu") - print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") - runner.vae.to(get_device()) - cond_latents = runner.vae_encode(cond_latents) - runner.vae.to("cpu") - runner.dit.to(get_device()) - - for i, emb in enumerate(text_embeds["texts_pos"]): - text_embeds["texts_pos"][i] = emb.to(get_device()) - for i, emb in enumerate(text_embeds["texts_neg"]): - text_embeds["texts_neg"][i] = emb.to(get_device()) - - samples = generation_step(runner, text_embeds, cond_latents=cond_latents) - runner.dit.to("cpu") - del cond_latents - - # dump samples to the output directory - if get_sequence_parallel_rank() == 0: - for path, input, sample, ori_length in zip( - videos, input_videos, samples, ori_lengths - ): - if ori_length < sample.shape[0]: - sample = sample[:ori_length] - filename = os.path.join(tgt_path, os.path.basename(path)) - # color fix - input = ( - rearrange(input[:, None], "c t h w -> t c h w") - if input.ndim == 3 - else rearrange(input, "c t h w -> t c h w") - ) - if use_colorfix: - sample = wavelet_reconstruction( - sample.to("cpu"), input[: sample.size(0)].to("cpu") - ) - else: - sample = sample.to("cpu") - sample = ( - rearrange(sample[:, None], "t c h w -> t h w c") - if sample.ndim == 3 - else rearrange(sample, "t c h w -> t h w c") - ) - sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() - sample = sample.to(torch.uint8).numpy() - - if sample.shape[0] == 1: - mediapy.write_image(filename, sample.squeeze(0)) - else: - mediapy.write_video( - filename, sample, fps=24 - ) - gc.collect() - torch.cuda.empty_cache() - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--video_path", type=str, default="./test_videos") - parser.add_argument("--output_dir", type=str, default="./results") - parser.add_argument("--seed", type=int, default=666) - parser.add_argument("--res_h", type=int, default=720) - parser.add_argument("--res_w", type=int, default=1280) - parser.add_argument("--sp_size", type=int, default=1) - args = parser.parse_args() - - runner = configure_runner(args.sp_size) - generation_loop(runner, **vars(args)) diff --git a/projectsx/inference_seedvr_3b.py b/projectsx/inference_seedvr_3b.py deleted file mode 100644 index 469a97d8dac0769d208be21943b5f7215b249380..0000000000000000000000000000000000000000 --- a/projectsx/inference_seedvr_3b.py +++ /dev/null @@ -1,323 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import os -import torch -import mediapy -from einops import rearrange -from omegaconf import OmegaConf -print(os.getcwd()) -import datetime -from tqdm import tqdm -import gc - -from data.image.transforms.divisible_crop import DivisibleCrop -from data.image.transforms.na_resize import NaResize -from data.video.transforms.rearrange import Rearrange -if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): - from projects.video_diffusion_sr.color_fix import wavelet_reconstruction - use_colorfix=True -else: - use_colorfix = False - print('Note!!!!!! Color fix is not avaliable!') -from torchvision.transforms import Compose, Lambda, Normalize -from torchvision.io.video import read_video -import argparse - -from common.distributed import ( - get_device, - init_torch, -) - -from common.distributed.advanced import ( - get_data_parallel_rank, - get_data_parallel_world_size, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, - init_sequence_parallel, -) - -from projects.video_diffusion_sr.infer import VideoDiffusionInfer -from common.config import load_config -from common.distributed.ops import sync_data -from common.seed import set_seed -from common.partition import partition_by_groups, partition_by_size - - -def configure_sequence_parallel(sp_size): - if sp_size > 1: - init_sequence_parallel(sp_size) - -def configure_runner(sp_size): - config_path = os.path.join('./configs_3b', 'main.yaml') - config = load_config(config_path) - runner = VideoDiffusionInfer(config) - OmegaConf.set_readonly(runner.config, False) - - init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) - configure_sequence_parallel(sp_size) - runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr_ema_3b.pth') - runner.configure_vae_model() - # Set memory limit. - if hasattr(runner.vae, "set_memory_limit"): - runner.vae.set_memory_limit(**runner.config.vae.memory_limit) - return runner - -def generation_step(runner, text_embeds_dict, cond_latents): - def _move_to_cuda(x): - return [i.to(get_device()) for i in x] - - noises = [torch.randn_like(latent) for latent in cond_latents] - aug_noises = [torch.randn_like(latent) for latent in cond_latents] - print(f"Generating with noise shape: {noises[0].size()}.") - noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) - noises, aug_noises, cond_latents = list( - map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) - ) - cond_noise_scale = 0.1 - - def _add_noise(x, aug_noise): - t = ( - torch.tensor([1000.0], device=get_device()) - * cond_noise_scale - ) - shape = torch.tensor(x.shape[1:], device=get_device())[None] - t = runner.timestep_transform(t, shape) - print( - f"Timestep shifting from" - f" {1000.0 * cond_noise_scale} to {t}." 
- ) - x = runner.schedule.forward(x, aug_noise, t) - return x - - conditions = [ - runner.get_condition( - noise, - task="sr", - latent_blur=_add_noise(latent_blur, aug_noise), - ) - for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) - ] - - with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): - video_tensors = runner.inference( - noises=noises, - conditions=conditions, - dit_offload=True, - **text_embeds_dict, - ) - - samples = [ - ( - rearrange(video[:, None], "c t h w -> t c h w") - if video.ndim == 3 - else rearrange(video, "c t h w -> t c h w") - ) - for video in video_tensors - ] - del video_tensors - - return samples - -def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=6.5, cfg_rescale=0.0, sample_steps=50, seed=666, res_h=1280, res_w=720, sp_size=1): - - def _build_pos_and_neg_prompt(): - # read positive prompt - positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ - hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ - skin pore detailing, hyper sharpness, perfect without deformations." - # read negative prompt - negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ - CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ - signature, jpeg artifacts, deformed, lowres, over-smooth" - return positive_text, negative_text - - def _build_test_prompts(video_path): - positive_text, negative_text = _build_pos_and_neg_prompt() - original_videos = [] - prompts = {} - video_list = os.listdir(video_path) - for f in video_list: - if f.endswith(".mp4"): - original_videos.append(f) - prompts[f] = positive_text - print(f"Total prompts to be generated: {len(original_videos)}") - return original_videos, prompts, negative_text - - def _extract_text_embeds(): - # Text encoder forward. 
- positive_prompts_embeds = [] - for texts_pos in tqdm(original_videos_local): - text_pos_embeds = torch.load('pos_emb.pt') - text_neg_embeds = torch.load('neg_emb.pt') - - positive_prompts_embeds.append( - {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} - ) - gc.collect() - torch.cuda.empty_cache() - return positive_prompts_embeds - - def cut_videos(videos, sp_size): - t = videos.size(1) - if t <= 4 * sp_size: - print(f"Cut input video size: {videos.size()}") - padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4 * sp_size) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 * sp_size - ((t - 1) % (4 * sp_size)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4 * sp_size) == 0 - return videos - - # classifier-free guidance - runner.config.diffusion.cfg.scale = cfg_scale - runner.config.diffusion.cfg.rescale = cfg_rescale - # sampling steps - runner.config.diffusion.timesteps.sampling.steps = sample_steps - runner.configure_diffusion() - - # set random seed - set_seed(seed, same_across_ranks=True) - os.makedirs(output_dir, exist_ok=True) - tgt_path = output_dir - - # get test prompts - original_videos, _, _ = _build_test_prompts(video_path) - - # divide the prompts into different groups - original_videos_group = partition_by_groups( - original_videos, - get_data_parallel_world_size() // get_sequence_parallel_world_size(), - ) - # store prompt mapping - original_videos_local = original_videos_group[ - get_data_parallel_rank() // get_sequence_parallel_world_size() - ] - original_videos_local = partition_by_size(original_videos_local, batch_size) - - # pre-extract the text embeddings - positive_prompts_embeds = _extract_text_embeds() - - video_transform = Compose( - [ - NaResize( - resolution=( - res_h * res_w - ) - ** 0.5, - mode="area", - # Upsample image, model only trained for high res. 
- downsample_only=False, - ), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), - DivisibleCrop((16, 16)), - Normalize(0.5, 0.5), - Rearrange("t c h w -> c t h w"), - ] - ) - - # generation loop - for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): - # read condition latents - cond_latents = [] - for video in videos: - video = ( - read_video( - os.path.join(video_path, video), output_format="TCHW" - )[0] - / 255.0 - ) - print(f"Read video size: {video.size()}") - cond_latents.append(video_transform(video.to(get_device()))) - - ori_lengths = [video.size(1) for video in cond_latents] - input_videos = cond_latents - cond_latents = [cut_videos(video, sp_size) for video in cond_latents] - - runner.dit.to("cpu") - print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") - runner.vae.to(get_device()) - cond_latents = runner.vae_encode(cond_latents) - runner.vae.to("cpu") - runner.dit.to(get_device()) - - for i, emb in enumerate(text_embeds["texts_pos"]): - text_embeds["texts_pos"][i] = emb.to(get_device()) - for i, emb in enumerate(text_embeds["texts_neg"]): - text_embeds["texts_neg"][i] = emb.to(get_device()) - - samples = generation_step(runner, text_embeds, cond_latents=cond_latents) - runner.dit.to("cpu") - del cond_latents - - # dump samples to the output directory - if get_sequence_parallel_rank() == 0: - for path, input, sample, ori_length in zip( - videos, input_videos, samples, ori_lengths - ): - if ori_length < sample.shape[0]: - sample = sample[:ori_length] - filename = os.path.join(tgt_path, os.path.basename(path)) - # color fix - input = ( - rearrange(input[:, None], "c t h w -> t c h w") - if input.ndim == 3 - else rearrange(input, "c t h w -> t c h w") - ) - if use_colorfix: - sample = wavelet_reconstruction( - sample.to("cpu"), input[: sample.size(0)].to("cpu") - ) - else: - sample = sample.to("cpu") - sample = ( - rearrange(sample[:, None], "t c h w -> t h w c") - if sample.ndim == 3 - else rearrange(sample, "t c h w -> t h w c") - ) - sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() - sample = sample.to(torch.uint8).numpy() - - if sample.shape[0] == 1: - mediapy.write_image(filename, sample.squeeze(0)) - else: - mediapy.write_video( - filename, sample, fps=24 - ) - gc.collect() - torch.cuda.empty_cache() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--video_path", type=str, default="./test_videos") - parser.add_argument("--output_dir", type=str, default="./results") - parser.add_argument("--cfg_scale", type=float, default=6.5) - parser.add_argument("--sample_steps", type=int, default=50) - parser.add_argument("--seed", type=int, default=666) - parser.add_argument("--res_h", type=int, default=720) - parser.add_argument("--res_w", type=int, default=1280) - parser.add_argument("--sp_size", type=int, default=1) - args = parser.parse_args() - - runner = configure_runner(args.sp_size) - generation_loop(runner, **vars(args)) diff --git a/projectsx/inference_seedvr_7b.py b/projectsx/inference_seedvr_7b.py deleted file mode 100644 index 1408c9ca6f1b40ff522611f59937c968e4f15b44..0000000000000000000000000000000000000000 --- a/projectsx/inference_seedvr_7b.py +++ /dev/null @@ -1,324 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. 
-# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -import os -import torch -import mediapy -from einops import rearrange -from omegaconf import OmegaConf -print(os.getcwd()) -import datetime -from tqdm import tqdm -from models.dit import na -import gc - -from data.image.transforms.divisible_crop import DivisibleCrop -from data.image.transforms.na_resize import NaResize -from data.video.transforms.rearrange import Rearrange -if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): - from projects.video_diffusion_sr.color_fix import wavelet_reconstruction - use_colorfix=True -else: - use_colorfix = False - print('Note!!!!!! Color fix is not avaliable!') -from torchvision.transforms import Compose, Lambda, Normalize -from torchvision.io.video import read_video -import argparse - - -from common.distributed import ( - get_device, - init_torch, -) - -from common.distributed.advanced import ( - get_data_parallel_rank, - get_data_parallel_world_size, - get_sequence_parallel_rank, - get_sequence_parallel_world_size, - init_sequence_parallel, -) - -from projects.video_diffusion_sr.infer import VideoDiffusionInfer -from common.config import load_config -from common.distributed.ops import sync_data -from common.seed import set_seed -from common.partition import partition_by_groups, partition_by_size - - -def configure_sequence_parallel(sp_size): - if sp_size > 1: - init_sequence_parallel(sp_size) - -def configure_runner(sp_size): - config_path = os.path.join('./configs_7b', 'main.yaml') - config = load_config(config_path) - runner = VideoDiffusionInfer(config) - OmegaConf.set_readonly(runner.config, False) - - init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) - configure_sequence_parallel(sp_size) - runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr_ema_7b.pth') - runner.configure_vae_model() - # Set memory limit. - if hasattr(runner.vae, "set_memory_limit"): - runner.vae.set_memory_limit(**runner.config.vae.memory_limit) - return runner - -def generation_step(runner, text_embeds_dict, cond_latents): - def _move_to_cuda(x): - return [i.to(get_device()) for i in x] - - noises = [torch.randn_like(latent) for latent in cond_latents] - aug_noises = [torch.randn_like(latent) for latent in cond_latents] - print(f"Generating with noise shape: {noises[0].size()}.") - noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) - noises, aug_noises, cond_latents = list( - map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) - ) - cond_noise_scale = 0.1 - - def _add_noise(x, aug_noise): - t = ( - torch.tensor([1000.0], device=get_device()) - * cond_noise_scale - ) - shape = torch.tensor(x.shape[1:], device=get_device())[None] - t = runner.timestep_transform(t, shape) - print( - f"Timestep shifting from" - f" {1000.0 * cond_noise_scale} to {t}." 
- ) - x = runner.schedule.forward(x, aug_noise, t) - return x - - conditions = [ - runner.get_condition( - noise, - task="sr", - latent_blur=_add_noise(latent_blur, aug_noise), - ) - for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) - ] - - with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): - video_tensors = runner.inference( - noises=noises, - conditions=conditions, - dit_offload=True, - **text_embeds_dict, - ) - - samples = [ - ( - rearrange(video[:, None], "c t h w -> t c h w") - if video.ndim == 3 - else rearrange(video, "c t h w -> t c h w") - ) - for video in video_tensors - ] - del video_tensors - - return samples - -def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=6.5, cfg_rescale=0.0, sample_steps=50, seed=666, res_h=1280, res_w=720, sp_size=1): - - def _build_pos_and_neg_prompt(): - # read positive prompt - positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ - hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ - skin pore detailing, hyper sharpness, perfect without deformations." - # read negative prompt - negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ - CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ - signature, jpeg artifacts, deformed, lowres, over-smooth" - return positive_text, negative_text - - def _build_test_prompts(video_path): - positive_text, negative_text = _build_pos_and_neg_prompt() - original_videos = [] - prompts = {} - video_list = os.listdir(video_path) - for f in video_list: - if f.endswith(".mp4"): - original_videos.append(f) - prompts[f] = positive_text - print(f"Total prompts to be generated: {len(original_videos)}") - return original_videos, prompts, negative_text - - def _extract_text_embeds(): - # Text encoder forward. 
- positive_prompts_embeds = [] - for texts_pos in tqdm(original_videos_local): - text_pos_embeds = torch.load('pos_emb.pt') - text_neg_embeds = torch.load('neg_emb.pt') - - positive_prompts_embeds.append( - {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} - ) - gc.collect() - torch.cuda.empty_cache() - return positive_prompts_embeds - - def cut_videos(videos, sp_size): - t = videos.size(1) - if t <= 4 * sp_size: - print(f"Cut input video size: {videos.size()}") - padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4 * sp_size) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 * sp_size - ((t - 1) % (4 * sp_size)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4 * sp_size) == 0 - return videos - - # classifier-free guidance - runner.config.diffusion.cfg.scale = cfg_scale - runner.config.diffusion.cfg.rescale = cfg_rescale - # sampling steps - runner.config.diffusion.timesteps.sampling.steps = sample_steps - runner.configure_diffusion() - - # set random seed - set_seed(seed, same_across_ranks=True) - os.makedirs(output_dir, exist_ok=True) - tgt_path = output_dir - - # get test prompts - original_videos, _, _ = _build_test_prompts(video_path) - - # divide the prompts into different groups - original_videos_group = partition_by_groups( - original_videos, - get_data_parallel_world_size() // get_sequence_parallel_world_size(), - ) - # store prompt mapping - original_videos_local = original_videos_group[ - get_data_parallel_rank() // get_sequence_parallel_world_size() - ] - original_videos_local = partition_by_size(original_videos_local, batch_size) - - # pre-extract the text embeddings - positive_prompts_embeds = _extract_text_embeds() - - video_transform = Compose( - [ - NaResize( - resolution=( - res_h * res_w - ) - ** 0.5, - mode="area", - # Upsample image, model only trained for high res. 
- downsample_only=False, - ), - Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), - DivisibleCrop((16, 16)), - Normalize(0.5, 0.5), - Rearrange("t c h w -> c t h w"), - ] - ) - - # generation loop - for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): - # read condition latents - cond_latents = [] - for video in videos: - video = ( - read_video( - os.path.join(video_path, video), output_format="TCHW" - )[0] - / 255.0 - ) - print(f"Read video size: {video.size()}") - cond_latents.append(video_transform(video.to(get_device()))) - - ori_lengths = [video.size(1) for video in cond_latents] - input_videos = cond_latents - cond_latents = [cut_videos(video, sp_size) for video in cond_latents] - - runner.dit.to("cpu") - print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") - runner.vae.to(get_device()) - cond_latents = runner.vae_encode(cond_latents) - runner.vae.to("cpu") - runner.dit.to(get_device()) - - for i, emb in enumerate(text_embeds["texts_pos"]): - text_embeds["texts_pos"][i] = emb.to(get_device()) - for i, emb in enumerate(text_embeds["texts_neg"]): - text_embeds["texts_neg"][i] = emb.to(get_device()) - - samples = generation_step(runner, text_embeds, cond_latents=cond_latents) - runner.dit.to("cpu") - del cond_latents - - # dump samples to the output directory - if get_sequence_parallel_rank() == 0: - for path, input, sample, ori_length in zip( - videos, input_videos, samples, ori_lengths - ): - if ori_length < sample.shape[0]: - sample = sample[:ori_length] - filename = os.path.join(tgt_path, os.path.basename(path)) - # color fix - input = ( - rearrange(input[:, None], "c t h w -> t c h w") - if input.ndim == 3 - else rearrange(input, "c t h w -> t c h w") - ) - if use_colorfix: - sample = wavelet_reconstruction( - sample.to("cpu"), input[: sample.size(0)].to("cpu") - ) - else: - sample = sample.to("cpu") - sample = ( - rearrange(sample[:, None], "t c h w -> t h w c") - if sample.ndim == 3 - else rearrange(sample, "t c h w -> t h w c") - ) - sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() - sample = sample.to(torch.uint8).numpy() - - if sample.shape[0] == 1: - mediapy.write_image(filename, sample.squeeze(0)) - else: - mediapy.write_video( - filename, sample, fps=24 - ) - gc.collect() - torch.cuda.empty_cache() - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--video_path", type=str, default="./test_videos") - parser.add_argument("--output_dir", type=str, default="./results") - parser.add_argument("--cfg_scale", type=float, default=6.5) - parser.add_argument("--sample_steps", type=int, default=50) - parser.add_argument("--seed", type=int, default=666) - parser.add_argument("--res_h", type=int, default=720) - parser.add_argument("--res_w", type=int, default=1280) - parser.add_argument("--sp_size", type=int, default=1) - args = parser.parse_args() - - runner = configure_runner(args.sp_size) - generation_loop(runner, **vars(args)) diff --git a/projectsx/video_diffusion_sr/color_fix.py b/projectsx/video_diffusion_sr/color_fix.py deleted file mode 100644 index efe804519873717eee01468439c416325eb8e192..0000000000000000000000000000000000000000 --- a/projectsx/video_diffusion_sr/color_fix.py +++ /dev/null @@ -1,113 +0,0 @@ -import torch -from PIL import Image -from torch import Tensor -from torch.nn import functional as F - -from torchvision.transforms import ToTensor, ToPILImage - -def adain_color_fix(target: Image, source: Image): - # Convert images to tensors - to_tensor = ToTensor() - 
target_tensor = to_tensor(target).unsqueeze(0) - source_tensor = to_tensor(source).unsqueeze(0) - - # Apply adaptive instance normalization - result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) - - # Convert tensor back to image - to_image = ToPILImage() - result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) - - return result_image - -def wavelet_color_fix(target: Image, source: Image): - # Convert images to tensors - to_tensor = ToTensor() - target_tensor = to_tensor(target).unsqueeze(0) - source_tensor = to_tensor(source).unsqueeze(0) - - # Apply wavelet reconstruction - result_tensor = wavelet_reconstruction(target_tensor, source_tensor) - - # Convert tensor back to image - to_image = ToPILImage() - result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) - - return result_image - -def calc_mean_std(feat: Tensor, eps=1e-5): - """Calculate mean and std for adaptive_instance_normalization. - Args: - feat (Tensor): 4D tensor. - eps (float): A small value added to the variance to avoid - divide-by-zero. Default: 1e-5. - """ - size = feat.size() - assert len(size) == 4, 'The input feature should be 4D tensor.' - b, c = size[:2] - feat_var = feat.view(b, c, -1).var(dim=2) + eps - feat_std = feat_var.sqrt().view(b, c, 1, 1) - feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) - return feat_mean, feat_std - -def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): - """Adaptive instance normalization. - Adjust the reference features to have the similar color and illuminations - as those in the degradate features. - Args: - content_feat (Tensor): The reference feature. - style_feat (Tensor): The degradate features. - """ - size = content_feat.size() - style_mean, style_std = calc_mean_std(style_feat) - content_mean, content_std = calc_mean_std(content_feat) - normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) - -def wavelet_blur(image: Tensor, radius: int): - """ - Apply wavelet blur to the input tensor. - """ - # input shape: (1, 3, H, W) - # convolution kernel - kernel_vals = [ - [0.0625, 0.125, 0.0625], - [0.125, 0.25, 0.125], - [0.0625, 0.125, 0.0625], - ] - kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) - # add channel dimensions to the kernel to make it a 4D tensor - kernel = kernel[None, None] - # repeat the kernel across all input channels - kernel = kernel.repeat(3, 1, 1, 1) - image = F.pad(image, (radius, radius, radius, radius), mode='replicate') - # apply convolution - output = F.conv2d(image, kernel, groups=3, dilation=radius) - return output - -def wavelet_decomposition(image: Tensor, levels=5): - """ - Apply wavelet decomposition to the input tensor. - This function only returns the low frequency & the high frequency. - """ - high_freq = torch.zeros_like(image) - for i in range(levels): - radius = 2 ** i - low_freq = wavelet_blur(image, radius) - high_freq += (image - low_freq) - image = low_freq - - return high_freq, low_freq - -def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): - """ - Apply wavelet decomposition, so that the content will have the same color as the style. 
- """ - # calculate the wavelet decomposition of the content feature - content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq - # calculate the wavelet decomposition of the style feature - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq - # reconstruct the content feature with the style's high frequency - return content_high_freq + style_low_freq \ No newline at end of file diff --git a/projectsx/video_diffusion_sr/infer.py b/projectsx/video_diffusion_sr/infer.py deleted file mode 100644 index 54bb5fba186f884dd52aed61672b6c675046e42f..0000000000000000000000000000000000000000 --- a/projectsx/video_diffusion_sr/infer.py +++ /dev/null @@ -1,342 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. - -from typing import List, Optional, Tuple, Union -import torch -from einops import rearrange -from omegaconf import DictConfig, ListConfig -from torch import Tensor - -from common.config import create_object -from common.decorators import log_on_entry, log_runtime -from common.diffusion import ( - classifier_free_guidance_dispatcher, - create_sampler_from_config, - create_sampling_timesteps_from_config, - create_schedule_from_config, -) -from common.distributed import ( - get_device, - get_global_rank, -) - -from common.distributed.meta_init_utils import ( - meta_non_persistent_buffer_init_fn, -) -# from common.fs import download - -from models.dit_v2 import na - -class VideoDiffusionInfer(): - def __init__(self, config: DictConfig): - self.config = config - self.device = "cuda" - - def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor: - t, h, w, c = latent.shape - cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - if task == "t2v" or t == 1: - # t2i or t2v generation. - if task == "sr": - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - if task == "i2v": - # i2v generation. - cond[:1, ..., :-1] = latent[:1] - cond[:1, ..., -1:] = 1.0 - return cond - if task == "v2v": - # v2v frame extension. - cond[:2, ..., :-1] = latent[:2] - cond[:2, ..., -1:] = 1.0 - return cond - if task == "sr": - # sr generation. - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - raise NotImplementedError - - @log_on_entry - @log_runtime - def configure_dit_model(self, device="cuda", checkpoint=None): - # Load dit checkpoint. - # For fast init & resume, - # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP. - # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True. - - # Create dit model. 
- with torch.device(self.device): - self.dit = create_object(self.config.dit.model) - self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint) - - if checkpoint: - state = torch.load(checkpoint, map_location=self.device, mmap=True) - loading_info = self.dit.load_state_dict(state, strict=True, assign=True) - print(f"Loading pretrained ckpt from {checkpoint}") - print(f"Loading info: {loading_info}") - self.dit = meta_non_persistent_buffer_init_fn(self.dit) - - # Print model size. - num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad) - print(f"DiT trainable parameters: {num_params:,}") - - @log_on_entry - @log_runtime - def configure_vae_model(self): - # Create vae model. - dtype = getattr(torch, self.config.vae.dtype) - self.vae = create_object(self.config.vae.model) - self.vae.requires_grad_(False).eval() - self.vae.to(device=get_device(), dtype=dtype) - - # Load vae checkpoint. - state = torch.load( - self.config.vae.checkpoint, map_location=get_device(), mmap=True - ) - self.vae.load_state_dict(state) - - # Set causal slicing. - if hasattr(self.vae, "set_causal_slicing") and hasattr(self.config.vae, "slicing"): - self.vae.set_causal_slicing(**self.config.vae.slicing) - - # ------------------------------ Diffusion ------------------------------ # - - def configure_diffusion(self): - self.schedule = create_schedule_from_config( - config=self.config.diffusion.schedule, - device=get_device(), - ) - self.sampling_timesteps = create_sampling_timesteps_from_config( - config=self.config.diffusion.timesteps.sampling, - schedule=self.schedule, - device=get_device(), - ) - self.sampler = create_sampler_from_config( - config=self.config.diffusion.sampler, - schedule=self.schedule, - timesteps=self.sampling_timesteps, - ) - - # -------------------------------- Helper ------------------------------- # - - @torch.no_grad() - def vae_encode(self, samples: List[Tensor]) -> List[Tensor]: - use_sample = self.config.vae.get("use_sample", True) - latents = [] - if len(samples) > 0: - device = get_device() - dtype = getattr(torch, self.config.vae.dtype) - scale = self.config.vae.scaling_factor - shift = self.config.vae.get("shifting_factor", 0.0) - - if isinstance(scale, ListConfig): - scale = torch.tensor(scale, device=device, dtype=dtype) - if isinstance(shift, ListConfig): - shift = torch.tensor(shift, device=device, dtype=dtype) - - # Group samples of the same shape to batches if enabled. - if self.config.vae.grouping: - batches, indices = na.pack(samples) - else: - batches = [sample.unsqueeze(0) for sample in samples] - - # Vae process by each group. - for sample in batches: - sample = sample.to(device, dtype) - if hasattr(self.vae, "preprocess"): - sample = self.vae.preprocess(sample) - if use_sample: - latent = self.vae.encode(sample).latent - else: - # Deterministic vae encode, only used for i2v inference (optionally) - latent = self.vae.encode(sample).posterior.mode().squeeze(2) - latent = latent.unsqueeze(2) if latent.ndim == 4 else latent - latent = rearrange(latent, "b c ... -> b ... c") - latent = (latent - shift) * scale - latents.append(latent) - - # Ungroup back to individual latent with the original order. 
- if self.config.vae.grouping: - latents = na.unpack(latents, indices) - else: - latents = [latent.squeeze(0) for latent in latents] - - return latents - - @torch.no_grad() - def vae_decode(self, latents: List[Tensor]) -> List[Tensor]: - samples = [] - if len(latents) > 0: - device = get_device() - dtype = getattr(torch, self.config.vae.dtype) - scale = self.config.vae.scaling_factor - shift = self.config.vae.get("shifting_factor", 0.0) - - if isinstance(scale, ListConfig): - scale = torch.tensor(scale, device=device, dtype=dtype) - if isinstance(shift, ListConfig): - shift = torch.tensor(shift, device=device, dtype=dtype) - - # Group latents of the same shape to batches if enabled. - if self.config.vae.grouping: - latents, indices = na.pack(latents) - else: - latents = [latent.unsqueeze(0) for latent in latents] - - # Vae process by each group. - for latent in latents: - latent = latent.to(device, dtype) - latent = latent / scale + shift - latent = rearrange(latent, "b ... c -> b c ...") - latent = latent.squeeze(2) - sample = self.vae.decode(latent).sample - if hasattr(self.vae, "postprocess"): - sample = self.vae.postprocess(sample) - samples.append(sample) - - # Ungroup back to individual sample with the original order. - if self.config.vae.grouping: - samples = na.unpack(samples, indices) - else: - samples = [sample.squeeze(0) for sample in samples] - - return samples - - def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor): - # Skip if not needed. - if not self.config.diffusion.timesteps.get("transform", False): - return timesteps - - # Compute resolution. - vt = self.config.vae.model.get("temporal_downsample_factor", 4) - vs = self.config.vae.model.get("spatial_downsample_factor", 8) - frames = (latents_shapes[:, 0] - 1) * vt + 1 - heights = latents_shapes[:, 1] * vs - widths = latents_shapes[:, 2] * vs - - # Compute shift factor. - def get_lin_function(x1, y1, x2, y2): - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2) - vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0) - shift = torch.where( - frames > 1, - vid_shift_fn(heights * widths * frames), - img_shift_fn(heights * widths), - ) - - # Shift timesteps. - timesteps = timesteps / self.schedule.T - timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) - timesteps = timesteps * self.schedule.T - return timesteps - - @torch.no_grad() - def inference( - self, - noises: List[Tensor], - conditions: List[Tensor], - texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], - texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], - cfg_scale: Optional[float] = None, - dit_offload: bool = False, - ) -> List[Tensor]: - assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg) - batch_size = len(noises) - - # Return if empty. - if batch_size == 0: - return [] - - # Set cfg scale - if cfg_scale is None: - cfg_scale = self.config.diffusion.cfg.scale - - # Text embeddings. 
- assert type(texts_pos[0]) is type(texts_neg[0]) - if isinstance(texts_pos[0], str): - text_pos_embeds, text_pos_shapes = self.text_encode(texts_pos) - text_neg_embeds, text_neg_shapes = self.text_encode(texts_neg) - elif isinstance(texts_pos[0], tuple): - text_pos_embeds, text_pos_shapes = [], [] - text_neg_embeds, text_neg_shapes = [], [] - for pos in zip(*texts_pos): - emb, shape = na.flatten(pos) - text_pos_embeds.append(emb) - text_pos_shapes.append(shape) - for neg in zip(*texts_neg): - emb, shape = na.flatten(neg) - text_neg_embeds.append(emb) - text_neg_shapes.append(shape) - else: - text_pos_embeds, text_pos_shapes = na.flatten(texts_pos) - text_neg_embeds, text_neg_shapes = na.flatten(texts_neg) - - # Flatten. - latents, latents_shapes = na.flatten(noises) - latents_cond, _ = na.flatten(conditions) - - # Enter eval mode. - was_training = self.dit.training - self.dit.eval() - - # Sampling. - latents = self.sampler.sample( - x=latents, - f=lambda args: classifier_free_guidance_dispatcher( - pos=lambda: self.dit( - vid=torch.cat([args.x_t, latents_cond], dim=-1), - txt=text_pos_embeds, - vid_shape=latents_shapes, - txt_shape=text_pos_shapes, - timestep=args.t.repeat(batch_size), - ).vid_sample, - neg=lambda: self.dit( - vid=torch.cat([args.x_t, latents_cond], dim=-1), - txt=text_neg_embeds, - vid_shape=latents_shapes, - txt_shape=text_neg_shapes, - timestep=args.t.repeat(batch_size), - ).vid_sample, - scale=( - cfg_scale - if (args.i + 1) / len(self.sampler.timesteps) - <= self.config.diffusion.cfg.get("partial", 1) - else 1.0 - ), - rescale=self.config.diffusion.cfg.rescale, - ), - ) - - # Exit eval mode. - self.dit.train(was_training) - - # Unflatten. - latents = na.unflatten(latents, latents_shapes) - - if dit_offload: - self.dit.to("cpu") - - # Vae decode. - self.vae.to(get_device()) - samples = self.vae_decode(latents) - - if dit_offload: - self.dit.to(get_device()) - return samples \ No newline at end of file diff --git a/projectsx/video_diffusion_sr/utils.py b/projectsx/video_diffusion_sr/utils.py deleted file mode 100644 index ae48d2662d7ed8630579cb52b97fc0e256335a5a..0000000000000000000000000000000000000000 --- a/projectsx/video_diffusion_sr/utils.py +++ /dev/null @@ -1,368 +0,0 @@ -# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# // -# // Licensed under the Apache License, Version 2.0 (the "License"); -# // you may not use this file except in compliance with the License. -# // You may obtain a copy of the License at -# // -# // http://www.apache.org/licenses/LICENSE-2.0 -# // -# // Unless required by applicable law or agreed to in writing, software -# // distributed under the License is distributed on an "AS IS" BASIS, -# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# // See the License for the specific language governing permissions and -# // limitations under the License. 
- -import os -import random -import threading -from abc import ABC -from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass -from functools import partial -from itertools import chain -from typing import Any, Dict, List, Optional, Tuple, Union -import pyarrow as pa -import pyarrow.parquet as pq -from omegaconf import DictConfig - -from common.distributed import get_global_rank, get_world_size -from common.fs import copy, exists, listdir, mkdir, remove -from common.partition import partition_by_groups -from common.persistence.utils import get_local_path -from data.common.parquet_sampler import ( - IdentityParquetSampler, - ParquetSampler, - create_parquet_sampler, -) -from data.common.utils import filter_parquets, get_parquet_metadata - - -# Function to save a Parquet file and copy it to a target path -def save_and_copy( - pa_table, - local_path: str, - target_path: str, - row_group_size: int, - executor: ThreadPoolExecutor, - do_async: bool = False, - futures: List[Tuple[threading.Thread, str]] = [], -): - # Function to handle completion of the future - def _make_on_complete(local_path): - def _on_complete(future): - target_path = future.result() - remove(local_path) - # del future - print(f"Target path saved: {target_path}") - - return _on_complete - - # Function to write Parquet table and copy it - def _fn(pa_table, local_path, target_path, row_group_size): - pq.write_table( - pa_table, - local_path, - row_group_size=row_group_size, - ) - mkdir(os.path.dirname(target_path)) - copy(local_path, target_path) - return target_path - - # Submit the task to the executor - future = executor.submit(_fn, pa_table, local_path, target_path, row_group_size) - future.add_done_callback(_make_on_complete(local_path)) - futures.append(future) - - # If not asynchronous, wait for all futures to complete - if not do_async: - for future in as_completed(futures): - try: - future.result() - except Exception as exc: - print(f"Generated an exception: {exc}") - executor.shutdown(wait=True) - - -@dataclass -class FileListOutput: - existing_files: List[str] - source_files: List[Any] - target_files: List[str] - - -@dataclass -class PersistedParquet: - path: str - - # Method to save the Parquet file - def save( - self, - row_group_size: int, - executor: ThreadPoolExecutor, - pa_table: Optional[pa.Table] = None, - data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None, - is_last_file=False, - futures: List[threading.Thread] = [], - ): - assert (pa_table is None) != (data_dict is None) - local_path = get_local_path(self.path) - if not pa_table: - schema_dict = self.generate_schema_from_dict(data_dict) - pa_table = pa.Table.from_pydict(data_dict, schema=schema_dict) - save_and_copy( - pa_table, - local_path=local_path, - target_path=self.path, - row_group_size=row_group_size, - executor=executor, - do_async=not is_last_file, - futures=futures, - ) - - # Method to generate schema from a dictionary - def generate_schema_from_dict( - self, - data_dict: Dict[str, List[Union[str, bytes]]], - ): - schema_dict = {} - for key, value in data_dict.items(): - if isinstance(value[0], str): - schema_dict[key] = pa.string() - elif isinstance(value[0], bytes): - schema_dict[key] = pa.binary() - else: - raise ValueError(f"Unsupported data type for key '{key}': {type(value)}") - return pa.schema(schema_dict) - - -# Base class for managing Parquet files -class ParquetManager(ABC): - """ - Base class for the DumpingManager and RepackingManager. 
- """ - - def __init__( - self, - task: Optional[DictConfig] = None, - target_dir: str = ".", - ): - self.task = task - self.target_dir = target_dir.rstrip("/") - self.executor = ThreadPoolExecutor(max_workers=4) - self.futures = [] - - # Method to get list of Parquet files from source path - def get_parquet_files( - self, - source_path: str, - parquet_sampler: ParquetSampler = IdentityParquetSampler(), - path_mode: str = "dir", - ): - - # Helper function to flatten nested lists - def _flatten(paths): - if isinstance(paths, list): - if any(isinstance(i, list) for i in paths): - return list(chain(*paths)) - else: - return paths - else: - return [paths] - - file_paths = _flatten(source_path) - if path_mode == "dir": - file_paths = map(listdir, file_paths) - if isinstance(parquet_sampler.size, float): - file_paths = map(filter_parquets, file_paths) - file_paths = map(parquet_sampler, file_paths) - file_paths = list(chain(*file_paths)) - else: - file_paths = chain(*file_paths) - file_paths = parquet_sampler(filter_parquets(file_paths)) - - return file_paths - - # Method to save a Parquet file - def save_parquet( - self, - *, - file_name: str, - row_group_size: int, - pa_table: Optional[pa.Table] = None, - data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None, - override: bool = True, - is_last_file: bool = False, - ): - - persist = self._get_parquet(file_name) - if override or not exists(persist.path): - persist.save( - pa_table=pa_table, - data_dict=data_dict, - executor=self.executor, - row_group_size=row_group_size, - is_last_file=is_last_file, - futures=self.futures, - ) - - # Method to get a PersistedParquet object - def _get_parquet(self, file_name: str) -> PersistedParquet: - return PersistedParquet(file_name) - - -# Class to manage dumping of Parquet files -class DumpingManager(ParquetManager): - """ - Dumping manager handles parquet saving and resuming. 
- """ - - def __init__( - self, - task: DictConfig, - target_dir: str, - ): - super().__init__(task=task, target_dir=target_dir) - - # Method to generate saving path - def generate_saving_path(self, file_path: str, rsplit: int): - part_list = file_path.rsplit("/", rsplit) - result_folder = "/".join( - [self.target_dir] + [f"epoch_{self.task.epoch}"] + part_list[-rsplit:-1] - ) - result_file = "/".join([result_folder, part_list[-1]]) - return result_folder, result_file - - # Method to configure task paths - def configure_task_path(self, source_path: str, rsplit: int, path_mode: str = "dir"): - - file_paths = self.get_parquet_files( - source_path=source_path, - path_mode=path_mode, - ) - - # Shuffle file paths - random.Random(0).shuffle(file_paths) - - # Partition the file paths based on task configuration - full_source_files = partition_by_groups(file_paths, self.task.total_count)[self.task.index] - full_source_files = partition_by_groups(full_source_files, get_world_size())[ - get_global_rank() - ] - - if not full_source_files: - return FileListOutput([], [], []) - - generate_saving_path = partial(self.generate_saving_path, rsplit=rsplit) - full_paths = map(generate_saving_path, full_source_files) - full_target_folders, full_target_files = map(list, zip(*full_paths)) - full_target_folders = set(full_target_folders) - - existing_file_paths = map( - lambda folder: listdir(folder) if exists(folder) else [], full_target_folders - ) - existing_file_paths = chain(*existing_file_paths) - self.existing_files = list( - filter( - lambda path: path.endswith(".parquet") and path in full_target_files, - existing_file_paths, - ) - ) - - filtered_pairs = list( - filter( - lambda pair: pair[1] not in self.existing_files, - zip(full_source_files, full_target_files), - ) - ) - if filtered_pairs: - filtered_source_files, filtered_target_files = map(list, zip(*filtered_pairs)) - else: - filtered_source_files, filtered_target_files = [], [] - - # Skip existing file paths if specified - skip_exists = self.task.skip_exists - self.source_files = filtered_source_files if skip_exists else full_source_files - self.target_files = filtered_target_files if skip_exists else full_target_files - - return FileListOutput(self.existing_files, self.source_files, self.target_files) - - -class RepackingManager(ParquetManager): - """ - Repacking manager handles parquet spliting and saving. 
- """ - - def __init__( - self, - task: DictConfig, - target_dir: str, - repackaging: DictConfig, - ): - super().__init__(task=task, target_dir=target_dir) - self.repackaging = repackaging - - # Configure the task paths for repacking - def configure_task_path( - self, - source_path: str, - parquet_sampler: Optional[DictConfig] = None, - path_mode: str = "dir", - ): - - parquet_sampler = create_parquet_sampler(config=parquet_sampler) - file_paths = self.get_parquet_files( - source_path=source_path, - parquet_sampler=parquet_sampler, - path_mode=path_mode, - ) - - random.Random(0).shuffle(file_paths) - target_dir = self.target_dir - size = abs(parquet_sampler.size) - - if self.task: - # Partition the file paths based on task configuration - file_paths = partition_by_groups(file_paths, self.task.total_count)[self.task.index] - target_dir = os.path.join(target_dir, f"{self.task.total_count}_{self.task.index}") - - if size > 1: - size = len( - partition_by_groups(range(size), self.task.total_count)[self.task.index] - ) - - # Get metadata for each Parquet file - metadatas = get_parquet_metadata(file_paths, self.repackaging.num_processes) - - # Create a list of (file_path, row) tuples for each row in the files - target_items = [ - (file_path, row) - for file_path, metadata in zip(file_paths, metadatas) - for row in range(metadata.num_rows) - ] - - # Shuffle the target items - random.Random(0).shuffle(target_items) - - if size > 1: - target_items = target_items[:size] - - # Partition the items into groups for each target file - items_per_file = partition_by_groups(target_items, self.repackaging.num_files) - - # Generate target file paths - target_files = [ - os.path.join(target_dir, f"{str(i).zfill(5)}.parquet") - for i in range(self.repackaging.num_files) - ] - - existing_file_paths = listdir(target_dir) if exists(target_dir) else [] - self.existing_files = list( - filter( - lambda path: path.endswith(".parquet"), - existing_file_paths, - ) - ) - self.source_files = items_per_file - self.target_files = target_files - - return FileListOutput(self.existing_files, self.source_files, self.target_files)