Committed by orrzohar
Commit dc15cdf · verified · 1 Parent(s): 110fbe9

Upload lumina_nextdit2d.py with huggingface_hub

Files changed (1)
  1. lumina_nextdit2d.py +365 -0
lumina_nextdit2d.py ADDED
@@ -0,0 +1,365 @@
+# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import LuminaFeedForward
+from diffusers.models.attention_processor import Attention, LuminaAttnProcessor2_0
+from diffusers.models.embeddings import (
+    LuminaCombinedTimestepCaptionEmbedding,
+    LuminaPatchEmbed,
+    PixArtAlphaTextProjection,
+)
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+from diffusers.utils import is_torch_version, logging
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class LuminaNextDiTBlock(nn.Module):
+    """
+    A LuminaNextDiTBlock for LuminaNextDiT2DModel.
+
+    Parameters:
+        dim (`int`): Embedding dimension of the input features.
+        num_attention_heads (`int`): Number of attention heads.
+        num_kv_heads (`int`):
+            Number of attention heads for key and value features (if using GQA), or `None` to use the same
+            number as the query heads.
+        multiple_of (`int`): The feed-forward hidden dimension is rounded up to a multiple of this value.
+        ffn_dim_multiplier (`float`): A multiplier applied to the feed-forward hidden dimension.
+        norm_eps (`float`): The epsilon used by the normalization layers.
+        qk_norm (`bool`): Whether to normalize the query and key projections.
+        cross_attention_dim (`int`): Cross-attention embedding dimension of the input text prompt hidden_states.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether the normalization layers use learnable affine parameters.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        qk_norm: bool,
+        cross_attention_dim: int,
+        norm_elementwise_affine: bool = True,
+    ) -> None:
+        super().__init__()
+        self.head_dim = dim // num_attention_heads
+
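+        # Per-head gate for the cross-attention branch, initialized to zero: tanh(0) = 0, so the block
+        # starts out as pure self-attention and learns how strongly to mix in the text conditioning.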
+        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
+
+        # Self-attention
+        self.attn1 = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            qk_norm="layer_norm_across_heads" if qk_norm else None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=LuminaAttnProcessor2_0(),
+        )
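+        # The self-attention output projection is bypassed; the mixed self-/cross-attention output is
+        # projected once through attn2.to_out in forward().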
+        self.attn1.to_out = nn.Identity()
+
+        # Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            dim_head=dim // num_attention_heads,
+            qk_norm="layer_norm_across_heads" if qk_norm else None,
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=LuminaAttnProcessor2_0(),
+        )
+
+        self.feed_forward = LuminaFeedForward(
+            dim=dim,
+            inner_dim=4 * dim,
+            multiple_of=multiple_of,
+            ffn_dim_multiplier=ffn_dim_multiplier,
+        )
+
+        self.norm1 = LuminaRMSNormZero(
+            embedding_dim=dim,
+            norm_eps=norm_eps,
+            norm_elementwise_affine=norm_elementwise_affine,
+        )
+        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+        self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+        self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        temb: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform a forward pass through the LuminaNextDiTBlock.
+
+        Parameters:
+            hidden_states (`torch.Tensor`): The input hidden_states for the LuminaNextDiTBlock.
+            attention_mask (`torch.Tensor`): The attention mask corresponding to hidden_states.
+            image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
+            encoder_hidden_states (`torch.Tensor`): The text prompt hidden_states produced by the Gemma encoder.
+            encoder_mask (`torch.Tensor`): The attention mask for the text prompt hidden_states.
+            temb (`torch.Tensor`): Timestep embedding combined with the text prompt embedding.
+            cross_attention_kwargs (`Dict[str, Any]`): Additional keyword arguments for the attention layers.
+        """
+        residual = hidden_states
+
+        # Self-attention
+        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+        self_attn_output = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_hidden_states,
+            attention_mask=attention_mask,
+            query_rotary_emb=image_rotary_emb,
+            key_rotary_emb=image_rotary_emb,
+            **cross_attention_kwargs,
+        )
+
+        # Cross-attention
+        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+        cross_attn_output = self.attn2(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=encoder_mask,
+            query_rotary_emb=image_rotary_emb,
+            key_rotary_emb=None,
+            **cross_attention_kwargs,
+        )
+        cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
+        mixed_attn_output = self_attn_output + cross_attn_output
+        mixed_attn_output = mixed_attn_output.flatten(-2)
+        # linear proj
+        hidden_states = self.attn2.to_out[0](mixed_attn_output)
+
+        hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
+
+        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+
+        return hidden_states
+
+
+class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
+    """
+    LuminaNextDiT: Diffusion model with a Transformer backbone.
+
+    Inherits from ModelMixin and ConfigMixin so that it is compatible with diffusers pipelines such as
+    StableDiffusionPipeline.
+
+    Parameters:
+        sample_size (`int`): The width of the latent images. This is fixed during training since
+            it is used to learn a number of position embeddings.
+        patch_size (`int`, *optional*, defaults to 2):
+            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+        in_channels (`int`, *optional*, defaults to 4):
+            The number of input channels for the model. Typically, this matches the number of channels in the input
+            images.
+        hidden_size (`int`, *optional*, defaults to 4096):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        num_layers (`int`, *optional*, defaults to 32):
+            The number of layers in the model. This defines the depth of the neural network.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            The number of attention heads in each attention layer. This parameter specifies how many separate
+            attention mechanisms are used.
+        num_kv_heads (`int`, *optional*, defaults to 8):
+            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+            If None, it defaults to num_attention_heads.
+        multiple_of (`int`, *optional*, defaults to 256):
+            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+            configurations.
+        ffn_dim_multiplier (`float`, *optional*):
+            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+            the model configuration.
+        norm_eps (`float`, *optional*, defaults to 1e-5):
+            A small value added to the denominator for numerical stability in normalization layers.
+        learn_sigma (`bool`, *optional*, defaults to True):
+            Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
+            predictions.
+        qk_norm (`bool`, *optional*, defaults to True):
+            Indicates if the queries and keys in the attention mechanism should be normalized.
+        cross_attention_dim (`int`, *optional*, defaults to 2048):
+            The dimensionality of the text embeddings. This parameter defines the size of the text representations used
+            in the model.
+        scaling_factor (`float`, *optional*, defaults to 1.0):
+            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+            overall scale of the model's operations.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["LuminaNextDiTBlock"]
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: Optional[int] = 2,
+        in_channels: Optional[int] = 4,
+        hidden_size: Optional[int] = 2304,
+        num_layers: Optional[int] = 32,
+        num_attention_heads: Optional[int] = 32,
+        num_kv_heads: Optional[int] = None,
+        multiple_of: Optional[int] = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+        norm_eps: Optional[float] = 1e-5,
+        learn_sigma: Optional[bool] = True,
+        qk_norm: Optional[bool] = True,
+        cross_attention_dim: Optional[int] = 2048,
+        scaling_factor: Optional[float] = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_size = sample_size
+        self.patch_size = patch_size
+        self.in_channels = in_channels
+        self.out_channels = in_channels * 2 if learn_sigma else in_channels
+        self.hidden_size = hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = hidden_size // num_attention_heads
+        self.scaling_factor = scaling_factor
+        self.gradient_checkpointing = False
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=cross_attention_dim, hidden_size=hidden_size)
+        self.patch_embedder = LuminaPatchEmbed(
+            patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
+        )
+
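+        # Combined timestep + caption embedding; the resulting conditioning vector (temb) drives the
+        # adaLN-style modulation in every block and the final norm_out layer.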
+        self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
+            hidden_size=min(hidden_size, 1024), cross_attention_dim=hidden_size
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                LuminaNextDiTBlock(
+                    hidden_size,
+                    num_attention_heads,
+                    num_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    hidden_size,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_out = LuminaLayerNormContinuous(
+            embedding_dim=hidden_size,
+            conditioning_embedding_dim=min(hidden_size, 1024),
+            elementwise_affine=False,
+            eps=1e-6,
+            bias=True,
+            out_dim=patch_size * patch_size * self.out_channels,
+        )
+        # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
+
+        assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> torch.Tensor:
+        """
+        Forward pass of LuminaNextDiT.
+
+        Parameters:
+            hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
+            timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
+            encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, L, D).
+            encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
+        """
+        hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
+        image_rotary_emb = image_rotary_emb.to(hidden_states.device)
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
+
+        encoder_mask = encoder_mask.bool()
+
+        for layer in self.layers:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer),
+                    hidden_states,
+                    mask,
+                    image_rotary_emb,
+                    encoder_hidden_states,
+                    encoder_mask,
+                    temb,
+                    cross_attention_kwargs,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = layer(
+                    hidden_states,
+                    mask,
+                    image_rotary_emb,
+                    encoder_hidden_states,
+                    encoder_mask,
+                    temb=temb,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )
+
+        hidden_states = self.norm_out(hidden_states, temb)
+
+        # unpatchify
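+        # Each token carries a (patch_size x patch_size x out_channels) patch; fold the token grid back
+        # into a (batch, out_channels, height, width) latent.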
+        height_tokens = width_tokens = self.patch_size
+        height, width = img_size[0]
+        batch_size = hidden_states.size(0)
+        sequence_length = (height // height_tokens) * (width // width_tokens)
+        hidden_states = hidden_states[:, :sequence_length].view(
+            batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
+        )
+        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
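
For reference, a minimal sketch (not part of the commit) of how the uploaded module might be instantiated once lumina_nextdit2d.py is on the Python path and a diffusers version that ships the Lumina building blocks is installed. The configuration values below are illustrative assumptions, far smaller than any released Lumina-Next checkpoint, and a real forward pass additionally needs the precomputed 2D rotary embeddings (image_rotary_emb) that the Lumina pipeline normally prepares.

import torch
from lumina_nextdit2d import LuminaNextDiT2DModel

# Toy configuration (assumed values, not the checkpoint's hyperparameters).
model = LuminaNextDiT2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=4,
    hidden_size=256,         # head_dim = 256 / 4 = 64, divisible by 4 as the 2D RoPE assert requires
    num_layers=2,
    num_attention_heads=4,
    num_kv_heads=2,
    cross_attention_dim=64,  # stand-in for the text encoder hidden size
)
print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")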