Shaoan committed
Commit c867393 · verified · 1 Parent(s): e476a4b

Upload folder using huggingface_hub

Files changed (1)
  1. layers.py +734 -0
layers.py ADDED
@@ -0,0 +1,734 @@
# Modified from Flux
#
# Copyright 2024 Black Forest Labs

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math  # noqa: I001
from dataclasses import dataclass
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint

# Optional fused kernels used below (flash_attn_varlen_func, LigerRMSNormFunction).
# The file references these names without importing them; the import paths here are
# assumptions about the flash-attn and liger-kernel packages, with a None fallback so
# the pure-PyTorch code paths can still run without them.
try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None

try:
    from liger_kernel.ops.rms_norm import LigerRMSNormFunction
except ImportError:
    LigerRMSNormFunction = None


def to_cuda(x):
    if isinstance(x, torch.Tensor):
        return x.cuda()
    elif isinstance(x, (list, tuple)):
        return [to_cuda(elem) for elem in x]
    elif isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    else:
        return x


def to_cpu(x):
    if isinstance(x, torch.Tensor):
        return x.cpu()
    elif isinstance(x, (list, tuple)):
        return [to_cpu(elem) for elem in x]
    elif isinstance(x, dict):
        return {k: to_cpu(v) for k, v in x.items()}
    else:
        return x


MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
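# Note (added for clarity): each MEMORY_LAYOUT entry is a (pre, post) pair applied
# around the attention call. With inputs shaped [B, S, H, D]:
#   - "flash" flattens batch and sequence to [(B*S), H, D] for the varlen kernel,
#   - "torch"/"vanilla" transpose to [B, H, S, D] and transpose back afterwards.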


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    q = pre_attn_layout(q)
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_varlen_func is not None
        x: torch.Tensor = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )  # type: ignore
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # type: ignore  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, (
                "Causal mask and attn_mask cannot be used together"
            )
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
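# Usage sketch (added, not part of the original file; shapes follow the docstring above):
#   q = k = v = torch.randn(2, 16, 8, 64)      # [b, s, heads, head_dim]
#   out = attention(q, k, v, mode="torch")     # -> [2, 16, 8 * 64]
# The "flash" mode additionally requires cu_seqlens_q/cu_seqlens_kv and
# max_seqlen_q/max_seqlen_kv for the varlen kernel.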


def apply_gate(x, gate=None, tanh=False):
    """Apply a per-sample gate to the input tensor.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to apply tanh to the gate first. Defaults to False.

    Returns:
        torch.Tensor: the output tensor after applying the gate.
    """
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)
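# Example (added): for x of shape [B, L, D] and gate of shape [B, D],
# gate.unsqueeze(1) broadcasts to [B, 1, D], scaling every sequence position.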


class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = (bias, bias)
        drop_probs = (drop, drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, device=device, dtype=dtype)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
            ),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore
        nn.init.normal_(self.mlp[2].weight, std=0.02)  # type: ignore

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
            dim (int): the dimension of the output.
            max_period (int): controls the minimum frequency of the embeddings.

        Returns:
            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(
            t, self.frequency_embedding_size, self.max_period
        ).type(self.mlp[0].weight.dtype)  # type: ignore
        t_emb = self.mlp(t_freq)
        return t_emb
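# Note (added): timestep_embedding computes, for half = dim // 2 and i in [0, half):
#   freq_i = exp(-ln(max_period) * i / half) = max_period ** (-i / half)
#   embedding[n] = [cos(t_n * freq_0), ..., cos(t_n * freq_{half-1}),
#                   sin(t_n * freq_0), ..., sin(t_n * freq_{half-1})]
# i.e. the standard sinusoidal encoding with all cosines followed by all sines.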


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)
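# Note (added): EmbedND builds one RoPE table per position axis and concatenates them
# along the frequency dimension, so sum(axes_dim) is expected to equal the attention
# head dimension. For ids of shape [B, N, n_axes] the result has shape
# [B, 1, N, sum(axes_dim) // 2, 2, 2], which is what apply_rope consumes below.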


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)


def rope(pos, dim: int, theta: int):
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()
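# Note (added): for each position p and frequency index d, rope() stores the 2x2
# rotation matrix [[cos(p * w_d), -sin(p * w_d)], [sin(p * w_d), cos(p * w_d)]]
# with w_d = theta ** (-2d / dim), giving an output of shape [b, n, dim // 2, 2, 2].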


def attention_after_rope(q, k, v, pe, mode="flash"):
    # mode defaults to "flash"; SelfAttention.forward below calls this without an explicit mode.
    q, k = apply_rope(q, k, pe)

    from .attention import attention

    x = attention(q, k, v, mode)
    return x


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def apply_rope(xq, xk, freqs_cis):
    # Swap the num_heads and seq_len dimensions back to the order the original function expects
    xq = xq.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
    xk = xk.transpose(1, 2)

    # Split head_dim into complex components (real and imaginary parts)
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)

    # Apply rotary position embedding (complex multiplication)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]

    # Restore the tensor shape and transpose back to the target dimension order
    xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
    xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)

    return xq_out, xk_out
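# Note (added): freqs_cis holds the per-position 2x2 rotation matrices produced by
# rope()/EmbedND, so the two multiply-adds above are the matrix-vector products that
# rotate each consecutive feature pair of q and k; the math runs in float32 and the
# results are cast back to the input dtype.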


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def scale_add_residual(
    x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
) -> torch.Tensor:
    return x * scale + residual


@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def layernorm_and_scale_shift(
    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
) -> torch.Tensor:
    return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention_after_rope(q, k, v, pe=pe)
        x = self.proj(x)
        return x


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    @staticmethod
    def rms_norm_fast(x, weight, eps):
        return LigerRMSNormFunction.apply(
            x,
            weight,
            eps,
            0.0,
            "gemma",
            True,
        )

    @staticmethod
    def rms_norm(x, weight, eps):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
        return (x * rrms).to(dtype=x_dtype) * weight

    def forward(self, x: Tensor):
        # Use the fused Liger kernel when available; otherwise fall back to the
        # pure-PyTorch implementation above.
        if LigerRMSNormFunction is None:
            return self.rms_norm(x, self.scale.to(x.dtype), 1e-6)
        return self.rms_norm_fast(x, self.scale.to(x.dtype), 1e-6)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)
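# Note (added): RMSNorm computes y = x / sqrt(mean(x^2) + eps) * weight (see rms_norm
# above); QKNorm applies an independent RMSNorm to queries and keys and casts both to
# the value tensor's dtype before attention.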


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )
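# Note (added): Modulation is the adaLN-style conditioning used by the blocks below.
# From the conditioning vector vec it predicts (shift, scale, gate) triples of shape
# [B, 1, dim]; with double=True it returns two triples, one for the attention branch
# and one for the MLP branch.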


class DoubleStreamBlock(nn.Module):
    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, mode: str = "flash"
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.mode = mode
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def _forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=1)
        k = torch.cat((txt_k, img_k), dim=1)
        v = torch.cat((txt_v, img_v), dim=1)

        attn = attention_after_rope(q, k, v, pe, self.mode)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp = self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )
        img = scale_add_residual(img_mlp, img_mod2.gate, img)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp = self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
        return img, txt

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        if self.training and self.gradient_checkpointing:
            if not self.cpu_offload_checkpointing:
                return checkpoint(self._forward, img, txt, vec, pe, use_reentrant=False)

            # cpu offload checkpointing

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    cuda_inputs = to_cuda(inputs)
                    outputs = func(*cuda_inputs)
                    return to_cpu(outputs)

                return custom_forward

            return torch.utils.checkpoint.checkpoint(
                create_custom_forward(self._forward), img, txt, vec, pe, use_reentrant=False
            )

        else:
            return self._forward(img, txt, vec, pe)
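# Note (added): DoubleStreamBlock keeps separate projections, norms, and modulations
# for the text and image streams, but runs a single attention pass over the
# concatenated sequence [txt; img] so the two modalities can attend to each other;
# the per-stream gates then control how much of the attention and MLP updates is
# added back to each residual stream.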


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        mode: str = "flash"
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

        self.mode = mode
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention_after_rope(q, k, v, pe, self.mode)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return scale_add_residual(output, mod.gate, x)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        if self.training and self.gradient_checkpointing:
            if not self.cpu_offload_checkpointing:
                return checkpoint(self._forward, x, vec, pe, use_reentrant=False)

            # cpu offload checkpointing

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    cuda_inputs = to_cuda(inputs)
                    outputs = func(*cuda_inputs)
                    return to_cpu(outputs)

                return custom_forward

            return torch.utils.checkpoint.checkpoint(
                create_custom_forward(self._forward), x, vec, pe, use_reentrant=False
            )
        else:
            return self._forward(x, vec, pe)
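# Note (added): SingleStreamBlock fuses the attention and MLP branches: linear1 emits
# QKV and the MLP hidden activations in one matmul, and linear2 projects their
# concatenation back to hidden_size, so both branches share a single gated residual
# update.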


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x