jiaosiyu.111 committed
Commit 5e154e3 · 1 Parent(s): 963f7cd

init commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,84 @@
+
+ ## 🚀 Quick Start
+
+ ### 🛠️ Environment Setup
+
+ #### ✅ Recommended Setup
+
+ ```bash
+ # 1. Clone the repo
+ git clone https://github.com/jiaosiyuu/ThinkGen.git
+ cd ThinkGen
+
+ # 2. (Optional) Create a clean Python environment
+ conda create -n thinkgen python=3.11
+ conda activate thinkgen
+
+ # 3. Install dependencies
+ # 3.1 Install PyTorch (choose the correct CUDA version)
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
+
+ # 3.2 Install other required packages
+ pip install -r req.txt
+
+ # ThinkGen runs even without flash-attn, though we recommend installing it for best performance.
+ pip install --no-cache-dir flash-attn==2.7.4.post1 --no-build-isolation
+ ```
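+
+ A quick sanity check after setup (a minimal sketch; the flash-attn import is optional, since ThinkGen falls back to standard attention without it):
+
+ ```python
+ import torch
+ print(torch.__version__, torch.cuda.is_available())  # expect 2.6.0 and True on a CUDA machine
+ try:
+     import flash_attn  # optional; only present if you ran the flash-attn install step
+     print("flash-attn", flash_attn.__version__)
+ except ImportError:
+     print("flash-attn not installed; continuing without it")
+ ```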
+
+ #### 🌏 For users in Mainland China
+
+ ```bash
+ pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://mirror.sjtu.edu.cn/pytorch-wheels/cu124
+ pip install -r req.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+ pip install --no-cache-dir flash-attn==2.7.4.post1 --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple
+ ```
+
+ ---
+
+ * **Run Locally**:
+ ```python
+ from ThinkGen.model import ThinkGen_Chat
+ import os
+
+ model_path = "/home/tiger/ThinkGen"
+
+ chat_model = ThinkGen_Chat(
+     model_path=model_path,
+     dtype='bf16',
+     height=1024,
+     width=1024
+ )
+
+ # Generation
+ messages = [
+     {"type": "text", "value": '''A close-up image of a red apple with the words 'Tart & Sweet' in white, cursive font on its surface, forming a spiral pattern. The apple is centered in the frame, and the background is a green surface labeled 'Organic Produce' in black, bold letters. The apple has a visible stem and a small bite mark on its side with the word 'Juicy' written in a small, handwritten style near the bite.'''}
+ ]
+ results = chat_model.generate_image(messages)
+ output_dir = "vis/chat"
+ os.makedirs(output_dir, exist_ok=True)
+
+ for i, img in enumerate(results.images):
+     save_path = os.path.join(output_dir, f"result_{i}.png")
+     img.save(save_path)
+     print(f"Saved to {save_path}")
+
+ # Understanding
+ messages = [
+     {"type": "image", "value": "images/teaser.png"},
+     {"type": "text", "value": "Describe this image"}
+ ]
+
+ response = chat_model.generate_text(messages)
+ print(response)
+ ```
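+
+ As the examples above show, each message is a dict with a `type` (`"text"` or `"image"`) and a `value` (the prompt text or an image path); `generate_image` returns an object whose `.images` field holds PIL images, while `generate_text` returns a string.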
+
+ ## License
+ This work is licensed under the Apache License 2.0.
mllm/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd6ddc642758556cbfe31c59342b6e7b4ddcf7c62c0534723e53452ceb73abb4
+ size 1566
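
(The JSON and safetensors files in this commit are Git LFS pointer files, which is why the `*.json` rule was added to `.gitattributes` above. Only each pointer's `oid` and `size` are shown; the actual contents live in LFS storage.)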
mllm/generation_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e577a4a2ea83445cbb1b79f73a6ce55fde9b8c60aff7cd0bb8752a3449919fd3
+ size 147
mllm/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
mllm/model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e81d1972134a60c6bef2a0deaa41fbfef563fe0abf53ec1d7f022a2721895b90
+ size 4940477760
mllm/model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca92aff29a9a5e533d427ce34d53d14352ea42ce92030940817253775c3fc82d
+ size 4954046904
mllm/model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d31e957cea4b9230e31034a5dec0c131c16b741846e205f768f7a1b23be0dddb
+ size 4997839528
mllm/model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:570b3975d4a9ba00999aa26c6a4aa53006eb680da0bd3664d83a1a45c2f73eaa
+ size 2641975280
mllm/model.safetensors.index.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:53f1e0c5682443e902a41f329b300fcbaab3116499d465086e77233c78439674
+ size 67795
model_index.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e99a472180a7cea970e085ee56020eab27f85dd6a813a26db8e4eccc6957ba2
+ size 456
processor/added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0284b582e14987fbd3d5a2cb2bd139084371ed9acbae488829a1c900833c680
+ size 707
processor/chat_template.jinja ADDED
@@ -0,0 +1,110 @@
+ {%- set image_count = namespace(value=0) %}
+ {%- set video_count = namespace(value=0) %}
+ {%- macro render_content(content, do_vision_count) %}
+ {%- if content is string %}
+ {{- content }}
+ {%- else %}
+ {%- for item in content %}
+ {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
+ {%- if do_vision_count %}
+ {%- set image_count.value = image_count.value + 1 %}
+ {%- endif %}
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
+ <|vision_start|><|image_pad|><|vision_end|>
+ {%- elif 'video' in item or item.type == 'video' %}
+ {%- if do_vision_count %}
+ {%- set video_count.value = video_count.value + 1 %}
+ {%- endif %}
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
+ <|vision_start|><|video_pad|><|vision_end|>
+ {%- elif 'text' in item %}
+ {{- item.text }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endmacro %}
+ {%- if tools %}
+ {{- '<|im_start|>system\n' }}
+ {%- if messages[0].role == 'system' %}
+ {{- render_content(messages[0].content, false) + '\n\n' }}
+ {%- endif %}
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+ {%- for tool in tools %}
+ {{- "\n" }}
+ {{- tool | tojson }}
+ {%- endfor %}
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+ {%- else %}
+ {%- if messages[0].role == 'system' %}
+ {{- '<|im_start|>system\n' + render_content(messages[0].content, false) + '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+ {%- set index = (messages|length - 1) - loop.index0 %}
+ {%- if ns.multi_step_tool and message.role == "user" %}
+ {%- set content = render_content(message.content, false) %}
+ {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
+ {%- set ns.multi_step_tool = false %}
+ {%- set ns.last_query_index = index %}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- for message in messages %}
+ {%- set content = render_content(message.content, True) %}
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+ {%- elif message.role == "assistant" %}
+ {%- set reasoning_content = '' %}
+ {%- if message.reasoning_content is string %}
+ {%- set reasoning_content = message.reasoning_content %}
+ {%- else %}
+ {%- if '</think>' in content %}
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+ {%- endif %}
+ {%- endif %}
+ {%- if loop.index0 > ns.last_query_index %}
+ {%- if loop.last or (not loop.last and reasoning_content) %}
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- else %}
+ {{- '<|im_start|>' + message.role + '\n' + content }}
+ {%- endif %}
+ {%- if message.tool_calls %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if (loop.first and content) or (not loop.first) %}
+ {{- '\n' }}
+ {%- endif %}
+ {%- if tool_call.function %}
+ {%- set tool_call = tool_call.function %}
+ {%- endif %}
+ {{- '<tool_call>\n{"name": "' }}
+ {{- tool_call.name }}
+ {{- '", "arguments": ' }}
+ {%- if tool_call.arguments is string %}
+ {{- tool_call.arguments }}
+ {%- else %}
+ {{- tool_call.arguments | tojson }}
+ {%- endif %}
+ {{- '}\n</tool_call>' }}
+ {%- endfor %}
+ {%- endif %}
+ {{- '<|im_end|>\n' }}
+ {%- elif message.role == "tool" %}
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+ {{- '<|im_start|>user' }}
+ {%- endif %}
+ {{- '\n<tool_response>\n' }}
+ {{- content }}
+ {{- '\n</tool_response>' }}
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+ {{- '<|im_end|>\n' }}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<|im_start|>assistant\n<think>\n' }}
+ {%- endif %}
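
For reference, a minimal sketch of how this template renders with the `transformers` chat-template API (the load path is an assumption; run it from the repo root, and the expected output follows from the template above):

```python
from transformers import AutoTokenizer

# Assumption: the tokenizer files ship in this repo's processor/ subfolder.
tok = AutoTokenizer.from_pretrained("./processor")
messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "Describe this image"},
]}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Per the template, prompt should render as:
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>Describe this image<|im_end|>
# <|im_start|>assistant
# <think>
print(prompt)
```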
processor/chat_template.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4757062314f864cf47b9ce6ea4bd921590611c5b90f0860c523831756edc4fa1
+ size 1072
processor/preprocessor_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93585062a80db5e8ca038efc7726a3e6411d9db948472d81d63c6303993be8c5
+ size 782
processor/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76862e765266b85aa9459767e33cbaf13970f327a0e88d1c65846c2ddd3a1ecd
+ size 613
processor/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+ size 11422654
processor/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59b3e9b9a46fd9e8447842ca20aba3fc4eb9c22cd10969ae28812f1bc7c3fa22
+ size 5465
processor/video_preprocessor_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59c5c9eb52182eb14c06ffb10ca9effd29adce5f238a95de23ca14a38dbd2cb1
+ size 817
processor/vocab.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910
+ size 2776833
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0808a703cdff03e47929de027fa90cd909b44011d7c83c27d2db17a2332e5fa1
+ size 150
scheduler/scheduling_flow_match_euler_discrete.py ADDED
@@ -0,0 +1,229 @@
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput, logging
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
+     """
+     Output class for the scheduler's `step` function output.
+
+     Args:
+         prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+             denoising loop.
+     """
+
+     prev_sample: torch.FloatTensor
+
+
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+     """
+     Euler scheduler.
+
+     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+     methods the library implements for all schedulers such as loading and saving.
+
+     Args:
+         num_train_timesteps (`int`, defaults to 1000):
+             The number of diffusion steps to train the model.
+         dynamic_time_shift (`bool`, defaults to `False`):
+             Whether to shift the timestep schedule based on the number of image tokens.
+     """
+
+     _compatibles = []
+     order = 1
+
+     @register_to_config
+     def __init__(
+         self,
+         num_train_timesteps: int = 1000,
+         dynamic_time_shift: bool = False
+     ):
+         timesteps = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float32)[:-1]
+
+         self.timesteps = timesteps
+
+         self._step_index = None
+         self._begin_index = None
+
+     @property
+     def step_index(self):
+         """
+         The index counter for current timestep. It will increase 1 after each scheduler step.
+         """
+         return self._step_index
+
+     @property
+     def begin_index(self):
+         """
+         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+         """
+         return self._begin_index
+
+     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+     def set_begin_index(self, begin_index: int = 0):
+         """
+         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+         Args:
+             begin_index (`int`):
+                 The begin index for the scheduler.
+         """
+         self._begin_index = begin_index
+
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self._timesteps
+
+         indices = (schedule_timesteps == timestep).nonzero()
+
+         # The sigma index that is taken for the **very** first `step`
+         # is always the second index (or the last index if there is only 1)
+         # This way we can ensure we don't accidentally skip a sigma in
+         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+         pos = 1 if len(indices) > 1 else 0
+
+         return indices[pos].item()
+
+     # def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+     #     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+     def set_timesteps(
+         self,
+         num_inference_steps: int = None,
+         device: Union[str, torch.device] = None,
+         timesteps: Optional[List[float]] = None,
+         num_tokens: Optional[int] = None
+     ):
+         """
+         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+         Args:
+             num_inference_steps (`int`):
+                 The number of diffusion steps used when generating samples with a pre-trained model.
+             device (`str` or `torch.device`, *optional*):
+                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+             timesteps (`List[float]`, *optional*):
+                 Custom timesteps to use; if provided, `num_inference_steps` is ignored.
+             num_tokens (`int`, *optional*):
+                 The number of image tokens, used to compute the dynamic time shift when
+                 `dynamic_time_shift` is enabled.
+         """
+
+         if timesteps is None:
+             self.num_inference_steps = num_inference_steps
+             timesteps = np.linspace(0, 1, num_inference_steps + 1, dtype=np.float32)[:-1]
+             if self.config.dynamic_time_shift and num_tokens is not None:
+                 m = np.sqrt(num_tokens) / 40  # when input resolution is 320 * 320, m = 1; when input resolution is 1024 * 1024, m = 3.2
+                 timesteps = timesteps / (m - m * timesteps + timesteps)
+
+         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+         _timesteps = torch.cat([timesteps, torch.ones(1, device=timesteps.device)])
+
+         self.timesteps = timesteps
+         self._timesteps = _timesteps
+         self._step_index = None
+         self._begin_index = None
+
+     def _init_step_index(self, timestep):
+         if self.begin_index is None:
+             if isinstance(timestep, torch.Tensor):
+                 timestep = timestep.to(self.timesteps.device)
+             self._step_index = self.index_for_timestep(timestep)
+         else:
+             self._step_index = self._begin_index
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`):
+                 Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
+
+         Returns:
+             [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a tuple is
+                 returned where the first element is the sample tensor.
+         """
+
+         if (
+             isinstance(timestep, int)
+             or isinstance(timestep, torch.IntTensor)
+             or isinstance(timestep, torch.LongTensor)
+         ):
+             raise ValueError(
+                 (
+                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                     " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                     " one of the `scheduler.timesteps` as a timestep."
+                 ),
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+         # Upcast to avoid precision issues when computing prev_sample
+         sample = sample.to(torch.float32)
+         t = self._timesteps[self.step_index]
+         t_next = self._timesteps[self.step_index + 1]
+
+         prev_sample = sample + (t_next - t) * model_output
+
+         # Cast sample back to model compatible dtype
+         prev_sample = prev_sample.to(model_output.dtype)
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
+
+     def __len__(self):
+         return self.config.num_train_timesteps
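
A quick numeric sketch of the dynamic time shift in `set_timesteps` (assuming `num_tokens` is the latent token count, e.g. 1600 for a 320x320 input and 16384 for 1024x1024, matching the inline comment):

```python
import numpy as np

num_tokens = 16384                                # 1024x1024 case, so m = sqrt(16384) / 40 = 3.2
m = np.sqrt(num_tokens) / 40
t = np.linspace(0, 1, 6, dtype=np.float32)[:-1]   # [0.0, 0.2, 0.4, 0.6, 0.8]
shifted = t / (m - m * t + t)
print(shifted)                                    # -> [0.0, 0.0725, 0.1724, 0.3191, 0.5556]
```

The shift compresses the schedule toward t = 0, so larger inputs spend more of the step budget early in sampling; each `step` call is then a single explicit Euler update of the learned velocity field, `prev_sample = sample + (t_next - t) * model_output`.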
transformer/config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dedfbf53b121b97bee777508dd91de0f31cdace62d27aa54c19830acf021382e
+ size 497
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:803e214724ea75baa6f0334d054729af5ede84cfb9c4dfcccca60c3098585008
+ size 7965739976
transformer/transformer_thinkgen.py ADDED
@@ -0,0 +1,2457 @@
+ import warnings
+ import itertools
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ from dataclasses import dataclass
+ import math
+ import numpy as np
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from einops import rearrange, repeat
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import PeftAdapterMixin
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.models.attention_processor import Attention
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
+ from diffusers.models.activations import get_activation
+ from diffusers.models.embeddings import Timesteps
+
+
+ import importlib.util
+ import sys
+
+ # The package importlib_metadata is in a different place, depending on the python version.
+ if sys.version_info < (3, 8):
+     import importlib_metadata
+ else:
+     import importlib.metadata as importlib_metadata
+
+ def _is_package_available(pkg_name: str):
+     pkg_exists = importlib.util.find_spec(pkg_name) is not None
+     pkg_version = "N/A"
+
+     if pkg_exists:
+         try:
+             pkg_version = importlib_metadata.version(pkg_name)
+         except (ImportError, importlib_metadata.PackageNotFoundError):
+             pkg_exists = False
+
+     return pkg_exists, pkg_version
+
+ _triton_available, _triton_version = _is_package_available("triton")
+ _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+
+ def is_triton_available():
+     return _triton_available
+
+ def is_flash_attn_available():
+     return _flash_attn_available
+
+ if is_flash_attn_available():
+     from flash_attn import flash_attn_varlen_func
+     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+ else:
+     warnings.warn("Cannot import flash_attn; install flash_attn to use varlen FlashAttention for better performance")
+
63
+
64
+ if is_triton_available():
65
+ # from ...ops.triton.layer_norm import RMSNorm
66
+ import triton
67
+ import triton.language as tl
68
+
69
+
70
+ from typing import Callable
71
+
72
+
73
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
74
+ def decorator(*args, **kwargs):
75
+ if cuda_amp_deprecated:
76
+ kwargs["device_type"] = "cuda"
77
+ return dec(*args, **kwargs)
78
+ return decorator
79
+
80
+
81
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
82
+ deprecated = True
83
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
84
+ else:
85
+ deprecated = False
86
+ from torch.cuda.amp import custom_fwd, custom_bwd
87
+
88
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
89
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
90
+
91
+
92
+ def triton_autotune_configs():
93
+ # Return configs with a valid warp count for the current device
94
+ configs=[]
95
+ # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
96
+ max_threads_per_block=1024
97
+ # Default to warp size 32 if not defined by device
98
+ warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
99
+ # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
100
+ warp_count=1
101
+ while warp_count*warp_size <= max_threads_per_block:
102
+ configs.append(triton.Config({}, num_warps=warp_count))
103
+ warp_count*=2
104
+ return configs
105
+
106
+ @triton.autotune(
107
+ configs=triton_autotune_configs(),
108
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
109
+ )
110
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
111
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
112
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
113
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
114
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
115
+ @triton.jit
116
+ def _layer_norm_fwd_1pass_kernel(
117
+ X, # pointer to the input
118
+ Y, # pointer to the output
119
+ W, # pointer to the weights
120
+ B, # pointer to the biases
121
+ RESIDUAL, # pointer to the residual
122
+ X1,
123
+ W1,
124
+ B1,
125
+ Y1,
126
+ RESIDUAL_OUT, # pointer to the residual
127
+ ROWSCALE,
128
+ SEEDS, # Dropout seeds for each row
129
+ DROPOUT_MASK,
130
+ Mean, # pointer to the mean
131
+ Rstd, # pointer to the 1/std
132
+ stride_x_row, # how much to increase the pointer when moving by 1 row
133
+ stride_y_row,
134
+ stride_res_row,
135
+ stride_res_out_row,
136
+ stride_x1_row,
137
+ stride_y1_row,
138
+ M, # number of rows in X
139
+ N, # number of columns in X
140
+ eps, # epsilon to avoid division by zero
141
+ dropout_p, # Dropout probability
142
+ zero_centered_weight, # If true, add 1.0 to the weight
143
+ IS_RMS_NORM: tl.constexpr,
144
+ BLOCK_N: tl.constexpr,
145
+ HAS_RESIDUAL: tl.constexpr,
146
+ STORE_RESIDUAL_OUT: tl.constexpr,
147
+ HAS_BIAS: tl.constexpr,
148
+ HAS_DROPOUT: tl.constexpr,
149
+ STORE_DROPOUT_MASK: tl.constexpr,
150
+ HAS_ROWSCALE: tl.constexpr,
151
+ HAS_X1: tl.constexpr,
152
+ HAS_W1: tl.constexpr,
153
+ HAS_B1: tl.constexpr,
154
+ ):
155
+ # Map the program id to the row of X and Y it should compute.
156
+ row = tl.program_id(0)
157
+ X += row * stride_x_row
158
+ Y += row * stride_y_row
159
+ if HAS_RESIDUAL:
160
+ RESIDUAL += row * stride_res_row
161
+ if STORE_RESIDUAL_OUT:
162
+ RESIDUAL_OUT += row * stride_res_out_row
163
+ if HAS_X1:
164
+ X1 += row * stride_x1_row
165
+ if HAS_W1:
166
+ Y1 += row * stride_y1_row
167
+ # Compute mean and variance
168
+ cols = tl.arange(0, BLOCK_N)
169
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
170
+ if HAS_ROWSCALE:
171
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
172
+ x *= rowscale
173
+ if HAS_DROPOUT:
174
+ # Compute dropout mask
175
+ # 7 rounds is good enough, and reduces register pressure
176
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
177
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
178
+ if STORE_DROPOUT_MASK:
179
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
180
+ if HAS_X1:
181
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
182
+ if HAS_ROWSCALE:
183
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
184
+ x1 *= rowscale
185
+ if HAS_DROPOUT:
186
+ # Compute dropout mask
187
+ # 7 rounds is good enough, and reduces register pressure
188
+ keep_mask = (
189
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
190
+ )
191
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
192
+ if STORE_DROPOUT_MASK:
193
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
194
+ x += x1
195
+ if HAS_RESIDUAL:
196
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
197
+ x += residual
198
+ if STORE_RESIDUAL_OUT:
199
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
200
+ if not IS_RMS_NORM:
201
+ mean = tl.sum(x, axis=0) / N
202
+ tl.store(Mean + row, mean)
203
+ xbar = tl.where(cols < N, x - mean, 0.0)
204
+ var = tl.sum(xbar * xbar, axis=0) / N
205
+ else:
206
+ xbar = tl.where(cols < N, x, 0.0)
207
+ var = tl.sum(xbar * xbar, axis=0) / N
208
+ rstd = 1 / tl.sqrt(var + eps)
209
+ tl.store(Rstd + row, rstd)
210
+ # Normalize and apply linear transformation
211
+ mask = cols < N
212
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
213
+ if zero_centered_weight:
214
+ w += 1.0
215
+ if HAS_BIAS:
216
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
217
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
218
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
219
+ # Write output
220
+ tl.store(Y + cols, y, mask=mask)
221
+ if HAS_W1:
222
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
223
+ if zero_centered_weight:
224
+ w1 += 1.0
225
+ if HAS_B1:
226
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
227
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
228
+ tl.store(Y1 + cols, y1, mask=mask)
229
+
230
+
231
+ def _layer_norm_fwd(
232
+ x,
233
+ weight,
234
+ bias,
235
+ eps,
236
+ residual=None,
237
+ x1=None,
238
+ weight1=None,
239
+ bias1=None,
240
+ dropout_p=0.0,
241
+ rowscale=None,
242
+ out_dtype=None,
243
+ residual_dtype=None,
244
+ zero_centered_weight=False,
245
+ is_rms_norm=False,
246
+ return_dropout_mask=False,
247
+ out=None,
248
+ residual_out=None
249
+ ):
250
+ if residual is not None:
251
+ residual_dtype = residual.dtype
252
+ M, N = x.shape
253
+ assert x.stride(-1) == 1
254
+ if residual is not None:
255
+ assert residual.stride(-1) == 1
256
+ assert residual.shape == (M, N)
257
+ assert weight.shape == (N,)
258
+ assert weight.stride(-1) == 1
259
+ if bias is not None:
260
+ assert bias.stride(-1) == 1
261
+ assert bias.shape == (N,)
262
+ if x1 is not None:
263
+ assert x1.shape == x.shape
264
+ assert rowscale is None
265
+ assert x1.stride(-1) == 1
266
+ if weight1 is not None:
267
+ assert weight1.shape == (N,)
268
+ assert weight1.stride(-1) == 1
269
+ if bias1 is not None:
270
+ assert bias1.shape == (N,)
271
+ assert bias1.stride(-1) == 1
272
+ if rowscale is not None:
273
+ assert rowscale.is_contiguous()
274
+ assert rowscale.shape == (M,)
275
+ # allocate output
276
+ if out is None:
277
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
278
+ else:
279
+ assert out.shape == x.shape
280
+ assert out.stride(-1) == 1
281
+ if weight1 is not None:
282
+ y1 = torch.empty_like(out)
283
+ assert y1.stride(-1) == 1
284
+ else:
285
+ y1 = None
286
+ if (
287
+ residual is not None
288
+ or (residual_dtype is not None and residual_dtype != x.dtype)
289
+ or dropout_p > 0.0
290
+ or rowscale is not None
291
+ or x1 is not None
292
+ ):
293
+ if residual_out is None:
294
+ residual_out = torch.empty(
295
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
296
+ )
297
+ else:
298
+ assert residual_out.shape == x.shape
299
+ assert residual_out.stride(-1) == 1
300
+ else:
301
+ residual_out = None
302
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
303
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
304
+ if dropout_p > 0.0:
305
+ seeds = torch.randint(
306
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
307
+ )
308
+ else:
309
+ seeds = None
310
+ if return_dropout_mask and dropout_p > 0.0:
311
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
312
+ else:
313
+ dropout_mask = None
314
+ # Less than 64KB per feature: enqueue fused kernel
315
+ MAX_FUSED_SIZE = 65536 // x.element_size()
316
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
317
+ if N > BLOCK_N:
318
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
319
+ with torch.cuda.device(x.device.index):
320
+ _layer_norm_fwd_1pass_kernel[(M,)](
321
+ x,
322
+ out,
323
+ weight,
324
+ bias,
325
+ residual,
326
+ x1,
327
+ weight1,
328
+ bias1,
329
+ y1,
330
+ residual_out,
331
+ rowscale,
332
+ seeds,
333
+ dropout_mask,
334
+ mean,
335
+ rstd,
336
+ x.stride(0),
337
+ out.stride(0),
338
+ residual.stride(0) if residual is not None else 0,
339
+ residual_out.stride(0) if residual_out is not None else 0,
340
+ x1.stride(0) if x1 is not None else 0,
341
+ y1.stride(0) if y1 is not None else 0,
342
+ M,
343
+ N,
344
+ eps,
345
+ dropout_p,
346
+ zero_centered_weight,
347
+ is_rms_norm,
348
+ BLOCK_N,
349
+ residual is not None,
350
+ residual_out is not None,
351
+ bias is not None,
352
+ dropout_p > 0.0,
353
+ dropout_mask is not None,
354
+ rowscale is not None,
355
+ )
356
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
357
+ if dropout_mask is not None and x1 is not None:
358
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
359
+ else:
360
+ dropout_mask1 = None
361
+ return (
362
+ out,
363
+ y1,
364
+ mean,
365
+ rstd,
366
+ residual_out if residual_out is not None else x,
367
+ seeds,
368
+ dropout_mask,
369
+ dropout_mask1,
370
+ )
371
+
372
+ @triton.autotune(
373
+ configs=triton_autotune_configs(),
374
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
375
+ )
376
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
377
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
378
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
379
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
380
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
381
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
382
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
383
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
384
+ @triton.jit
385
+ def _layer_norm_bwd_kernel(
386
+ X, # pointer to the input
387
+ W, # pointer to the weights
388
+ B, # pointer to the biases
389
+ Y, # pointer to the output to be recomputed
390
+ DY, # pointer to the output gradient
391
+ DX, # pointer to the input gradient
392
+ DW, # pointer to the partial sum of weights gradient
393
+ DB, # pointer to the partial sum of biases gradient
394
+ DRESIDUAL,
395
+ W1,
396
+ DY1,
397
+ DX1,
398
+ DW1,
399
+ DB1,
400
+ DRESIDUAL_IN,
401
+ ROWSCALE,
402
+ SEEDS,
403
+ Mean, # pointer to the mean
404
+ Rstd, # pointer to the 1/std
405
+ stride_x_row, # how much to increase the pointer when moving by 1 row
406
+ stride_y_row,
407
+ stride_dy_row,
408
+ stride_dx_row,
409
+ stride_dres_row,
410
+ stride_dy1_row,
411
+ stride_dx1_row,
412
+ stride_dres_in_row,
413
+ M, # number of rows in X
414
+ N, # number of columns in X
415
+ eps, # epsilon to avoid division by zero
416
+ dropout_p,
417
+ zero_centered_weight,
418
+ rows_per_program,
419
+ IS_RMS_NORM: tl.constexpr,
420
+ BLOCK_N: tl.constexpr,
421
+ HAS_DRESIDUAL: tl.constexpr,
422
+ STORE_DRESIDUAL: tl.constexpr,
423
+ HAS_BIAS: tl.constexpr,
424
+ HAS_DROPOUT: tl.constexpr,
425
+ HAS_ROWSCALE: tl.constexpr,
426
+ HAS_DY1: tl.constexpr,
427
+ HAS_DX1: tl.constexpr,
428
+ HAS_B1: tl.constexpr,
429
+ RECOMPUTE_OUTPUT: tl.constexpr,
430
+ ):
431
+ # Map the program id to the elements of X, DX, and DY it should compute.
432
+ row_block_id = tl.program_id(0)
433
+ row_start = row_block_id * rows_per_program
434
+ # Do not early exit if row_start >= M, because we need to write DW and DB
435
+ cols = tl.arange(0, BLOCK_N)
436
+ mask = cols < N
437
+ X += row_start * stride_x_row
438
+ if HAS_DRESIDUAL:
439
+ DRESIDUAL += row_start * stride_dres_row
440
+ if STORE_DRESIDUAL:
441
+ DRESIDUAL_IN += row_start * stride_dres_in_row
442
+ DY += row_start * stride_dy_row
443
+ DX += row_start * stride_dx_row
444
+ if HAS_DY1:
445
+ DY1 += row_start * stride_dy1_row
446
+ if HAS_DX1:
447
+ DX1 += row_start * stride_dx1_row
448
+ if RECOMPUTE_OUTPUT:
449
+ Y += row_start * stride_y_row
450
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
451
+ if zero_centered_weight:
452
+ w += 1.0
453
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
454
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
455
+ if HAS_DY1:
456
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
457
+ if zero_centered_weight:
458
+ w1 += 1.0
459
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
460
+ if HAS_BIAS:
461
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
462
+ if HAS_DY1:
463
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
464
+ if HAS_B1:
465
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
466
+ row_end = min((row_block_id + 1) * rows_per_program, M)
467
+ for row in range(row_start, row_end):
468
+ # Load data to SRAM
469
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
470
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
471
+ if HAS_DY1:
472
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
473
+ if not IS_RMS_NORM:
474
+ mean = tl.load(Mean + row)
475
+ rstd = tl.load(Rstd + row)
476
+ # Compute dx
477
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
478
+ xhat = tl.where(mask, xhat, 0.0)
479
+ if RECOMPUTE_OUTPUT:
480
+ y = xhat * w + b if HAS_BIAS else xhat * w
481
+ tl.store(Y + cols, y, mask=mask)
482
+ wdy = w * dy
483
+ dw += dy * xhat
484
+ if HAS_BIAS:
485
+ db += dy
486
+ if HAS_DY1:
487
+ wdy += w1 * dy1
488
+ dw1 += dy1 * xhat
489
+ if HAS_B1:
490
+ db1 += dy1
491
+ if not IS_RMS_NORM:
492
+ c1 = tl.sum(xhat * wdy, axis=0) / N
493
+ c2 = tl.sum(wdy, axis=0) / N
494
+ dx = (wdy - (xhat * c1 + c2)) * rstd
495
+ else:
496
+ c1 = tl.sum(xhat * wdy, axis=0) / N
497
+ dx = (wdy - xhat * c1) * rstd
498
+ if HAS_DRESIDUAL:
499
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
500
+ dx += dres
501
+ # Write dx
502
+ if STORE_DRESIDUAL:
503
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
504
+ if HAS_DX1:
505
+ if HAS_DROPOUT:
506
+ keep_mask = (
507
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
508
+ )
509
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
510
+ else:
511
+ dx1 = dx
512
+ tl.store(DX1 + cols, dx1, mask=mask)
513
+ if HAS_DROPOUT:
514
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
515
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
516
+ if HAS_ROWSCALE:
517
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
518
+ dx *= rowscale
519
+ tl.store(DX + cols, dx, mask=mask)
520
+
521
+ X += stride_x_row
522
+ if HAS_DRESIDUAL:
523
+ DRESIDUAL += stride_dres_row
524
+ if STORE_DRESIDUAL:
525
+ DRESIDUAL_IN += stride_dres_in_row
526
+ if RECOMPUTE_OUTPUT:
527
+ Y += stride_y_row
528
+ DY += stride_dy_row
529
+ DX += stride_dx_row
530
+ if HAS_DY1:
531
+ DY1 += stride_dy1_row
532
+ if HAS_DX1:
533
+ DX1 += stride_dx1_row
534
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
535
+ if HAS_BIAS:
536
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
537
+ if HAS_DY1:
538
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
539
+ if HAS_B1:
540
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
541
+
542
+
543
+ def _layer_norm_bwd(
544
+ dy,
545
+ x,
546
+ weight,
547
+ bias,
548
+ eps,
549
+ mean,
550
+ rstd,
551
+ dresidual=None,
552
+ dy1=None,
553
+ weight1=None,
554
+ bias1=None,
555
+ seeds=None,
556
+ dropout_p=0.0,
557
+ rowscale=None,
558
+ has_residual=False,
559
+ has_x1=False,
560
+ zero_centered_weight=False,
561
+ is_rms_norm=False,
562
+ x_dtype=None,
563
+ recompute_output=False,
564
+ ):
565
+ M, N = x.shape
566
+ assert x.stride(-1) == 1
567
+ assert dy.stride(-1) == 1
568
+ assert dy.shape == (M, N)
569
+ if dresidual is not None:
570
+ assert dresidual.stride(-1) == 1
571
+ assert dresidual.shape == (M, N)
572
+ assert weight.shape == (N,)
573
+ assert weight.stride(-1) == 1
574
+ if bias is not None:
575
+ assert bias.stride(-1) == 1
576
+ assert bias.shape == (N,)
577
+ if dy1 is not None:
578
+ assert weight1 is not None
579
+ assert dy1.shape == dy.shape
580
+ assert dy1.stride(-1) == 1
581
+ if weight1 is not None:
582
+ assert weight1.shape == (N,)
583
+ assert weight1.stride(-1) == 1
584
+ if bias1 is not None:
585
+ assert bias1.shape == (N,)
586
+ assert bias1.stride(-1) == 1
587
+ if seeds is not None:
588
+ assert seeds.is_contiguous()
589
+ assert seeds.shape == (M if not has_x1 else M * 2,)
590
+ if rowscale is not None:
591
+ assert rowscale.is_contiguous()
592
+ assert rowscale.shape == (M,)
593
+ # allocate output
594
+ dx = (
595
+ torch.empty_like(x)
596
+ if x_dtype is None
597
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
598
+ )
599
+ dresidual_in = (
600
+ torch.empty_like(x)
601
+ if has_residual
602
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
603
+ else None
604
+ )
605
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
606
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
607
+ if recompute_output:
608
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
609
+
610
+ # Less than 64KB per feature: enqueue fused kernel
611
+ MAX_FUSED_SIZE = 65536 // x.element_size()
612
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
613
+ if N > BLOCK_N:
614
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
615
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
616
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
617
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
618
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
619
+ _db = (
620
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
621
+ if bias is not None
622
+ else None
623
+ )
624
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
625
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
626
+ rows_per_program = math.ceil(M / sm_count)
627
+ grid = (sm_count,)
628
+ with torch.cuda.device(x.device.index):
629
+ _layer_norm_bwd_kernel[grid](
630
+ x,
631
+ weight,
632
+ bias,
633
+ y,
634
+ dy,
635
+ dx,
636
+ _dw,
637
+ _db,
638
+ dresidual,
639
+ weight1,
640
+ dy1,
641
+ dx1,
642
+ _dw1,
643
+ _db1,
644
+ dresidual_in,
645
+ rowscale,
646
+ seeds,
647
+ mean,
648
+ rstd,
649
+ x.stride(0),
650
+ 0 if not recompute_output else y.stride(0),
651
+ dy.stride(0),
652
+ dx.stride(0),
653
+ dresidual.stride(0) if dresidual is not None else 0,
654
+ dy1.stride(0) if dy1 is not None else 0,
655
+ dx1.stride(0) if dx1 is not None else 0,
656
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
657
+ M,
658
+ N,
659
+ eps,
660
+ dropout_p,
661
+ zero_centered_weight,
662
+ rows_per_program,
663
+ is_rms_norm,
664
+ BLOCK_N,
665
+ dresidual is not None,
666
+ dresidual_in is not None,
667
+ bias is not None,
668
+ dropout_p > 0.0,
669
+ )
670
+ dw = _dw.sum(0).to(weight.dtype)
671
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
672
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
673
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
674
+ # Don't need to compute dresidual_in separately in this case
675
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
676
+ dresidual_in = dx
677
+ if has_x1 and dropout_p == 0.0:
678
+ dx1 = dx
679
+ return (
680
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
681
+ if not recompute_output
682
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
683
+ )
684
+
685
+ class LayerNormFn(torch.autograd.Function):
686
+ @staticmethod
687
+ def forward(
688
+ ctx,
689
+ x,
690
+ weight,
691
+ bias,
692
+ residual=None,
693
+ x1=None,
694
+ weight1=None,
695
+ bias1=None,
696
+ eps=1e-6,
697
+ dropout_p=0.0,
698
+ rowscale=None,
699
+ prenorm=False,
700
+ residual_in_fp32=False,
701
+ zero_centered_weight=False,
702
+ is_rms_norm=False,
703
+ return_dropout_mask=False,
704
+ out=None,
705
+ residual_out=None
706
+ ):
707
+ x_shape_og = x.shape
708
+ # Check for zero sequence length
709
+ if x.numel() == 0:
710
+ ctx.zero_seq_length = True
711
+ # Only save minimal required tensors for backward
712
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
713
+ ctx.x_shape_og = x_shape_og
714
+ ctx.weight_shape = weight.shape
715
+ ctx.weight_dtype = weight.dtype
716
+ ctx.weight_device = weight.device
717
+
718
+ ctx.has_bias = bias is not None
719
+ ctx.bias_shape = bias.shape if bias is not None else None
720
+ ctx.bias_dtype = bias.dtype if bias is not None else None
721
+ ctx.bias_device = bias.device if bias is not None else None
722
+
723
+ ctx.has_weight1 = weight1 is not None
724
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
725
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
726
+ ctx.weight1_device = weight1.device if weight1 is not None else None
727
+
728
+ ctx.has_bias1 = bias1 is not None
729
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
730
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
731
+ ctx.bias1_device = bias1.device if bias1 is not None else None
732
+
733
+ ctx.has_residual = residual is not None
734
+ ctx.has_x1 = x1 is not None
735
+ ctx.dropout_p = dropout_p
736
+
737
+ # Handle output tensors with correct dtype
738
+ y = x # Preserve input tensor properties
739
+ y1 = torch.empty_like(x) if x1 is not None else None
740
+
741
+ # Only create residual_out if prenorm is True
742
+ residual_out = torch.empty(x.shape,
743
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
744
+ device=x.device) if prenorm else None
745
+
746
+ # Handle dropout masks
747
+ dropout_mask = None
748
+ dropout_mask1 = None
749
+ if return_dropout_mask:
750
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
751
+ if x1 is not None:
752
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
753
+
754
+ # Return based on configuration
755
+ if not return_dropout_mask:
756
+ if weight1 is None:
757
+ return y if not prenorm else (y, residual_out)
758
+ else:
759
+ return (y, y1) if not prenorm else (y, y1, residual_out)
760
+ else:
761
+ if weight1 is None:
762
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
763
+ else (y, residual_out, dropout_mask, dropout_mask1))
764
+ else:
765
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
766
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
767
+
768
+ ctx.zero_seq_length = False
769
+ # reshape input data into 2D tensor
770
+ x = x.reshape(-1, x.shape[-1])
771
+ if x.stride(-1) != 1:
772
+ x = x.contiguous()
773
+ if residual is not None:
774
+ assert residual.shape == x_shape_og
775
+ residual = residual.reshape(-1, residual.shape[-1])
776
+ if residual.stride(-1) != 1:
777
+ residual = residual.contiguous()
778
+ if x1 is not None:
779
+ assert x1.shape == x_shape_og
780
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
781
+ x1 = x1.reshape(-1, x1.shape[-1])
782
+ if x1.stride(-1) != 1:
783
+ x1 = x1.contiguous()
784
+ weight = weight.contiguous()
785
+ if bias is not None:
786
+ bias = bias.contiguous()
787
+ if weight1 is not None:
788
+ weight1 = weight1.contiguous()
789
+ if bias1 is not None:
790
+ bias1 = bias1.contiguous()
791
+ if rowscale is not None:
792
+ rowscale = rowscale.reshape(-1).contiguous()
793
+ residual_dtype = (
794
+ residual.dtype
795
+ if residual is not None
796
+ else (torch.float32 if residual_in_fp32 else None)
797
+ )
798
+ if out is not None:
799
+ out = out.reshape(-1, out.shape[-1])
800
+ if residual_out is not None:
801
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
802
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
803
+ x,
804
+ weight,
805
+ bias,
806
+ eps,
807
+ residual,
808
+ x1,
809
+ weight1,
810
+ bias1,
811
+ dropout_p=dropout_p,
812
+ rowscale=rowscale,
813
+ residual_dtype=residual_dtype,
814
+ zero_centered_weight=zero_centered_weight,
815
+ is_rms_norm=is_rms_norm,
816
+ return_dropout_mask=return_dropout_mask,
817
+ out=out,
818
+ residual_out=residual_out
819
+ )
820
+ ctx.save_for_backward(
821
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
822
+ )
823
+ ctx.x_shape_og = x_shape_og
824
+ ctx.eps = eps
825
+ ctx.dropout_p = dropout_p
826
+ ctx.is_rms_norm = is_rms_norm
827
+ ctx.has_residual = residual is not None
828
+ ctx.has_x1 = x1 is not None
829
+ ctx.prenorm = prenorm
830
+ ctx.x_dtype = x.dtype
831
+ ctx.zero_centered_weight = zero_centered_weight
832
+ y = y.reshape(x_shape_og)
833
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
834
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
835
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
836
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
837
+ if not return_dropout_mask:
838
+ if weight1 is None:
839
+ return y if not prenorm else (y, residual_out)
840
+ else:
841
+ return (y, y1) if not prenorm else (y, y1, residual_out)
842
+ else:
843
+ if weight1 is None:
844
+ return (
845
+ (y, dropout_mask, dropout_mask1)
846
+ if not prenorm
847
+ else (y, residual_out, dropout_mask, dropout_mask1)
848
+ )
849
+ else:
850
+ return (
851
+ (y, y1, dropout_mask, dropout_mask1)
852
+ if not prenorm
853
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
854
+ )
855
+
856
+ @staticmethod
857
+ def backward(ctx, dy, *args):
858
+ if ctx.zero_seq_length:
859
+ return (
860
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
861
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
862
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
863
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
864
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
865
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
866
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
867
+ None,
868
+ None,
869
+ None,
870
+ None,
871
+ None,
872
+ None,
873
+ None,
874
+ None,
875
+ None,
876
+ None,
877
+ )
878
+
879
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
880
+ dy = dy.reshape(-1, dy.shape[-1])
881
+ if dy.stride(-1) != 1:
882
+ dy = dy.contiguous()
883
+ assert dy.shape == x.shape
884
+ if weight1 is not None:
885
+ dy1, args = args[0], args[1:]
886
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
887
+ if dy1.stride(-1) != 1:
888
+ dy1 = dy1.contiguous()
889
+ assert dy1.shape == x.shape
890
+ else:
891
+ dy1 = None
892
+ if ctx.prenorm:
893
+ dresidual = args[0]
894
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
895
+ if dresidual.stride(-1) != 1:
896
+ dresidual = dresidual.contiguous()
897
+ assert dresidual.shape == x.shape
898
+ else:
899
+ dresidual = None
900
+
901
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
902
+ dy,
903
+ x,
904
+ weight,
905
+ bias,
906
+ ctx.eps,
907
+ mean,
908
+ rstd,
909
+ dresidual,
910
+ dy1,
911
+ weight1,
912
+ bias1,
913
+ seeds,
914
+ ctx.dropout_p,
915
+ rowscale,
916
+ ctx.has_residual,
917
+ ctx.has_x1,
918
+ ctx.zero_centered_weight,
919
+ ctx.is_rms_norm,
920
+ x_dtype=ctx.x_dtype,
921
+ )
922
+ return (
923
+ dx.reshape(ctx.x_shape_og),
924
+ dw,
925
+ db,
926
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
927
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
928
+ dw1,
929
+ db1,
930
+ None,
931
+ None,
932
+ None,
933
+ None,
934
+ None,
935
+ None,
936
+ None,
937
+ None,
938
+ None,
939
+ None,
940
+ )
941
+
    def rms_norm_fn(
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        zero_centered_weight=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None
    ):
        return LayerNormFn.apply(
            x,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            eps,
            dropout_p,
            rowscale,
            prenorm,
            residual_in_fp32,
            zero_centered_weight,
            True,
            return_dropout_mask,
            out,
            residual_out
        )

+ class RMSNorm(torch.nn.Module):
981
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
982
+ device=None, dtype=None):
983
+ factory_kwargs = {"device": device, "dtype": dtype}
984
+ super().__init__()
985
+ self.eps = eps
986
+ if dropout_p > 0.0:
987
+ self.drop = torch.nn.Dropout(dropout_p)
988
+ else:
989
+ self.drop = None
990
+ self.zero_centered_weight = zero_centered_weight
991
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
992
+ self.register_parameter("bias", None)
993
+ self.reset_parameters()
994
+
995
+ def reset_parameters(self):
996
+ if not self.zero_centered_weight:
997
+ torch.nn.init.ones_(self.weight)
998
+ else:
999
+ torch.nn.init.zeros_(self.weight)
1000
+
1001
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1002
+ return rms_norm_fn(
1003
+ x,
1004
+ self.weight,
1005
+ self.bias,
1006
+ residual=residual,
1007
+ eps=self.eps,
1008
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1009
+ prenorm=prenorm,
1010
+ residual_in_fp32=residual_in_fp32,
1011
+ zero_centered_weight=self.zero_centered_weight,
1012
+ )
1013
+ else:
1014
+ from torch.nn import RMSNorm
1015
+ warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
1016
+
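# NOTE: whichever branch is taken above, `RMSNorm` is constructed the same way,
# e.g. `norm = RMSNorm(hidden_size, eps=1e-5); y = norm(x)`. Only the fused Triton
# variant additionally accepts `residual=...` / `prenorm=True` in `forward`.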
def swiglu(x, y):
    return F.silu(x.float(), inplace=False).to(x.dtype) * y

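# NOTE: this is the gating half of SwiGLU, SwiGLU(x, y) = SiLU(x) * y with
# SiLU(x) = x * sigmoid(x). The gate is computed in float32 for numerical
# stability and cast back to the input dtype; `LuminaFeedForward` below supplies
# x and y via two parallel linear projections of the same input.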
logger = logging.get_logger(__name__)

@dataclass
class TeaCacheParams:
    previous_residual: Optional[torch.Tensor] = None
    previous_modulated_inp: Optional[torch.Tensor] = None
    accumulated_rel_l1_distance: float = 0
    is_first_or_last_step: bool = False


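# NOTE: TeaCache state, as consumed in `ThinkGenTransformer2DModel.forward` below.
# Per denoising step, the relative L1 change of the first block's modulated input
# is rescaled by a fitted polynomial and accumulated; while the accumulator stays
# below `teacache_rel_l1_thresh`, the transformer stack is skipped and
# `previous_residual` is re-applied instead. First/last steps always recompute.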
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        nn.init.normal_(self.linear_1.weight, std=0.02)
        nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to an input tensor using the given frequency tensor. This function applies rotary
    embeddings to the given query or key tensor 'x' using the provided frequency tensor 'freqs_cis'. The input
    tensor is reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility.
    The resulting tensor contains rotary embeddings and is returned as a real tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        torch.Tensor: The query or key tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)

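# NOTE: a minimal worked example of the rotation above (use_real path). For one
# feature pair (x_real, x_imag) at a position with angle theta, the output is
#   (x_real * cos(theta) - x_imag * sin(theta),
#    x_imag * cos(theta) + x_real * sin(theta)),
# i.e. a 2D rotation of each consecutive feature pair. Tokens therefore differ
# only by a relative rotation, which makes attention scores depend on relative
# position.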
class ThinkGenRotaryPosEmbed(nn.Module):
    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        freqs_cis = []
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        device = ids.device
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # add text position ids
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # add image position ids

                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )


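# NOTE: layout of the 3-axis position ids built above. Text token t gets (t, t, t);
# every token of a reference image shares a constant first axis `pe_shift` plus its
# (row, col) on the remaining axes; the noise image continues the same scheme after
# all reference images. Example for a 4-token caption followed by one 2x2-token
# image: text -> (0,0,0)..(3,3,3); image -> (4,0,0), (4,0,1), (4,1,0), (4,1,1).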
class LuminaRMSNormZero(nn.Module):
    """
    Norm layer adaptive RMS normalization zero.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_eps: float,
        norm_elementwise_affine: bool,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        self.norm = RMSNorm(embedding_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp


class LuminaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
    ):
        super().__init__()

        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x


class LuminaFeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the model's hidden representations; input and output width of the layer.
        inner_dim (`int`): The intermediate dimension of the feed-forward layer.
        multiple_of (`int`, *optional*): Value the inner dimension is rounded up to a multiple of.
        ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the inner
            dimension. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()

        self.swiglu = swiglu

        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))


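# NOTE: worked example of the inner_dim rounding above. With dim=2304,
# inner_dim=4*2304=9216 and multiple_of=256: 256 * ((9216 + 255) // 256) = 9216
# (already a multiple). With ffn_dim_multiplier=2/3, inner_dim first becomes
# int(9216 * 2 / 3) = 6144, which is likewise already a multiple of 256.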
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        text_feat_dim: int = 204800,  # 2048
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim * 2, eps=norm_eps),
            nn.Linear(text_feat_dim * 2, hidden_size, bias=True),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        # truncated-normal init for weights, zeros for biases of the caption embedder
        for name, module in self.caption_embedder.named_modules():
            if hasattr(module, 'weight') and module.weight is not None:
                nn.init.trunc_normal_(module.weight, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed


class ThinkGenAttnProcessor:
    """
    Processor for implementing scaled dot-product attention with support for padded, variable-length sequences.

    This processor requires PyTorch 2.0 and implements:
    - PyTorch scaled dot-product attention (SDPA) with an additive padding mask
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Args:
        None

    Raises:
        ImportError: If PyTorch version is less than 2.0
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ThinkGenAttnProcessor requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with PyTorch SDPA.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor
            attention_mask: Optional attention mask tensor
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # explicitly repeat key and value to match the number of query heads; otherwise
        # `enable_gqa=True` falls back to the MATH backend of SDPA in our tests on PyTorch 2.6
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


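# NOTE: the proportional attention scaling used by the processor above (and the
# varlen processor below). With base_sequence_length = N_base and current length N,
# the softmax scale becomes
#   sqrt(log(N) / log(N_base)) * attn.scale,
# since math.log(N, N_base) = log(N) / log(N_base). At N = N_base the factor is 1,
# and it grows slowly for longer sequences to keep attention entropy roughly stable.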
class ThinkGenAttnProcessorFlash2Varlen:
    """
    Processor for implementing scaled dot-product attention with flash attention and variable length sequences.

    This processor implements:
    - Flash attention with variable length sequences
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Args:
        None
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not is_flash_attn_available():
            raise ImportError(
                "ThinkGenAttnProcessorFlash2Varlen requires flash_attn. "
                "Please install flash_attn."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """
        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Helper function to get unpadding data from attention mask."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with flash attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor
            attention_mask: Optional attention mask tensor
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        # Handle different number of heads
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)

        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )

        # Pad output and apply final transformations
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


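# NOTE: a small worked example for the unpadding above. For an attention mask
# [[1, 1, 1, 0], [1, 1, 0, 0]] (batch of 2, padded to length 4), _get_unpad_data
# yields indices [0, 1, 2, 4, 5], cu_seqlens [0, 3, 5] and max_seqlen 3;
# flash_attn_varlen_func then runs over the 5 real tokens only, and pad_input
# scatters the result back to the (2, 4, ...) padded layout.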
class ThinkGenTransformerBlock(nn.Module):
    """
    Transformer block for the ThinkGen model.

    This block implements a transformer layer with:
    - Multi-head attention with flash attention
    - Feed-forward network with SwiGLU activation
    - RMS normalization
    - Optional modulation for conditional generation

    Args:
        dim: Dimension of the input and output tensors
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Value the feed-forward inner dimension is rounded up to a multiple of
        ffn_dim_multiplier: Multiplier for the feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        modulation: Whether to use modulation for conditional generation
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        """Initialize the transformer block."""
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        try:
            processor = ThinkGenAttnProcessorFlash2Varlen()
        except ImportError:
            processor = ThinkGenAttnProcessor()

        # Initialize attention layer
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        # Initialize feed-forward network
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier
        )

        # Initialize normalization layers
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the transformer block.

        Uses Xavier uniform initialization for linear layers and zero initialization for biases.
        """
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer block.

        Args:
            hidden_states: Input hidden states tensor
            attention_mask: Attention mask tensor
            image_rotary_emb: Rotary embeddings for image tokens
            temb: Optional timestep embedding tensor

        Returns:
            torch.Tensor: Output hidden states after transformer block processing
        """
        enable_taylorseer = getattr(self, 'enable_taylorseer', False)
        if enable_taylorseer:
            if self.modulation:
                if temb is None:
                    raise ValueError("temb must be provided when modulation is enabled")

                if self.current['type'] == 'full':
                    self.current['module'] = 'total'
                    taylor_cache_init(cache_dic=self.cache_dic, current=self.current)

                    norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
                    attn_output = self.attn(
                        hidden_states=norm_hidden_states,
                        encoder_hidden_states=norm_hidden_states,
                        attention_mask=attention_mask,
                        image_rotary_emb=image_rotary_emb,
                    )
                    hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
                    mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
                    hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)

                    derivative_approximation(cache_dic=self.cache_dic, current=self.current, feature=hidden_states)

                elif self.current['type'] == 'Taylor':
                    self.current['module'] = 'total'
                    hidden_states = taylor_formula(cache_dic=self.cache_dic, current=self.current)
            else:
                norm_hidden_states = self.norm1(hidden_states)
                attn_output = self.attn(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_hidden_states,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
                hidden_states = hidden_states + self.norm2(attn_output)
                mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
                hidden_states = hidden_states + self.ffn_norm2(mlp_output)
        else:
            if self.modulation:
                if temb is None:
                    raise ValueError("temb must be provided when modulation is enabled")

                norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
                attn_output = self.attn(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_hidden_states,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
                hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
                mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
                hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
            else:
                norm_hidden_states = self.norm1(hidden_states)
                attn_output = self.attn(
                    hidden_states=norm_hidden_states,
                    encoder_hidden_states=norm_hidden_states,
                    attention_mask=attention_mask,
                    image_rotary_emb=image_rotary_emb,
                )
                hidden_states = hidden_states + self.norm2(attn_output)
                mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
                hidden_states = hidden_states + self.ffn_norm2(mlp_output)

        return hidden_states


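# NOTE: data flow of one modulated block above, with scales s_* and gates g_*
# coming from LuminaRMSNormZero(temb):
#   h = x + tanh(g_msa) * norm2(Attn(norm1(x) * (1 + s_msa)))
#   y = h + tanh(g_mlp) * ffn_norm2(FFN(ffn_norm1(h) * (1 + s_mlp)))
# Without modulation (the context refiner blocks), the scales and gates drop out:
#   h = x + norm2(Attn(norm1(x))),  y = h + ffn_norm2(FFN(ffn_norm1(h))).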
class ThinkGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    ThinkGen Transformer 2D Model.

    A transformer-based diffusion model for image generation with:
    - Patch-based image processing
    - Rotary position embeddings
    - Multi-head attention
    - Conditional generation support

    Args:
        patch_size: Size of image patches
        in_channels: Number of input channels
        out_channels: Number of output channels (defaults to in_channels)
        hidden_size: Size of hidden layers
        num_layers: Number of transformer layers
        num_refiner_layers: Number of refiner layers
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Value the feed-forward inner dimension is rounded up to a multiple of
        ffn_dim_multiplier: Multiplier for feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        axes_dim_rope: Dimensions for rotary position embeddings
        axes_lens: Lengths for rotary position embeddings
        text_feat_dim: Dimension of text features
        timestep_scale: Scale factor for timestep embeddings
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["ThinkGenTransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0
    ) -> None:
        """Initialize the ThinkGen transformer model."""
        super().__init__()

        # Validate configuration
        if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
            raise ValueError(
                f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
                f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
            )

        self.out_channels = out_channels or in_channels

        # Initialize embeddings
        self.rope_embedder = ThinkGenRotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.ref_image_patch_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale
        )

        # Initialize transformer blocks
        self.noise_refiner = nn.ModuleList([
            ThinkGenTransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.ref_image_refiner = nn.ModuleList([
            ThinkGenTransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.context_refiner = nn.ModuleList(
            [
                ThinkGenTransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                ThinkGenTransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels
        )

        # Add learnable embeddings to distinguish different images
        self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))  # support max 5 ref images

        self.gradient_checkpointing = False

        self.initialize_weights()

        # TeaCache settings
        self.enable_teacache = False
        self.teacache_rel_l1_thresh = 0.05
        self.teacache_params = TeaCacheParams()

        coefficients = [-5.48259225, 11.48772289, -4.47407401, 2.47730926, -0.03316487]
        self.rescale_func = np.poly1d(coefficients)

        # learnable pre-padding tokens prepended to the text stream in `forward`
        self.prepad_embed = nn.Parameter(torch.randn(1, 23, 8192))
        self.register_buffer('prepad_mask', torch.ones(1, 23, dtype=torch.int64))

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the model.

        Uses Xavier uniform initialization for linear layers.
        """
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.constant_(self.x_embedder.bias, 0.0)

        nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
        nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)

        nn.init.zeros_(self.norm_out.linear_1.weight)
        nn.init.zeros_(self.norm_out.linear_1.bias)
        nn.init.zeros_(self.norm_out.linear_2.weight)
        nn.init.zeros_(self.norm_out.linear_2.bias)

        nn.init.normal_(self.image_index_embedding, std=0.02)

    def img_patch_embed_and_refine(
        self,
        hidden_states,
        ref_image_hidden_states,
        padded_img_mask,
        padded_ref_img_mask,
        noise_rotary_emb,
        ref_img_rotary_emb,
        l_effective_ref_img_len,
        l_effective_img_len,
        temb
    ):
        batch_size = len(hidden_states)
        max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])

        hidden_states = self.x_embedder(hidden_states)
        ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)

        # add per-reference-image index embeddings
        for i in range(batch_size):
            shift = 0
            for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
                shift += ref_img_len

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
        num_ref_images = len(flat_l_effective_ref_img_len)
        max_ref_img_len = max(flat_l_effective_ref_img_len)

        batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
        batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
        batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
        batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)

        # sequence of ref imgs to batch
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                batch_ref_img_mask[idx, :ref_img_len] = True
                batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
                batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
                batch_temb[idx] = temb[i]
                shift += ref_img_len
                idx += 1

        # refine ref imgs separately
        for layer in self.ref_image_refiner:
            batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)

        # batch of ref imgs to sequence
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
                shift += ref_img_len
                idx += 1

        combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
        for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
            combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
            combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]

        return combined_img_hidden_states

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        batch_size = len(hidden_states)
        p = self.config.patch_size
        device = hidden_states[0].device

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is not None:
            ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
            l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
        else:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]

        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # ref image patch embeddings
        flat_ref_img_hidden_states = []
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                imgs = []
                for ref_img in ref_image_hidden_states[i]:
                    C, H, W = ref_img.size()
                    ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
                    imgs.append(ref_img)

                img = torch.cat(imgs, dim=0)
                flat_ref_img_hidden_states.append(img)
            else:
                flat_ref_img_hidden_states.append(None)

        # image patch embeddings
        flat_hidden_states = []
        for i in range(batch_size):
            img = hidden_states[i]
            C, H, W = img.size()

            img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
            flat_hidden_states.append(img)

        padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
        padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if ref_img_sizes[i] is not None:
                padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
                padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True

        padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
            padded_img_mask[i, :l_effective_img_len[i]] = True

        return (
            padded_hidden_states,
            padded_ref_img_hidden_states,
            padded_img_mask,
            padded_ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        )

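    # NOTE: patchify bookkeeping example for the method above. A latent of shape
    # (C, H, W) = (16, 128, 128) with patch_size p = 2 is rearranged into
    # (H/p)*(W/p) = 64*64 = 4096 tokens of dimension p*p*C = 64; l_effective_img_len
    # records that token count per sample, and shorter sequences are zero-padded up
    # to the batch maximum, with the boolean masks marking the real tokens.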
    def forward(
        self,
        hidden_states: Union[torch.Tensor, List[torch.Tensor]],
        timestep: torch.Tensor,
        text_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        text_attention_mask: torch.Tensor,
        ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        enable_taylorseer = getattr(self, 'enable_taylorseer', False)

        # if self.prepad_embed.dtype != text_hidden_states.dtype:
        #     self.prepad_embed = self.prepad_embed.to(text_hidden_states.dtype)
        # if self.prepad_mask.device != text_attention_mask.device:
        #     self.prepad_mask = self.prepad_mask.to(text_attention_mask.device)

        bs = text_hidden_states.shape[0]
        prepad_embed = self.prepad_embed.repeat(bs, 1, 1)
        prepad_mask = self.prepad_mask.repeat(bs, 1)
        text_hidden_states = torch.cat([prepad_embed, text_hidden_states], dim=1)
        text_attention_mask = torch.cat([prepad_mask, text_attention_mask], dim=1)

        if enable_taylorseer:
            cal_type(self.cache_dic, self.current)

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # 1. Condition, positional & patch embedding
        batch_size = len(hidden_states)
        is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)

        if is_hidden_states_tensor:
            assert hidden_states.ndim == 4
            hidden_states = [_hidden_states for _hidden_states in hidden_states]

        device = hidden_states[0].device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb,
            ref_img_rotary_emb,
            noise_rotary_emb,
            rotary_emb,
            encoder_seq_lengths,
            seq_lengths,
        ) = self.rope_embedder(
            freqs_cis,
            text_attention_mask,
            l_effective_ref_img_len,
            l_effective_img_len,
            ref_img_sizes,
            img_sizes,
            device,
        )

        # 2. Context refinement
        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)

        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states,
            ref_image_hidden_states,
            img_mask,
            ref_img_mask,
            noise_rotary_emb,
            ref_img_rotary_emb,
            l_effective_ref_img_len,
            l_effective_img_len,
            temb,
        )

        # 3. Joint transformer blocks (over the concatenated text and image embeddings)
        max_seq_len = max(seq_lengths)

        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
            joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]

        hidden_states = joint_hidden_states

        if self.enable_teacache:
            teacache_hidden_states = hidden_states.clone()
            teacache_temb = temb.clone()
            modulated_inp, _, _, _ = self.layers[0].norm1(teacache_hidden_states, teacache_temb)
            if self.teacache_params.is_first_or_last_step:
                should_calc = True
                self.teacache_params.accumulated_rel_l1_distance = 0
            else:
                self.teacache_params.accumulated_rel_l1_distance += self.rescale_func(
                    ((modulated_inp - self.teacache_params.previous_modulated_inp).abs().mean() \
                        / self.teacache_params.previous_modulated_inp.abs().mean()).cpu().item()
                )
                if self.teacache_params.accumulated_rel_l1_distance < self.teacache_rel_l1_thresh:
                    should_calc = False
                else:
                    should_calc = True
                    self.teacache_params.accumulated_rel_l1_distance = 0
            self.teacache_params.previous_modulated_inp = modulated_inp

        if self.enable_teacache:
            if not should_calc:
                hidden_states += self.teacache_params.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()
                for layer_idx, layer in enumerate(self.layers):
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        hidden_states = self._gradient_checkpointing_func(
                            layer, hidden_states, attention_mask, rotary_emb, temb
                        )
                    else:
                        hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
                self.teacache_params.previous_residual = hidden_states - ori_hidden_states
        else:
            if enable_taylorseer:
                self.current['stream'] = 'layers_stream'

            for layer_idx, layer in enumerate(self.layers):
                if enable_taylorseer:
                    layer.current = self.current
                    layer.cache_dic = self.cache_dic
                    layer.enable_taylorseer = True
                    self.current['layer'] = layer_idx

                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    hidden_states = self._gradient_checkpointing_func(
                        layer, hidden_states, attention_mask, rotary_emb, temb
                    )
                else:
                    hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)

        # 4. Output norm & projection
        hidden_states = self.norm_out(hidden_states, temb)

        p = self.config.patch_size
        output = []
        for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
            height, width = img_size
            output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
        if is_hidden_states_tensor:
            output = torch.stack(output, dim=0)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if enable_taylorseer:
            self.current['step'] += 1

        if not return_dict:
            return output
        return Transformer2DModelOutput(sample=output)
vae/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bf4a6f861189e5647b58c1b532fee4a4ce602fda9ff2a744931d72c2f6c2fc3
3
+ size 840
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c717328c8ad41faab2ccfd52ae17332505c6833cf176aad56e7b58f2c4d4c94
3
+ size 335306212