Carlexxx commited on
Commit
206f0da
·
1 Parent(s): 5fad6fa

feat: Implement self-contained specialist managers

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. build_apexx.sh +0 -45
  3. commonx/__init__.py +0 -0
  4. commonx/cache.py +0 -47
  5. commonx/config.py +0 -110
  6. commonx/decorators.py +0 -147
  7. commonx/diffusion/__init__.py +0 -56
  8. commonx/diffusion/config.py +0 -74
  9. commonx/diffusion/samplers/base.py +0 -108
  10. commonx/diffusion/samplers/euler.py +0 -89
  11. commonx/diffusion/schedules/base.py +0 -131
  12. commonx/diffusion/schedules/lerp.py +0 -55
  13. commonx/diffusion/timesteps/base.py +0 -72
  14. commonx/diffusion/timesteps/sampling/trailing.py +0 -49
  15. commonx/diffusion/types.py +0 -59
  16. commonx/diffusion/utils.py +0 -84
  17. commonx/distributed/__init__.py +0 -37
  18. commonx/distributed/advanced.py +0 -208
  19. commonx/distributed/basic.py +0 -84
  20. commonx/distributed/meta_init_utils.py +0 -41
  21. commonx/distributed/ops.py +0 -494
  22. commonx/logger.py +0 -44
  23. commonx/partition.py +0 -59
  24. commonx/seed.py +0 -30
  25. configs_3bx/main.yaml +0 -88
  26. configs_7bx/main.yaml +0 -85
  27. datax/image/transforms/area_resize.py +0 -135
  28. datax/image/transforms/divisible_crop.py +0 -40
  29. datax/image/transforms/na_resize.py +0 -50
  30. datax/image/transforms/side_resize.py +0 -54
  31. datax/video/transforms/rearrange.py +0 -24
  32. environmentx.yml +0 -238
  33. modelsx/dit/attention.py +0 -46
  34. modelsx/dit/blocks/__init__.py +0 -25
  35. modelsx/dit/blocks/mmdit_window_block.py +0 -233
  36. modelsx/dit/embedding.py +0 -62
  37. modelsx/dit/mlp.py +0 -62
  38. modelsx/dit/mm.py +0 -67
  39. modelsx/dit/modulation.py +0 -97
  40. modelsx/dit/na.py +0 -241
  41. modelsx/dit/nablocks/__init__.py +0 -25
  42. modelsx/dit/nablocks/mmsr_block.py +0 -248
  43. modelsx/dit/nadit.py +0 -350
  44. modelsx/dit/normalization.py +0 -63
  45. modelsx/dit/patch.py +0 -112
  46. modelsx/dit/rope.py +0 -101
  47. modelsx/dit/window.py +0 -83
  48. modelsx/dit_v2/attention.py +0 -46
  49. modelsx/dit_v2/embedding.py +0 -62
  50. modelsx/dit_v2/mlp.py +0 -62
.DS_Store DELETED
Binary file (6.15 kB)
 
build_apexx.sh DELETED
@@ -1,45 +0,0 @@
1
- #!/bin/bash
2
- set -e
3
-
4
- # 🧾 Config
5
- # APEX_COMMIT=f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0 # Known stable commit
6
-
7
- echo "🧹 Cleaning any previous apex build..."
8
- rm -rf apex
9
- rm -rf *.egg-info build dist
10
-
11
- echo "📥 Cloning NVIDIA/apex..."
12
- git clone https://github.com/NVIDIA/apex.git
13
- cd apex
14
- # git checkout $APEX_COMMIT
15
-
16
- echo "🛠 Installing build dependencies..."
17
- sudo apt-get update
18
- sudo apt-get install -y \
19
- build-essential \
20
- ninja-build \
21
- python3-dev \
22
- libffi-dev \
23
- libncurses5-dev \
24
- libncursesw5-dev \
25
- libreadline-dev \
26
- libssl-dev \
27
- libsqlite3-dev \
28
- zlib1g-dev \
29
- libbz2-dev \
30
- liblzma-dev \
31
- git
32
-
33
- echo "🐍 Upgrading pip and wheel..."
34
- pip install -U pip setuptools wheel
35
-
36
- echo "🧪 Verifying PyTorch + CUDA availability..."
37
- python -c "import torch; print('PyTorch:', torch.__version__, '| CUDA:', torch.version.cuda)"
38
-
39
- echo "🔧 Building Apex with CUDA and C++ extensions..."
40
- python setup.py bdist_wheel --cuda_ext --cpp_ext
41
-
42
- echo "✅ Done! Built wheel:"
43
- ls -lh dist/*.whl
44
-
45
- cd ..
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/__init__.py DELETED
File without changes
commonx/cache.py DELETED
@@ -1,47 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Callable
16
-
17
-
18
- class Cache:
19
- """Caching reusable args for faster inference"""
20
-
21
- def __init__(self, disable=False, prefix="", cache=None):
22
- self.cache = cache if cache is not None else {}
23
- self.disable = disable
24
- self.prefix = prefix
25
-
26
- def __call__(self, key: str, fn: Callable):
27
- if self.disable:
28
- return fn()
29
-
30
- key = self.prefix + key
31
- try:
32
- result = self.cache[key]
33
- except KeyError:
34
- result = fn()
35
- self.cache[key] = result
36
- return result
37
-
38
- def namespace(self, namespace: str):
39
- return Cache(
40
- disable=self.disable,
41
- prefix=self.prefix + namespace + ".",
42
- cache=self.cache,
43
- )
44
-
45
- def get(self, key: str):
46
- key = self.prefix + key
47
- return self.cache[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/config.py DELETED
@@ -1,110 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Configuration utility functions
17
- """
18
-
19
- import importlib
20
- from typing import Any, Callable, List, Union
21
- from omegaconf import DictConfig, ListConfig, OmegaConf
22
-
23
- OmegaConf.register_new_resolver("eval", eval)
24
-
25
-
26
- def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
27
- """
28
- Load a configuration. Will resolve inheritance.
29
- """
30
- config = OmegaConf.load(path)
31
- if argv is not None:
32
- config_argv = OmegaConf.from_dotlist(argv)
33
- config = OmegaConf.merge(config, config_argv)
34
- config = resolve_recursive(config, resolve_inheritance)
35
- return config
36
-
37
-
38
- def resolve_recursive(
39
- config: Any,
40
- resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
41
- ) -> Any:
42
- config = resolver(config)
43
- if isinstance(config, DictConfig):
44
- for k in config.keys():
45
- v = config.get(k)
46
- if isinstance(v, (DictConfig, ListConfig)):
47
- config[k] = resolve_recursive(v, resolver)
48
- if isinstance(config, ListConfig):
49
- for i in range(len(config)):
50
- v = config.get(i)
51
- if isinstance(v, (DictConfig, ListConfig)):
52
- config[i] = resolve_recursive(v, resolver)
53
- return config
54
-
55
-
56
- def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
57
- """
58
- Recursively resolve inheritance if the config contains:
59
- __inherit__: path/to/parent.yaml or a ListConfig of such paths.
60
- """
61
- if isinstance(config, DictConfig):
62
- inherit = config.pop("__inherit__", None)
63
-
64
- if inherit:
65
- inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
66
-
67
- parent_config = None
68
- for parent_path in inherit_list:
69
- assert isinstance(parent_path, str)
70
- parent_config = (
71
- load_config(parent_path)
72
- if parent_config is None
73
- else OmegaConf.merge(parent_config, load_config(parent_path))
74
- )
75
-
76
- if len(config.keys()) > 0:
77
- config = OmegaConf.merge(parent_config, config)
78
- else:
79
- config = parent_config
80
- return config
81
-
82
-
83
- def import_item(path: str, name: str) -> Any:
84
- """
85
- Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
86
- """
87
- return getattr(importlib.import_module(path), name)
88
-
89
-
90
- def create_object(config: DictConfig) -> Any:
91
- """
92
- Create an object from config.
93
- The config is expected to contains the following:
94
- __object__:
95
- path: path.to.module
96
- name: MyClass
97
- args: as_config | as_params (default to as_config)
98
- """
99
- item = import_item(
100
- path=config.__object__.path,
101
- name=config.__object__.name,
102
- )
103
- args = config.__object__.get("args", "as_config")
104
- if args == "as_config":
105
- return item(config)
106
- if args == "as_params":
107
- config = OmegaConf.to_object(config)
108
- config.pop("__object__")
109
- return item(**config)
110
- raise NotImplementedError(f"Unknown args type: {args}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/decorators.py DELETED
@@ -1,147 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Decorators.
17
- """
18
-
19
- import functools
20
- import threading
21
- import time
22
- from typing import Callable
23
- import torch
24
-
25
- from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank
26
- from common.logger import get_logger
27
-
28
- logger = get_logger(__name__)
29
-
30
-
31
- def log_on_entry(func: Callable) -> Callable:
32
- """
33
- Functions with this decorator will log the function name at entry.
34
- When using multiple decorators, this must be applied innermost to properly capture the name.
35
- """
36
-
37
- def log_on_entry_wrapper(*args, **kwargs):
38
- logger.info(f"Entering {func.__name__}")
39
- return func(*args, **kwargs)
40
-
41
- return log_on_entry_wrapper
42
-
43
-
44
- def barrier_on_entry(func: Callable) -> Callable:
45
- """
46
- Functions with this decorator will start executing when all ranks are ready to enter.
47
- """
48
-
49
- def barrier_on_entry_wrapper(*args, **kwargs):
50
- barrier_if_distributed()
51
- return func(*args, **kwargs)
52
-
53
- return barrier_on_entry_wrapper
54
-
55
-
56
- def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable:
57
- """
58
- Helper function for local_rank_zero_only and global_rank_zero_only.
59
- """
60
-
61
- def conditional_execute_wrapper(*args, **kwargs):
62
- # Only execute if needed.
63
- result = func(*args, **kwargs) if execute else None
64
- # All GPUs must wait.
65
- barrier_if_distributed()
66
- # Return results.
67
- return result
68
-
69
- return conditional_execute_wrapper
70
-
71
-
72
- def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable:
73
- """
74
- Helper function for some functions with special constraints,
75
- especially functions called by other global_rank_zero_only / local_rank_zero_only ones,
76
- in case they are wrongly invoked in other scenarios.
77
- """
78
-
79
- def asserted_execute_wrapper(*args, **kwargs):
80
- assert condition, err_msg
81
- result = func(*args, **kwargs)
82
- return result
83
-
84
- return asserted_execute_wrapper
85
-
86
-
87
- def local_rank_zero_only(func: Callable) -> Callable:
88
- """
89
- Functions with this decorator will only execute on local rank zero.
90
- """
91
- return _conditional_execute_wrapper_factory(get_local_rank() == 0, func)
92
-
93
-
94
- def global_rank_zero_only(func: Callable) -> Callable:
95
- """
96
- Functions with this decorator will only execute on global rank zero.
97
- """
98
- return _conditional_execute_wrapper_factory(get_global_rank() == 0, func)
99
-
100
-
101
- def assert_only_global_rank_zero(func: Callable) -> Callable:
102
- """
103
- Functions with this decorator are only accessible to processes with global rank zero.
104
- """
105
- return _asserted_wrapper_factory(
106
- get_global_rank() == 0, func, err_msg="Not accessible to processes with global_rank != 0"
107
- )
108
-
109
-
110
- def assert_only_local_rank_zero(func: Callable) -> Callable:
111
- """
112
- Functions with this decorator are only accessible to processes with local rank zero.
113
- """
114
- return _asserted_wrapper_factory(
115
- get_local_rank() == 0, func, err_msg="Not accessible to processes with local_rank != 0"
116
- )
117
-
118
-
119
- def new_thread(func: Callable) -> Callable:
120
- """
121
- Functions with this decorator will run in a new thread.
122
- The function will return the thread, which can be joined to wait for completion.
123
- """
124
-
125
- def new_thread_wrapper(*args, **kwargs):
126
- thread = threading.Thread(target=func, args=args, kwargs=kwargs)
127
- thread.start()
128
- return thread
129
-
130
- return new_thread_wrapper
131
-
132
-
133
- def log_runtime(func: Callable) -> Callable:
134
- """
135
- Functions with this decorator will logging the runtime.
136
- """
137
-
138
- @functools.wraps(func)
139
- def wrapped(*args, **kwargs):
140
- torch.distributed.barrier()
141
- start = time.perf_counter()
142
- result = func(*args, **kwargs)
143
- torch.distributed.barrier()
144
- logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
145
- return result
146
-
147
- return wrapped
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/__init__.py DELETED
@@ -1,56 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Diffusion package.
17
- """
18
-
19
- from .config import (
20
- create_sampler_from_config,
21
- create_sampling_timesteps_from_config,
22
- create_schedule_from_config,
23
- )
24
- from .samplers.base import Sampler
25
- from .samplers.euler import EulerSampler
26
- from .schedules.base import Schedule
27
- from .schedules.lerp import LinearInterpolationSchedule
28
- from .timesteps.base import SamplingTimesteps, Timesteps
29
- from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps
30
- from .types import PredictionType, SamplingDirection
31
- from .utils import classifier_free_guidance, classifier_free_guidance_dispatcher, expand_dims
32
-
33
- __all__ = [
34
- # Configs
35
- "create_sampler_from_config",
36
- "create_sampling_timesteps_from_config",
37
- "create_schedule_from_config",
38
- # Schedules
39
- "Schedule",
40
- "DiscreteVariancePreservingSchedule",
41
- "LinearInterpolationSchedule",
42
- # Samplers
43
- "Sampler",
44
- "EulerSampler",
45
- # Timesteps
46
- "Timesteps",
47
- "SamplingTimesteps",
48
- # Types
49
- "PredictionType",
50
- "SamplingDirection",
51
- "UniformTrailingSamplingTimesteps",
52
- # Utils
53
- "classifier_free_guidance",
54
- "classifier_free_guidance_dispatcher",
55
- "expand_dims",
56
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/config.py DELETED
@@ -1,74 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Utility functions for creating schedules and samplers from config.
17
- """
18
-
19
- import torch
20
- from omegaconf import DictConfig
21
-
22
- from .samplers.base import Sampler
23
- from .samplers.euler import EulerSampler
24
- from .schedules.base import Schedule
25
- from .schedules.lerp import LinearInterpolationSchedule
26
- from .timesteps.base import SamplingTimesteps
27
- from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps
28
-
29
-
30
- def create_schedule_from_config(
31
- config: DictConfig,
32
- device: torch.device,
33
- dtype: torch.dtype = torch.float32,
34
- ) -> Schedule:
35
- """
36
- Create a schedule from configuration.
37
- """
38
- if config.type == "lerp":
39
- return LinearInterpolationSchedule(T=config.get("T", 1.0))
40
-
41
- raise NotImplementedError
42
-
43
-
44
- def create_sampler_from_config(
45
- config: DictConfig,
46
- schedule: Schedule,
47
- timesteps: SamplingTimesteps,
48
- ) -> Sampler:
49
- """
50
- Create a sampler from configuration.
51
- """
52
- if config.type == "euler":
53
- return EulerSampler(
54
- schedule=schedule,
55
- timesteps=timesteps,
56
- prediction_type=config.prediction_type,
57
- )
58
- raise NotImplementedError
59
-
60
-
61
- def create_sampling_timesteps_from_config(
62
- config: DictConfig,
63
- schedule: Schedule,
64
- device: torch.device,
65
- dtype: torch.dtype = torch.float32,
66
- ) -> SamplingTimesteps:
67
- if config.type == "uniform_trailing":
68
- return UniformTrailingSamplingTimesteps(
69
- T=schedule.T,
70
- steps=config.steps,
71
- shift=config.get("shift", 1.0),
72
- device=device,
73
- )
74
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/samplers/base.py DELETED
@@ -1,108 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Sampler base class.
17
- """
18
-
19
- from abc import ABC, abstractmethod
20
- from dataclasses import dataclass
21
- from typing import Callable
22
- import torch
23
- from tqdm import tqdm
24
-
25
- from ..schedules.base import Schedule
26
- from ..timesteps.base import SamplingTimesteps
27
- from ..types import PredictionType, SamplingDirection
28
- from ..utils import assert_schedule_timesteps_compatible
29
-
30
-
31
- @dataclass
32
- class SamplerModelArgs:
33
- x_t: torch.Tensor
34
- t: torch.Tensor
35
- i: int
36
-
37
-
38
- class Sampler(ABC):
39
- """
40
- Samplers are ODE/SDE solvers.
41
- """
42
-
43
- def __init__(
44
- self,
45
- schedule: Schedule,
46
- timesteps: SamplingTimesteps,
47
- prediction_type: PredictionType,
48
- return_endpoint: bool = True,
49
- ):
50
- assert_schedule_timesteps_compatible(
51
- schedule=schedule,
52
- timesteps=timesteps,
53
- )
54
- self.schedule = schedule
55
- self.timesteps = timesteps
56
- self.prediction_type = prediction_type
57
- self.return_endpoint = return_endpoint
58
-
59
- @abstractmethod
60
- def sample(
61
- self,
62
- x: torch.Tensor,
63
- f: Callable[[SamplerModelArgs], torch.Tensor],
64
- ) -> torch.Tensor:
65
- """
66
- Generate a new sample given the the intial sample x and score function f.
67
- """
68
-
69
- def get_next_timestep(
70
- self,
71
- t: torch.Tensor,
72
- ) -> torch.Tensor:
73
- """
74
- Get the next sample timestep.
75
- Support multiple different timesteps t in a batch.
76
- If no more steps, return out of bound value -1 or T+1.
77
- """
78
- T = self.timesteps.T
79
- steps = len(self.timesteps)
80
- curr_idx = self.timesteps.index(t)
81
- next_idx = curr_idx + 1
82
- bound = -1 if self.timesteps.direction == SamplingDirection.backward else T + 1
83
-
84
- s = self.timesteps[next_idx.clamp_max(steps - 1)]
85
- s = s.where(next_idx < steps, bound)
86
- return s
87
-
88
- def get_endpoint(
89
- self,
90
- pred: torch.Tensor,
91
- x_t: torch.Tensor,
92
- t: torch.Tensor,
93
- ) -> torch.Tensor:
94
- """
95
- Get to the endpoint of the probability flow.
96
- """
97
- x_0, x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
98
- return x_0 if self.timesteps.direction == SamplingDirection.backward else x_T
99
-
100
- def get_progress_bar(self):
101
- """
102
- Get progress bar for sampling.
103
- """
104
- return tqdm(
105
- iterable=range(len(self.timesteps) - (0 if self.return_endpoint else 1)),
106
- dynamic_ncols=True,
107
- desc=self.__class__.__name__,
108
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/samplers/euler.py DELETED
@@ -1,89 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
-
16
- """
17
- Euler ODE solver.
18
- """
19
-
20
- from typing import Callable
21
- import torch
22
- from einops import rearrange
23
- from torch.nn import functional as F
24
-
25
- from models.dit_v2 import na
26
-
27
- from ..types import PredictionType
28
- from ..utils import expand_dims
29
- from .base import Sampler, SamplerModelArgs
30
-
31
-
32
- class EulerSampler(Sampler):
33
- """
34
- The Euler method is the simplest ODE solver.
35
- <https://en.wikipedia.org/wiki/Euler_method>
36
- """
37
-
38
- def sample(
39
- self,
40
- x: torch.Tensor,
41
- f: Callable[[SamplerModelArgs], torch.Tensor],
42
- ) -> torch.Tensor:
43
- timesteps = self.timesteps.timesteps
44
- progress = self.get_progress_bar()
45
- i = 0
46
- for t, s in zip(timesteps[:-1], timesteps[1:]):
47
- pred = f(SamplerModelArgs(x, t, i))
48
- x = self.step_to(pred, x, t, s)
49
- i += 1
50
- progress.update()
51
-
52
- if self.return_endpoint:
53
- t = timesteps[-1]
54
- pred = f(SamplerModelArgs(x, t, i))
55
- x = self.get_endpoint(pred, x, t)
56
- progress.update()
57
- return x
58
-
59
- def step(
60
- self,
61
- pred: torch.Tensor,
62
- x_t: torch.Tensor,
63
- t: torch.Tensor,
64
- ) -> torch.Tensor:
65
- """
66
- Step to the next timestep.
67
- """
68
- return self.step_to(pred, x_t, t, self.get_next_timestep(t))
69
-
70
- def step_to(
71
- self,
72
- pred: torch.Tensor,
73
- x_t: torch.Tensor,
74
- t: torch.Tensor,
75
- s: torch.Tensor,
76
- ) -> torch.Tensor:
77
- """
78
- Steps from x_t at timestep t to x_s at timestep s. Returns x_s.
79
- """
80
- t = expand_dims(t, x_t.ndim)
81
- s = expand_dims(s, x_t.ndim)
82
- T = self.schedule.T
83
- # Step from x_t to x_s.
84
- pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
85
- pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))
86
- # Clamp x_s to x_0 and x_T if s is out of bound.
87
- pred_x_s = pred_x_s.where(s >= 0, pred_x_0)
88
- pred_x_s = pred_x_s.where(s <= T, pred_x_T)
89
- return pred_x_s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/schedules/base.py DELETED
@@ -1,131 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Schedule base class.
17
- """
18
-
19
- from abc import ABC, abstractmethod, abstractproperty
20
- from typing import Tuple, Union
21
- import torch
22
-
23
- from ..types import PredictionType
24
- from ..utils import expand_dims
25
-
26
-
27
- class Schedule(ABC):
28
- """
29
- Diffusion schedules are uniquely defined by T, A, B:
30
-
31
- x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T]
32
-
33
- Schedules can be continuous or discrete.
34
- """
35
-
36
- @abstractproperty
37
- def T(self) -> Union[int, float]:
38
- """
39
- Maximum timestep inclusive.
40
- Schedule is continuous if float, discrete if int.
41
- """
42
-
43
- @abstractmethod
44
- def A(self, t: torch.Tensor) -> torch.Tensor:
45
- """
46
- Interpolation coefficient A.
47
- Returns tensor with the same shape as t.
48
- """
49
-
50
- @abstractmethod
51
- def B(self, t: torch.Tensor) -> torch.Tensor:
52
- """
53
- Interpolation coefficient B.
54
- Returns tensor with the same shape as t.
55
- """
56
-
57
- # ----------------------------------------------------
58
-
59
- def snr(self, t: torch.Tensor) -> torch.Tensor:
60
- """
61
- Signal to noise ratio.
62
- Returns tensor with the same shape as t.
63
- """
64
- return (self.A(t) ** 2) / (self.B(t) ** 2)
65
-
66
- def isnr(self, snr: torch.Tensor) -> torch.Tensor:
67
- """
68
- Inverse signal to noise ratio.
69
- Returns tensor with the same shape as snr.
70
- Subclass may implement.
71
- """
72
- raise NotImplementedError
73
-
74
- # ----------------------------------------------------
75
-
76
- def is_continuous(self) -> bool:
77
- """
78
- Whether the schedule is continuous.
79
- """
80
- return isinstance(self.T, float)
81
-
82
- def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
83
- """
84
- Diffusion forward function.
85
- """
86
- t = expand_dims(t, x_0.ndim)
87
- return self.A(t) * x_0 + self.B(t) * x_T
88
-
89
- def convert_from_pred(
90
- self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor
91
- ) -> Tuple[torch.Tensor, torch.Tensor]:
92
- """
93
- Convert from prediction. Return predicted x_0 and x_T.
94
- """
95
- t = expand_dims(t, x_t.ndim)
96
- A_t = self.A(t)
97
- B_t = self.B(t)
98
-
99
- if pred_type == PredictionType.x_T:
100
- pred_x_T = pred
101
- pred_x_0 = (x_t - B_t * pred_x_T) / A_t
102
- elif pred_type == PredictionType.x_0:
103
- pred_x_0 = pred
104
- pred_x_T = (x_t - A_t * pred_x_0) / B_t
105
- elif pred_type == PredictionType.v_cos:
106
- pred_x_0 = A_t * x_t - B_t * pred
107
- pred_x_T = A_t * pred + B_t * x_t
108
- elif pred_type == PredictionType.v_lerp:
109
- pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
110
- pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
111
- else:
112
- raise NotImplementedError
113
-
114
- return pred_x_0, pred_x_T
115
-
116
- def convert_to_pred(
117
- self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType
118
- ) -> torch.FloatTensor:
119
- """
120
- Convert to prediction target given x_0 and x_T.
121
- """
122
- if pred_type == PredictionType.x_T:
123
- return x_T
124
- if pred_type == PredictionType.x_0:
125
- return x_0
126
- if pred_type == PredictionType.v_cos:
127
- t = expand_dims(t, x_0.ndim)
128
- return self.A(t) * x_T - self.B(t) * x_0
129
- if pred_type == PredictionType.v_lerp:
130
- return x_T - x_0
131
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/schedules/lerp.py DELETED
@@ -1,55 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Linear interpolation schedule (lerp).
17
- """
18
-
19
- from typing import Union
20
- import torch
21
-
22
- from .base import Schedule
23
-
24
-
25
- class LinearInterpolationSchedule(Schedule):
26
- """
27
- Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow.
28
- It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3.
29
- <https://arxiv.org/abs/2209.03003>
30
- <https://arxiv.org/abs/2210.02747>
31
-
32
- x_t = (1 - t) * x_0 + t * x_T
33
-
34
- Can be either continuous or discrete.
35
- """
36
-
37
- def __init__(self, T: Union[int, float] = 1.0):
38
- self._T = T
39
-
40
- @property
41
- def T(self) -> Union[int, float]:
42
- return self._T
43
-
44
- def A(self, t: torch.Tensor) -> torch.Tensor:
45
- return 1 - (t / self.T)
46
-
47
- def B(self, t: torch.Tensor) -> torch.Tensor:
48
- return t / self.T
49
-
50
- # ----------------------------------------------------
51
-
52
- def isnr(self, snr: torch.Tensor) -> torch.Tensor:
53
- t = self.T / (1 + snr**0.5)
54
- t = t if self.is_continuous() else t.round().int()
55
- return t
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/timesteps/base.py DELETED
@@ -1,72 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Sequence, Union
3
- import torch
4
-
5
- from ..types import SamplingDirection
6
-
7
-
8
- class Timesteps(ABC):
9
- """
10
- Timesteps base class.
11
- """
12
-
13
- def __init__(self, T: Union[int, float]):
14
- assert T > 0
15
- self._T = T
16
-
17
- @property
18
- def T(self) -> Union[int, float]:
19
- """
20
- Maximum timestep inclusive.
21
- int if discrete, float if continuous.
22
- """
23
- return self._T
24
-
25
- def is_continuous(self) -> bool:
26
- """
27
- Whether the schedule is continuous.
28
- """
29
- return isinstance(self.T, float)
30
-
31
-
32
- class SamplingTimesteps(Timesteps):
33
- """
34
- Sampling timesteps.
35
- It defines the discretization of sampling steps.
36
- """
37
-
38
- def __init__(
39
- self,
40
- T: Union[int, float],
41
- timesteps: torch.Tensor,
42
- direction: SamplingDirection,
43
- ):
44
- assert timesteps.ndim == 1
45
- super().__init__(T)
46
- self.timesteps = timesteps
47
- self.direction = direction
48
-
49
- def __len__(self) -> int:
50
- """
51
- Number of sampling steps.
52
- """
53
- return len(self.timesteps)
54
-
55
- def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
56
- """
57
- The timestep at the sampling step.
58
- Returns a scalar tensor if idx is int,
59
- or tensor of the same size if idx is a tensor.
60
- """
61
- return self.timesteps[idx]
62
-
63
- def index(self, t: torch.Tensor) -> torch.Tensor:
64
- """
65
- Find index by t.
66
- Return index of the same shape as t.
67
- Index is -1 if t not found in timesteps.
68
- """
69
- i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True)
70
- idx = torch.full_like(t, fill_value=-1, dtype=torch.int)
71
- idx.view(-1)[i] = j.int()
72
- return idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/timesteps/sampling/trailing.py DELETED
@@ -1,49 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- import torch
16
-
17
- from ...types import SamplingDirection
18
- from ..base import SamplingTimesteps
19
-
20
-
21
- class UniformTrailingSamplingTimesteps(SamplingTimesteps):
22
- """
23
- Uniform trailing sampling timesteps.
24
- Defined in (https://arxiv.org/abs/2305.08891)
25
-
26
- Shift is proposed in SD3 for RF schedule.
27
- Defined in (https://arxiv.org/pdf/2403.03206) eq.23
28
- """
29
-
30
- def __init__(
31
- self,
32
- T: int,
33
- steps: int,
34
- shift: float = 1.0,
35
- device: torch.device = "cpu",
36
- ):
37
- # Create trailing timesteps.
38
- timesteps = torch.arange(1.0, 0.0, -1.0 / steps, device=device)
39
-
40
- # Shift timesteps.
41
- timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
42
-
43
- # Scale to T range.
44
- if isinstance(T, float):
45
- timesteps = timesteps * T
46
- else:
47
- timesteps = timesteps.mul(T + 1).sub(1).round().int()
48
-
49
- super().__init__(T=T, timesteps=timesteps, direction=SamplingDirection.backward)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/types.py DELETED
@@ -1,59 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Type definitions.
17
- """
18
-
19
- from enum import Enum
20
-
21
-
22
- class PredictionType(str, Enum):
23
- """
24
- x_0:
25
- Predict data sample.
26
- x_T:
27
- Predict noise sample.
28
- Proposed by DDPM (https://arxiv.org/abs/2006.11239)
29
- Proved problematic by zsnr paper (https://arxiv.org/abs/2305.08891)
30
- v_cos:
31
- Predict velocity dx/dt based on the cosine schedule (A_t * x_T - B_t * x_0).
32
- Proposed by progressive distillation (https://arxiv.org/abs/2202.00512)
33
- v_lerp:
34
- Predict velocity dx/dt based on the lerp schedule (x_T - x_0).
35
- Proposed by rectified flow (https://arxiv.org/abs/2209.03003)
36
- """
37
-
38
- x_0 = "x_0"
39
- x_T = "x_T"
40
- v_cos = "v_cos"
41
- v_lerp = "v_lerp"
42
-
43
-
44
- class SamplingDirection(str, Enum):
45
- """
46
- backward: Sample from x_T to x_0 for data generation.
47
- forward: Sample from x_0 to x_T for noise inversion.
48
- """
49
-
50
- backward = "backward"
51
- forward = "forward"
52
-
53
- @staticmethod
54
- def reverse(direction):
55
- if direction == SamplingDirection.backward:
56
- return SamplingDirection.forward
57
- if direction == SamplingDirection.forward:
58
- return SamplingDirection.backward
59
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/diffusion/utils.py DELETED
@@ -1,84 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Utility functions.
17
- """
18
-
19
- from typing import Callable
20
- import torch
21
-
22
-
23
- def expand_dims(tensor: torch.Tensor, ndim: int):
24
- """
25
- Expand tensor to target ndim. New dims are added to the right.
26
- For example, if the tensor shape was (8,), target ndim is 4, return (8, 1, 1, 1).
27
- """
28
- shape = tensor.shape + (1,) * (ndim - tensor.ndim)
29
- return tensor.reshape(shape)
30
-
31
-
32
- def assert_schedule_timesteps_compatible(schedule, timesteps):
33
- """
34
- Check if schedule and timesteps are compatible.
35
- """
36
- if schedule.T != timesteps.T:
37
- raise ValueError("Schedule and timesteps must have the same T.")
38
- if schedule.is_continuous() != timesteps.is_continuous():
39
- raise ValueError("Schedule and timesteps must have the same continuity.")
40
-
41
-
42
- def classifier_free_guidance(
43
- pos: torch.Tensor,
44
- neg: torch.Tensor,
45
- scale: float,
46
- rescale: float = 0.0,
47
- ):
48
- """
49
- Apply classifier-free guidance.
50
- """
51
- # Classifier-free guidance (https://arxiv.org/abs/2207.12598)
52
- cfg = neg + scale * (pos - neg)
53
-
54
- # Classifier-free guidance rescale (https://arxiv.org/pdf/2305.08891.pdf)
55
- if rescale != 0.0:
56
- pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True)
57
- cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True)
58
- factor = pos_std / cfg_std
59
- factor = rescale * factor + (1 - rescale)
60
- cfg *= factor
61
-
62
- return cfg
63
-
64
-
65
- def classifier_free_guidance_dispatcher(
66
- pos: Callable,
67
- neg: Callable,
68
- scale: float,
69
- rescale: float = 0.0,
70
- ):
71
- """
72
- Optionally execute models depending on classifer-free guidance scale.
73
- """
74
- # If scale is 1, no need to execute neg model.
75
- if scale == 1.0:
76
- return pos()
77
-
78
- # Otherwise, execute both pos nad neg models and apply cfg.
79
- return classifier_free_guidance(
80
- pos=pos(),
81
- neg=neg(),
82
- scale=scale,
83
- rescale=rescale,
84
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/distributed/__init__.py DELETED
@@ -1,37 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Distributed package.
17
- """
18
-
19
- from .basic import (
20
- barrier_if_distributed,
21
- convert_to_ddp,
22
- get_device,
23
- get_global_rank,
24
- get_local_rank,
25
- get_world_size,
26
- init_torch,
27
- )
28
-
29
- __all__ = [
30
- "barrier_if_distributed",
31
- "convert_to_ddp",
32
- "get_device",
33
- "get_global_rank",
34
- "get_local_rank",
35
- "get_world_size",
36
- "init_torch",
37
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/distributed/advanced.py DELETED
@@ -1,208 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Advanced distributed functions for sequence parallel.
17
- """
18
-
19
- from typing import Optional, List
20
- import torch
21
- import torch.distributed as dist
22
- from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
23
- from torch.distributed.fsdp import ShardingStrategy
24
-
25
- from .basic import get_global_rank, get_world_size
26
-
27
-
28
- _DATA_PARALLEL_GROUP = None
29
- _SEQUENCE_PARALLEL_GROUP = None
30
- _SEQUENCE_PARALLEL_CPU_GROUP = None
31
- _MODEL_SHARD_CPU_INTER_GROUP = None
32
- _MODEL_SHARD_CPU_INTRA_GROUP = None
33
- _MODEL_SHARD_INTER_GROUP = None
34
- _MODEL_SHARD_INTRA_GROUP = None
35
- _SEQUENCE_PARALLEL_GLOBAL_RANKS = None
36
-
37
-
38
- def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
39
- """
40
- Get data parallel process group.
41
- """
42
- return _DATA_PARALLEL_GROUP
43
-
44
-
45
- def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
46
- """
47
- Get sequence parallel process group.
48
- """
49
- return _SEQUENCE_PARALLEL_GROUP
50
-
51
-
52
- def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
53
- """
54
- Get sequence parallel CPU process group.
55
- """
56
- return _SEQUENCE_PARALLEL_CPU_GROUP
57
-
58
-
59
- def get_data_parallel_rank() -> int:
60
- """
61
- Get data parallel rank.
62
- """
63
- group = get_data_parallel_group()
64
- return dist.get_rank(group) if group else get_global_rank()
65
-
66
-
67
- def get_data_parallel_world_size() -> int:
68
- """
69
- Get data parallel world size.
70
- """
71
- group = get_data_parallel_group()
72
- return dist.get_world_size(group) if group else get_world_size()
73
-
74
-
75
- def get_sequence_parallel_rank() -> int:
76
- """
77
- Get sequence parallel rank.
78
- """
79
- group = get_sequence_parallel_group()
80
- return dist.get_rank(group) if group else 0
81
-
82
-
83
- def get_sequence_parallel_world_size() -> int:
84
- """
85
- Get sequence parallel world size.
86
- """
87
- group = get_sequence_parallel_group()
88
- return dist.get_world_size(group) if group else 1
89
-
90
-
91
- def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
92
- """
93
- Get the CPU intra process group of model sharding.
94
- """
95
- return _MODEL_SHARD_CPU_INTRA_GROUP
96
-
97
-
98
- def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]:
99
- """
100
- Get the CPU inter process group of model sharding.
101
- """
102
- return _MODEL_SHARD_CPU_INTER_GROUP
103
-
104
-
105
- def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]:
106
- """
107
- Get the GPU intra process group of model sharding.
108
- """
109
- return _MODEL_SHARD_INTRA_GROUP
110
-
111
-
112
- def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]:
113
- """
114
- Get the GPU inter process group of model sharding.
115
- """
116
- return _MODEL_SHARD_INTER_GROUP
117
-
118
-
119
- def init_sequence_parallel(sequence_parallel_size: int):
120
- """
121
- Initialize sequence parallel.
122
- """
123
- global _DATA_PARALLEL_GROUP
124
- global _SEQUENCE_PARALLEL_GROUP
125
- global _SEQUENCE_PARALLEL_CPU_GROUP
126
- global _SEQUENCE_PARALLEL_GLOBAL_RANKS
127
- assert dist.is_initialized()
128
- world_size = dist.get_world_size()
129
- rank = dist.get_rank()
130
- data_parallel_size = world_size // sequence_parallel_size
131
- for i in range(data_parallel_size):
132
- start_rank = i * sequence_parallel_size
133
- end_rank = (i + 1) * sequence_parallel_size
134
- ranks = range(start_rank, end_rank)
135
- group = dist.new_group(ranks)
136
- cpu_group = dist.new_group(ranks, backend="gloo")
137
- if rank in ranks:
138
- _SEQUENCE_PARALLEL_GROUP = group
139
- _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group
140
- _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks)
141
-
142
-
143
- def init_model_shard_group(
144
- *,
145
- sharding_strategy: ShardingStrategy,
146
- device_mesh: Optional[DeviceMesh] = None,
147
- ):
148
- """
149
- Initialize process group of model sharding.
150
- """
151
- global _MODEL_SHARD_INTER_GROUP
152
- global _MODEL_SHARD_INTRA_GROUP
153
- global _MODEL_SHARD_CPU_INTER_GROUP
154
- global _MODEL_SHARD_CPU_INTRA_GROUP
155
- assert dist.is_initialized()
156
- world_size = dist.get_world_size()
157
- if device_mesh is not None:
158
- num_shards_per_group = device_mesh.shape[1]
159
- elif sharding_strategy == ShardingStrategy.NO_SHARD:
160
- num_shards_per_group = 1
161
- elif sharding_strategy in [
162
- ShardingStrategy.HYBRID_SHARD,
163
- ShardingStrategy._HYBRID_SHARD_ZERO2,
164
- ]:
165
- num_shards_per_group = torch.cuda.device_count()
166
- else:
167
- num_shards_per_group = world_size
168
- num_groups = world_size // num_shards_per_group
169
- device_mesh = (num_groups, num_shards_per_group)
170
-
171
- gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra"))
172
- cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra"))
173
-
174
- _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter")
175
- _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra")
176
- _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter")
177
- _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra")
178
-
179
- def get_sequence_parallel_global_ranks() -> List[int]:
180
- """
181
- Get all global ranks of the sequence parallel process group
182
- that the caller rank belongs to.
183
- """
184
- if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
185
- return [dist.get_rank()]
186
- return _SEQUENCE_PARALLEL_GLOBAL_RANKS
187
-
188
-
189
- def get_next_sequence_parallel_rank() -> int:
190
- """
191
- Get the next global rank of the sequence parallel process group
192
- that the caller rank belongs to.
193
- """
194
- sp_global_ranks = get_sequence_parallel_global_ranks()
195
- sp_rank = get_sequence_parallel_rank()
196
- sp_size = get_sequence_parallel_world_size()
197
- return sp_global_ranks[(sp_rank + 1) % sp_size]
198
-
199
-
200
- def get_prev_sequence_parallel_rank() -> int:
201
- """
202
- Get the previous global rank of the sequence parallel process group
203
- that the caller rank belongs to.
204
- """
205
- sp_global_ranks = get_sequence_parallel_global_ranks()
206
- sp_rank = get_sequence_parallel_rank()
207
- sp_size = get_sequence_parallel_world_size()
208
- return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/distributed/basic.py DELETED
@@ -1,84 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Distributed basic functions.
17
- """
18
-
19
- import os
20
- from datetime import timedelta
21
- import torch
22
- import torch.distributed as dist
23
- from torch.nn.parallel import DistributedDataParallel
24
-
25
-
26
- def get_global_rank() -> int:
27
- """
28
- Get the global rank, the global index of the GPU.
29
- """
30
- return int(os.environ.get("RANK", "0"))
31
-
32
-
33
- def get_local_rank() -> int:
34
- """
35
- Get the local rank, the local index of the GPU.
36
- """
37
- return int(os.environ.get("LOCAL_RANK", "0"))
38
-
39
-
40
- def get_world_size() -> int:
41
- """
42
- Get the world size, the total amount of GPUs.
43
- """
44
- return int(os.environ.get("WORLD_SIZE", "1"))
45
-
46
-
47
- def get_device() -> torch.device:
48
- """
49
- Get current rank device.
50
- """
51
- return torch.device("cuda", get_local_rank())
52
-
53
-
54
- def barrier_if_distributed(*args, **kwargs):
55
- """
56
- Synchronizes all processes if under distributed context.
57
- """
58
- if dist.is_initialized():
59
- return dist.barrier(*args, **kwargs)
60
-
61
-
62
- def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
63
- """
64
- Common PyTorch initialization configuration.
65
- """
66
- torch.backends.cuda.matmul.allow_tf32 = True
67
- torch.backends.cudnn.allow_tf32 = True
68
- torch.backends.cudnn.benchmark = cudnn_benchmark
69
- torch.cuda.set_device(get_local_rank())
70
- dist.init_process_group(
71
- backend="nccl",
72
- rank=get_global_rank(),
73
- world_size=get_world_size(),
74
- timeout=timeout,
75
- )
76
-
77
-
78
- def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
79
- return DistributedDataParallel(
80
- module=module,
81
- device_ids=[get_local_rank()],
82
- output_device=get_local_rank(),
83
- **kwargs,
84
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/distributed/meta_init_utils.py DELETED
@@ -1,41 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- import torch
16
- from rotary_embedding_torch import RotaryEmbedding
17
- from torch import nn
18
- from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
19
-
20
- __all__ = ["meta_non_persistent_buffer_init_fn"]
21
-
22
-
23
- def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
24
- """
25
- Used for materializing `non-persistent tensor buffers` while model resuming.
26
-
27
- Since non-persistent tensor buffers are not saved in state_dict,
28
- when initializing model with meta device, user should materialize those buffers manually.
29
-
30
- Currently, only `rope.dummy` is this special case.
31
- """
32
- with torch.no_grad():
33
- for submodule in module.modules():
34
- if not isinstance(submodule, RotaryEmbedding):
35
- continue
36
- for buffer_name, buffer in submodule.named_buffers(recurse=False):
37
- if buffer.is_meta and "dummy" in buffer_name:
38
- materialized_buffer = torch.zeros_like(buffer, device="cpu")
39
- setattr(submodule, buffer_name, materialized_buffer)
40
- assert not any(b.is_meta for n, b in module.named_buffers())
41
- return module
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/distributed/ops.py DELETED
@@ -1,494 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Distributed ops for supporting sequence parallel.
17
- """
18
-
19
- from collections import defaultdict
20
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
- import torch
22
- import torch.distributed as dist
23
- from torch import Tensor
24
-
25
- from common.cache import Cache
26
- from common.distributed.advanced import (
27
- get_sequence_parallel_group,
28
- get_sequence_parallel_rank,
29
- get_sequence_parallel_world_size,
30
- )
31
-
32
- from .basic import get_device
33
-
34
- _SEQ_DATA_BUF = defaultdict(lambda: [None, None, None])
35
- _SEQ_DATA_META_SHAPES = defaultdict()
36
- _SEQ_DATA_META_DTYPES = defaultdict()
37
- _SEQ_DATA_ASYNC_COMMS = defaultdict(list)
38
- _SYNC_BUFFER = defaultdict(dict)
39
-
40
-
41
- def single_all_to_all(
42
- local_input: Tensor,
43
- scatter_dim: int,
44
- gather_dim: int,
45
- group: dist.ProcessGroup,
46
- async_op: bool = False,
47
- ):
48
- """
49
- A function to do all-to-all on a tensor
50
- """
51
- seq_world_size = dist.get_world_size(group)
52
- prev_scatter_dim = scatter_dim
53
- if scatter_dim != 0:
54
- local_input = local_input.transpose(0, scatter_dim)
55
- if gather_dim == 0:
56
- gather_dim = scatter_dim
57
- scatter_dim = 0
58
-
59
- inp_shape = list(local_input.shape)
60
- inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
61
- input_t = local_input.reshape(
62
- [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]
63
- ).contiguous()
64
- output = torch.empty_like(input_t)
65
- comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
66
- if async_op:
67
- # let user's code transpose & reshape
68
- return output, comm, prev_scatter_dim
69
-
70
- # first dim is seq_world_size, so we can split it directly
71
- output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0)
72
- if prev_scatter_dim:
73
- output = output.transpose(0, prev_scatter_dim).contiguous()
74
- return output
75
-
76
-
77
- def _all_to_all(
78
- local_input: Tensor,
79
- scatter_dim: int,
80
- gather_dim: int,
81
- group: dist.ProcessGroup,
82
- ):
83
- seq_world_size = dist.get_world_size(group)
84
- input_list = [
85
- t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
86
- ]
87
- output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
88
- dist.all_to_all(output_list, input_list, group=group)
89
- return torch.cat(output_list, dim=gather_dim).contiguous()
90
-
91
-
92
- class SeqAllToAll(torch.autograd.Function):
93
- @staticmethod
94
- def forward(
95
- ctx: Any,
96
- group: dist.ProcessGroup,
97
- local_input: Tensor,
98
- scatter_dim: int,
99
- gather_dim: int,
100
- async_op: bool,
101
- ) -> Tensor:
102
- ctx.group = group
103
- ctx.scatter_dim = scatter_dim
104
- ctx.gather_dim = gather_dim
105
- ctx.async_op = async_op
106
- if async_op:
107
- output, comm, prev_scatter_dim = single_all_to_all(
108
- local_input, scatter_dim, gather_dim, group, async_op=async_op
109
- )
110
- ctx.prev_scatter_dim = prev_scatter_dim
111
- return output, comm
112
-
113
- return _all_to_all(local_input, scatter_dim, gather_dim, group)
114
-
115
- @staticmethod
116
- def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
117
- if ctx.async_op:
118
- input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0)
119
- if ctx.prev_scatter_dim:
120
- input_t = input_t.transpose(0, ctx.prev_scatter_dim)
121
- else:
122
- input_t = grad_output[0]
123
- return (
124
- None,
125
- _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group),
126
- None,
127
- None,
128
- None,
129
- )
130
-
131
-
132
- class Slice(torch.autograd.Function):
133
- @staticmethod
134
- def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
135
- ctx.group = group
136
- ctx.rank = dist.get_rank(group)
137
- seq_world_size = dist.get_world_size(group)
138
- ctx.seq_world_size = seq_world_size
139
- ctx.dim = dim
140
- dim_size = local_input.shape[dim]
141
- return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()
142
-
143
- @staticmethod
144
- def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
145
- dim_size = list(grad_output.size())
146
- split_size = dim_size[0]
147
- dim_size[0] = dim_size[0] * ctx.seq_world_size
148
- output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
149
- dist._all_gather_base(output, grad_output, group=ctx.group)
150
- return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)
151
-
152
-
153
- class Gather(torch.autograd.Function):
154
- @staticmethod
155
- def forward(
156
- ctx: Any,
157
- group: dist.ProcessGroup,
158
- local_input: Tensor,
159
- dim: int,
160
- grad_scale: Optional[bool] = False,
161
- ) -> Tensor:
162
- ctx.group = group
163
- ctx.rank = dist.get_rank(group)
164
- ctx.dim = dim
165
- ctx.grad_scale = grad_scale
166
- seq_world_size = dist.get_world_size(group)
167
- ctx.seq_world_size = seq_world_size
168
- dim_size = list(local_input.size())
169
- split_size = dim_size[0]
170
- ctx.part_size = dim_size[dim]
171
- dim_size[0] = dim_size[0] * seq_world_size
172
- output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
173
- dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
174
- return torch.cat(output.split(split_size), dim=dim)
175
-
176
- @staticmethod
177
- def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
178
- if ctx.grad_scale:
179
- grad_output = grad_output * ctx.seq_world_size
180
- return (
181
- None,
182
- grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
183
- None,
184
- None,
185
- )
186
-
187
-
188
- def gather_seq_scatter_heads_qkv(
189
- qkv_tensor: Tensor,
190
- *,
191
- seq_dim: int,
192
- qkv_shape: Optional[Tensor] = None,
193
- cache: Cache = Cache(disable=True),
194
- restore_shape: bool = True,
195
- ):
196
- """
197
- A func to sync splited qkv tensor
198
- qkv_tensor: the tensor we want to do alltoall with. The last dim must
199
- be the projection_idx, which we will split into 3 part. After
200
- spliting, the gather idx will be projecttion_idx + 1
201
- seq_dim: gather_dim for all2all comm
202
- restore_shape: if True, output will has the same shape length as input
203
- """
204
- group = get_sequence_parallel_group()
205
- if not group:
206
- return qkv_tensor
207
- world = get_sequence_parallel_world_size()
208
- orig_shape = qkv_tensor.shape
209
- scatter_dim = qkv_tensor.dim()
210
- bef_all2all_shape = list(orig_shape)
211
- qkv_proj_dim = bef_all2all_shape[-1]
212
- bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
213
- qkv_tensor = qkv_tensor.view(bef_all2all_shape)
214
- qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False)
215
- if restore_shape:
216
- out_shape = list(orig_shape)
217
- out_shape[seq_dim] *= world
218
- out_shape[-1] = qkv_proj_dim // world
219
- qkv_tensor = qkv_tensor.view(out_shape)
220
-
221
- # remove padding
222
- if qkv_shape is not None:
223
- unpad_dim_size = cache(
224
- "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item()
225
- )
226
- if unpad_dim_size % world != 0:
227
- padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size
228
- qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size)
229
- return qkv_tensor
230
-
231
-
232
- def slice_inputs(x: Tensor, dim: int, padding: bool = True):
233
- """
234
- A func to slice the input sequence in sequence parallel
235
- """
236
- group = get_sequence_parallel_group()
237
- if group is None:
238
- return x
239
- sp_rank = get_sequence_parallel_rank()
240
- sp_world = get_sequence_parallel_world_size()
241
- dim_size = x.shape[dim]
242
- unit = (dim_size + sp_world - 1) // sp_world
243
- if padding and dim_size % sp_world:
244
- padding_size = sp_world - (dim_size % sp_world)
245
- x = _pad_tensor(x, dim, padding_size)
246
- slc = [slice(None)] * len(x.shape)
247
- slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
248
- return x[slc]
249
-
250
-
251
- def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
252
- """
253
- A func to remove the padding part of the tensor based on its original shape
254
- """
255
- group = get_sequence_parallel_group()
256
- if group is None:
257
- return x
258
- sp_world = get_sequence_parallel_world_size()
259
- if unpad_dim_size % sp_world == 0:
260
- return x
261
- padding_size = sp_world - (unpad_dim_size % sp_world)
262
- assert (padding_size + unpad_dim_size) % sp_world == 0
263
- return _unpad_tensor(x, dim=dim, padding_size=padding_size)
264
-
265
-
266
- def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
267
- """
268
- A func to sync attention result with alltoall in sequence parallel
269
- """
270
- group = get_sequence_parallel_group()
271
- if not group:
272
- return x
273
- dim_size = x.size(seq_dim)
274
- sp_world = get_sequence_parallel_world_size()
275
- if dim_size % sp_world != 0:
276
- padding_size = sp_world - (dim_size % sp_world)
277
- x = _pad_tensor(x, seq_dim, padding_size)
278
- return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
279
-
280
-
281
- def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor:
282
- """
283
- A func to sync embedding input with alltoall in sequence parallel
284
- """
285
- group = get_sequence_parallel_group()
286
- if not group:
287
- return x
288
- return SeqAllToAll.apply(group, x, head_dim, seq_dim, False)
289
-
290
-
291
- def scatter_heads(x: Tensor, dim: int) -> Tensor:
292
- """
293
- A func to split heads before attention in sequence parallel
294
- """
295
- group = get_sequence_parallel_group()
296
- if not group:
297
- return x
298
- return Slice.apply(group, x, dim)
299
-
300
-
301
- def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor:
302
- """
303
- A func to gather heads for the attention result in sequence parallel
304
- """
305
- group = get_sequence_parallel_group()
306
- if not group:
307
- return x
308
- return Gather.apply(group, x, dim, grad_scale)
309
-
310
-
311
- def gather_outputs(
312
- x: Tensor,
313
- *,
314
- gather_dim: int,
315
- padding_dim: Optional[int] = None,
316
- unpad_shape: Optional[Tensor] = None,
317
- cache: Cache = Cache(disable=True),
318
- scale_grad=True,
319
- ):
320
- """
321
- A func to gather the outputs for the model result in sequence parallel
322
- """
323
- group = get_sequence_parallel_group()
324
- if not group:
325
- return x
326
- x = Gather.apply(group, x, gather_dim, scale_grad)
327
- if padding_dim is not None:
328
- unpad_dim_size = cache(
329
- "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item()
330
- )
331
- x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size)
332
- return x
333
-
334
-
335
- def _pad_tensor(x: Tensor, dim: int, padding_size: int):
336
- shape = list(x.shape)
337
- shape[dim] = padding_size
338
- pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
339
- return torch.cat([x, pad], dim=dim)
340
-
341
-
342
- def _unpad_tensor(x: Tensor, dim: int, padding_size):
343
- slc = [slice(None)] * len(x.shape)
344
- slc[dim] = slice(0, -padding_size)
345
- return x[slc]
346
-
347
-
348
- def _broadcast_data(data, shape, dtype, src, group, async_op):
349
- comms = []
350
- if isinstance(data, (list, tuple)):
351
- for i, sub_shape in enumerate(shape):
352
- comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op)
353
- elif isinstance(data, dict):
354
- for key, sub_data in data.items():
355
- comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op)
356
- elif isinstance(data, Tensor):
357
- comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op))
358
- return comms
359
-
360
-
361
- def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]:
362
- if isinstance(data, (list, tuple)):
363
- return [_traverse(sub_data, op) for sub_data in data]
364
- elif isinstance(data, dict):
365
- return {key: _traverse(sub_data, op) for key, sub_data in data.items()}
366
- elif isinstance(data, Tensor):
367
- return op(data)
368
- else:
369
- return None
370
-
371
-
372
- def _get_shapes(data):
373
- return _traverse(data, op=lambda x: x.shape)
374
-
375
-
376
- def _get_dtypes(data):
377
- return _traverse(data, op=lambda x: x.dtype)
378
-
379
-
380
- def _construct_broadcast_buffer(shapes, dtypes, device):
381
- if isinstance(shapes, torch.Size):
382
- return torch.empty(shapes, dtype=dtypes, device=device)
383
-
384
- if isinstance(shapes, (list, tuple)):
385
- buffer = []
386
- for i, sub_shape in enumerate(shapes):
387
- buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device))
388
- elif isinstance(shapes, dict):
389
- buffer = {}
390
- for key, sub_shape in shapes.items():
391
- buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device)
392
- else:
393
- return None
394
- return buffer
395
-
396
-
397
- class SPDistForward:
398
- """A forward tool to sync different result across sp group
399
-
400
- Args:
401
- module: a function or module to process users input
402
- sp_step: current training step to judge which rank to broadcast its result to all
403
- name: a distinct str to save meta and async comm
404
- comm_shape: if different ranks have different shape, mark this arg to True
405
- device: the device for current rank, can be empty
406
- """
407
-
408
- def __init__(
409
- self,
410
- name: str,
411
- comm_shape: bool,
412
- device: torch.device = None,
413
- ):
414
- self.name = name
415
- self.comm_shape = comm_shape
416
- if device:
417
- self.device = device
418
- else:
419
- self.device = get_device()
420
-
421
- def __call__(self, inputs) -> Any:
422
- group = get_sequence_parallel_group()
423
- if not group:
424
- yield inputs
425
- else:
426
- device = self.device
427
- sp_world = get_sequence_parallel_world_size()
428
- sp_rank = get_sequence_parallel_rank()
429
- for local_step in range(sp_world):
430
- src_rank = dist.get_global_rank(group, local_step)
431
- is_src = sp_rank == local_step
432
- local_shapes = []
433
- local_dtypes = []
434
- if local_step == 0:
435
- local_result = inputs
436
- _SEQ_DATA_BUF[self.name][-1] = local_result
437
- local_shapes = _get_shapes(local_result)
438
- local_dtypes = _get_dtypes(local_result)
439
- if self.comm_shape:
440
- group_shapes_lists = [None] * sp_world
441
- dist.all_gather_object(group_shapes_lists, local_shapes, group=group)
442
- _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists
443
- else:
444
- _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world
445
- _SEQ_DATA_META_DTYPES[self.name] = local_dtypes
446
- shapes = _SEQ_DATA_META_SHAPES[self.name][local_step]
447
- dtypes = _SEQ_DATA_META_DTYPES[self.name]
448
- buf_id = local_step % 2
449
- if local_step == 0:
450
- sync_data = (
451
- local_result
452
- if is_src
453
- else _construct_broadcast_buffer(shapes, dtypes, device)
454
- )
455
- _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False)
456
- _SEQ_DATA_BUF[self.name][buf_id] = sync_data
457
-
458
- # wait for async comm ops
459
- if _SEQ_DATA_ASYNC_COMMS[self.name]:
460
- for comm in _SEQ_DATA_ASYNC_COMMS[self.name]:
461
- comm.wait()
462
- # before return the sync result, do async broadcast for next batch
463
- if local_step < sp_world - 1:
464
- next_buf_id = 1 - buf_id
465
- shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1]
466
- src_rank = dist.get_global_rank(group, local_step + 1)
467
- is_src = sp_rank == local_step + 1
468
- next_sync_data = (
469
- _SEQ_DATA_BUF[self.name][-1]
470
- if is_src
471
- else _construct_broadcast_buffer(shapes, dtypes, device)
472
- )
473
- _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data(
474
- next_sync_data, shapes, dtypes, src_rank, group, True
475
- )
476
- _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data
477
- yield _SEQ_DATA_BUF[self.name][buf_id]
478
-
479
-
480
- sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True)
481
-
482
-
483
- def sync_data(data, sp_idx, name="tmp"):
484
- group = get_sequence_parallel_group()
485
- if group is None:
486
- return data
487
- # if sp_idx in _SYNC_BUFFER[name]:
488
- # return _SYNC_BUFFER[name][sp_idx]
489
- sp_rank = get_sequence_parallel_rank()
490
- src_rank = dist.get_global_rank(group, sp_idx)
491
- objects = [data] if sp_rank == sp_idx else [None]
492
- dist.broadcast_object_list(objects, src=src_rank, group=group)
493
- # _SYNC_BUFFER[name] = {sp_idx: objects[0]}
494
- return objects[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/logger.py DELETED
@@ -1,44 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Logging utility functions.
17
- """
18
-
19
- import logging
20
- import sys
21
- from typing import Optional
22
-
23
- from common.distributed import get_global_rank, get_local_rank, get_world_size
24
-
25
- _default_handler = logging.StreamHandler(sys.stdout)
26
- _default_handler.setFormatter(
27
- logging.Formatter(
28
- "%(asctime)s "
29
- + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
30
- + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
31
- + "[%(threadName).12s][%(name)s][%(levelname).5s] "
32
- + "%(message)s"
33
- )
34
- )
35
-
36
-
37
- def get_logger(name: Optional[str] = None) -> logging.Logger:
38
- """
39
- Get a logger.
40
- """
41
- logger = logging.getLogger(name)
42
- logger.addHandler(_default_handler)
43
- logger.setLevel(logging.INFO)
44
- return logger
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/partition.py DELETED
@@ -1,59 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- """
16
- Partition utility functions.
17
- """
18
-
19
- from typing import Any, List
20
-
21
-
22
- def partition_by_size(data: List[Any], size: int) -> List[List[Any]]:
23
- """
24
- Partition a list by size.
25
- When indivisible, the last group contains fewer items than the target size.
26
-
27
- Examples:
28
- - data: [1,2,3,4,5]
29
- - size: 2
30
- - return: [[1,2], [3,4], [5]]
31
- """
32
- assert size > 0
33
- return [data[i : (i + size)] for i in range(0, len(data), size)]
34
-
35
-
36
- def partition_by_groups(data: List[Any], groups: int) -> List[List[Any]]:
37
- """
38
- Partition a list by groups.
39
- When indivisible, some groups may have more items than others.
40
-
41
- Examples:
42
- - data: [1,2,3,4,5]
43
- - groups: 2
44
- - return: [[1,3,5], [2,4]]
45
- """
46
- assert groups > 0
47
- return [data[i::groups] for i in range(groups)]
48
-
49
-
50
- def shift_list(data: List[Any], n: int) -> List[Any]:
51
- """
52
- Rotate a list by n elements.
53
-
54
- Examples:
55
- - data: [1,2,3,4,5]
56
- - n: 3
57
- - return: [4,5,1,2,3]
58
- """
59
- return data[(n % len(data)) :] + data[: (n % len(data))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
commonx/seed.py DELETED
@@ -1,30 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- import random
16
- from typing import Optional
17
- import numpy as np
18
- import torch
19
-
20
- from common.distributed import get_global_rank
21
-
22
-
23
- def set_seed(seed: Optional[int], same_across_ranks: bool = False):
24
- """Function that sets the seed for pseudo-random number generators."""
25
- if seed is not None:
26
- seed += get_global_rank() if not same_across_ranks else 0
27
- random.seed(seed)
28
- np.random.seed(seed)
29
- torch.manual_seed(seed)
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs_3bx/main.yaml DELETED
@@ -1,88 +0,0 @@
1
- __object__:
2
- path: projects.video_diffusion_sr.train
3
- name: VideoDiffusionTrainer
4
-
5
- dit:
6
- model:
7
- __object__:
8
- path: models.dit_v2.nadit
9
- name: NaDiT
10
- args: as_params
11
- vid_in_channels: 33
12
- vid_out_channels: 16
13
- vid_dim: 2560
14
- vid_out_norm: fusedrms
15
- txt_in_dim: 5120
16
- txt_in_norm: fusedln
17
- txt_dim: ${.vid_dim}
18
- emb_dim: ${eval:'6 * ${.vid_dim}'}
19
- heads: 20
20
- head_dim: 128 # llm-like
21
- expand_ratio: 4
22
- norm: fusedrms
23
- norm_eps: 1.0e-05
24
- ada: single
25
- qk_bias: False
26
- qk_norm: fusedrms
27
- patch_size: [ 1,2,2 ]
28
- num_layers: 32 # llm-like
29
- mm_layers: 10
30
- mlp_type: swiglu
31
- msa_type: None
32
- block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full
33
- window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full
34
- window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full
35
- rope_type: mmrope3d
36
- rope_dim: 128
37
- compile: False
38
- gradient_checkpoint: True
39
- fsdp:
40
- sharding_strategy: _HYBRID_SHARD_ZERO2
41
-
42
- ema:
43
- decay: 0.9998
44
-
45
- vae:
46
- model:
47
- __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml
48
- freeze_encoder: False
49
- # gradient_checkpoint: True
50
- slicing:
51
- split_size: 4
52
- memory_device: same
53
- memory_limit:
54
- conv_max_mem: 0.5
55
- norm_max_mem: 0.5
56
- checkpoint: ./ckpts/ema_vae.pth
57
- scaling_factor: 0.9152
58
- compile: False
59
- grouping: False
60
- dtype: bfloat16
61
-
62
- diffusion:
63
- schedule:
64
- type: lerp
65
- T: 1000.0
66
- sampler:
67
- type: euler
68
- prediction_type: v_lerp
69
- timesteps:
70
- training:
71
- type: logitnormal
72
- loc: 0.0
73
- scale: 1.0
74
- sampling:
75
- type: uniform_trailing
76
- steps: 50
77
- transform: True
78
- loss:
79
- type: v_lerp
80
- cfg:
81
- scale: 7.5
82
- rescale: 0
83
-
84
- condition:
85
- i2v: 0.0
86
- v2v: 0.0
87
- sr: 1.0
88
- noise_scale: 0.25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs_7bx/main.yaml DELETED
@@ -1,85 +0,0 @@
1
- __object__:
2
- path: projects.video_diffusion_sr.train
3
- name: VideoDiffusionTrainer
4
-
5
- dit:
6
- model:
7
- __object__:
8
- path: models.dit.nadit
9
- name: NaDiT
10
- args: as_params
11
- vid_in_channels: 33
12
- vid_out_channels: 16
13
- vid_dim: 3072
14
- txt_in_dim: 5120
15
- txt_dim: ${.vid_dim}
16
- emb_dim: ${eval:'6 * ${.vid_dim}'}
17
- heads: 24
18
- head_dim: 128 # llm-like
19
- expand_ratio: 4
20
- norm: fusedrms
21
- norm_eps: 1e-5
22
- ada: single
23
- qk_bias: False
24
- qk_rope: True
25
- qk_norm: fusedrms
26
- patch_size: [ 1,2,2 ]
27
- num_layers: 36 # llm-like
28
- shared_mlp: False
29
- shared_qkv: False
30
- mlp_type: normal
31
- block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full
32
- window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full
33
- window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full
34
- compile: False
35
- gradient_checkpoint: True
36
- fsdp:
37
- sharding_strategy: _HYBRID_SHARD_ZERO2
38
-
39
- ema:
40
- decay: 0.9998
41
-
42
- vae:
43
- model:
44
- __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml
45
- freeze_encoder: False
46
- # gradient_checkpoint: True
47
- slicing:
48
- split_size: 4
49
- memory_device: same
50
- memory_limit:
51
- conv_max_mem: 0.5
52
- norm_max_mem: 0.5
53
- checkpoint: ./ckpts/ema_vae.pth
54
- scaling_factor: 0.9152
55
- compile: False
56
- grouping: False
57
- dtype: bfloat16
58
-
59
- diffusion:
60
- schedule:
61
- type: lerp
62
- T: 1000.0
63
- sampler:
64
- type: euler
65
- prediction_type: v_lerp
66
- timesteps:
67
- training:
68
- type: logitnormal
69
- loc: 0.0
70
- scale: 1.0
71
- sampling:
72
- type: uniform_trailing
73
- steps: 50
74
- transform: True
75
- loss:
76
- type: v_lerp
77
- cfg:
78
- scale: 7.5
79
- rescale: 0
80
-
81
- condition:
82
- i2v: 0.0
83
- v2v: 0.0
84
- sr: 1.0
85
- noise_scale: 0.25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datax/image/transforms/area_resize.py DELETED
@@ -1,135 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- import math
16
- import random
17
- from typing import Union
18
- import torch
19
- from PIL import Image
20
- from torchvision.transforms import functional as TVF
21
- from torchvision.transforms.functional import InterpolationMode
22
-
23
-
24
- class AreaResize:
25
- def __init__(
26
- self,
27
- max_area: float,
28
- downsample_only: bool = False,
29
- interpolation: InterpolationMode = InterpolationMode.BICUBIC,
30
- ):
31
- self.max_area = max_area
32
- self.downsample_only = downsample_only
33
- self.interpolation = interpolation
34
-
35
- def __call__(self, image: Union[torch.Tensor, Image.Image]):
36
-
37
- if isinstance(image, torch.Tensor):
38
- height, width = image.shape[-2:]
39
- elif isinstance(image, Image.Image):
40
- width, height = image.size
41
- else:
42
- raise NotImplementedError
43
-
44
- scale = math.sqrt(self.max_area / (height * width))
45
-
46
- # keep original height and width for small pictures.
47
- scale = 1 if scale >= 1 and self.downsample_only else scale
48
-
49
- resized_height, resized_width = round(height * scale), round(width * scale)
50
-
51
- return TVF.resize(
52
- image,
53
- size=(resized_height, resized_width),
54
- interpolation=self.interpolation,
55
- )
56
-
57
-
58
- class AreaRandomCrop:
59
- def __init__(
60
- self,
61
- max_area: float,
62
- ):
63
- self.max_area = max_area
64
-
65
- def get_params(self, input_size, output_size):
66
- """Get parameters for ``crop`` for a random crop.
67
-
68
- Args:
69
- img (PIL Image): Image to be cropped.
70
- output_size (tuple): Expected output size of the crop.
71
-
72
- Returns:
73
- tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
74
- """
75
- # w, h = _get_image_size(img)
76
- h, w = input_size
77
- th, tw = output_size
78
- if w <= tw and h <= th:
79
- return 0, 0, h, w
80
-
81
- i = random.randint(0, h - th)
82
- j = random.randint(0, w - tw)
83
- return i, j, th, tw
84
-
85
- def __call__(self, image: Union[torch.Tensor, Image.Image]):
86
- if isinstance(image, torch.Tensor):
87
- height, width = image.shape[-2:]
88
- elif isinstance(image, Image.Image):
89
- width, height = image.size
90
- else:
91
- raise NotImplementedError
92
-
93
- resized_height = math.sqrt(self.max_area / (width / height))
94
- resized_width = (width / height) * resized_height
95
-
96
- # print('>>>>>>>>>>>>>>>>>>>>>')
97
- # print((height, width))
98
- # print( (resized_height, resized_width))
99
-
100
- resized_height, resized_width = round(resized_height), round(resized_width)
101
- i, j, h, w = self.get_params((height, width), (resized_height, resized_width))
102
- image = TVF.crop(image, i, j, h, w)
103
- return image
104
-
105
- class ScaleResize:
106
- def __init__(
107
- self,
108
- scale: float,
109
- ):
110
- self.scale = scale
111
-
112
- def __call__(self, image: Union[torch.Tensor, Image.Image]):
113
- if isinstance(image, torch.Tensor):
114
- height, width = image.shape[-2:]
115
- interpolation_mode = InterpolationMode.BILINEAR
116
- antialias = True if image.ndim == 4 else "warn"
117
- elif isinstance(image, Image.Image):
118
- width, height = image.size
119
- interpolation_mode = InterpolationMode.LANCZOS
120
- antialias = "warn"
121
- else:
122
- raise NotImplementedError
123
-
124
- scale = self.scale
125
-
126
- # keep original height and width for small pictures
127
-
128
- resized_height, resized_width = round(height * scale), round(width * scale)
129
- image = TVF.resize(
130
- image,
131
- size=(resized_height, resized_width),
132
- interpolation=interpolation_mode,
133
- antialias=antialias,
134
- )
135
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datax/image/transforms/divisible_crop.py DELETED
@@ -1,40 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Union
16
- import torch
17
- from PIL import Image
18
- from torchvision.transforms import functional as TVF
19
-
20
-
21
- class DivisibleCrop:
22
- def __init__(self, factor):
23
- if not isinstance(factor, tuple):
24
- factor = (factor, factor)
25
-
26
- self.height_factor, self.width_factor = factor[0], factor[1]
27
-
28
- def __call__(self, image: Union[torch.Tensor, Image.Image]):
29
- if isinstance(image, torch.Tensor):
30
- height, width = image.shape[-2:]
31
- elif isinstance(image, Image.Image):
32
- width, height = image.size
33
- else:
34
- raise NotImplementedError
35
-
36
- cropped_height = height - (height % self.height_factor)
37
- cropped_width = width - (width % self.width_factor)
38
-
39
- image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width))
40
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datax/image/transforms/na_resize.py DELETED
@@ -1,50 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Literal
16
- from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize
17
-
18
- from .area_resize import AreaResize
19
- from .side_resize import SideResize
20
-
21
-
22
- def NaResize(
23
- resolution: int,
24
- mode: Literal["area", "side"],
25
- downsample_only: bool,
26
- interpolation: InterpolationMode = InterpolationMode.BICUBIC,
27
- ):
28
- if mode == "area":
29
- return AreaResize(
30
- max_area=resolution**2,
31
- downsample_only=downsample_only,
32
- interpolation=interpolation,
33
- )
34
- if mode == "side":
35
- return SideResize(
36
- size=resolution,
37
- downsample_only=downsample_only,
38
- interpolation=interpolation,
39
- )
40
- if mode == "square":
41
- return Compose(
42
- [
43
- Resize(
44
- size=resolution,
45
- interpolation=interpolation,
46
- ),
47
- CenterCrop(resolution),
48
- ]
49
- )
50
- raise ValueError(f"Unknown resize mode: {mode}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datax/image/transforms/side_resize.py DELETED
@@ -1,54 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Union
16
- import torch
17
- from PIL import Image
18
- from torchvision.transforms import InterpolationMode
19
- from torchvision.transforms import functional as TVF
20
-
21
-
22
- class SideResize:
23
- def __init__(
24
- self,
25
- size: int,
26
- downsample_only: bool = False,
27
- interpolation: InterpolationMode = InterpolationMode.BICUBIC,
28
- ):
29
- self.size = size
30
- self.downsample_only = downsample_only
31
- self.interpolation = interpolation
32
-
33
- def __call__(self, image: Union[torch.Tensor, Image.Image]):
34
- """
35
- Args:
36
- image (PIL Image or Tensor): Image to be scaled.
37
-
38
- Returns:
39
- PIL Image or Tensor: Rescaled image.
40
- """
41
- if isinstance(image, torch.Tensor):
42
- height, width = image.shape[-2:]
43
- elif isinstance(image, Image.Image):
44
- width, height = image.size
45
- else:
46
- raise NotImplementedError
47
-
48
- if self.downsample_only and min(width, height) < self.size:
49
- # keep original height and width for small pictures.
50
- size = min(width, height)
51
- else:
52
- size = self.size
53
-
54
- return TVF.resize(image, size, self.interpolation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
datax/video/transforms/rearrange.py DELETED
@@ -1,24 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from einops import rearrange
16
-
17
-
18
- class Rearrange:
19
- def __init__(self, pattern: str, **kwargs):
20
- self.pattern = pattern
21
- self.kwargs = kwargs
22
-
23
- def __call__(self, x):
24
- return rearrange(x, self.pattern, **self.kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
environmentx.yml DELETED
@@ -1,238 +0,0 @@
1
- name: seedvr
2
- channels:
3
- - defaults
4
- dependencies:
5
- - _libgcc_mutex=0.1=main
6
- - _openmp_mutex=5.1=1_gnu
7
- - ld_impl_linux-64=2.38=h1181459_1
8
- - libffi=3.4.4=h6a678d5_1
9
- - libgcc-ng=11.2.0=h1234567_1
10
- - libgomp=11.2.0=h1234567_1
11
- - libstdcxx-ng=11.2.0=h1234567_1
12
- - ncurses=6.4=h6a678d5_0
13
- - openssl=3.0.14=h5eee18b_0
14
- - pip=24.0=py39h06a4308_0
15
- - python=3.9.19=h955ad1f_1
16
- - readline=8.2=h5eee18b_0
17
- - setuptools=69.5.1=py39h06a4308_0
18
- - sqlite=3.45.3=h5eee18b_0
19
- - tk=8.6.14=h39e8969_0
20
- - tzdata=2024a=h04d1e81_0
21
- - wheel=0.43.0=py39h06a4308_0
22
- - xz=5.4.6=h5eee18b_1
23
- - zlib=1.2.13=h5eee18b_1
24
- - pip:
25
- - absl-py==2.1.0
26
- - accelerate==0.33.0
27
- - addict==2.4.0
28
- - antlr4-python3-runtime==4.9.3
29
- - anykeystore==0.2
30
- - apex==0.1
31
- - asttokens==2.4.1
32
- - astunparse==1.6.3
33
- - attrs==23.2.0
34
- - av==12.0.0
35
- - basicsr==1.4.2
36
- - beartype==0.18.5
37
- - beautifulsoup4==4.12.3
38
- - bitsandbytes==0.44.1
39
- - black==24.4.2
40
- - bs4==0.0.2
41
- - bson==0.5.10
42
- - certifi==2024.6.2
43
- - cffi==1.16.0
44
- - cfgv==3.4.0
45
- - charset-normalizer==3.3.2
46
- - click==8.1.7
47
- - colorama==0.4.6
48
- - contourpy==1.2.1
49
- - cryptacular==1.6.2
50
- - cryptography==39.0.2
51
- - cycler==0.12.1
52
- - decorator==4.4.2
53
- - deepdiff==7.0.1
54
- - deprecated==1.2.14
55
- - diffusers==0.33.1
56
- - distlib==0.3.8
57
- - dnspython==2.6.1
58
- - docker-pycreds==0.4.0
59
- - docstring-parser==0.16
60
- - einops==0.7.0
61
- - exceptiongroup==1.2.1
62
- - executing==2.0.1
63
- - expecttest==0.2.1
64
- - facexlib==0.3.0
65
- - ffmpeg-python==0.2.0
66
- - filelock==3.15.4
67
- - filterpy==1.4.5
68
- - flake8==7.1.0
69
- - flash-attn==2.5.9.post1
70
- - flatbuffers==24.3.25
71
- - fonttools==4.53.0
72
- - fsspec==2023.6.0
73
- - ftfy==6.2.0
74
- - future==1.0.0
75
- - gast==0.5.4
76
- - gdown==5.2.0
77
- - gitdb==4.0.11
78
- - gitpython==3.1.43
79
- - google-pasta==0.2.0
80
- - greenlet==3.0.3
81
- - grpcio==1.64.1
82
- - h5py==3.11.0
83
- - hf-xet==1.1.2
84
- - huggingface-hub==0.32.2
85
- - hupper==1.12.1
86
- - hypothesis==6.100.1
87
- - icecream==2.1.3
88
- - identify==2.5.36
89
- - idna==3.7
90
- - imageio==2.34.0
91
- - imageio-ffmpeg==0.5.1
92
- - importlib-metadata==7.2.1
93
- - importlib-resources==6.4.0
94
- - iniconfig==2.0.0
95
- - ipaddress==1.0.23
96
- - ipython==8.18.1
97
- - isort==5.13.2
98
- - jedi==0.19.1
99
- - jinja2==3.1.4
100
- - jsonargparse==4.14.1
101
- - keras==3.3.3
102
- - kiwisolver==1.4.5
103
- - lazy-loader==0.4
104
- - libclang==18.1.1
105
- - lightning-utilities==0.11.2
106
- - llvmlite==0.43.0
107
- - lmdb==1.5.1
108
- - lpips==0.1.4
109
- - markdown==3.6
110
- - markdown-it-py==3.0.0
111
- - markupsafe==2.1.5
112
- - matplotlib==3.9.0
113
- - matplotlib-inline==0.1.7
114
- - mccabe==0.7.0
115
- - mdurl==0.1.2
116
- - mediapy==1.2.0
117
- - ml-dtypes==0.3.2
118
- - moviepy==1.0.3
119
- - mpmath==1.3.0
120
- - msgpack==1.0.8
121
- - mypy-extensions==1.0.0
122
- - namex==0.0.8
123
- - networkx==3.2.1
124
- - nodeenv==1.9.1
125
- - numba==0.60.0
126
- - numpy==1.24.4
127
- - oauthlib==3.2.2
128
- - omegaconf==2.3.0
129
- - openai-clip==1.0.1
130
- - opencv-python==4.9.0.80
131
- - opencv-python-headless==4.10.0.84
132
- - opt-einsum==3.3.0
133
- - optree==0.11.0
134
- - ordered-set==4.1.0
135
- - packaging==22.0
136
- - pandas==1.5.3
137
- - parameterized==0.9.0
138
- - parso==0.8.4
139
- - pastedeploy==3.1.0
140
- - pathspec==0.12.1
141
- - pathtools==0.1.2
142
- - pbkdf2==1.3
143
- - pexpect==4.9.0
144
- - pillow==10.3.0
145
- - plaster==1.1.2
146
- - plaster-pastedeploy==1.0.1
147
- - platformdirs==4.2.2
148
- - pluggy==1.5.0
149
- - proglog==0.1.10
150
- - promise==2.3
151
- - prompt-toolkit==3.0.47
152
- - protobuf==3.20.3
153
- - psutil==6.0.0
154
- - ptyprocess==0.7.0
155
- - pure-eval==0.2.2
156
- - pyarrow==11.0.0
157
- - pycocotools==2.0.7
158
- - pycodestyle==2.12.0
159
- - pycparser==2.22
160
- - pydantic==1.10.17
161
- - pyflakes==3.2.0
162
- - pygments==2.18.0
163
- - pyiqa==0.1.13
164
- - pyjwt==2.8.0
165
- - pyopenssl==23.2.0
166
- - pyparsing==3.1.2
167
- - pyramid==2.0.2
168
- - pyramid-mailer==0.15.1
169
- - pysocks==1.7.1
170
- - pytest==8.3.3
171
- - python-dateutil==2.9.0.post0
172
- - python-etcd==0.4.5
173
- - python3-openid==3.2.0
174
- - pytz==2024.1
175
- - pyyaml==6.0.1
176
- - regex==2024.5.15
177
- - repoze-sendmail==4.4.1
178
- - requests==2.32.3
179
- - requests-oauthlib==2.0.0
180
- - rich==13.7.1
181
- - rotary-embedding-torch==0.5.3
182
- - safetensors==0.4.3
183
- - scenedetect==0.6.4
184
- - schedule==1.2.2
185
- - scikit-image==0.24.0
186
- - scipy==1.13.1
187
- - sentencepiece==0.2.0
188
- - sentry-sdk==2.6.0
189
- - setproctitle==1.3.3
190
- - shortuuid==1.0.13
191
- - six==1.16.0
192
- - smmap==5.0.1
193
- - sortedcontainers==2.4.0
194
- - soupsieve==2.5
195
- - sqlalchemy==2.0.31
196
- - stack-data==0.6.3
197
- - sympy==1.12.1
198
- - tabulate==0.9.0
199
- - tb-nightly==2.20.0a20250528
200
- - tenacity==8.4.1
201
- - tensorboard==2.16.2
202
- - tensorboard-data-server==0.7.2
203
- - tensorflow==2.16.1
204
- - tensorflow-io-gcs-filesystem==0.37.0
205
- - termcolor==2.4.0
206
- - tifffile==2024.8.30
207
- - tiktoken==0.9.0
208
- - timm==1.0.11
209
- - tokenizers==0.20.3
210
- - tomli==2.0.1
211
- - torch==2.4.0+cu121
212
- - torch-fidelity==0.3.0
213
- - torchaudio==2.4.0+cu121
214
- - torchmetrics==1.3.2
215
- - torchvision==0.19.0+cu121
216
- - tqdm==4.66.4
217
- - traitlets==5.14.3
218
- - transaction==4.0
219
- - transformers==4.46.2
220
- - transformers-stream-generator==0.0.5
221
- - translationstring==1.4
222
- - triton==3.0.0
223
- - typing-extensions==4.12.2
224
- - urllib3==1.26.19
225
- - velruse==1.1.1
226
- - venusian==3.1.0
227
- - virtualenv==20.26.3
228
- - wcwidth==0.2.13
229
- - webob==1.8.7
230
- - werkzeug==3.0.3
231
- - wrapt==1.16.0
232
- - wtforms==3.1.2
233
- - wtforms-recaptcha==0.3.2
234
- - yapf==0.43.0
235
- - zipp==3.19.2
236
- - zope-deprecation==5.0
237
- - zope-interface==6.4.post2
238
- - zope-sqlalchemy==3.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/attention.py DELETED
@@ -1,46 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- import torch
16
- import torch.nn.functional as F
17
-
18
- from flash_attn import flash_attn_varlen_func
19
-
20
- from torch import nn
21
-
22
- class TorchAttention(nn.Module):
23
- def tflops(self, args, kwargs, output) -> float:
24
- assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs"
25
- q = kwargs.get("query") or args[0]
26
- k = kwargs.get("key") or args[1]
27
- b, h, sq, d = q.shape
28
- b, h, sk, d = k.shape
29
- return b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
30
-
31
- def forward(self, *args, **kwargs):
32
- return F.scaled_dot_product_attention(*args, **kwargs)
33
-
34
-
35
- class FlashAttentionVarlen(nn.Module):
36
- def tflops(self, args, kwargs, output) -> float:
37
- cu_seqlens_q = kwargs["cu_seqlens_q"]
38
- cu_seqlens_k = kwargs["cu_seqlens_k"]
39
- _, h, d = output.shape
40
- seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6
41
- seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6
42
- return h * (4 * d * (seqlens_q * seqlens_k).sum())
43
-
44
- def forward(self, *args, **kwargs):
45
- kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
46
- return flash_attn_varlen_func(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/blocks/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from .mmdit_window_block import MMWindowTransformerBlock
16
-
17
- dit_blocks = {
18
- "mmdit_window": MMWindowTransformerBlock,
19
- }
20
-
21
-
22
- def get_block(block_type: str):
23
- if block_type in dit_blocks:
24
- return dit_blocks[block_type]
25
- raise NotImplementedError(f"{block_type} is not supported")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/blocks/mmdit_window_block.py DELETED
@@ -1,233 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Tuple, Union
16
- import torch
17
- from einops import rearrange
18
- from torch import nn
19
- from torch.nn import functional as F
20
- from torch.nn.modules.utils import _triple
21
-
22
- from common.distributed.ops import (
23
- gather_heads,
24
- gather_heads_scatter_seq,
25
- gather_seq_scatter_heads_qkv,
26
- scatter_heads,
27
- )
28
-
29
- from ..attention import TorchAttention
30
- from ..mlp import get_mlp
31
- from ..mm import MMArg, MMModule
32
- from ..modulation import ada_layer_type
33
- from ..normalization import norm_layer_type
34
- from ..rope import RotaryEmbedding3d
35
-
36
-
37
- class MMWindowAttention(nn.Module):
38
- def __init__(
39
- self,
40
- vid_dim: int,
41
- txt_dim: int,
42
- heads: int,
43
- head_dim: int,
44
- qk_bias: bool,
45
- qk_rope: bool,
46
- qk_norm: norm_layer_type,
47
- qk_norm_eps: float,
48
- window: Union[int, Tuple[int, int, int]],
49
- window_method: str,
50
- shared_qkv: bool,
51
- ):
52
- super().__init__()
53
- dim = MMArg(vid_dim, txt_dim)
54
- inner_dim = heads * head_dim
55
- qkv_dim = inner_dim * 3
56
-
57
- self.window = _triple(window)
58
- self.window_method = window_method
59
- assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
60
-
61
- self.head_dim = head_dim
62
- self.proj_qkv = MMModule(nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_qkv)
63
- self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_qkv)
64
- self.norm_q = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True)
65
- self.norm_k = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True)
66
- self.rope = RotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None
67
- self.attn = TorchAttention()
68
-
69
- def forward(
70
- self,
71
- vid: torch.FloatTensor, # b T H W c
72
- txt: torch.FloatTensor, # b L c
73
- txt_mask: torch.BoolTensor, # b L
74
- ) -> Tuple[
75
- torch.FloatTensor,
76
- torch.FloatTensor,
77
- ]:
78
- # Project q, k, v.
79
- vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
80
- vid_qkv = gather_seq_scatter_heads_qkv(vid_qkv, seq_dim=2)
81
- _, T, H, W, _ = vid_qkv.shape
82
- _, L, _ = txt.shape
83
-
84
- if self.window_method == "win":
85
- nt, nh, nw = self.window
86
- tt, hh, ww = T // nt, H // nh, W // nw
87
- elif self.window_method == "win_by_size":
88
- tt, hh, ww = self.window
89
- tt, hh, ww = (
90
- tt if tt > 0 else T,
91
- hh if hh > 0 else H,
92
- ww if ww > 0 else W,
93
- )
94
- nt, nh, nw = T // tt, H // hh, W // ww
95
- else:
96
- raise NotImplementedError
97
-
98
- vid_qkv = rearrange(vid_qkv, "b T H W (o h d) -> o b h (T H W) d", o=3, d=self.head_dim)
99
- txt_qkv = rearrange(txt_qkv, "b L (o h d) -> o b h L d", o=3, d=self.head_dim)
100
- txt_qkv = scatter_heads(txt_qkv, dim=2)
101
-
102
- vid_q, vid_k, vid_v = vid_qkv.unbind()
103
- txt_q, txt_k, txt_v = txt_qkv.unbind()
104
-
105
- vid_q, txt_q = self.norm_q(vid_q, txt_q)
106
- vid_k, txt_k = self.norm_k(vid_k, txt_k)
107
-
108
- if self.rope:
109
- vid_q, vid_k = self.rope(vid_q, vid_k, (T, H, W))
110
-
111
- def vid_window(v):
112
- return rearrange(
113
- v,
114
- "b h (nt tt nh hh nw ww) d -> b h (nt nh nw) (tt hh ww) d",
115
- hh=hh,
116
- ww=ww,
117
- tt=tt,
118
- nh=nh,
119
- nw=nw,
120
- nt=nt,
121
- )
122
-
123
- def txt_window(t):
124
- return rearrange(t, "b h L d -> b h 1 L d").expand(-1, -1, nt * nh * nw, -1, -1)
125
-
126
- # Process video attention.
127
- vid_msk = F.pad(txt_mask, (tt * hh * ww, 0), value=True)
128
- vid_msk = rearrange(vid_msk, "b l -> b 1 1 1 l").expand(-1, 1, 1, tt * hh * ww, -1)
129
- vid_out = self.attn(
130
- vid_window(vid_q),
131
- torch.cat([vid_window(vid_k), txt_window(txt_k)], dim=-2),
132
- torch.cat([vid_window(vid_v), txt_window(txt_v)], dim=-2),
133
- vid_msk,
134
- )
135
- vid_out = rearrange(
136
- vid_out,
137
- "b h (nt nh nw) (tt hh ww) d -> b (nt tt) (nh hh) (nw ww) (h d)",
138
- hh=hh,
139
- ww=ww,
140
- tt=tt,
141
- nh=nh,
142
- nw=nw,
143
- )
144
- vid_out = gather_heads_scatter_seq(vid_out, head_dim=4, seq_dim=2)
145
-
146
- # Process text attention.
147
- txt_msk = F.pad(txt_mask, (T * H * W, 0), value=True)
148
- txt_msk = rearrange(txt_msk, "b l -> b 1 1 l").expand(-1, 1, L, -1)
149
- txt_out = self.attn(
150
- txt_q,
151
- torch.cat([vid_k, txt_k], dim=-2),
152
- torch.cat([vid_v, txt_v], dim=-2),
153
- txt_msk,
154
- )
155
- txt_out = rearrange(txt_out, "b h L d -> b L (h d)")
156
- txt_out = gather_heads(txt_out, dim=2)
157
-
158
- # Project output.
159
- vid_out, txt_out = self.proj_out(vid_out, txt_out)
160
- return vid_out, txt_out
161
-
162
-
163
- class MMWindowTransformerBlock(nn.Module):
164
- def __init__(
165
- self,
166
- *,
167
- vid_dim: int,
168
- txt_dim: int,
169
- emb_dim: int,
170
- heads: int,
171
- head_dim: int,
172
- expand_ratio: int,
173
- norm: norm_layer_type,
174
- norm_eps: float,
175
- ada: ada_layer_type,
176
- qk_bias: bool,
177
- qk_rope: bool,
178
- qk_norm: norm_layer_type,
179
- window: Union[int, Tuple[int, int, int]],
180
- window_method: str,
181
- shared_qkv: bool,
182
- shared_mlp: bool,
183
- mlp_type: str,
184
- **kwargs,
185
- ):
186
- super().__init__()
187
- dim = MMArg(vid_dim, txt_dim)
188
- self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False)
189
- self.attn = MMWindowAttention(
190
- vid_dim=vid_dim,
191
- txt_dim=txt_dim,
192
- heads=heads,
193
- head_dim=head_dim,
194
- qk_bias=qk_bias,
195
- qk_rope=qk_rope,
196
- qk_norm=qk_norm,
197
- qk_norm_eps=norm_eps,
198
- window=window,
199
- window_method=window_method,
200
- shared_qkv=shared_qkv,
201
- )
202
- self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False)
203
- self.mlp = MMModule(
204
- get_mlp(mlp_type),
205
- dim=dim,
206
- expand_ratio=expand_ratio,
207
- shared_weights=shared_mlp,
208
- )
209
- self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"])
210
-
211
- def forward(
212
- self,
213
- vid: torch.FloatTensor,
214
- txt: torch.FloatTensor,
215
- txt_mask: torch.BoolTensor,
216
- emb: torch.FloatTensor,
217
- ) -> Tuple[
218
- torch.FloatTensor,
219
- torch.FloatTensor,
220
- ]:
221
- vid_attn, txt_attn = self.attn_norm(vid, txt)
222
- vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="in")
223
- vid_attn, txt_attn = self.attn(vid_attn, txt_attn, txt_mask=txt_mask)
224
- vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="out")
225
- vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
226
-
227
- vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
228
- vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="in")
229
- vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
230
- vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="out")
231
- vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
232
-
233
- return vid_mlp, txt_mlp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/embedding.py DELETED
@@ -1,62 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Optional, Union
16
- import torch
17
- from diffusers.models.embeddings import get_timestep_embedding
18
- from torch import nn
19
-
20
-
21
- def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
22
- return emb1 if emb2 is None else emb1 + emb2
23
-
24
-
25
- class TimeEmbedding(nn.Module):
26
- def __init__(
27
- self,
28
- sinusoidal_dim: int,
29
- hidden_dim: int,
30
- output_dim: int,
31
- ):
32
- super().__init__()
33
- self.sinusoidal_dim = sinusoidal_dim
34
- self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
35
- self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
36
- self.proj_out = nn.Linear(hidden_dim, output_dim)
37
- self.act = nn.SiLU()
38
-
39
- def forward(
40
- self,
41
- timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
42
- device: torch.device,
43
- dtype: torch.dtype,
44
- ) -> torch.FloatTensor:
45
- if not torch.is_tensor(timestep):
46
- timestep = torch.tensor([timestep], device=device, dtype=dtype)
47
- if timestep.ndim == 0:
48
- timestep = timestep[None]
49
-
50
- emb = get_timestep_embedding(
51
- timesteps=timestep,
52
- embedding_dim=self.sinusoidal_dim,
53
- flip_sin_to_cos=False,
54
- downscale_freq_shift=0,
55
- )
56
- emb = emb.to(dtype)
57
- emb = self.proj_in(emb)
58
- emb = self.act(emb)
59
- emb = self.proj_hid(emb)
60
- emb = self.act(emb)
61
- emb = self.proj_out(emb)
62
- return emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/mlp.py DELETED
@@ -1,62 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Optional
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
-
21
- def get_mlp(mlp_type: Optional[str] = "normal"):
22
- if mlp_type == "normal":
23
- return MLP
24
- elif mlp_type == "swiglu":
25
- return SwiGLUMLP
26
-
27
-
28
- class MLP(nn.Module):
29
- def __init__(
30
- self,
31
- dim: int,
32
- expand_ratio: int,
33
- ):
34
- super().__init__()
35
- self.proj_in = nn.Linear(dim, dim * expand_ratio)
36
- self.act = nn.GELU("tanh")
37
- self.proj_out = nn.Linear(dim * expand_ratio, dim)
38
-
39
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
40
- x = self.proj_in(x)
41
- x = self.act(x)
42
- x = self.proj_out(x)
43
- return x
44
-
45
-
46
- class SwiGLUMLP(nn.Module):
47
- def __init__(
48
- self,
49
- dim: int,
50
- expand_ratio: int,
51
- multiple_of: int = 256,
52
- ):
53
- super().__init__()
54
- hidden_dim = int(2 * dim * expand_ratio / 3)
55
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
56
- self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
57
- self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
58
- self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
59
-
60
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
61
- x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
62
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/mm.py DELETED
@@ -1,67 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from dataclasses import dataclass
16
- from typing import Any, Callable, Dict, List, Tuple
17
- import torch
18
- from torch import nn
19
-
20
-
21
- @dataclass
22
- class MMArg:
23
- vid: Any
24
- txt: Any
25
-
26
-
27
- def get_args(key: str, args: List[Any]) -> List[Any]:
28
- return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
29
-
30
-
31
- def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
32
- return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
33
-
34
-
35
- class MMModule(nn.Module):
36
- def __init__(
37
- self,
38
- module: Callable[..., nn.Module],
39
- *args,
40
- shared_weights: bool = False,
41
- **kwargs,
42
- ):
43
- super().__init__()
44
- self.shared_weights = shared_weights
45
- if self.shared_weights:
46
- assert get_args("vid", args) == get_args("txt", args)
47
- assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
48
- self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
49
- else:
50
- self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
51
- self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs))
52
-
53
- def forward(
54
- self,
55
- vid: torch.FloatTensor,
56
- txt: torch.FloatTensor,
57
- *args,
58
- **kwargs,
59
- ) -> Tuple[
60
- torch.FloatTensor,
61
- torch.FloatTensor,
62
- ]:
63
- vid_module = self.vid if not self.shared_weights else self.all
64
- txt_module = self.txt if not self.shared_weights else self.all
65
- vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
66
- txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
67
- return vid, txt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/modulation.py DELETED
@@ -1,97 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Callable, List, Optional
16
- import torch
17
- from einops import rearrange
18
- from torch import nn
19
-
20
- from common.cache import Cache
21
- from common.distributed.ops import slice_inputs
22
-
23
- # (dim: int, emb_dim: int)
24
- ada_layer_type = Callable[[int, int], nn.Module]
25
-
26
-
27
- def get_ada_layer(ada_layer: str) -> ada_layer_type:
28
- if ada_layer == "single":
29
- return AdaSingle
30
- raise NotImplementedError(f"{ada_layer} is not supported")
31
-
32
-
33
- def expand_dims(x: torch.Tensor, dim: int, ndim: int):
34
- """
35
- Expand tensor "x" to "ndim" by adding empty dims at "dim".
36
- Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d).
37
- """
38
- shape = x.shape
39
- shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
40
- return x.reshape(shape)
41
-
42
-
43
- class AdaSingle(nn.Module):
44
- def __init__(
45
- self,
46
- dim: int,
47
- emb_dim: int,
48
- layers: List[str],
49
- ):
50
- assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
51
- super().__init__()
52
- self.dim = dim
53
- self.emb_dim = emb_dim
54
- self.layers = layers
55
- for l in layers:
56
- self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
57
- self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1))
58
- self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
59
-
60
- def forward(
61
- self,
62
- hid: torch.FloatTensor, # b ... c
63
- emb: torch.FloatTensor, # b d
64
- layer: str,
65
- mode: str,
66
- cache: Cache = Cache(disable=True),
67
- branch_tag: str = "",
68
- hid_len: Optional[torch.LongTensor] = None, # b
69
- ) -> torch.FloatTensor:
70
- idx = self.layers.index(layer)
71
- emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
72
- emb = expand_dims(emb, 1, hid.ndim + 1)
73
-
74
- if hid_len is not None:
75
- emb = cache(
76
- f"emb_repeat_{idx}_{branch_tag}",
77
- lambda: slice_inputs(
78
- torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
79
- dim=0,
80
- ),
81
- )
82
-
83
- shiftA, scaleA, gateA = emb.unbind(-1)
84
- shiftB, scaleB, gateB = (
85
- getattr(self, f"{layer}_shift"),
86
- getattr(self, f"{layer}_scale"),
87
- getattr(self, f"{layer}_gate"),
88
- )
89
-
90
- if mode == "in":
91
- return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
92
- if mode == "out":
93
- return hid.mul_(gateA + gateB)
94
- raise NotImplementedError
95
-
96
- def extra_repr(self) -> str:
97
- return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/na.py DELETED
@@ -1,241 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from itertools import chain
16
- from typing import Callable, Dict, List, Tuple
17
- import einops
18
- import torch
19
-
20
-
21
- def flatten(
22
- hid: List[torch.FloatTensor], # List of (*** c)
23
- ) -> Tuple[
24
- torch.FloatTensor, # (L c)
25
- torch.LongTensor, # (b n)
26
- ]:
27
- assert len(hid) > 0
28
- shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
29
- hid = torch.cat([x.flatten(0, -2) for x in hid])
30
- return hid, shape
31
-
32
-
33
- def unflatten(
34
- hid: torch.FloatTensor, # (L c) or (L ... c)
35
- hid_shape: torch.LongTensor, # (b n)
36
- ) -> List[torch.Tensor]: # List of (*** c) or (*** ... c)
37
- hid_len = hid_shape.prod(-1)
38
- hid = hid.split(hid_len.tolist())
39
- hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
40
- return hid
41
-
42
-
43
- def concat(
44
- vid: torch.FloatTensor, # (VL ... c)
45
- txt: torch.FloatTensor, # (TL ... c)
46
- vid_len: torch.LongTensor, # (b)
47
- txt_len: torch.LongTensor, # (b)
48
- ) -> torch.FloatTensor: # (L ... c)
49
- vid = torch.split(vid, vid_len.tolist())
50
- txt = torch.split(txt, txt_len.tolist())
51
- return torch.cat(list(chain(*zip(vid, txt))))
52
-
53
-
54
- def concat_idx(
55
- vid_len: torch.LongTensor, # (b)
56
- txt_len: torch.LongTensor, # (b)
57
- ) -> Tuple[
58
- Callable,
59
- Callable,
60
- ]:
61
- device = vid_len.device
62
- vid_idx = torch.arange(vid_len.sum(), device=device)
63
- txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
64
- tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
65
- src_idx = torch.argsort(tgt_idx)
66
- return (
67
- lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
68
- lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
69
- )
70
-
71
-
72
- def unconcat(
73
- all: torch.FloatTensor, # (L ... c)
74
- vid_len: torch.LongTensor, # (b)
75
- txt_len: torch.LongTensor, # (b)
76
- ) -> Tuple[
77
- torch.FloatTensor, # (VL ... c)
78
- torch.FloatTensor, # (TL ... c)
79
- ]:
80
- interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
81
- all = all.split(interleave_len)
82
- vid = torch.cat(all[0::2])
83
- txt = torch.cat(all[1::2])
84
- return vid, txt
85
-
86
-
87
- def repeat_concat(
88
- vid: torch.FloatTensor, # (VL ... c)
89
- txt: torch.FloatTensor, # (TL ... c)
90
- vid_len: torch.LongTensor, # (n*b)
91
- txt_len: torch.LongTensor, # (b)
92
- txt_repeat: List, # (n)
93
- ) -> torch.FloatTensor: # (L ... c)
94
- vid = torch.split(vid, vid_len.tolist())
95
- txt = torch.split(txt, txt_len.tolist())
96
- txt = [[x] * n for x, n in zip(txt, txt_repeat)]
97
- txt = list(chain(*txt))
98
- return torch.cat(list(chain(*zip(vid, txt))))
99
-
100
-
101
- def repeat_concat_idx(
102
- vid_len: torch.LongTensor, # (n*b)
103
- txt_len: torch.LongTensor, # (b)
104
- txt_repeat: torch.LongTensor, # (n)
105
- ) -> Tuple[
106
- Callable,
107
- Callable,
108
- ]:
109
- device = vid_len.device
110
- vid_idx = torch.arange(vid_len.sum(), device=device)
111
- txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
112
- txt_repeat_list = txt_repeat.tolist()
113
- tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
114
- src_idx = torch.argsort(tgt_idx)
115
- txt_idx_len = len(tgt_idx) - len(vid_idx)
116
- repeat_txt_len = (txt_len * txt_repeat).tolist()
117
-
118
- def unconcat_coalesce(all):
119
- """
120
- Un-concat vid & txt, and coalesce the repeated txt.
121
- e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
122
- txt [9 10]
123
- repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
124
- 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
125
- split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10]
126
- 2. reshape & mean for each sample to coalesce the repeated txt.
127
- """
128
- vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
129
- txt_out_coalesced = []
130
- for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
131
- txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
132
- txt_out_coalesced.append(txt)
133
- return vid_out, torch.cat(txt_out_coalesced)
134
-
135
- # Note: Backward of torch.index_select is non-deterministic when existing repeated index,
136
- # the difference may cumulative like torch.repeat_interleave, so we use vanilla index here.
137
- return (
138
- lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
139
- lambda all: unconcat_coalesce(all),
140
- )
141
-
142
-
143
- def rearrange(
144
- hid: torch.FloatTensor, # (L c)
145
- hid_shape: torch.LongTensor, # (b n)
146
- pattern: str,
147
- **kwargs: Dict[str, int],
148
- ) -> Tuple[
149
- torch.FloatTensor,
150
- torch.LongTensor,
151
- ]:
152
- return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])
153
-
154
-
155
- def rearrange_idx(
156
- hid_shape: torch.LongTensor, # (b n)
157
- pattern: str,
158
- **kwargs: Dict[str, int],
159
- ) -> Tuple[Callable, Callable, torch.LongTensor]:
160
- hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
161
- tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
162
- tgt_idx = tgt_idx.squeeze(-1)
163
- src_idx = torch.argsort(tgt_idx)
164
- return (
165
- lambda hid: torch.index_select(hid, 0, tgt_idx),
166
- lambda hid: torch.index_select(hid, 0, src_idx),
167
- tgt_shape,
168
- )
169
-
170
-
171
- def repeat(
172
- hid: torch.FloatTensor, # (L c)
173
- hid_shape: torch.LongTensor, # (b n)
174
- pattern: str,
175
- **kwargs: Dict[str, torch.LongTensor], # (b)
176
- ) -> Tuple[
177
- torch.FloatTensor,
178
- torch.LongTensor,
179
- ]:
180
- hid = unflatten(hid, hid_shape)
181
- kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
182
- return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
183
-
184
-
185
- def pack(
186
- samples: List[torch.Tensor], # List of (h w c).
187
- ) -> Tuple[
188
- List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)]
189
- List[List[int]], # reversal indices.
190
- ]:
191
- batches = {}
192
- indices = {}
193
- for i, sample in enumerate(samples):
194
- shape = sample.shape
195
- batches[shape] = batches.get(shape, [])
196
- indices[shape] = indices.get(shape, [])
197
- batches[shape].append(sample)
198
- indices[shape].append(i)
199
-
200
- batches = list(map(torch.stack, batches.values()))
201
- indices = list(indices.values())
202
- return batches, indices
203
-
204
-
205
- def unpack(
206
- batches: List[torch.Tensor],
207
- indices: List[List[int]],
208
- ) -> List[torch.Tensor]:
209
- samples = [None] * (max(chain(*indices)) + 1)
210
- for batch, index in zip(batches, indices):
211
- for sample, i in zip(batch.unbind(), index):
212
- samples[i] = sample
213
- return samples
214
-
215
-
216
- def window(
217
- hid: torch.FloatTensor, # (L c)
218
- hid_shape: torch.LongTensor, # (b n)
219
- window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
220
- ):
221
- hid = unflatten(hid, hid_shape)
222
- hid = list(map(window_fn, hid))
223
- hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
224
- hid, hid_shape = flatten(list(chain(*hid)))
225
- return hid, hid_shape, hid_windows
226
-
227
-
228
- def window_idx(
229
- hid_shape: torch.LongTensor, # (b n)
230
- window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
231
- ):
232
- hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
233
- tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
234
- tgt_idx = tgt_idx.squeeze(-1)
235
- src_idx = torch.argsort(tgt_idx)
236
- return (
237
- lambda hid: torch.index_select(hid, 0, tgt_idx),
238
- lambda hid: torch.index_select(hid, 0, src_idx),
239
- tgt_shape,
240
- tgt_windows,
241
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/nablocks/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from .mmsr_block import NaMMSRTransformerBlock
16
-
17
- nadit_blocks = {
18
- "mmdit_sr": NaMMSRTransformerBlock,
19
- }
20
-
21
-
22
- def get_nablock(block_type: str):
23
- if block_type in nadit_blocks:
24
- return nadit_blocks[block_type]
25
- raise NotImplementedError(f"{block_type} is not supported")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/nablocks/mmsr_block.py DELETED
@@ -1,248 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Tuple, Union
16
- import torch
17
- from einops import rearrange
18
- from torch.nn import functional as F
19
-
20
- # from ..cache import Cache
21
- from common.cache import Cache
22
- from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv
23
-
24
- from .. import na
25
- from ..attention import FlashAttentionVarlen
26
- from ..blocks.mmdit_window_block import MMWindowAttention, MMWindowTransformerBlock
27
- from ..mm import MMArg
28
- from ..modulation import ada_layer_type
29
- from ..normalization import norm_layer_type
30
- from ..rope import NaRotaryEmbedding3d
31
- from ..window import get_window_op
32
-
33
-
34
- class NaSwinAttention(MMWindowAttention):
35
- def __init__(
36
- self,
37
- vid_dim: int,
38
- txt_dim: int,
39
- heads: int,
40
- head_dim: int,
41
- qk_bias: bool,
42
- qk_rope: bool,
43
- qk_norm: norm_layer_type,
44
- qk_norm_eps: float,
45
- window: Union[int, Tuple[int, int, int]],
46
- window_method: str,
47
- shared_qkv: bool,
48
- **kwargs,
49
- ):
50
- super().__init__(
51
- vid_dim=vid_dim,
52
- txt_dim=txt_dim,
53
- heads=heads,
54
- head_dim=head_dim,
55
- qk_bias=qk_bias,
56
- qk_rope=qk_rope,
57
- qk_norm=qk_norm,
58
- qk_norm_eps=qk_norm_eps,
59
- window=window,
60
- window_method=window_method,
61
- shared_qkv=shared_qkv,
62
- )
63
- self.rope = NaRotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None
64
- self.attn = FlashAttentionVarlen()
65
- self.window_op = get_window_op(window_method)
66
-
67
- def forward(
68
- self,
69
- vid: torch.FloatTensor, # l c
70
- txt: torch.FloatTensor, # l c
71
- vid_shape: torch.LongTensor, # b 3
72
- txt_shape: torch.LongTensor, # b 1
73
- cache: Cache,
74
- ) -> Tuple[
75
- torch.FloatTensor,
76
- torch.FloatTensor,
77
- ]:
78
-
79
- vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
80
- vid_qkv = gather_seq_scatter_heads_qkv(
81
- vid_qkv,
82
- seq_dim=0,
83
- qkv_shape=vid_shape,
84
- cache=cache.namespace("vid"),
85
- )
86
- txt_qkv = gather_seq_scatter_heads_qkv(
87
- txt_qkv,
88
- seq_dim=0,
89
- qkv_shape=txt_shape,
90
- cache=cache.namespace("txt"),
91
- )
92
-
93
- # re-org the input seq for window attn
94
- cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
95
-
96
- def make_window(x: torch.Tensor):
97
- t, h, w, _ = x.shape
98
- window_slices = self.window_op((t, h, w), self.window)
99
- return [x[st, sh, sw] for (st, sh, sw) in window_slices]
100
-
101
- window_partition, window_reverse, window_shape, window_count = cache_win(
102
- "win_transform",
103
- lambda: na.window_idx(vid_shape, make_window),
104
- )
105
- vid_qkv_win = window_partition(vid_qkv)
106
-
107
- vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
108
- txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
109
-
110
- vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
111
- txt_q, txt_k, txt_v = txt_qkv.unbind(1)
112
-
113
- vid_q, txt_q = self.norm_q(vid_q, txt_q)
114
- vid_k, txt_k = self.norm_k(vid_k, txt_k)
115
-
116
- txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
117
-
118
- vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
119
- txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
120
- all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
121
- concat_win, unconcat_win = cache_win(
122
- "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count)
123
- )
124
-
125
- # window rope
126
- if self.rope:
127
- vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
128
-
129
- out = self.attn(
130
- q=concat_win(vid_q, txt_q).bfloat16(),
131
- k=concat_win(vid_k, txt_k).bfloat16(),
132
- v=concat_win(vid_v, txt_v).bfloat16(),
133
- cu_seqlens_q=cache_win(
134
- "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
135
- ),
136
- cu_seqlens_k=cache_win(
137
- "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
138
- ),
139
- max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
140
- max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
141
- ).type_as(vid_q)
142
-
143
- # text pooling
144
- vid_out, txt_out = unconcat_win(out)
145
-
146
- vid_out = rearrange(vid_out, "l h d -> l (h d)")
147
- txt_out = rearrange(txt_out, "l h d -> l (h d)")
148
- vid_out = window_reverse(vid_out)
149
-
150
- vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0)
151
- txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0)
152
-
153
- vid_out, txt_out = self.proj_out(vid_out, txt_out)
154
-
155
- return vid_out, txt_out
156
-
157
-
158
- class NaMMSRTransformerBlock(MMWindowTransformerBlock):
159
- def __init__(
160
- self,
161
- *,
162
- vid_dim: int,
163
- txt_dim: int,
164
- emb_dim: int,
165
- heads: int,
166
- head_dim: int,
167
- expand_ratio: int,
168
- norm: norm_layer_type,
169
- norm_eps: float,
170
- ada: ada_layer_type,
171
- qk_bias: bool,
172
- qk_rope: bool,
173
- qk_norm: norm_layer_type,
174
- shared_qkv: bool,
175
- shared_mlp: bool,
176
- mlp_type: str,
177
- **kwargs,
178
- ):
179
- super().__init__(
180
- vid_dim=vid_dim,
181
- txt_dim=txt_dim,
182
- emb_dim=emb_dim,
183
- heads=heads,
184
- head_dim=head_dim,
185
- expand_ratio=expand_ratio,
186
- norm=norm,
187
- norm_eps=norm_eps,
188
- ada=ada,
189
- qk_bias=qk_bias,
190
- qk_rope=qk_rope,
191
- qk_norm=qk_norm,
192
- shared_qkv=shared_qkv,
193
- shared_mlp=shared_mlp,
194
- mlp_type=mlp_type,
195
- **kwargs,
196
- )
197
-
198
- self.attn = NaSwinAttention(
199
- vid_dim=vid_dim,
200
- txt_dim=txt_dim,
201
- heads=heads,
202
- head_dim=head_dim,
203
- qk_bias=qk_bias,
204
- qk_rope=qk_rope,
205
- qk_norm=qk_norm,
206
- qk_norm_eps=norm_eps,
207
- shared_qkv=shared_qkv,
208
- **kwargs,
209
- )
210
-
211
- def forward(
212
- self,
213
- vid: torch.FloatTensor, # l c
214
- txt: torch.FloatTensor, # l c
215
- vid_shape: torch.LongTensor, # b 3
216
- txt_shape: torch.LongTensor, # b 1
217
- emb: torch.FloatTensor,
218
- cache: Cache,
219
- ) -> Tuple[
220
- torch.FloatTensor,
221
- torch.FloatTensor,
222
- torch.LongTensor,
223
- torch.LongTensor,
224
- ]:
225
- hid_len = MMArg(
226
- cache("vid_len", lambda: vid_shape.prod(-1)),
227
- cache("txt_len", lambda: txt_shape.prod(-1)),
228
- )
229
- ada_kwargs = {
230
- "emb": emb,
231
- "hid_len": hid_len,
232
- "cache": cache,
233
- "branch_tag": MMArg("vid", "txt"),
234
- }
235
-
236
- vid_attn, txt_attn = self.attn_norm(vid, txt)
237
- vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
238
- vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
239
- vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
240
- vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
241
-
242
- vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
243
- vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
244
- vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
245
- vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
246
- vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
247
-
248
- return vid_mlp, txt_mlp, vid_shape, txt_shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/nadit.py DELETED
@@ -1,350 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from dataclasses import dataclass
16
- from typing import Optional, Tuple, Union, Callable
17
- import torch
18
- from torch import nn
19
-
20
- from common.cache import Cache
21
- from common.distributed.ops import slice_inputs
22
-
23
- from . import na
24
- from .embedding import TimeEmbedding
25
- from .modulation import get_ada_layer
26
- from .nablocks import get_nablock
27
- from .normalization import get_norm_layer
28
- from .patch import NaPatchIn, NaPatchOut
29
-
30
- # Fake func, no checkpointing is required for inference
31
- def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
32
- return module(*args, **kwargs)
33
-
34
- @dataclass
35
- class NaDiTOutput:
36
- vid_sample: torch.Tensor
37
-
38
-
39
- class NaDiT(nn.Module):
40
- """
41
- Native Resolution Diffusion Transformer (NaDiT)
42
- """
43
-
44
- gradient_checkpointing = False
45
-
46
- def __init__(
47
- self,
48
- vid_in_channels: int,
49
- vid_out_channels: int,
50
- vid_dim: int,
51
- txt_in_dim: Optional[int],
52
- txt_dim: Optional[int],
53
- emb_dim: int,
54
- heads: int,
55
- head_dim: int,
56
- expand_ratio: int,
57
- norm: Optional[str],
58
- norm_eps: float,
59
- ada: str,
60
- qk_bias: bool,
61
- qk_rope: bool,
62
- qk_norm: Optional[str],
63
- patch_size: Union[int, Tuple[int, int, int]],
64
- num_layers: int,
65
- block_type: Union[str, Tuple[str]],
66
- shared_qkv: bool = False,
67
- shared_mlp: bool = False,
68
- mlp_type: str = "normal",
69
- window: Optional[Tuple] = None,
70
- window_method: Optional[Tuple[str]] = None,
71
- temporal_window_size: int = None,
72
- temporal_shifted: bool = False,
73
- **kwargs,
74
- ):
75
- ada = get_ada_layer(ada)
76
- norm = get_norm_layer(norm)
77
- qk_norm = get_norm_layer(qk_norm)
78
- if isinstance(block_type, str):
79
- block_type = [block_type] * num_layers
80
- elif len(block_type) != num_layers:
81
- raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
82
- super().__init__()
83
- self.vid_in = NaPatchIn(
84
- in_channels=vid_in_channels,
85
- patch_size=patch_size,
86
- dim=vid_dim,
87
- )
88
- self.txt_in = (
89
- nn.Linear(txt_in_dim, txt_dim)
90
- if txt_in_dim and txt_in_dim != txt_dim
91
- else nn.Identity()
92
- )
93
- self.emb_in = TimeEmbedding(
94
- sinusoidal_dim=256,
95
- hidden_dim=max(vid_dim, txt_dim),
96
- output_dim=emb_dim,
97
- )
98
-
99
- if window is None or isinstance(window[0], int):
100
- window = [window] * num_layers
101
- if window_method is None or isinstance(window_method, str):
102
- window_method = [window_method] * num_layers
103
- if temporal_window_size is None or isinstance(temporal_window_size, int):
104
- temporal_window_size = [temporal_window_size] * num_layers
105
- if temporal_shifted is None or isinstance(temporal_shifted, bool):
106
- temporal_shifted = [temporal_shifted] * num_layers
107
-
108
- self.blocks = nn.ModuleList(
109
- [
110
- get_nablock(block_type[i])(
111
- vid_dim=vid_dim,
112
- txt_dim=txt_dim,
113
- emb_dim=emb_dim,
114
- heads=heads,
115
- head_dim=head_dim,
116
- expand_ratio=expand_ratio,
117
- norm=norm,
118
- norm_eps=norm_eps,
119
- ada=ada,
120
- qk_bias=qk_bias,
121
- qk_rope=qk_rope,
122
- qk_norm=qk_norm,
123
- shared_qkv=shared_qkv,
124
- shared_mlp=shared_mlp,
125
- mlp_type=mlp_type,
126
- window=window[i],
127
- window_method=window_method[i],
128
- temporal_window_size=temporal_window_size[i],
129
- temporal_shifted=temporal_shifted[i],
130
- **kwargs,
131
- )
132
- for i in range(num_layers)
133
- ]
134
- )
135
- self.vid_out = NaPatchOut(
136
- out_channels=vid_out_channels,
137
- patch_size=patch_size,
138
- dim=vid_dim,
139
- )
140
-
141
- self.need_txt_repeat = block_type[0] in [
142
- "mmdit_stwin",
143
- "mmdit_stwin_spatial",
144
- "mmdit_stwin_3d_spatial",
145
- ]
146
-
147
- def set_gradient_checkpointing(self, enable: bool):
148
- self.gradient_checkpointing = enable
149
-
150
- def forward(
151
- self,
152
- vid: torch.FloatTensor, # l c
153
- txt: torch.FloatTensor, # l c
154
- vid_shape: torch.LongTensor, # b 3
155
- txt_shape: torch.LongTensor, # b 1
156
- timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
157
- disable_cache: bool = True, # for test
158
- ):
159
- # Text input.
160
- if txt_shape.size(-1) == 1 and self.need_txt_repeat:
161
- txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
162
- # slice vid after patching in when using sequence parallelism
163
- txt = slice_inputs(txt, dim=0)
164
- txt = self.txt_in(txt)
165
-
166
- # Video input.
167
- # Sequence parallel slicing is done inside patching class.
168
- vid, vid_shape = self.vid_in(vid, vid_shape)
169
-
170
- # Embedding input.
171
- emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
172
-
173
- # Body
174
- cache = Cache(disable=disable_cache)
175
- for i, block in enumerate(self.blocks):
176
- vid, txt, vid_shape, txt_shape = gradient_checkpointing(
177
- enabled=(self.gradient_checkpointing and self.training),
178
- module=block,
179
- vid=vid,
180
- txt=txt,
181
- vid_shape=vid_shape,
182
- txt_shape=txt_shape,
183
- emb=emb,
184
- cache=cache,
185
- )
186
-
187
- vid, vid_shape = self.vid_out(vid, vid_shape, cache)
188
- return NaDiTOutput(vid_sample=vid)
189
-
190
-
191
- class NaDiTUpscaler(nn.Module):
192
- """
193
- Native Resolution Diffusion Transformer (NaDiT)
194
- """
195
-
196
- gradient_checkpointing = False
197
-
198
- def __init__(
199
- self,
200
- vid_in_channels: int,
201
- vid_out_channels: int,
202
- vid_dim: int,
203
- txt_in_dim: Optional[int],
204
- txt_dim: Optional[int],
205
- emb_dim: int,
206
- heads: int,
207
- head_dim: int,
208
- expand_ratio: int,
209
- norm: Optional[str],
210
- norm_eps: float,
211
- ada: str,
212
- qk_bias: bool,
213
- qk_rope: bool,
214
- qk_norm: Optional[str],
215
- patch_size: Union[int, Tuple[int, int, int]],
216
- num_layers: int,
217
- block_type: Union[str, Tuple[str]],
218
- shared_qkv: bool = False,
219
- shared_mlp: bool = False,
220
- mlp_type: str = "normal",
221
- window: Optional[Tuple] = None,
222
- window_method: Optional[Tuple[str]] = None,
223
- temporal_window_size: int = None,
224
- temporal_shifted: bool = False,
225
- **kwargs,
226
- ):
227
- ada = get_ada_layer(ada)
228
- norm = get_norm_layer(norm)
229
- qk_norm = get_norm_layer(qk_norm)
230
- if isinstance(block_type, str):
231
- block_type = [block_type] * num_layers
232
- elif len(block_type) != num_layers:
233
- raise ValueError("The ``block_type`` list should equal to ``num_layers``.")
234
- super().__init__()
235
- self.vid_in = NaPatchIn(
236
- in_channels=vid_in_channels,
237
- patch_size=patch_size,
238
- dim=vid_dim,
239
- )
240
- self.txt_in = (
241
- nn.Linear(txt_in_dim, txt_dim)
242
- if txt_in_dim and txt_in_dim != txt_dim
243
- else nn.Identity()
244
- )
245
- self.emb_in = TimeEmbedding(
246
- sinusoidal_dim=256,
247
- hidden_dim=max(vid_dim, txt_dim),
248
- output_dim=emb_dim,
249
- )
250
-
251
- self.emb_scale = TimeEmbedding(
252
- sinusoidal_dim=256,
253
- hidden_dim=max(vid_dim, txt_dim),
254
- output_dim=emb_dim,
255
- )
256
-
257
- if window is None or isinstance(window[0], int):
258
- window = [window] * num_layers
259
- if window_method is None or isinstance(window_method, str):
260
- window_method = [window_method] * num_layers
261
- if temporal_window_size is None or isinstance(temporal_window_size, int):
262
- temporal_window_size = [temporal_window_size] * num_layers
263
- if temporal_shifted is None or isinstance(temporal_shifted, bool):
264
- temporal_shifted = [temporal_shifted] * num_layers
265
-
266
- self.blocks = nn.ModuleList(
267
- [
268
- get_nablock(block_type[i])(
269
- vid_dim=vid_dim,
270
- txt_dim=txt_dim,
271
- emb_dim=emb_dim,
272
- heads=heads,
273
- head_dim=head_dim,
274
- expand_ratio=expand_ratio,
275
- norm=norm,
276
- norm_eps=norm_eps,
277
- ada=ada,
278
- qk_bias=qk_bias,
279
- qk_rope=qk_rope,
280
- qk_norm=qk_norm,
281
- shared_qkv=shared_qkv,
282
- shared_mlp=shared_mlp,
283
- mlp_type=mlp_type,
284
- window=window[i],
285
- window_method=window_method[i],
286
- temporal_window_size=temporal_window_size[i],
287
- temporal_shifted=temporal_shifted[i],
288
- **kwargs,
289
- )
290
- for i in range(num_layers)
291
- ]
292
- )
293
- self.vid_out = NaPatchOut(
294
- out_channels=vid_out_channels,
295
- patch_size=patch_size,
296
- dim=vid_dim,
297
- )
298
-
299
- self.need_txt_repeat = block_type[0] in [
300
- "mmdit_stwin",
301
- "mmdit_stwin_spatial",
302
- "mmdit_stwin_3d_spatial",
303
- ]
304
-
305
- def set_gradient_checkpointing(self, enable: bool):
306
- self.gradient_checkpointing = enable
307
-
308
- def forward(
309
- self,
310
- vid: torch.FloatTensor, # l c
311
- txt: torch.FloatTensor, # l c
312
- vid_shape: torch.LongTensor, # b 3
313
- txt_shape: torch.LongTensor, # b 1
314
- timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
315
- downscale: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
316
- disable_cache: bool = False, # for test
317
- ):
318
-
319
- # Text input.
320
- if txt_shape.size(-1) == 1 and self.need_txt_repeat:
321
- txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
322
- # slice vid after patching in when using sequence parallelism
323
- txt = slice_inputs(txt, dim=0)
324
- txt = self.txt_in(txt)
325
-
326
- # Video input.
327
- # Sequence parallel slicing is done inside patching class.
328
- vid, vid_shape = self.vid_in(vid, vid_shape)
329
-
330
- # Embedding input.
331
- emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
332
- emb_scale = self.emb_scale(downscale, device=vid.device, dtype=vid.dtype)
333
- emb = emb + emb_scale
334
-
335
- # Body
336
- cache = Cache(disable=disable_cache)
337
- for i, block in enumerate(self.blocks):
338
- vid, txt, vid_shape, txt_shape = gradient_checkpointing(
339
- enabled=(self.gradient_checkpointing and self.training),
340
- module=block,
341
- vid=vid,
342
- txt=txt,
343
- vid_shape=vid_shape,
344
- txt_shape=txt_shape,
345
- emb=emb,
346
- cache=cache,
347
- )
348
-
349
- vid, vid_shape = self.vid_out(vid, vid_shape, cache)
350
- return NaDiTOutput(vid_sample=vid)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/normalization.py DELETED
@@ -1,63 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Callable, Optional
16
- from diffusers.models.normalization import RMSNorm
17
- from torch import nn
18
-
19
- # (dim: int, eps: float, elementwise_affine: bool)
20
- norm_layer_type = Callable[[int, float, bool], nn.Module]
21
-
22
-
23
- def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type:
24
-
25
- def _norm_layer(dim: int, eps: float, elementwise_affine: bool):
26
- if norm_type is None:
27
- return nn.Identity()
28
-
29
- if norm_type == "layer":
30
- return nn.LayerNorm(
31
- normalized_shape=dim,
32
- eps=eps,
33
- elementwise_affine=elementwise_affine,
34
- )
35
-
36
- if norm_type == "rms":
37
- return RMSNorm(
38
- dim=dim,
39
- eps=eps,
40
- elementwise_affine=elementwise_affine,
41
- )
42
-
43
- if norm_type == "fusedln":
44
- from apex.normalization import FusedLayerNorm
45
-
46
- return FusedLayerNorm(
47
- normalized_shape=dim,
48
- elementwise_affine=elementwise_affine,
49
- eps=eps,
50
- )
51
-
52
- if norm_type == "fusedrms":
53
- from apex.normalization import FusedRMSNorm
54
-
55
- return FusedRMSNorm(
56
- normalized_shape=dim,
57
- elementwise_affine=elementwise_affine,
58
- eps=eps,
59
- )
60
-
61
- raise NotImplementedError(f"{norm_type} is not supported")
62
-
63
- return _norm_layer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/patch.py DELETED
@@ -1,112 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Tuple, Union
16
- import torch
17
- from einops import rearrange
18
- from torch import nn
19
- from torch.nn.modules.utils import _triple
20
-
21
- from common.cache import Cache
22
- from common.distributed.ops import gather_outputs, slice_inputs
23
-
24
- from . import na
25
-
26
-
27
- class PatchIn(nn.Module):
28
- def __init__(
29
- self,
30
- in_channels: int,
31
- patch_size: Union[int, Tuple[int, int, int]],
32
- dim: int,
33
- ):
34
- super().__init__()
35
- t, h, w = _triple(patch_size)
36
- self.patch_size = t, h, w
37
- self.proj = nn.Linear(in_channels * t * h * w, dim)
38
-
39
- def forward(
40
- self,
41
- vid: torch.Tensor,
42
- ) -> torch.Tensor:
43
- t, h, w = self.patch_size
44
- vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
45
- vid = self.proj(vid)
46
- return vid
47
-
48
-
49
- class PatchOut(nn.Module):
50
- def __init__(
51
- self,
52
- out_channels: int,
53
- patch_size: Union[int, Tuple[int, int, int]],
54
- dim: int,
55
- ):
56
- super().__init__()
57
- t, h, w = _triple(patch_size)
58
- self.patch_size = t, h, w
59
- self.proj = nn.Linear(dim, out_channels * t * h * w)
60
-
61
- def forward(
62
- self,
63
- vid: torch.Tensor,
64
- ) -> torch.Tensor:
65
- t, h, w = self.patch_size
66
- vid = self.proj(vid)
67
- vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
68
- return vid
69
-
70
-
71
- class NaPatchIn(PatchIn):
72
- def forward(
73
- self,
74
- vid: torch.Tensor, # l c
75
- vid_shape: torch.LongTensor,
76
- ) -> torch.Tensor:
77
- t, h, w = self.patch_size
78
- if not (t == h == w == 1):
79
- vid, vid_shape = na.rearrange(
80
- vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w
81
- )
82
- # slice vid after patching in when using sequence parallelism
83
- vid = slice_inputs(vid, dim=0)
84
- vid = self.proj(vid)
85
- return vid, vid_shape
86
-
87
-
88
- class NaPatchOut(PatchOut):
89
- def forward(
90
- self,
91
- vid: torch.FloatTensor, # l c
92
- vid_shape: torch.LongTensor,
93
- cache: Cache = Cache(disable=True),
94
- ) -> Tuple[
95
- torch.FloatTensor,
96
- torch.LongTensor,
97
- ]:
98
- t, h, w = self.patch_size
99
- vid = self.proj(vid)
100
- # gather vid before patching out when enabling sequence parallelism
101
- vid = gather_outputs(
102
- vid,
103
- gather_dim=0,
104
- padding_dim=0,
105
- unpad_shape=vid_shape,
106
- cache=cache.namespace("vid"),
107
- )
108
- if not (t == h == w == 1):
109
- vid, vid_shape = na.rearrange(
110
- vid, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w
111
- )
112
- return vid, vid_shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/rope.py DELETED
@@ -1,101 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from functools import lru_cache
16
- from typing import Tuple
17
- import torch
18
- from einops import rearrange
19
- from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
20
- from torch import nn
21
-
22
- from common.cache import Cache
23
-
24
-
25
- class RotaryEmbeddingBase(nn.Module):
26
- def __init__(self, dim: int, rope_dim: int):
27
- super().__init__()
28
- self.rope = RotaryEmbedding(
29
- dim=dim // rope_dim,
30
- freqs_for="pixel",
31
- max_freq=256,
32
- )
33
- # 1. Set model.requires_grad_(True) after model creation will make
34
- # the `requires_grad=False` for rope freqs no longer hold.
35
- # 2. Even if we don't set requires_grad_(True) explicitly,
36
- # FSDP is not memory efficient when handling fsdp_wrap
37
- # with mixed requires_grad=True/False.
38
- # With above consideration, it is easier just remove the freqs
39
- # out of nn.Parameters when `learned_freq=False`
40
- freqs = self.rope.freqs
41
- del self.rope.freqs
42
- self.rope.register_buffer("freqs", freqs.data)
43
-
44
- @lru_cache(maxsize=128)
45
- def get_axial_freqs(self, *dims):
46
- return self.rope.get_axial_freqs(*dims)
47
-
48
-
49
- class RotaryEmbedding3d(RotaryEmbeddingBase):
50
- def __init__(self, dim: int):
51
- super().__init__(dim, rope_dim=3)
52
-
53
- def forward(
54
- self,
55
- q: torch.FloatTensor, # b h l d
56
- k: torch.FloatTensor, # b h l d
57
- size: Tuple[int, int, int],
58
- ) -> Tuple[
59
- torch.FloatTensor,
60
- torch.FloatTensor,
61
- ]:
62
- T, H, W = size
63
- freqs = self.get_axial_freqs(T, H, W)
64
- q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
65
- k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
66
- q = apply_rotary_emb(freqs, q)
67
- k = apply_rotary_emb(freqs, k)
68
- q = rearrange(q, "b h T H W d -> b h (T H W) d")
69
- k = rearrange(k, "b h T H W d -> b h (T H W) d")
70
- return q, k
71
-
72
-
73
- class NaRotaryEmbedding3d(RotaryEmbedding3d):
74
- def forward(
75
- self,
76
- q: torch.FloatTensor, # L h d
77
- k: torch.FloatTensor, # L h d
78
- shape: torch.LongTensor,
79
- cache: Cache,
80
- ) -> Tuple[
81
- torch.FloatTensor,
82
- torch.FloatTensor,
83
- ]:
84
- freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape))
85
- q = rearrange(q, "L h d -> h L d")
86
- k = rearrange(k, "L h d -> h L d")
87
- q = apply_rotary_emb(freqs, q.float()).to(q.dtype)
88
- k = apply_rotary_emb(freqs, k.float()).to(k.dtype)
89
- q = rearrange(q, "h L d -> L h d")
90
- k = rearrange(k, "h L d -> L h d")
91
- return q, k
92
-
93
- def get_freqs(
94
- self,
95
- shape: torch.LongTensor,
96
- ) -> torch.Tensor:
97
- freq_list = []
98
- for f, h, w in shape.tolist():
99
- freqs = self.get_axial_freqs(f, h, w)
100
- freq_list.append(freqs.view(-1, freqs.size(-1)))
101
- return torch.cat(freq_list, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit/window.py DELETED
@@ -1,83 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from math import ceil
16
- from typing import Tuple
17
- import math
18
-
19
- def get_window_op(name: str):
20
- if name == "720pwin_by_size_bysize":
21
- return make_720Pwindows_bysize
22
- if name == "720pswin_by_size_bysize":
23
- return make_shifted_720Pwindows_bysize
24
- raise ValueError(f"Unknown windowing method: {name}")
25
-
26
-
27
- # -------------------------------- Windowing -------------------------------- #
28
- def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
29
- t, h, w = size
30
- resized_nt, resized_nh, resized_nw = num_windows
31
- #cal windows under 720p
32
- scale = math.sqrt((45 * 80) / (h * w))
33
- resized_h, resized_w = round(h * scale), round(w * scale)
34
- wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
35
- wt = ceil(min(t, 30) / resized_nt) # window size.
36
- nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size.
37
- return [
38
- (
39
- slice(it * wt, min((it + 1) * wt, t)),
40
- slice(ih * wh, min((ih + 1) * wh, h)),
41
- slice(iw * ww, min((iw + 1) * ww, w)),
42
- )
43
- for iw in range(nw)
44
- if min((iw + 1) * ww, w) > iw * ww
45
- for ih in range(nh)
46
- if min((ih + 1) * wh, h) > ih * wh
47
- for it in range(nt)
48
- if min((it + 1) * wt, t) > it * wt
49
- ]
50
-
51
- def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
52
- t, h, w = size
53
- resized_nt, resized_nh, resized_nw = num_windows
54
- #cal windows under 720p
55
- scale = math.sqrt((45 * 80) / (h * w))
56
- resized_h, resized_w = round(h * scale), round(w * scale)
57
- wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
58
- wt = ceil(min(t, 30) / resized_nt) # window size.
59
-
60
- st, sh, sw = ( # shift size.
61
- 0.5 if wt < t else 0,
62
- 0.5 if wh < h else 0,
63
- 0.5 if ww < w else 0,
64
- )
65
- nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size.
66
- nt, nh, nw = ( # number of window.
67
- nt + 1 if st > 0 else 1,
68
- nh + 1 if sh > 0 else 1,
69
- nw + 1 if sw > 0 else 1,
70
- )
71
- return [
72
- (
73
- slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)),
74
- slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)),
75
- slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)),
76
- )
77
- for iw in range(nw)
78
- if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0)
79
- for ih in range(nh)
80
- if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0)
81
- for it in range(nt)
82
- if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0)
83
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit_v2/attention.py DELETED
@@ -1,46 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- import torch
16
- import torch.nn.functional as F
17
-
18
- from flash_attn import flash_attn_varlen_func
19
-
20
- from torch import nn
21
-
22
- class TorchAttention(nn.Module):
23
- def tflops(self, args, kwargs, output) -> float:
24
- assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs"
25
- q = kwargs.get("query") or args[0]
26
- k = kwargs.get("key") or args[1]
27
- b, h, sq, d = q.shape
28
- b, h, sk, d = k.shape
29
- return b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
30
-
31
- def forward(self, *args, **kwargs):
32
- return F.scaled_dot_product_attention(*args, **kwargs)
33
-
34
-
35
- class FlashAttentionVarlen(nn.Module):
36
- def tflops(self, args, kwargs, output) -> float:
37
- cu_seqlens_q = kwargs["cu_seqlens_q"]
38
- cu_seqlens_k = kwargs["cu_seqlens_k"]
39
- _, h, d = output.shape
40
- seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6
41
- seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6
42
- return h * (4 * d * (seqlens_q * seqlens_k).sum())
43
-
44
- def forward(self, *args, **kwargs):
45
- kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
46
- return flash_attn_varlen_func(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit_v2/embedding.py DELETED
@@ -1,62 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Optional, Union
16
- import torch
17
- from diffusers.models.embeddings import get_timestep_embedding
18
- from torch import nn
19
-
20
-
21
- def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
22
- return emb1 if emb2 is None else emb1 + emb2
23
-
24
-
25
- class TimeEmbedding(nn.Module):
26
- def __init__(
27
- self,
28
- sinusoidal_dim: int,
29
- hidden_dim: int,
30
- output_dim: int,
31
- ):
32
- super().__init__()
33
- self.sinusoidal_dim = sinusoidal_dim
34
- self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
35
- self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
36
- self.proj_out = nn.Linear(hidden_dim, output_dim)
37
- self.act = nn.SiLU()
38
-
39
- def forward(
40
- self,
41
- timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
42
- device: torch.device,
43
- dtype: torch.dtype,
44
- ) -> torch.FloatTensor:
45
- if not torch.is_tensor(timestep):
46
- timestep = torch.tensor([timestep], device=device, dtype=dtype)
47
- if timestep.ndim == 0:
48
- timestep = timestep[None]
49
-
50
- emb = get_timestep_embedding(
51
- timesteps=timestep,
52
- embedding_dim=self.sinusoidal_dim,
53
- flip_sin_to_cos=False,
54
- downscale_freq_shift=0,
55
- )
56
- emb = emb.to(dtype)
57
- emb = self.proj_in(emb)
58
- emb = self.act(emb)
59
- emb = self.proj_hid(emb)
60
- emb = self.act(emb)
61
- emb = self.proj_out(emb)
62
- return emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modelsx/dit_v2/mlp.py DELETED
@@ -1,62 +0,0 @@
1
- # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # //
3
- # // Licensed under the Apache License, Version 2.0 (the "License");
4
- # // you may not use this file except in compliance with the License.
5
- # // You may obtain a copy of the License at
6
- # //
7
- # // http://www.apache.org/licenses/LICENSE-2.0
8
- # //
9
- # // Unless required by applicable law or agreed to in writing, software
10
- # // distributed under the License is distributed on an "AS IS" BASIS,
11
- # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # // See the License for the specific language governing permissions and
13
- # // limitations under the License.
14
-
15
- from typing import Optional
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
-
21
- def get_mlp(mlp_type: Optional[str] = "normal"):
22
- if mlp_type == "normal":
23
- return MLP
24
- elif mlp_type == "swiglu":
25
- return SwiGLUMLP
26
-
27
-
28
- class MLP(nn.Module):
29
- def __init__(
30
- self,
31
- dim: int,
32
- expand_ratio: int,
33
- ):
34
- super().__init__()
35
- self.proj_in = nn.Linear(dim, dim * expand_ratio)
36
- self.act = nn.GELU("tanh")
37
- self.proj_out = nn.Linear(dim * expand_ratio, dim)
38
-
39
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
40
- x = self.proj_in(x)
41
- x = self.act(x)
42
- x = self.proj_out(x)
43
- return x
44
-
45
-
46
- class SwiGLUMLP(nn.Module):
47
- def __init__(
48
- self,
49
- dim: int,
50
- expand_ratio: int,
51
- multiple_of: int = 256,
52
- ):
53
- super().__init__()
54
- hidden_dim = int(2 * dim * expand_ratio / 3)
55
- hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
56
- self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
57
- self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
58
- self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
59
-
60
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
61
- x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
62
- return x