AlexGraikos committed
Commit 0691f24 · verified · 1 Parent(s): 9aa4d51

Upload pixcell_transformer_2d.py with huggingface_hub

Files changed (1)
  1. pixcell_transformer_2d.py +467 -0
pixcell_transformer_2d.py ADDED
@@ -0,0 +1,467 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from torch import nn
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import is_torch_version, logging
+ from diffusers.models.attention import BasicTransformerBlock
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
+ from diffusers.models.embeddings import PatchEmbed
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.normalization import AdaLayerNormSingle
+
+ from .embeddings_pixcell import PixcellUNIProjection, UNIPosEmbed
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class PixCellTransformer2DModel(ModelMixin, ConfigMixin):
+     r"""
+     A 2D Transformer model as introduced in the PixArt family of models (https://arxiv.org/abs/2310.00426,
+     https://arxiv.org/abs/2403.04692). Modified for the pathology domain.
+
+     Parameters:
+         num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
+         attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
+         in_channels (int, defaults to 4): The number of channels in the input.
+         out_channels (int, optional):
+             The number of channels in the output. Specify this parameter if the output channel number differs from
+             the input.
+         num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
+         dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
+         norm_num_groups (int, optional, defaults to 32):
+             Number of groups for group normalization within Transformer blocks.
+         cross_attention_dim (int, optional):
+             The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
+         attention_bias (bool, optional, defaults to True):
+             Configure if the Transformer blocks' attention should contain a bias parameter.
+         sample_size (int, defaults to 128):
+             The width of the latent images. This parameter is fixed during training.
+         patch_size (int, defaults to 2):
+             Size of the patches the model processes, relevant for architectures working on non-sequential data.
+         activation_fn (str, optional, defaults to "gelu-approximate"):
+             Activation function to use in feed-forward networks within Transformer blocks.
+         num_embeds_ada_norm (int, optional, defaults to 1000):
+             Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps
+             during inference.
+         upcast_attention (bool, optional, defaults to False):
+             If true, upcasts the attention computation to float32 for improved numerical stability.
+         norm_type (str, optional, defaults to "ada_norm_single"):
+             Specifies the type of normalization used; only 'ada_norm_single' is supported.
+         norm_elementwise_affine (bool, optional, defaults to False):
+             If true, enables element-wise affine parameters in the normalization layers.
+         norm_eps (float, optional, defaults to 1e-6):
+             A small constant added to the denominator in normalization layers to prevent division by zero.
+         interpolation_scale (int, optional): Scale factor to use when interpolating the position embeddings.
+         use_additional_conditions (bool, optional): Whether additional conditions are used as inputs.
+         attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
+         caption_channels (int, optional, defaults to None):
+             Number of channels to use for projecting the caption (UNI) embeddings.
+         caption_num_tokens (int, optional, defaults to 1):
+             Number of UNI embedding tokens used for conditioning.
+     """
+
+     _supports_gradient_checkpointing = True
+     _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 72,
+         in_channels: int = 4,
+         out_channels: Optional[int] = 8,
+         num_layers: int = 28,
+         dropout: float = 0.0,
+         norm_num_groups: int = 32,
+         cross_attention_dim: Optional[int] = 1152,
+         attention_bias: bool = True,
+         sample_size: int = 128,
+         patch_size: int = 2,
+         activation_fn: str = "gelu-approximate",
+         num_embeds_ada_norm: Optional[int] = 1000,
+         upcast_attention: bool = False,
+         norm_type: str = "ada_norm_single",
+         norm_elementwise_affine: bool = False,
+         norm_eps: float = 1e-6,
+         interpolation_scale: Optional[int] = None,
+         use_additional_conditions: Optional[bool] = None,
+         caption_channels: Optional[int] = None,
+         caption_num_tokens: int = 1,
+         attention_type: Optional[str] = "default",
+     ):
+         super().__init__()
+
+         # 1. Validate inputs.
+         if norm_type != "ada_norm_single":
+             raise NotImplementedError(
+                 f"Only 'ada_norm_single' is supported for `norm_type`, but got '{norm_type}'."
+             )
+         elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
+             raise ValueError(
+                 f"When using `norm_type` '{norm_type}', `num_embeds_ada_norm` cannot be None."
+             )
+
+         # Set some common variables used across the board.
+         self.attention_head_dim = attention_head_dim
+         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
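+         # With the defaults this is 16 * 72 = 1152, which also matches the default cross_attention_dim.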
+         self.out_channels = in_channels if out_channels is None else out_channels
+         if use_additional_conditions is None:
+             if sample_size == 128:
+                 use_additional_conditions = True
+             else:
+                 use_additional_conditions = False
+         self.use_additional_conditions = use_additional_conditions
+
+         self.gradient_checkpointing = False
+
+         # 2. Initialize the position embedding and transformer blocks.
+         self.height = self.config.sample_size
+         self.width = self.config.sample_size
+
+         interpolation_scale = (
+             self.config.interpolation_scale
+             if self.config.interpolation_scale is not None
+             else max(self.config.sample_size // 64, 1)
+         )
+         self.pos_embed = PatchEmbed(
+             height=self.config.sample_size,
+             width=self.config.sample_size,
+             patch_size=self.config.patch_size,
+             in_channels=self.config.in_channels,
+             embed_dim=self.inner_dim,
+             interpolation_scale=interpolation_scale,
+         )
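+         # The latent is split into (sample_size // patch_size) ** 2 tokens, e.g. (128 // 2) ** 2 = 4096 at the defaults.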
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     self.inner_dim,
+                     self.config.num_attention_heads,
+                     self.config.attention_head_dim,
+                     dropout=self.config.dropout,
+                     cross_attention_dim=self.config.cross_attention_dim,
+                     activation_fn=self.config.activation_fn,
+                     num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                     attention_bias=self.config.attention_bias,
+                     upcast_attention=self.config.upcast_attention,
+                     norm_type=norm_type,
+                     norm_elementwise_affine=self.config.norm_elementwise_affine,
+                     norm_eps=self.config.norm_eps,
+                     attention_type=self.config.attention_type,
+                 )
+                 for _ in range(self.config.num_layers)
+             ]
+         )
+
+         # Initialize the positional embedding for the conditions when more than one UNI embedding is given.
+         if self.config.caption_num_tokens == 1:
+             self.y_pos_embed = None
+         else:
+             # 1:1 aspect ratio
+             self.uni_height = int(self.config.caption_num_tokens ** 0.5)
+             self.uni_width = int(self.config.caption_num_tokens ** 0.5)
+
+             self.y_pos_embed = UNIPosEmbed(
+                 height=self.uni_height,
+                 width=self.uni_width,
+                 base_size=self.config.sample_size // self.config.patch_size,
+                 embed_dim=self.config.caption_channels,
+                 interpolation_scale=2,  # Should this be fixed?
+                 pos_embed_type="sincos",  # This is fixed
+             )
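+             # A fixed 2D sin-cos grid laid over the sqrt(N) x sqrt(N) UNI tokens, mirroring the latent patch grid.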
+
+         # 3. Output blocks.
+         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+         self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+
+         self.adaln_single = AdaLayerNormSingle(
+             self.inner_dim, use_additional_conditions=self.use_additional_conditions
+         )
+         self.caption_projection = None
+         if self.config.caption_channels is not None:
+             self.caption_projection = PixcellUNIProjection(
+                 in_features=self.config.caption_channels,
+                 hidden_size=self.inner_dim,
+                 num_tokens=self.config.caption_num_tokens,
+             )
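+         # caption_projection maps the UNI embedding(s) from caption_channels to the transformer width (inner_dim).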
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     @property
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
+             indexed by their weight names.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor()
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the
+                 processor for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def set_default_attn_processor(self):
+         """
+         Disables custom attention processors and sets the default attention implementation.
+
+         Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in the default
+         model.
+         """
+         self.set_attn_processor(AttnProcessor())
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+     def fuse_qkv_projections(self):
+         """
+         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+         are fused. For cross-attention modules, key and value projection matrices are fused.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+         """
+         self.original_attn_processors = None
+
+         for _, attn_processor in self.attn_processors.items():
+             if "Added" in str(attn_processor.__class__.__name__):
+                 raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+         self.original_attn_processors = self.attn_processors
+
+         for module in self.modules():
+             if isinstance(module, Attention):
+                 module.fuse_projections(fuse=True)
+
+         self.set_attn_processor(FusedAttnProcessor2_0())
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+     def unfuse_qkv_projections(self):
+         """Disables the fused QKV projection if enabled.
+
+         <Tip warning={true}>
+
+         This API is 🧪 experimental.
+
+         </Tip>
+
+         """
+         if self.original_attn_processors is not None:
+             self.set_attn_processor(self.original_attn_processors)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         timestep: Optional[torch.LongTensor] = None,
+         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ):
+         """
+         The [`PixCellTransformer2DModel`] forward method.
+
+         Args:
+             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                 Input `hidden_states`.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                 self-attention.
+             timestep (`torch.LongTensor`, *optional*):
+                 Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+             added_cond_kwargs (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
+             cross_attention_kwargs (`Dict[str, Any]`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             attention_mask (`torch.Tensor`, *optional*):
+                 An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the
+                 mask is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds
+                 large negative values to the attention scores corresponding to "discard" tokens.
+             encoder_attention_mask (`torch.Tensor`, *optional*):
+                 Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                     * Mask `(batch, sequence_length)` True = keep, False = discard.
+                     * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                 If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                 above. This bias will be added to the cross-attention scores.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             If `return_dict` is True, a [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise
+             a `tuple` where the first element is the sample tensor.
+         """
+         if self.use_additional_conditions and added_cond_kwargs is None:
+             raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
+
+         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+         #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+         #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+         # expects mask of shape:
+         #   [batch, key_tokens]
+         # adds singleton query_tokens dimension:
+         #   [batch, 1, key_tokens]
+         # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+         #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+         #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+         if attention_mask is not None and attention_mask.ndim == 2:
+             # assume that mask is expressed as:
+             #   (1 = keep, 0 = discard)
+             # convert mask into a bias that can be added to attention scores:
+             #   (keep = +0, discard = -10000.0)
+             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+             attention_mask = attention_mask.unsqueeze(1)
+
+         # convert encoder_attention_mask to a bias the same way we do for attention_mask
+         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+         # 1. Input
+         batch_size = hidden_states.shape[0]
+         height, width = (
+             hidden_states.shape[-2] // self.config.patch_size,
+             hidden_states.shape[-1] // self.config.patch_size,
+         )
+         hidden_states = self.pos_embed(hidden_states)
+
+         timestep, embedded_timestep = self.adaln_single(
+             timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+         )
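+         # adaln_single expands the timestep into the per-block modulation signals;
+         # embedded_timestep is reused for the final scale/shift below.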
+
+         if self.caption_projection is not None:
+             # Add positional embeddings to the conditions when more than one UNI embedding is given
+             if self.y_pos_embed is not None:
+                 encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
+             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
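+             # encoder_hidden_states is now (batch, num_caption_tokens, inner_dim), ready for cross-attention.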
+
+         # 2. Blocks
+         for block in self.transformer_blocks:
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     attention_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                     timestep,
+                     cross_attention_kwargs,
+                     None,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 hidden_states = block(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_attention_mask,
+                     timestep=timestep,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     class_labels=None,
+                 )
+
+         # 3. Output
+         shift, scale = (
+             self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
+         ).chunk(2, dim=1)
+         hidden_states = self.norm_out(hidden_states)
+         # Modulation
+         hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = hidden_states.squeeze(1)
+
+         # unpatchify
+         hidden_states = hidden_states.reshape(
+             shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
+         )
+         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
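+         # nhwpqc -> nchpwq: channels first, so (h, p) and (w, q) can collapse back into the full latent grid.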
+         output = hidden_states.reshape(
+             shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
+         )
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
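
A minimal usage sketch for the file above. The tiny debug config, the 1024-dimensional UNI embedding, and the PixArt-style empty `added_cond_kwargs` are illustrative assumptions, not the shipped configuration:

import torch

# Tiny config so a smoke test runs quickly: inner_dim = 4 * 32 = 128, and
# cross_attention_dim must match it so cross-attention consumes the projected UNI tokens.
model = PixCellTransformer2DModel(
    num_attention_heads=4,
    attention_head_dim=32,
    cross_attention_dim=128,
    num_layers=2,
    sample_size=32,          # != 128, so use_additional_conditions resolves to False
    caption_channels=1024,   # assumed UNI feature dimension
    caption_num_tokens=1,
)

latents = torch.randn(1, 4, 32, 32)   # (batch, in_channels, H, W)
uni = torch.randn(1, 1, 1024)         # (batch, caption_num_tokens, caption_channels)
timestep = torch.tensor([999])

# PixArt pipelines pass empty additional conditions this way when they are unused.
out = model(
    latents,
    encoder_hidden_states=uni,
    timestep=timestep,
    added_cond_kwargs={"resolution": None, "aspect_ratio": None},
).sample
print(out.shape)  # torch.Size([1, 8, 32, 32]); out_channels defaults to 8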