vulus98 committed
Commit 912db97 · 1 Parent(s): 0c05ca5

marigold folder added

Marigold/LICENSE.txt ADDED
@@ -0,0 +1,177 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
Marigold/README.md ADDED
@@ -0,0 +1 @@
+ Code is copied from https://github.com/prs-eth/Marigold.
Marigold/resnet/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .resnet import ResnetBlock2D, ResnetBlockCondNorm2D
+ from .downsampling import Downsample2D, FirDownsample2D, KDownsample2D
+ from .upsampling import Upsample2D, FirUpsample2D, KUpsample2D
+
+ __all__ = [
+     "ResnetBlock2D",
+     "ResnetBlockCondNorm2D",
+     "Downsample2D",
+     "FirDownsample2D",
+     "KDownsample2D",
+     "Upsample2D",
+     "FirUpsample2D",
+     "KUpsample2D",
+ ]
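
For orientation, a minimal usage sketch of the re-exported blocks (hypothetical, not part of the commit; it assumes the repository root is on PYTHONPATH and that torch, diffusers, and einops are installed, since the vendored modules import them):

import torch

from Marigold.resnet import Downsample2D

# Stride-2 conv downsampling with default settings; norm_type=None skips
# the extra normalization branch added in this fork.
down = Downsample2D(channels=64, use_conv=True)
x = torch.randn(6, 64, 32, 32)  # (batch * frames, C, H, W)
print(down(x).shape)  # torch.Size([6, 64, 16, 16])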
Marigold/resnet/downsampling.py ADDED
@@ -0,0 +1,406 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffusers.utils import deprecate
+ from diffusers.models.normalization import RMSNorm
+ from Marigold.resnet.upsampling import upfirdn2d_native
+ from einops import rearrange
+
+
+ class Downsample1D(nn.Module):
+     """A 1D downsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         padding (`int`, default `1`):
+             padding for the convolution.
+         name (`str`, default `conv`):
+             name of the downsampling 1D layer.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         out_channels: Optional[int] = None,
+         padding: int = 1,
+         name: str = "conv",
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.padding = padding
+         stride = 2
+         self.name = name
+
+         if use_conv:
+             self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+         else:
+             assert self.channels == self.out_channels
+             self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         assert inputs.shape[1] == self.channels
+         return self.conv(inputs)
+
+
+ class Downsample2D(nn.Module):
+     """A 2D downsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         padding (`int`, default `1`):
+             padding for the convolution.
+         name (`str`, default `conv`):
+             name of the downsampling 2D layer.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         out_channels: Optional[int] = None,
+         padding: int = 1,
+         name: str = "conv",
+         kernel_size=3,
+         norm_type=None,
+         eps=None,
+         elementwise_affine=None,
+         bias=True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.padding = padding
+         stride = 2
+         self.name = name
+
+         if norm_type == "ln_norm":
+             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+         elif norm_type == "rms_norm":
+             self.norm = RMSNorm(channels, eps, elementwise_affine)
+         elif norm_type is None:
+             self.norm = None
+         else:
+             raise ValueError(f"unknown norm_type: {norm_type}")
+
+         if use_conv:
+             conv = nn.Conv2d(
+                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias,
+             )
+         else:
+             assert self.channels == self.out_channels
+             conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+         if name == "conv":
+             self.Conv2d_0 = conv
+             self.conv = conv
+         elif name == "Conv2d_0":
+             self.conv = conv
+         else:
+             self.conv = conv
+
+     def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+         if len(args) > 0 or kwargs.get("scale", None) is not None:
+             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+             deprecate("scale", "1.0.0", deprecation_message)
+         assert hidden_states.shape[1] == self.channels
+
+         if self.norm is not None:
+             # Fork-specific change relative to stock diffusers: the rearrange with b=1
+             # folds the leading (batch/frame) dimension into the flattened last axis
+             # before normalizing, then restores the original layout.
+             hidden_states_permuted = hidden_states.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
+             b, c, h, w = hidden_states_permuted.shape
+             hidden_states_permuted = rearrange(hidden_states_permuted, "(b t) c h w -> b c (h w t)", b=1, h=h, w=w)
+             hidden_states = self.norm(hidden_states_permuted)
+             hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", b=1, h=h, w=w)
+             hidden_states = hidden_states.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
+         if self.use_conv and self.padding == 0:
+             pad = (0, 1, 0, 1)
+             hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+
+         assert hidden_states.shape[1] == self.channels
+
+         hidden_states = self.conv(hidden_states)
+
+         return hidden_states
+
+
+ class FirDownsample2D(nn.Module):
+     """A 2D FIR downsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+             kernel for the FIR filter.
+     """
+
+     def __init__(
+         self,
+         channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         use_conv: bool = False,
+         fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+     ):
+         super().__init__()
+         out_channels = out_channels if out_channels else channels
+         if use_conv:
+             self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.fir_kernel = fir_kernel
+         self.use_conv = use_conv
+         self.out_channels = out_channels
+
+     def _downsample_2d(
+         self,
+         hidden_states: torch.Tensor,
+         weight: Optional[torch.Tensor] = None,
+         kernel: Optional[torch.Tensor] = None,
+         factor: int = 2,
+         gain: float = 1,
+     ) -> torch.Tensor:
+         """Fused `Conv2d()` followed by `downsample_2d()`.
+         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+         arbitrary order.
+
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+             weight (`torch.Tensor`, *optional*):
+                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+                 performed by `inChannels = x.shape[0] // numGroups`.
+             kernel (`torch.Tensor`, *optional*):
+                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                 corresponds to average pooling.
+             factor (`int`, *optional*, default to `2`):
+                 Integer downsampling factor.
+             gain (`float`, *optional*, default to `1.0`):
+                 Scaling factor for signal magnitude.
+
+         Returns:
+             output (`torch.Tensor`):
+                 Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+                 datatype as `x`.
+         """
+
+         assert isinstance(factor, int) and factor >= 1
+         if kernel is None:
+             kernel = [1] * factor
+
+         # setup kernel
+         kernel = torch.tensor(kernel, dtype=torch.float32)
+         if kernel.ndim == 1:
+             kernel = torch.outer(kernel, kernel)
+         kernel /= torch.sum(kernel)
+
+         kernel = kernel * gain
+
+         if self.use_conv:
+             _, _, convH, convW = weight.shape
+             pad_value = (kernel.shape[0] - factor) + (convW - 1)
+             stride_value = [factor, factor]
+             upfirdn_input = upfirdn2d_native(
+                 hidden_states,
+                 torch.tensor(kernel, device=hidden_states.device),
+                 pad=((pad_value + 1) // 2, pad_value // 2),
+             )
+             output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
+         else:
+             pad_value = kernel.shape[0] - factor
+             output = upfirdn2d_native(
+                 hidden_states,
+                 torch.tensor(kernel, device=hidden_states.device),
+                 down=factor,
+                 pad=((pad_value + 1) // 2, pad_value // 2),
+             )
+
+         return output
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         if self.use_conv:
+             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+         else:
+             hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+         return hidden_states
+
+
+ # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
+ class KDownsample2D(nn.Module):
+     r"""A 2D K-downsampling layer.
+
+     Parameters:
+         pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+     """
+
+     def __init__(self, pad_mode: str = "reflect"):
+         super().__init__()
+         self.pad_mode = pad_mode
+         kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
+         self.pad = kernel_1d.shape[1] // 2 - 1
+         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
+         weight = inputs.new_zeros(
+             [
+                 inputs.shape[1],
+                 inputs.shape[1],
+                 self.kernel.shape[0],
+                 self.kernel.shape[1],
+             ]
+         )
+         indices = torch.arange(inputs.shape[1], device=inputs.device)
+         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
+         weight[indices, indices] = kernel
+         return F.conv2d(inputs, weight, stride=2)
+
+
+ class CogVideoXDownsample3D(nn.Module):
+     # TODO: Wait for paper release.
+     r"""
+     A 3D downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI
+
+     Args:
+         in_channels (`int`):
+             Number of channels in the input image.
+         out_channels (`int`):
+             Number of channels produced by the convolution.
+         kernel_size (`int`, defaults to `3`):
+             Size of the convolving kernel.
+         stride (`int`, defaults to `2`):
+             Stride of the convolution.
+         padding (`int`, defaults to `0`):
+             Padding added to all four sides of the input.
+         compress_time (`bool`, defaults to `False`):
+             Whether or not to compress the time dimension.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         stride: int = 2,
+         padding: int = 0,
+         compress_time: bool = False,
+     ):
+         super().__init__()
+
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+         self.compress_time = compress_time
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.compress_time:
+             batch_size, channels, frames, height, width = x.shape
+
+             # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
+             x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
+
+             if x.shape[-1] % 2 == 1:
+                 x_first, x_rest = x[..., 0], x[..., 1:]
+                 if x_rest.shape[-1] > 0:
+                     # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
+                     x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+
+                 x = torch.cat([x_first[..., None], x_rest], dim=-1)
+                 # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
+                 x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+             else:
+                 # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
+                 x = F.avg_pool1d(x, kernel_size=2, stride=2)
+                 # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
+                 x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+         # Pad the tensor
+         pad = (0, 1, 0, 1)
+         x = F.pad(x, pad, mode="constant", value=0)
+         batch_size, channels, frames, height, width = x.shape
+         # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
+         x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
+         x = self.conv(x)
+         # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
+         x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+         return x
+
+
+ def downsample_2d(
+     hidden_states: torch.Tensor,
+     kernel: Optional[torch.Tensor] = None,
+     factor: int = 2,
+     gain: float = 1,
+ ) -> torch.Tensor:
+     r"""Downsample a batch of 2D images with the given filter.
+     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+     specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+     shape is a multiple of the downsampling factor.
+
+     Args:
+         hidden_states (`torch.Tensor`):
+             Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+         kernel (`torch.Tensor`, *optional*):
+             FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+             corresponds to average pooling.
+         factor (`int`, *optional*, default to `2`):
+             Integer downsampling factor.
+         gain (`float`, *optional*, default to `1.0`):
+             Scaling factor for signal magnitude.
+
+     Returns:
+         output (`torch.Tensor`):
+             Tensor of the shape `[N, C, H // factor, W // factor]`
+     """
+
+     assert isinstance(factor, int) and factor >= 1
+     if kernel is None:
+         kernel = [1] * factor
+
+     kernel = torch.tensor(kernel, dtype=torch.float32)
+     if kernel.ndim == 1:
+         kernel = torch.outer(kernel, kernel)
+     kernel /= torch.sum(kernel)
+
+     kernel = kernel * gain
+     pad_value = kernel.shape[0] - factor
+     output = upfirdn2d_native(
+         hidden_states,
+         kernel.to(device=hidden_states.device),
+         down=factor,
+         pad=((pad_value + 1) // 2, pad_value // 2),
+     )
+     return output
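
A quick sanity check of the module-level FIR helper and the K-downsampler (hypothetical, not part of the commit; it also needs the sibling Marigold/resnet/upsampling.py, which provides upfirdn2d_native). With no kernel given, downsample_2d falls back to [1] * factor, i.e. average pooling, so per the docstring's gain normalization a constant image stays constant while the resolution halves:

import torch

from Marigold.resnet.downsampling import KDownsample2D, downsample_2d

x = torch.ones(1, 3, 8, 8)
y = downsample_2d(x, factor=2)   # default kernel -> 2x2 average pooling
print(y.shape, float(y.mean()))  # torch.Size([1, 3, 4, 4]) 1.0

k = KDownsample2D()  # fixed binomial [1, 3, 3, 1] / 8 separable kernel
print(k(torch.randn(2, 3, 8, 8)).shape)  # torch.Size([2, 3, 4, 4])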
Marigold/resnet/resnet.py ADDED
@@ -0,0 +1,802 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2024 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.utils import deprecate
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import SpatialNorm
26
+ from Marigold.resnet.downsampling import ( # noqa
27
+ Downsample2D,
28
+ downsample_2d,
29
+ )
30
+ from diffusers.models.normalization import AdaGroupNorm
31
+ from Marigold.resnet.upsampling import ( # noqa
32
+ Upsample2D,
33
+ upsample_2d,
34
+ )
35
+ from einops import rearrange
36
+
37
+ class ResnetBlockCondNorm2D(nn.Module):
38
+ r"""
39
+ A Resnet block that use normalization layer that incorporate conditioning information.
40
+
41
+ Parameters:
42
+ in_channels (`int`): The number of channels in the input.
43
+ out_channels (`int`, *optional*, default to be `None`):
44
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
45
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
46
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
47
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
48
+ groups_out (`int`, *optional*, default to None):
49
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
50
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
51
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
52
+ time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
53
+ The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
54
+ kernel (`torch.Tensor`, optional, default to None): FIR filter, see
55
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
56
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
57
+ use_in_shortcut (`bool`, *optional*, default to `True`):
58
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
59
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
60
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
61
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
62
+ `conv_shortcut` output.
63
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
64
+ If None, same as `out_channels`.
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ *,
70
+ in_channels: int,
71
+ out_channels: Optional[int] = None,
72
+ conv_shortcut: bool = False,
73
+ dropout: float = 0.0,
74
+ temb_channels: int = 512,
75
+ groups: int = 32,
76
+ groups_out: Optional[int] = None,
77
+ eps: float = 1e-6,
78
+ non_linearity: str = "swish",
79
+ time_embedding_norm: str = "ada_group", # ada_group, spatial
80
+ output_scale_factor: float = 1.0,
81
+ use_in_shortcut: Optional[bool] = None,
82
+ up: bool = False,
83
+ down: bool = False,
84
+ conv_shortcut_bias: bool = True,
85
+ conv_2d_out_channels: Optional[int] = None,
86
+ ):
87
+ super().__init__()
88
+ self.in_channels = in_channels
89
+ out_channels = in_channels if out_channels is None else out_channels
90
+ self.out_channels = out_channels
91
+ self.use_conv_shortcut = conv_shortcut
92
+ self.up = up
93
+ self.down = down
94
+ self.output_scale_factor = output_scale_factor
95
+ self.time_embedding_norm = time_embedding_norm
96
+
97
+ if groups_out is None:
98
+ groups_out = groups
99
+
100
+ if self.time_embedding_norm == "ada_group": # ada_group
101
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
102
+ elif self.time_embedding_norm == "spatial":
103
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
104
+ else:
105
+ raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
106
+
107
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
108
+
109
+ if self.time_embedding_norm == "ada_group": # ada_group
110
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
111
+ elif self.time_embedding_norm == "spatial": # spatial
112
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
113
+ else:
114
+ raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
115
+
116
+ self.dropout = torch.nn.Dropout(dropout)
117
+
118
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
119
+ self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
120
+
121
+ self.nonlinearity = get_activation(non_linearity)
122
+
123
+ self.upsample = self.downsample = None
124
+ if self.up:
125
+ self.upsample = Upsample2D(in_channels, use_conv=False)
126
+ elif self.down:
127
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
128
+
129
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
130
+
131
+ self.conv_shortcut = None
132
+ if self.use_in_shortcut:
133
+ self.conv_shortcut = nn.Conv2d(
134
+ in_channels,
135
+ conv_2d_out_channels,
136
+ kernel_size=1,
137
+ stride=1,
138
+ padding=0,
139
+ bias=conv_shortcut_bias,
140
+ )
141
+
142
+ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
143
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
144
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
145
+ deprecate("scale", "1.0.0", deprecation_message)
146
+
147
+ hidden_states = input_tensor
148
+
149
+ hidden_states = self.norm1(hidden_states, temb)
150
+
151
+ hidden_states = self.nonlinearity(hidden_states)
152
+
153
+ if self.upsample is not None:
154
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
155
+ if hidden_states.shape[0] >= 64:
156
+ input_tensor = input_tensor.contiguous()
157
+ hidden_states = hidden_states.contiguous()
158
+ input_tensor = self.upsample(input_tensor)
159
+ hidden_states = self.upsample(hidden_states)
160
+
161
+ elif self.downsample is not None:
162
+ input_tensor = self.downsample(input_tensor)
163
+ hidden_states = self.downsample(hidden_states)
164
+
165
+ hidden_states = self.conv1(hidden_states)
166
+
167
+ hidden_states = self.norm2(hidden_states, temb)
168
+
169
+ hidden_states = self.nonlinearity(hidden_states)
170
+
171
+ hidden_states = self.dropout(hidden_states)
172
+ hidden_states = self.conv2(hidden_states)
173
+
174
+ if self.conv_shortcut is not None:
175
+ input_tensor = self.conv_shortcut(input_tensor)
176
+
177
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
178
+
179
+ return output_tensor
180
+
181
+
182
+ class ResnetBlock2D(nn.Module):
183
+ r"""
184
+ A Resnet block.
185
+
186
+ Parameters:
187
+ in_channels (`int`): The number of channels in the input.
188
+ out_channels (`int`, *optional*, default to be `None`):
189
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
190
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
191
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
192
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
193
+ groups_out (`int`, *optional*, default to None):
194
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
195
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
196
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
197
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
198
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
199
+ stronger conditioning with scale and shift.
200
+ kernel (`torch.Tensor`, optional, default to None): FIR filter, see
201
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
202
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
203
+ use_in_shortcut (`bool`, *optional*, default to `True`):
204
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
205
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
206
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
207
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
208
+ `conv_shortcut` output.
209
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
210
+ If None, same as `out_channels`.
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ *,
216
+ in_channels: int,
217
+ out_channels: Optional[int] = None,
218
+ conv_shortcut: bool = False,
219
+ dropout: float = 0.0,
220
+ temb_channels: int = 512,
221
+ groups: int = 32,
222
+ groups_out: Optional[int] = None,
223
+ pre_norm: bool = True,
224
+ eps: float = 1e-6,
225
+ non_linearity: str = "swish",
226
+ skip_time_act: bool = False,
227
+ time_embedding_norm: str = "default", # default, scale_shift,
228
+ kernel: Optional[torch.Tensor] = None,
229
+ output_scale_factor: float = 1.0,
230
+ use_in_shortcut: Optional[bool] = None,
231
+ up: bool = False,
232
+ down: bool = False,
233
+ conv_shortcut_bias: bool = True,
234
+ conv_2d_out_channels: Optional[int] = None,
235
+ ):
236
+ super().__init__()
237
+ if time_embedding_norm == "ada_group":
238
+ raise ValueError(
239
+ "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
240
+ )
241
+ if time_embedding_norm == "spatial":
242
+ raise ValueError(
243
+ "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
244
+ )
245
+
246
+ self.pre_norm = True
247
+ self.in_channels = in_channels
248
+ out_channels = in_channels if out_channels is None else out_channels
249
+ self.out_channels = out_channels
250
+ self.use_conv_shortcut = conv_shortcut
251
+ self.up = up
252
+ self.down = down
253
+ self.output_scale_factor = output_scale_factor
254
+ self.time_embedding_norm = time_embedding_norm
255
+ self.skip_time_act = skip_time_act
256
+
257
+ if groups_out is None:
258
+ groups_out = groups
259
+
260
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
261
+
262
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
263
+
264
+ if temb_channels is not None:
265
+ if self.time_embedding_norm == "default":
266
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
267
+ elif self.time_embedding_norm == "scale_shift":
268
+ self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
269
+ else:
270
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
271
+ else:
272
+ self.time_emb_proj = None
273
+
274
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
275
+
276
+ self.dropout = torch.nn.Dropout(dropout)
277
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
278
+ self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
279
+
280
+ self.nonlinearity = get_activation(non_linearity)
281
+
282
+ self.upsample = self.downsample = None
283
+ if self.up:
284
+ if kernel == "fir":
285
+ fir_kernel = (1, 3, 3, 1)
286
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
287
+ elif kernel == "sde_vp":
288
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
289
+ else:
290
+ self.upsample = Upsample2D(in_channels, use_conv=False)
291
+ elif self.down:
292
+ if kernel == "fir":
293
+ fir_kernel = (1, 3, 3, 1)
294
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
295
+ elif kernel == "sde_vp":
296
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
297
+ else:
298
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
299
+
300
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
301
+
302
+ self.conv_shortcut = None
303
+ if self.use_in_shortcut:
304
+ self.conv_shortcut = nn.Conv2d(
305
+ in_channels,
306
+ conv_2d_out_channels,
307
+ kernel_size=1,
308
+ stride=1,
309
+ padding=0,
310
+ bias=conv_shortcut_bias,
311
+ )
312
+
313
+ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
314
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
315
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
316
+ deprecate("scale", "1.0.0", deprecation_message)
317
+
318
+ hidden_states = input_tensor
319
+
320
+ b, c, h, w = hidden_states.shape
321
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", t=6, h=h, w=w)
322
+ hidden_states = self.norm1(hidden_states)
323
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", t=6, h=h, w=w)
324
+ hidden_states = self.nonlinearity(hidden_states)
325
+
326
+ if self.upsample is not None:
327
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
328
+ if hidden_states.shape[0] >= 64:
329
+ input_tensor = input_tensor.contiguous()
330
+ hidden_states = hidden_states.contiguous()
331
+ input_tensor = self.upsample(input_tensor)
332
+ hidden_states = self.upsample(hidden_states)
333
+ elif self.downsample is not None:
334
+ input_tensor = self.downsample(input_tensor)
335
+ hidden_states = self.downsample(hidden_states)
336
+
337
+ hidden_states = self.conv1(hidden_states)
338
+
339
+ if self.time_emb_proj is not None:
340
+ if not self.skip_time_act:
341
+ temb = self.nonlinearity(temb)
342
+ temb = self.time_emb_proj(temb)[:, :, None, None]
343
+
344
+ if self.time_embedding_norm == "default":
345
+ if temb is not None:
346
+ hidden_states = hidden_states + temb
347
+ b, c, h, w = hidden_states.shape
348
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", t=6, h=h, w=w)
349
+ hidden_states = self.norm2(hidden_states)
350
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", t=6, h=h, w=w)
351
+ elif self.time_embedding_norm == "scale_shift":
352
+ if temb is None:
353
+ raise ValueError(
354
+ f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
355
+ )
356
+ time_scale, time_shift = torch.chunk(temb, 2, dim=1)
357
+ b, c, h, w = hidden_states.shape
358
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", t=6, h=h, w=w)
359
+ hidden_states = self.norm2(hidden_states)
360
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", t=6, h=h, w=w)
361
+ hidden_states = hidden_states * (1 + time_scale) + time_shift
362
+ else:
363
+ b, c, h, w = hidden_states.shape
364
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", t=6, h=h, w=w)
365
+ hidden_states = self.norm2(hidden_states)
366
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", t=6, h=h, w=w)
367
+
368
+ hidden_states = self.nonlinearity(hidden_states)
369
+
370
+ hidden_states = self.dropout(hidden_states)
371
+ hidden_states = self.conv2(hidden_states)
372
+
373
+ if self.conv_shortcut is not None:
374
+ input_tensor = self.conv_shortcut(input_tensor)
375
+
376
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
377
+
378
+ return output_tensor
379
+
380
+
381
+ # unet_rl.py
382
+ def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
383
+ if len(tensor.shape) == 2:
384
+ return tensor[:, :, None]
385
+ if len(tensor.shape) == 3:
386
+ return tensor[:, :, None, :]
387
+ elif len(tensor.shape) == 4:
388
+ return tensor[:, :, 0, :]
389
+ else:
390
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
391
+
392
+
393
+ class Conv1dBlock(nn.Module):
394
+ """
395
+ Conv1d --> GroupNorm --> Mish
396
+
397
+ Parameters:
398
+ inp_channels (`int`): Number of input channels.
399
+ out_channels (`int`): Number of output channels.
400
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
401
+ n_groups (`int`, default `8`): Number of groups to separate the channels into.
402
+ activation (`str`, defaults to `mish`): Name of the activation function.
403
+ """
404
+
405
+ def __init__(
406
+ self,
407
+ inp_channels: int,
408
+ out_channels: int,
409
+ kernel_size: Union[int, Tuple[int, int]],
410
+ n_groups: int = 8,
411
+ activation: str = "mish",
412
+ ):
413
+ super().__init__()
414
+
415
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
416
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
417
+ self.mish = get_activation(activation)
418
+
419
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
420
+ intermediate_repr = self.conv1d(inputs)
421
+ intermediate_repr = rearrange_dims(intermediate_repr)
422
+ intermediate_repr = self.group_norm(intermediate_repr)
423
+ intermediate_repr = rearrange_dims(intermediate_repr)
424
+ output = self.mish(intermediate_repr)
425
+ return output
426
+
427
+
428
+ # unet_rl.py
429
+ class ResidualTemporalBlock1D(nn.Module):
430
+ """
431
+ Residual 1D block with temporal convolutions.
432
+
433
+ Parameters:
434
+ inp_channels (`int`): Number of input channels.
435
+ out_channels (`int`): Number of output channels.
436
+ embed_dim (`int`): Embedding dimension.
437
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
438
+ activation (`str`, defaults `mish`): It is possible to choose the right activation function.
439
+ """
440
+
441
+ def __init__(
442
+ self,
443
+ inp_channels: int,
444
+ out_channels: int,
445
+ embed_dim: int,
446
+ kernel_size: Union[int, Tuple[int, int]] = 5,
447
+ activation: str = "mish",
448
+ ):
449
+ super().__init__()
450
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
451
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
452
+
453
+ self.time_emb_act = get_activation(activation)
454
+ self.time_emb = nn.Linear(embed_dim, out_channels)
455
+
456
+ self.residual_conv = (
457
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
458
+ )
459
+
460
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
461
+ """
462
+ Args:
463
+ inputs : [ batch_size x inp_channels x horizon ]
464
+ t : [ batch_size x embed_dim ]
465
+
466
+ returns:
467
+ out : [ batch_size x out_channels x horizon ]
468
+ """
469
+ t = self.time_emb_act(t)
470
+ t = self.time_emb(t)
471
+ out = self.conv_in(inputs) + rearrange_dims(t)
472
+ out = self.conv_out(out)
473
+ return out + self.residual_conv(inputs)
474
+
475
+
476
+ class TemporalConvLayer(nn.Module):
477
+ """
478
+ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
479
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
480
+
481
+ Parameters:
482
+ in_dim (`int`): Number of input channels.
483
+ out_dim (`int`): Number of output channels.
484
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
485
+ """
486
+
487
+ def __init__(
488
+ self,
489
+ in_dim: int,
490
+ out_dim: Optional[int] = None,
491
+ dropout: float = 0.0,
492
+ norm_num_groups: int = 32,
493
+ ):
494
+ super().__init__()
495
+ out_dim = out_dim or in_dim
496
+ self.in_dim = in_dim
497
+ self.out_dim = out_dim
498
+
499
+ # conv layers
500
+ self.conv1 = nn.Sequential(
501
+ nn.GroupNorm(norm_num_groups, in_dim),
502
+ nn.SiLU(),
503
+ nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
504
+ )
505
+ self.conv2 = nn.Sequential(
506
+ nn.GroupNorm(norm_num_groups, out_dim),
507
+ nn.SiLU(),
508
+ nn.Dropout(dropout),
509
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
510
+ )
511
+ self.conv3 = nn.Sequential(
512
+ nn.GroupNorm(norm_num_groups, out_dim),
513
+ nn.SiLU(),
514
+ nn.Dropout(dropout),
515
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
516
+ )
517
+ self.conv4 = nn.Sequential(
518
+ nn.GroupNorm(norm_num_groups, out_dim),
519
+ nn.SiLU(),
520
+ nn.Dropout(dropout),
521
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
522
+ )
523
+
524
+ # zero out the last layer params,so the conv block is identity
525
+ nn.init.zeros_(self.conv4[-1].weight)
526
+ nn.init.zeros_(self.conv4[-1].bias)
527
+
528
+ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
529
+ hidden_states = (
530
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
531
+ )
532
+
533
+ identity = hidden_states
534
+ hidden_states = self.conv1(hidden_states)
535
+ hidden_states = self.conv2(hidden_states)
536
+ hidden_states = self.conv3(hidden_states)
537
+ hidden_states = self.conv4(hidden_states)
538
+
539
+ hidden_states = identity + hidden_states
540
+
541
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
542
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
543
+ )
544
+ return hidden_states
545
+
546
+
547
+ class TemporalResnetBlock(nn.Module):
548
+ r"""
549
+ A Resnet block.
550
+
551
+ Parameters:
552
+ in_channels (`int`): The number of channels in the input.
553
+ out_channels (`int`, *optional*, default to be `None`):
554
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
555
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
556
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ in_channels: int,
562
+ out_channels: Optional[int] = None,
563
+ temb_channels: int = 512,
564
+ eps: float = 1e-6,
565
+ ):
566
+ super().__init__()
567
+ self.in_channels = in_channels
568
+ out_channels = in_channels if out_channels is None else out_channels
569
+ self.out_channels = out_channels
570
+
571
+ kernel_size = (3, 1, 1)
572
+ padding = [k // 2 for k in kernel_size]
573
+
574
+ self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
575
+ self.conv1 = nn.Conv3d(
576
+ in_channels,
577
+ out_channels,
578
+ kernel_size=kernel_size,
579
+ stride=1,
580
+ padding=padding,
581
+ )
582
+
583
+ if temb_channels is not None:
584
+ self.time_emb_proj = nn.Linear(temb_channels, out_channels)
585
+ else:
586
+ self.time_emb_proj = None
587
+
588
+ self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
589
+
590
+ self.dropout = torch.nn.Dropout(0.0)
591
+ self.conv2 = nn.Conv3d(
592
+ out_channels,
593
+ out_channels,
594
+ kernel_size=kernel_size,
595
+ stride=1,
596
+ padding=padding,
597
+ )
598
+
599
+ self.nonlinearity = get_activation("silu")
600
+
601
+ self.use_in_shortcut = self.in_channels != out_channels
602
+
603
+ self.conv_shortcut = None
604
+ if self.use_in_shortcut:
605
+ self.conv_shortcut = nn.Conv3d(
606
+ in_channels,
607
+ out_channels,
608
+ kernel_size=1,
609
+ stride=1,
610
+ padding=0,
611
+ )
612
+
613
+ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
614
+ hidden_states = input_tensor
615
+
616
+ hidden_states = self.norm1(hidden_states)
617
+ hidden_states = self.nonlinearity(hidden_states)
618
+ hidden_states = self.conv1(hidden_states)
619
+
620
+ if self.time_emb_proj is not None:
621
+ temb = self.nonlinearity(temb)
622
+ temb = self.time_emb_proj(temb)[:, :, :, None, None]
623
+ temb = temb.permute(0, 2, 1, 3, 4)
624
+ hidden_states = hidden_states + temb
625
+
626
+ hidden_states = self.norm2(hidden_states)
627
+ hidden_states = self.nonlinearity(hidden_states)
628
+ hidden_states = self.dropout(hidden_states)
629
+ hidden_states = self.conv2(hidden_states)
630
+
631
+ if self.conv_shortcut is not None:
632
+ input_tensor = self.conv_shortcut(input_tensor)
633
+
634
+ output_tensor = input_tensor + hidden_states
635
+
636
+ return output_tensor
+
+
+ # VideoResBlock
+ class SpatioTemporalResBlock(nn.Module):
+     r"""
+     A SpatioTemporal Resnet block.
+
+     Parameters:
+         in_channels (`int`): The number of channels in the input.
+         out_channels (`int`, *optional*, defaults to `None`):
+             The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+         temb_channels (`int`, *optional*, defaults to `512`): the number of channels in the timestep embedding.
+         eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
+         temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
+         merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
+         merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+             The merge strategy to use for the temporal mixing.
+         switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+             If `True`, switch the spatial and temporal mixing.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: Optional[int] = None,
+         temb_channels: int = 512,
+         eps: float = 1e-6,
+         temporal_eps: Optional[float] = None,
+         merge_factor: float = 0.5,
+         merge_strategy="learned_with_images",
+         switch_spatial_to_temporal_mix: bool = False,
+     ):
+         super().__init__()
+
+         self.spatial_res_block = ResnetBlock2D(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             eps=eps,
+         )
+
+         self.temporal_res_block = TemporalResnetBlock(
+             in_channels=out_channels if out_channels is not None else in_channels,
+             out_channels=out_channels if out_channels is not None else in_channels,
+             temb_channels=temb_channels,
+             eps=temporal_eps if temporal_eps is not None else eps,
+         )
+
+         self.time_mixer = AlphaBlender(
+             alpha=merge_factor,
+             merge_strategy=merge_strategy,
+             switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         temb: Optional[torch.Tensor] = None,
+         image_only_indicator: Optional[torch.Tensor] = None,
+     ):
+         # `image_only_indicator` is required here: its last dimension carries num_frames.
+         num_frames = image_only_indicator.shape[-1]
+         hidden_states = self.spatial_res_block(hidden_states, temb)
+
+         batch_frames, channels, height, width = hidden_states.shape
+         batch_size = batch_frames // num_frames
+
+         hidden_states_mix = (
+             hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+         )
+         hidden_states = (
+             hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+         )
+
+         if temb is not None:
+             temb = temb.reshape(batch_size, num_frames, -1)
+
+         hidden_states = self.temporal_res_block(hidden_states, temb)
+         hidden_states = self.time_mixer(
+             x_spatial=hidden_states_mix,
+             x_temporal=hidden_states,
+             image_only_indicator=image_only_indicator,
+         )
+
+         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
+         return hidden_states
+
+
+ class AlphaBlender(nn.Module):
+     r"""
+     A module to blend spatial and temporal features.
+
+     Parameters:
+         alpha (`float`): The initial value of the blending factor.
+         merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+             The merge strategy to use for the temporal mixing.
+         switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+             If `True`, switch the spatial and temporal mixing.
+     """
+
+     strategies = ["learned", "fixed", "learned_with_images"]
+
+     def __init__(
+         self,
+         alpha: float,
+         merge_strategy: str = "learned_with_images",
+         switch_spatial_to_temporal_mix: bool = False,
+     ):
+         super().__init__()
+         self.merge_strategy = merge_strategy
+         self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE
+
+         if merge_strategy not in self.strategies:
+             raise ValueError(f"merge_strategy needs to be in {self.strategies}")
+
+         if self.merge_strategy == "fixed":
+             self.register_buffer("mix_factor", torch.Tensor([alpha]))
+         elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+             self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+         else:
+             raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
+
+     def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
+         if self.merge_strategy == "fixed":
+             alpha = self.mix_factor
+
+         elif self.merge_strategy == "learned":
+             alpha = torch.sigmoid(self.mix_factor)
+
+         elif self.merge_strategy == "learned_with_images":
+             if image_only_indicator is None:
+                 raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
+
+             alpha = torch.where(
+                 image_only_indicator.bool(),
+                 torch.ones(1, 1, device=image_only_indicator.device),
+                 torch.sigmoid(self.mix_factor)[..., None],
+             )
+
+             # (batch, channel, frames, height, width)
+             if ndims == 5:
+                 alpha = alpha[:, None, :, None, None]
+             # (batch*frames, height*width, channels)
+             elif ndims == 3:
+                 alpha = alpha.reshape(-1)[:, None, None]
+             else:
+                 raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
+
+         else:
+             raise NotImplementedError
+
+         return alpha
+
+     def forward(
+         self,
+         x_spatial: torch.Tensor,
+         x_temporal: torch.Tensor,
+         image_only_indicator: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
+         alpha = alpha.to(x_spatial.dtype)
+
+         if self.switch_spatial_to_temporal_mix:
+             alpha = 1.0 - alpha
+
+         x = alpha * x_spatial + (1.0 - alpha) * x_temporal
+         return x
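A hedged usage sketch of `AlphaBlender` (illustrative shapes; `image_only_indicator` has one entry per frame and flags frames that should bypass the temporal branch):

import torch

blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
x_spatial = torch.randn(2, 32, 8, 16, 16)    # (batch, channels, frames, height, width)
x_temporal = torch.randn(2, 32, 8, 16, 16)
image_only_indicator = torch.zeros(2, 8)     # all-video: alpha = sigmoid(mix_factor)

mixed = blender(x_spatial, x_temporal, image_only_indicator)
assert mixed.shape == x_spatial.shape
# Frames flagged with 1 get alpha = 1, i.e. the spatial branch passes through unchanged.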
Marigold/resnet/upsampling.py ADDED
@@ -0,0 +1,515 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffusers.utils import deprecate
+ from diffusers.models.normalization import RMSNorm
+ from einops import rearrange
+
+
+ class Upsample1D(nn.Module):
+     """A 1D upsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         use_conv_transpose (`bool`, default `False`):
+             option to use a convolution transpose.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         name (`str`, default `conv`):
+             name of the upsampling 1D layer.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         use_conv_transpose: bool = False,
+         out_channels: Optional[int] = None,
+         name: str = "conv",
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_conv_transpose = use_conv_transpose
+         self.name = name
+
+         self.conv = None
+         if use_conv_transpose:
+             self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+         elif use_conv:
+             self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         assert inputs.shape[1] == self.channels
+         if self.use_conv_transpose:
+             return self.conv(inputs)
+
+         outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+         if self.use_conv:
+             outputs = self.conv(outputs)
+
+         return outputs
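A quick hedged sketch of `Upsample1D` (illustrative sizes):

import torch

up = Upsample1D(channels=16, use_conv=True)
x = torch.randn(4, 16, 50)         # (batch, channels, length)
y = up(x)
assert y.shape == (4, 16, 100)     # nearest-neighbor x2, then a length-preserving 3-tap conv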
+
+
+ class Upsample2D(nn.Module):
+     """A 2D upsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         use_conv_transpose (`bool`, default `False`):
+             option to use a convolution transpose.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         name (`str`, default `conv`):
+             name of the upsampling 2D layer.
+     """
+
+     def __init__(
+         self,
+         channels: int,
+         use_conv: bool = False,
+         use_conv_transpose: bool = False,
+         out_channels: Optional[int] = None,
+         name: str = "conv",
+         kernel_size: Optional[int] = None,
+         padding=1,
+         norm_type=None,
+         eps=None,
+         elementwise_affine=None,
+         bias=True,
+         interpolate=True,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_conv_transpose = use_conv_transpose
+         self.name = name
+         self.interpolate = interpolate
+
+         if norm_type == "ln_norm":
+             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+         elif norm_type == "rms_norm":
+             self.norm = RMSNorm(channels, eps, elementwise_affine)
+         elif norm_type is None:
+             self.norm = None
+         else:
+             raise ValueError(f"unknown norm_type: {norm_type}")
+
+         conv = None
+         if use_conv_transpose:
+             if kernel_size is None:
+                 kernel_size = 4
+             conv = nn.ConvTranspose2d(
+                 channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
+             )
+         elif use_conv:
+             if kernel_size is None:
+                 kernel_size = 3
+             conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+
+         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+         if name == "conv":
+             self.conv = conv
+         else:
+             self.Conv2d_0 = conv
+
+     def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
+         if len(args) > 0 or kwargs.get("scale", None) is not None:
+             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+             deprecate("scale", "1.0.0", deprecation_message)
+
+         assert hidden_states.shape[1] == self.channels
+
+         if self.norm is not None:
+             hidden_states_permuted = hidden_states.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
+             b, c, h, w = hidden_states_permuted.shape
+             hidden_states_permuted = rearrange(hidden_states_permuted, "(b t) c h w -> b c (h w t)", b=1, h=h, w=w)
+             hidden_states = self.norm(hidden_states_permuted)
+             hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", b=1, h=h, w=w)
+             hidden_states = hidden_states.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
+
+         if self.use_conv_transpose:
+             return self.conv(hidden_states)
+
+         # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
+         # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+         # https://github.com/pytorch/pytorch/issues/86679
+         dtype = hidden_states.dtype
+         if dtype == torch.bfloat16:
+             hidden_states = hidden_states.to(torch.float32)
+
+         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+         if hidden_states.shape[0] >= 64:
+             hidden_states = hidden_states.contiguous()
+
+         # if `output_size` is passed we force the interpolation output
+         # size and do not make use of `scale_factor=2`
+         if self.interpolate:
+             if output_size is None:
+                 hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+             else:
+                 hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+         # If the input is bfloat16, we cast back to bfloat16
+         if dtype == torch.bfloat16:
+             hidden_states = hidden_states.to(dtype)
+
+         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+         if self.use_conv:
+             if self.name == "conv":
+                 hidden_states = self.conv(hidden_states)
+             else:
+                 hidden_states = self.Conv2d_0(hidden_states)
+
+         return hidden_states
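`Upsample2D` doubles the spatial size unless an explicit `output_size` is requested. A hedged sketch (illustrative sizes):

import torch

up = Upsample2D(channels=64, use_conv=True)
x = torch.randn(1, 64, 32, 32)
assert up(x).shape == (1, 64, 64, 64)                          # nearest x2 + 3x3 conv
assert up(x, output_size=(48, 48)).shape == (1, 64, 48, 48)    # forced interpolation target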
+
+
+ class FirUpsample2D(nn.Module):
+     """A 2D FIR upsampling layer with an optional convolution.
+
+     Parameters:
+         channels (`int`, optional):
+             number of channels in the inputs and outputs.
+         use_conv (`bool`, default `False`):
+             option to use a convolution.
+         out_channels (`int`, optional):
+             number of output channels. Defaults to `channels`.
+         fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+             kernel for the FIR filter.
+     """
+
+     def __init__(
+         self,
+         channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         use_conv: bool = False,
+         fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
+     ):
+         super().__init__()
+         out_channels = out_channels if out_channels else channels
+         if use_conv:
+             self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.use_conv = use_conv
+         self.fir_kernel = fir_kernel
+         self.out_channels = out_channels
+
+     def _upsample_2d(
+         self,
+         hidden_states: torch.Tensor,
+         weight: Optional[torch.Tensor] = None,
+         kernel: Optional[torch.Tensor] = None,
+         factor: int = 2,
+         gain: float = 1,
+     ) -> torch.Tensor:
+         """Fused `upsample_2d()` followed by `Conv2d()`.
+
+         Padding is performed only once at the beginning, not between the operations. The fused op is considerably
+         more efficient than performing the same calculation using separate ops. It supports gradients of arbitrary
+         order.
+
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+             weight (`torch.Tensor`, *optional*):
+                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+                 performed by `inChannels = x.shape[0] // numGroups`.
+             kernel (`torch.Tensor`, *optional*):
+                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                 corresponds to nearest-neighbor upsampling.
+             factor (`int`, *optional*): Integer upsampling factor (default: 2).
+             gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
+
+         Returns:
+             output (`torch.Tensor`):
+                 Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+                 datatype as `hidden_states`.
+         """
+
+         assert isinstance(factor, int) and factor >= 1
+
+         # Setup filter kernel.
+         if kernel is None:
+             kernel = [1] * factor
+
+         kernel = torch.tensor(kernel, dtype=torch.float32)
+         if kernel.ndim == 1:
+             kernel = torch.outer(kernel, kernel)
+         kernel /= torch.sum(kernel)
+
+         kernel = kernel * (gain * (factor**2))
+
+         if self.use_conv:
+             convH = weight.shape[2]
+             convW = weight.shape[3]
+             inC = weight.shape[1]
+
+             pad_value = (kernel.shape[0] - factor) - (convW - 1)
+
+             stride = (factor, factor)
+             # Determine data dimensions.
+             output_shape = (
+                 (hidden_states.shape[2] - 1) * factor + convH,
+                 (hidden_states.shape[3] - 1) * factor + convW,
+             )
+             output_padding = (
+                 output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
+                 output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
+             )
+             assert output_padding[0] >= 0 and output_padding[1] >= 0
+             num_groups = hidden_states.shape[1] // inC
+
+             # Transpose weights.
+             weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+             weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
+             weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+             inverse_conv = F.conv_transpose2d(
+                 hidden_states,
+                 weight,
+                 stride=stride,
+                 output_padding=output_padding,
+                 padding=0,
+             )
+
+             output = upfirdn2d_native(
+                 inverse_conv,
+                 kernel.to(device=inverse_conv.device),
+                 pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
+             )
+         else:
+             pad_value = kernel.shape[0] - factor
+             output = upfirdn2d_native(
+                 hidden_states,
+                 kernel.to(device=hidden_states.device),
+                 up=factor,
+                 pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+             )
+
+         return output
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         if self.use_conv:
+             height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
+             height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+         else:
+             height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+
+         return height
+
+
+ class KUpsample2D(nn.Module):
+     r"""A 2D K-upsampling layer.
+
+     Parameters:
+         pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+     """
+
+     def __init__(self, pad_mode: str = "reflect"):
+         super().__init__()
+         self.pad_mode = pad_mode
+         kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
+         self.pad = kernel_1d.shape[1] // 2 - 1
+         self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
+         weight = inputs.new_zeros(
+             [
+                 inputs.shape[1],
+                 inputs.shape[1],
+                 self.kernel.shape[0],
+                 self.kernel.shape[1],
+             ]
+         )
+         indices = torch.arange(inputs.shape[1], device=inputs.device)
+         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
+         weight[indices, indices] = kernel
+         return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+ class CogVideoXUpsample3D(nn.Module):
+     r"""
+     A 3D upsample layer used in CogVideoX by Tsinghua University & ZhipuAI (TODO: wait for the paper release).
+
+     Args:
+         in_channels (`int`):
+             Number of channels in the input image.
+         out_channels (`int`):
+             Number of channels produced by the convolution.
+         kernel_size (`int`, defaults to `3`):
+             Size of the convolving kernel.
+         stride (`int`, defaults to `1`):
+             Stride of the convolution.
+         padding (`int`, defaults to `1`):
+             Padding added to all four sides of the input.
+         compress_time (`bool`, defaults to `False`):
+             Whether or not to compress the time dimension.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         padding: int = 1,
+         compress_time: bool = False,
+     ) -> None:
+         super().__init__()
+
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+         self.compress_time = compress_time
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         if self.compress_time:
+             if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
+                 # split off the first frame
+                 x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
+
+                 x_first = F.interpolate(x_first, scale_factor=2.0)
+                 x_rest = F.interpolate(x_rest, scale_factor=2.0)
+                 x_first = x_first[:, :, None, :, :]
+                 inputs = torch.cat([x_first, x_rest], dim=2)
+             elif inputs.shape[2] > 1:
+                 inputs = F.interpolate(inputs, scale_factor=2.0)
+             else:
+                 inputs = inputs.squeeze(2)
+                 inputs = F.interpolate(inputs, scale_factor=2.0)
+                 inputs = inputs[:, :, None, :, :]
+         else:
+             # only interpolate 2D
+             b, c, t, h, w = inputs.shape
+             inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+             inputs = F.interpolate(inputs, scale_factor=2.0)
+             inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
+
+         b, c, t, h, w = inputs.shape
+         inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+         inputs = self.conv(inputs)
+         inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
+
+         return inputs
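A hedged shape sketch of the `compress_time` path above (illustrative sizes): the first frame of an odd-length clip is doubled only spatially, so an odd frame count maps to an odd frame count:

import torch

up = CogVideoXUpsample3D(in_channels=8, out_channels=8, compress_time=True)
x = torch.randn(1, 8, 5, 16, 16)   # (batch, channels, frames, height, width), odd frame count
y = up(x)
# 1 spatially-doubled first frame + 2 * (5 - 1) = 8 temporally-doubled frames -> 9 frames
assert y.shape == (1, 8, 9, 32, 32)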
+
+
+ def upfirdn2d_native(
+     tensor: torch.Tensor,
+     kernel: torch.Tensor,
+     up: int = 1,
+     down: int = 1,
+     pad: Tuple[int, int] = (0, 0),
+ ) -> torch.Tensor:
+     up_x = up_y = up
+     down_x = down_y = down
+     pad_x0 = pad_y0 = pad[0]
+     pad_x1 = pad_y1 = pad[1]
+
+     _, channel, in_h, in_w = tensor.shape
+     tensor = tensor.reshape(-1, in_h, in_w, 1)
+
+     _, in_h, in_w, minor = tensor.shape
+     kernel_h, kernel_w = kernel.shape
+
+     out = tensor.view(-1, in_h, 1, in_w, 1, minor)
+     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+     out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+     out = out.to(tensor.device)  # Move back to mps if necessary
+     out = out[
+         :,
+         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+         max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+         :,
+     ]
+
+     out = out.permute(0, 3, 1, 2)
+     out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+     w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+     out = F.conv2d(out, w)
+     out = out.reshape(
+         -1,
+         minor,
+         in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+     )
+     out = out.permute(0, 2, 3, 1)
+     out = out[:, ::down_y, ::down_x, :]
+
+     out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+     out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+     return out.view(-1, channel, out_h, out_w)
+
+
+ def upsample_2d(
+     hidden_states: torch.Tensor,
+     kernel: Optional[torch.Tensor] = None,
+     factor: int = 2,
+     gain: float = 1,
+ ) -> torch.Tensor:
+     r"""Upsample a batch of 2D images with the given filter.
+
+     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the
+     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+     specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that
+     its shape is a multiple of the upsampling factor.
+
+     Args:
+         hidden_states (`torch.Tensor`):
+             Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+         kernel (`torch.Tensor`, *optional*):
+             FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+             corresponds to nearest-neighbor upsampling.
+         factor (`int`, *optional*, defaults to `2`):
+             Integer upsampling factor.
+         gain (`float`, *optional*, defaults to `1.0`):
+             Scaling factor for signal magnitude.
+
+     Returns:
+         output (`torch.Tensor`):
+             Tensor of the shape `[N, C, H * factor, W * factor]`
+     """
+     assert isinstance(factor, int) and factor >= 1
+     if kernel is None:
+         kernel = [1] * factor
+
+     kernel = torch.tensor(kernel, dtype=torch.float32)
+     if kernel.ndim == 1:
+         kernel = torch.outer(kernel, kernel)
+     kernel /= torch.sum(kernel)
+
+     kernel = kernel * (gain * (factor**2))
+     pad_value = kernel.shape[0] - factor
+     output = upfirdn2d_native(
+         hidden_states,
+         kernel.to(device=hidden_states.device),
+         up=factor,
+         pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
+     )
+     return output
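With the default box kernel, `upsample_2d` should reduce to plain nearest-neighbor upsampling (a hedged equivalence check, not a guarantee from the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
y_fir = upsample_2d(x, factor=2)                             # kernel defaults to [1, 1]
y_nn = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert torch.allclose(y_fir, y_nn, atol=1e-6)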
Marigold/unet/__init__.py ADDED
File without changes
Marigold/unet/attention.py ADDED
@@ -0,0 +1,1342 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from diffusers.utils import deprecate, logging
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+ from diffusers.models.attention_processor import JointAttnProcessor2_0
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+
+ from Marigold.unet.attention_processor import Attention
+ from einops import rearrange
+
+ logger = logging.get_logger(__name__)
+
+
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+     # "feed_forward_chunk_size" can be used to save memory
+     if hidden_states.shape[chunk_dim] % chunk_size != 0:
+         raise ValueError(
+             f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+         )
+
+     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+     ff_output = torch.cat(
+         [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+         dim=chunk_dim,
+     )
+     return ff_output
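The helper above splits the sequence into equal chunks and runs the feed-forward on each, trading peak activation memory for loop overhead; the result matches the unchunked pass up to floating-point accumulation. A hedged sketch:

import torch
import torch.nn as nn

ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))  # stand-in feed-forward
x = torch.randn(2, 1024, 64)
full = ff(x)
chunked = _chunked_feed_forward(ff, x, chunk_dim=1, chunk_size=256)    # 1024 % 256 == 0
assert torch.allclose(full, chunked, atol=1e-6)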
+
+
+ @maybe_allow_in_graph
+ class GatedSelfAttentionDense(nn.Module):
+     r"""
+     A gated self-attention dense layer that combines visual features and object features.
+
+     Parameters:
+         query_dim (`int`): The number of channels in the query.
+         context_dim (`int`): The number of channels in the context.
+         n_heads (`int`): The number of heads to use for attention.
+         d_head (`int`): The number of channels in each head.
+     """
+
+     def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+         super().__init__()
+
+         # we need a linear projection because the visual features and object features are concatenated
+         self.linear = nn.Linear(context_dim, query_dim)
+
+         self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+         self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+         self.norm1 = nn.LayerNorm(query_dim)
+         self.norm2 = nn.LayerNorm(query_dim)
+
+         self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+         self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+         self.enabled = True
+
+     def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+         if not self.enabled:
+             return x
+
+         n_visual = x.shape[1]
+         objs = self.linear(objs)
+
+         x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+         x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+         return x
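Because both gates start at zero, the layer above is an exact identity at initialization, so it can be bolted onto a pretrained backbone without disturbing it. A hedged sketch (it assumes this module's `Attention` and `FeedForward` accept the diffusers-style arguments used in `__init__`):

import torch

gsa = GatedSelfAttentionDense(query_dim=64, context_dim=32, n_heads=4, d_head=16)
x = torch.randn(1, 10, 64)      # visual tokens
objs = torch.randn(1, 3, 32)    # object (grounding) tokens
assert torch.allclose(gsa(x, objs), x, atol=1e-6)   # tanh(0) = 0 gates both branches off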
+
+
+ @maybe_allow_in_graph
+ class JointTransformerBlock(nn.Module):
+     r"""
+     A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+     Reference: https://arxiv.org/abs/2403.03206
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+         context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+             processing of `context` conditions.
+     """
+
+     def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+         super().__init__()
+
+         self.context_pre_only = context_pre_only
+         context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+         self.norm1 = AdaLayerNormZero(dim)
+
+         if context_norm_type == "ada_norm_continous":
+             self.norm1_context = AdaLayerNormContinuous(
+                 dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+             )
+         elif context_norm_type == "ada_norm_zero":
+             self.norm1_context = AdaLayerNormZero(dim)
+         else:
+             raise ValueError(
+                 f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+             )
+         if hasattr(F, "scaled_dot_product_attention"):
+             processor = JointAttnProcessor2_0()
+         else:
+             raise ValueError(
+                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+             )
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             added_kv_proj_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             context_pre_only=context_pre_only,
+             bias=True,
+             processor=processor,
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+         if not context_pre_only:
+             self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+             self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+         else:
+             self.norm2_context = None
+             self.ff_context = None
+
+         # let chunk size default to None
+         self._chunk_size = None
+         self._chunk_dim = 0
+
+     # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+         # Sets chunk feed-forward
+         self._chunk_size = chunk_size
+         self._chunk_dim = dim
+
+     def forward(
+         self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+     ):
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+         if self.context_pre_only:
+             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+         else:
+             norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                 encoder_hidden_states, emb=temb
+             )
+
+         # Attention.
+         attn_output, context_attn_output = self.attn(
+             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+         )
+
+         # Process attention outputs for the `hidden_states`.
+         attn_output = gate_msa.unsqueeze(1) * attn_output
+         hidden_states = hidden_states + attn_output
+
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+         if self._chunk_size is not None:
+             # "feed_forward_chunk_size" can be used to save memory
+             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             ff_output = self.ff(norm_hidden_states)
+         ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+         hidden_states = hidden_states + ff_output
+
+         # Process attention outputs for the `encoder_hidden_states`.
+         if self.context_pre_only:
+             encoder_hidden_states = None
+         else:
+             context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+             encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+             norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+             norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+             if self._chunk_size is not None:
+                 # "feed_forward_chunk_size" can be used to save memory
+                 context_ff_output = _chunked_feed_forward(
+                     self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                 )
+             else:
+                 context_ff_output = self.ff_context(norm_encoder_hidden_states)
+             encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+         return encoder_hidden_states, hidden_states
212
+
213
+
214
+
215
+ @maybe_allow_in_graph
216
+ class BasicTransformerBlock(nn.Module):
217
+ r"""
218
+ A basic Transformer block.
219
+
220
+ Parameters:
221
+ dim (`int`): The number of channels in the input and output.
222
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
223
+ attention_head_dim (`int`): The number of channels in each head.
224
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
225
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
226
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
227
+ num_embeds_ada_norm (:
228
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
229
+ attention_bias (:
230
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
231
+ only_cross_attention (`bool`, *optional*):
232
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
233
+ double_self_attention (`bool`, *optional*):
234
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
235
+ upcast_attention (`bool`, *optional*):
236
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
237
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
238
+ Whether to use learnable elementwise affine parameters for normalization.
239
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
240
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
241
+ final_dropout (`bool` *optional*, defaults to False):
242
+ Whether to apply a final dropout after the last feed-forward layer.
243
+ attention_type (`str`, *optional*, defaults to `"default"`):
244
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
245
+ positional_embeddings (`str`, *optional*, defaults to `None`):
246
+ The type of positional embeddings to apply to.
247
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
248
+ The maximum number of positional embeddings to apply.
249
+ """
250
+
251
+ def __init__(
252
+ self,
253
+ dim: int,
254
+ num_attention_heads: int,
255
+ attention_head_dim: int,
256
+ dropout=0.0,
257
+ cross_attention_dim: Optional[int] = None,
258
+ activation_fn: str = "geglu",
259
+ num_embeds_ada_norm: Optional[int] = None,
260
+ attention_bias: bool = False,
261
+ only_cross_attention: bool = False,
262
+ double_self_attention: bool = False,
263
+ upcast_attention: bool = False,
264
+ norm_elementwise_affine: bool = True,
265
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
266
+ norm_eps: float = 1e-5,
267
+ final_dropout: bool = False,
268
+ attention_type: str = "default",
269
+ positional_embeddings: Optional[str] = None,
270
+ num_positional_embeddings: Optional[int] = None,
271
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
272
+ ada_norm_bias: Optional[int] = None,
273
+ ff_inner_dim: Optional[int] = None,
274
+ ff_bias: bool = True,
275
+ attention_out_bias: bool = True,
276
+ use_RoPE: bool = False,
277
+ ):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.num_attention_heads = num_attention_heads
281
+ self.attention_head_dim = attention_head_dim
282
+ self.dropout = dropout
283
+ self.cross_attention_dim = cross_attention_dim
284
+ self.activation_fn = activation_fn
285
+ self.attention_bias = attention_bias
286
+ self.double_self_attention = double_self_attention
287
+ self.norm_elementwise_affine = norm_elementwise_affine
288
+ self.positional_embeddings = positional_embeddings
289
+ self.num_positional_embeddings = num_positional_embeddings
290
+ self.only_cross_attention = only_cross_attention
291
+ self.use_RoPE = use_RoPE
292
+
293
+ # We keep these boolean flags for backward-compatibility.
294
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
295
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
296
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
297
+ self.use_layer_norm = norm_type == "layer_norm"
298
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
299
+
300
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
301
+ raise ValueError(
302
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
303
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
304
+ )
305
+
306
+ self.norm_type = norm_type
307
+ self.num_embeds_ada_norm = num_embeds_ada_norm
308
+
309
+ if positional_embeddings and (num_positional_embeddings is None):
310
+ raise ValueError(
311
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
312
+ )
313
+
314
+ if positional_embeddings == "sinusoidal":
315
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
316
+ else:
317
+ self.pos_embed = None
318
+
319
+ # Define 3 blocks. Each block has its own normalization layer.
320
+ # 1. Self-Attn
321
+ if norm_type == "ada_norm":
322
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
323
+ elif norm_type == "ada_norm_zero":
324
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
325
+ elif norm_type == "ada_norm_continuous":
326
+ self.norm1 = AdaLayerNormContinuous(
327
+ dim,
328
+ ada_norm_continous_conditioning_embedding_dim,
329
+ norm_elementwise_affine,
330
+ norm_eps,
331
+ ada_norm_bias,
332
+ "rms_norm",
333
+ )
334
+ else:
335
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
336
+
337
+ self.attn1 = Attention(
338
+ query_dim=dim,
339
+ heads=num_attention_heads,
340
+ dim_head=attention_head_dim,
341
+ dropout=dropout,
342
+ bias=attention_bias,
343
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
344
+ upcast_attention=upcast_attention,
345
+ out_bias=attention_out_bias,
346
+ use_RoPE=use_RoPE,
347
+ )
348
+
349
+ # 2. Cross-Attn
350
+ if cross_attention_dim is not None or double_self_attention:
351
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
352
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
353
+ # the second cross attention block.
354
+ if norm_type == "ada_norm":
355
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
356
+ elif norm_type == "ada_norm_continuous":
357
+ self.norm2 = AdaLayerNormContinuous(
358
+ dim,
359
+ ada_norm_continous_conditioning_embedding_dim,
360
+ norm_elementwise_affine,
361
+ norm_eps,
362
+ ada_norm_bias,
363
+ "rms_norm",
364
+ )
365
+ else:
366
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
367
+
368
+ self.attn2 = Attention(
369
+ query_dim=dim,
370
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
371
+ heads=num_attention_heads,
372
+ dim_head=attention_head_dim,
373
+ dropout=dropout,
374
+ bias=attention_bias,
375
+ upcast_attention=upcast_attention,
376
+ out_bias=attention_out_bias,
377
+ use_RoPE=use_RoPE,
378
+ ) # is self-attn if encoder_hidden_states is none
379
+ else:
380
+ if norm_type == "ada_norm_single": # For Latte
381
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
382
+ else:
383
+ self.norm2 = None
384
+ self.attn2 = None
385
+
386
+ # 3. Feed-forward
387
+ if norm_type == "ada_norm_continuous":
388
+ self.norm3 = AdaLayerNormContinuous(
389
+ dim,
390
+ ada_norm_continous_conditioning_embedding_dim,
391
+ norm_elementwise_affine,
392
+ norm_eps,
393
+ ada_norm_bias,
394
+ "layer_norm",
395
+ )
396
+
397
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
398
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
399
+ elif norm_type == "layer_norm_i2vgen":
400
+ self.norm3 = None
401
+
402
+ self.ff = FeedForward(
403
+ dim,
404
+ dropout=dropout,
405
+ activation_fn=activation_fn,
406
+ final_dropout=final_dropout,
407
+ inner_dim=ff_inner_dim,
408
+ bias=ff_bias,
409
+ )
410
+
411
+ # 4. Fuser
412
+ if attention_type == "gated" or attention_type == "gated-text-image":
413
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
414
+
415
+ # 5. Scale-shift for PixArt-Alpha.
416
+ if norm_type == "ada_norm_single":
417
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
418
+
419
+ # let chunk size default to None
420
+ self._chunk_size = None
421
+ self._chunk_dim = 0
422
+
423
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
424
+ # Sets chunk feed-forward
425
+ self._chunk_size = chunk_size
426
+ self._chunk_dim = dim
427
+
428
+ def forward(
429
+ self,
430
+ hidden_states: torch.Tensor,
431
+ attention_mask: Optional[torch.Tensor] = None,
432
+ encoder_hidden_states: Optional[torch.Tensor] = None,
433
+ encoder_attention_mask: Optional[torch.Tensor] = None,
434
+ timestep: Optional[torch.LongTensor] = None,
435
+ cross_attention_kwargs: Dict[str, Any] = None,
436
+ class_labels: Optional[torch.LongTensor] = None,
437
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
438
+ ) -> torch.Tensor:
439
+ if cross_attention_kwargs is not None:
440
+ if cross_attention_kwargs.get("scale", None) is not None:
441
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
442
+
443
+ # Notice that normalization is always applied before the real computation in the following blocks.
444
+ # 0. Self-Attention
445
+ batch_size = hidden_states.shape[0]
446
+ if self.norm_type == "ada_norm":
447
+ norm_hidden_states = self.norm1(hidden_states, timestep)
448
+ elif self.norm_type == "ada_norm_zero":
449
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
450
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
451
+ )
452
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
453
+ norm_hidden_states = self.norm1(hidden_states)
454
+ elif self.norm_type == "ada_norm_continuous":
455
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
456
+ elif self.norm_type == "ada_norm_single":
457
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
458
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
459
+ ).chunk(6, dim=1)
460
+ norm_hidden_states = self.norm1(hidden_states)
461
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
462
+ else:
463
+ raise ValueError("Incorrect norm used")
464
+ if self.pos_embed is not None:
465
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
466
+
467
+ # 1. Prepare GLIGEN inputs
468
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
469
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
470
+
471
+ attn_output = self.attn1(
472
+ norm_hidden_states,
473
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
474
+ attention_mask=attention_mask,
475
+ **cross_attention_kwargs,
476
+ )
477
+
478
+ if self.norm_type == "ada_norm_zero":
479
+ attn_output = gate_msa.unsqueeze(1) * attn_output
480
+ elif self.norm_type == "ada_norm_single":
481
+ attn_output = gate_msa * attn_output
482
+
483
+
484
+ hidden_states = attn_output + hidden_states
485
+ if hidden_states.ndim == 4:
486
+ hidden_states = hidden_states.squeeze(1)
487
+
488
+ # 1.2 GLIGEN Control
489
+ if gligen_kwargs is not None:
490
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
491
+
492
+ # 3. Cross-Attention
493
+ if self.attn2 is not None:
494
+ if self.norm_type == "ada_norm":
495
+ norm_hidden_states = self.norm2(hidden_states, timestep)
496
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
497
+ norm_hidden_states = self.norm2(hidden_states)
498
+ elif self.norm_type == "ada_norm_single":
499
+ # For PixArt norm2 isn't applied here:
500
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
501
+ norm_hidden_states = hidden_states
502
+ elif self.norm_type == "ada_norm_continuous":
503
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
504
+ else:
505
+ raise ValueError("Incorrect norm")
506
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
507
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
508
+
509
+ attn_output = self.attn2(
510
+ norm_hidden_states,
511
+ encoder_hidden_states=encoder_hidden_states,
512
+ attention_mask=encoder_attention_mask,
513
+ **cross_attention_kwargs,
514
+ )
515
+ hidden_states = attn_output + hidden_states
516
+
517
+ # 4. Feed-forward
518
+ # i2vgen doesn't have this norm 🤷‍♂️
519
+ if self.norm_type == "ada_norm_continuous":
520
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
521
+ elif not self.norm_type == "ada_norm_single":
522
+ norm_hidden_states = self.norm3(hidden_states)
523
+
524
+ if self.norm_type == "ada_norm_zero":
525
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
526
+
527
+ if self.norm_type == "ada_norm_single":
528
+ norm_hidden_states = self.norm2(hidden_states)
529
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
530
+
531
+ if self._chunk_size is not None:
532
+ # "feed_forward_chunk_size" can be used to save memory
533
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
534
+ else:
535
+ ff_output = self.ff(norm_hidden_states)
536
+
537
+ if self.norm_type == "ada_norm_zero":
538
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
539
+ elif self.norm_type == "ada_norm_single":
540
+ ff_output = gate_mlp * ff_output
541
+
542
+ hidden_states = ff_output + hidden_states
543
+ if hidden_states.ndim == 4:
544
+ hidden_states = hidden_states.squeeze(1)
545
+
546
+ return hidden_states
547
+
548
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
549
+ # Sets chunk feed-forward
550
+ self._chunk_size = chunk_size
551
+ self._chunk_dim = dim
552
+
553
+ def forward(
554
+ self,
555
+ hidden_states: torch.Tensor,
556
+ attention_mask: Optional[torch.Tensor] = None,
557
+ encoder_hidden_states: Optional[torch.Tensor] = None,
558
+ encoder_attention_mask: Optional[torch.Tensor] = None,
559
+ timestep: Optional[torch.LongTensor] = None,
560
+ cross_attention_kwargs: Dict[str, Any] = None,
561
+ class_labels: Optional[torch.LongTensor] = None,
562
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
563
+ num_cube_faces = 6,
564
+ ) -> torch.Tensor:
565
+ if cross_attention_kwargs is not None:
566
+ if cross_attention_kwargs.get("scale", None) is not None:
567
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
568
+
569
+ # Notice that normalization is always applied before the real computation in the following blocks.
570
+ # 0. Self-Attention
571
+ batch_size = hidden_states.shape[0]
572
+
573
+ if self.norm_type == "ada_norm":
574
+ # sync_norm:
575
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
576
+ norm_hidden_states = self.norm1(hidden_states, timestep)
577
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
578
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
579
+ elif self.norm_type == "ada_norm_zero":
580
+ # sync_norm:
581
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
582
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
583
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
584
+ )
585
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
586
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
587
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
588
+ # sync_norm:
589
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
590
+ norm_hidden_states = self.norm1(hidden_states)
591
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
592
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
593
+ elif self.norm_type == "ada_norm_continuous":
594
+ # sync_norm:
595
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
596
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
597
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
598
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
599
+ elif self.norm_type == "ada_norm_single":
600
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
601
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
602
+ ).chunk(6, dim=1)
603
+ # sync_norm:
604
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
605
+ norm_hidden_states = self.norm1(hidden_states)
606
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
607
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
608
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
609
+ else:
610
+ raise ValueError("Incorrect norm used")
611
+
612
+ if self.pos_embed is not None:
613
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
614
+
615
+ # 1. Prepare GLIGEN inputs
616
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
617
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
618
+
619
+ attn_output = self.attn1(
620
+ norm_hidden_states,
621
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
622
+ attention_mask=attention_mask,
623
+ **cross_attention_kwargs,
624
+ )
625
+
626
+ if self.norm_type == "ada_norm_zero":
627
+ attn_output = gate_msa.unsqueeze(1) * attn_output
628
+ elif self.norm_type == "ada_norm_single":
629
+ attn_output = gate_msa * attn_output
630
+
631
+ hidden_states = attn_output + hidden_states
632
+
633
+ if hidden_states.ndim == 4:
634
+ hidden_states = hidden_states.squeeze(1)
635
+
636
+ # 1.2 GLIGEN Control
637
+ if gligen_kwargs is not None:
638
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
639
+
640
+ # 3. Cross-Attention
641
+ if self.attn2 is not None:
642
+ if self.norm_type == "ada_norm":
643
+ # sync_norm:
644
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
645
+ norm_hidden_states = self.norm2(hidden_states, timestep)
646
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
647
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
648
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
649
+ # sync_norm:
650
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
651
+ norm_hidden_states = self.norm2(hidden_states)
652
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
653
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
654
+ elif self.norm_type == "ada_norm_single":
655
+ # For PixArt norm2 isn't applied here:
656
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
657
+ norm_hidden_states = hidden_states
658
+ elif self.norm_type == "ada_norm_continuous":
659
+ # sync_norm:
660
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
661
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
662
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
663
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
664
+ else:
665
+ raise ValueError("Incorrect norm")
666
+
667
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
668
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
669
+
670
+ attn_output = self.attn2(
671
+ norm_hidden_states,
672
+ encoder_hidden_states=encoder_hidden_states,
673
+ attention_mask=encoder_attention_mask,
674
+ **cross_attention_kwargs,
675
+ )
676
+ hidden_states = attn_output + hidden_states
677
+
678
+ # 4. Feed-forward
679
+ # i2vgen doesn't have this norm 🤷‍♂️
680
+ if self.norm_type == "ada_norm_continuous":
681
+ # sync_norm:
682
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=num_cube_faces)
683
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
684
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=num_cube_faces)
685
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=num_cube_faces)
686
+ elif not self.norm_type == "ada_norm_single":
687
+ # sync_norm:
688
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=num_cube_faces)
689
+ norm_hidden_states = self.norm3(hidden_states)
690
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=num_cube_faces)
691
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=num_cube_faces)
692
+ if self.norm_type == "ada_norm_zero":
693
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
694
+
695
+ if self.norm_type == "ada_norm_single":
696
+ # sync_norm:
697
+ hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
698
+ norm_hidden_states = self.norm2(hidden_states)
699
+ hidden_states = rearrange(hidden_states, 'b (t l) c -> (b t) l c', t=6)
700
+ norm_hidden_states = rearrange(norm_hidden_states, 'b (t l) c -> (b t) l c', t=6)
701
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
702
+ if self._chunk_size is not None:
703
+ # "feed_forward_chunk_size" can be used to save memory
704
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
705
+ else:
706
+ ff_output = self.ff(norm_hidden_states)
707
+
708
+ if self.norm_type == "ada_norm_zero":
709
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
710
+ elif self.norm_type == "ada_norm_single":
711
+ ff_output = gate_mlp * ff_output
712
+
713
+ hidden_states = ff_output + hidden_states
714
+ if hidden_states.ndim == 4:
715
+ hidden_states = hidden_states.squeeze(1)
716
+
717
+ return hidden_states
718
+
719
+ class LuminaFeedForward(nn.Module):
720
+ r"""
721
+ A feed-forward layer.
722
+
723
+ Parameters:
724
+ dim (`int`):
725
+ The dimensionality of the input and output of the layer. This parameter determines the width of the
726
+ model's hidden representations.
727
+ inner_dim (`int`): The intermediate dimension of the feedforward layer (rescaled by 2/3 in `__init__`).
728
+ multiple_of (`int`, *optional*, defaults to 256): Value to ensure the hidden dimension is a multiple
729
+ of this value.
730
+ ffn_dim_multiplier (`float`, *optional*): Custom multiplier for the hidden
731
+ dimension. Defaults to `None`.
732
+ """
733
+
734
+ def __init__(
735
+ self,
736
+ dim: int,
737
+ inner_dim: int,
738
+ multiple_of: Optional[int] = 256,
739
+ ffn_dim_multiplier: Optional[float] = None,
740
+ ):
741
+ super().__init__()
742
+ inner_dim = int(2 * inner_dim / 3)
743
+ # custom hidden_size factor multiplier
744
+ if ffn_dim_multiplier is not None:
745
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
746
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
747
+
748
+ self.linear_1 = nn.Linear(
749
+ dim,
750
+ inner_dim,
751
+ bias=False,
752
+ )
753
+ self.linear_2 = nn.Linear(
754
+ inner_dim,
755
+ dim,
756
+ bias=False,
757
+ )
758
+ self.linear_3 = nn.Linear(
759
+ dim,
760
+ inner_dim,
761
+ bias=False,
762
+ )
763
+ self.silu = FP32SiLU()
764
+
765
+ def forward(self, x):
766
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
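+ # Sizing sketch (illustrative numbers, not from any config): inner_dim=4096 with
+ # multiple_of=256 gives int(2 * 4096 / 3) = 2730, rounded up to the next
+ # multiple of 256 -> an effective hidden width of 2816.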
767
+
768
+
769
+ @maybe_allow_in_graph
770
+ class TemporalBasicTransformerBlock(nn.Module):
771
+ r"""
772
+ A basic Transformer block for video like data.
773
+
774
+ Parameters:
775
+ dim (`int`): The number of channels in the input and output.
776
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
777
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
778
+ attention_head_dim (`int`): The number of channels in each head.
779
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
780
+ """
781
+
782
+ def __init__(
783
+ self,
784
+ dim: int,
785
+ time_mix_inner_dim: int,
786
+ num_attention_heads: int,
787
+ attention_head_dim: int,
788
+ cross_attention_dim: Optional[int] = None,
789
+ ):
790
+ super().__init__()
791
+ self.is_res = dim == time_mix_inner_dim
792
+
793
+ self.norm_in = nn.LayerNorm(dim)
794
+
795
+ # Define 3 blocks. Each block has its own normalization layer.
796
+ # 1. Self-Attn
797
+ self.ff_in = FeedForward(
798
+ dim,
799
+ dim_out=time_mix_inner_dim,
800
+ activation_fn="geglu",
801
+ )
802
+
803
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
804
+ self.attn1 = Attention(
805
+ query_dim=time_mix_inner_dim,
806
+ heads=num_attention_heads,
807
+ dim_head=attention_head_dim,
808
+ cross_attention_dim=None,
809
+ )
810
+
811
+ # 2. Cross-Attn
812
+ if cross_attention_dim is not None:
813
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
814
+ # I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
815
+ # the second cross attention block.
816
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
817
+ self.attn2 = Attention(
818
+ query_dim=time_mix_inner_dim,
819
+ cross_attention_dim=cross_attention_dim,
820
+ heads=num_attention_heads,
821
+ dim_head=attention_head_dim,
822
+ ) # is self-attn if encoder_hidden_states is none
823
+ else:
824
+ self.norm2 = None
825
+ self.attn2 = None
826
+
827
+ # 3. Feed-forward
828
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
829
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
830
+
831
+ # let chunk size default to None
832
+ self._chunk_size = None
833
+ self._chunk_dim = None
834
+
835
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
836
+ # Sets chunk feed-forward
837
+ self._chunk_size = chunk_size
838
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
839
+ self._chunk_dim = 1
840
+
841
+ def forward(
842
+ self,
843
+ hidden_states: torch.Tensor,
844
+ num_frames: int,
845
+ encoder_hidden_states: Optional[torch.Tensor] = None,
846
+ ) -> torch.Tensor:
847
+ # Notice that normalization is always applied before the real computation in the following blocks.
848
+ # 0. Self-Attention
851
+ batch_frames, seq_length, channels = hidden_states.shape
852
+ batch_size = batch_frames // num_frames
853
+
854
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
855
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
856
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
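+ # Shapes: (batch*frames, seq, ch) -> (batch*seq, frames, ch), so the attention
+ # below mixes information across frames at each spatial location.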
857
+
858
+ residual = hidden_states
859
+ hidden_states = self.norm_in(hidden_states)
860
+
861
+ if self._chunk_size is not None:
862
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
863
+ else:
864
+ hidden_states = self.ff_in(hidden_states)
865
+
866
+ if self.is_res:
867
+ hidden_states = hidden_states + residual
868
+
869
+ norm_hidden_states = self.norm1(hidden_states)
870
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
871
+ hidden_states = attn_output + hidden_states
872
+
873
+ # 3. Cross-Attention
874
+ if self.attn2 is not None:
875
+ norm_hidden_states = self.norm2(hidden_states)
876
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
877
+ hidden_states = attn_output + hidden_states
878
+
879
+ # 4. Feed-forward
880
+ norm_hidden_states = self.norm3(hidden_states)
881
+
882
+ if self._chunk_size is not None:
883
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
884
+ else:
885
+ ff_output = self.ff(norm_hidden_states)
886
+
887
+ if self.is_res:
888
+ hidden_states = ff_output + hidden_states
889
+ else:
890
+ hidden_states = ff_output
891
+
892
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
893
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
894
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
895
+
896
+ return hidden_states
897
+
898
+
899
+ class SkipFFTransformerBlock(nn.Module):
900
+ def __init__(
901
+ self,
902
+ dim: int,
903
+ num_attention_heads: int,
904
+ attention_head_dim: int,
905
+ kv_input_dim: int,
906
+ kv_input_dim_proj_use_bias: bool,
907
+ dropout=0.0,
908
+ cross_attention_dim: Optional[int] = None,
909
+ attention_bias: bool = False,
910
+ attention_out_bias: bool = True,
911
+ ):
912
+ super().__init__()
913
+ if kv_input_dim != dim:
914
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
915
+ else:
916
+ self.kv_mapper = None
917
+
918
+ self.norm1 = RMSNorm(dim, 1e-06)
919
+
920
+ self.attn1 = Attention(
921
+ query_dim=dim,
922
+ heads=num_attention_heads,
923
+ dim_head=attention_head_dim,
924
+ dropout=dropout,
925
+ bias=attention_bias,
926
+ cross_attention_dim=cross_attention_dim,
927
+ out_bias=attention_out_bias,
928
+ )
929
+
930
+ self.norm2 = RMSNorm(dim, 1e-06)
931
+
932
+ self.attn2 = Attention(
933
+ query_dim=dim,
934
+ cross_attention_dim=cross_attention_dim,
935
+ heads=num_attention_heads,
936
+ dim_head=attention_head_dim,
937
+ dropout=dropout,
938
+ bias=attention_bias,
939
+ out_bias=attention_out_bias,
940
+ )
941
+
942
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
943
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
944
+
945
+ if self.kv_mapper is not None:
946
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
947
+
948
+ norm_hidden_states = self.norm1(hidden_states)
949
+
950
+ attn_output = self.attn1(
951
+ norm_hidden_states,
952
+ encoder_hidden_states=encoder_hidden_states,
953
+ **cross_attention_kwargs,
954
+ )
955
+
956
+ hidden_states = attn_output + hidden_states
957
+
958
+ norm_hidden_states = self.norm2(hidden_states)
959
+
960
+ attn_output = self.attn2(
961
+ norm_hidden_states,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ **cross_attention_kwargs,
964
+ )
965
+
966
+ hidden_states = attn_output + hidden_states
967
+
968
+ return hidden_states
969
+
970
+
971
+ @maybe_allow_in_graph
972
+ class FreeNoiseTransformerBlock(nn.Module):
973
+ r"""
974
+ A FreeNoise Transformer block.
975
+
976
+ Parameters:
977
+ dim (`int`):
978
+ The number of channels in the input and output.
979
+ num_attention_heads (`int`):
980
+ The number of heads to use for multi-head attention.
981
+ attention_head_dim (`int`):
982
+ The number of channels in each head.
983
+ dropout (`float`, *optional*, defaults to 0.0):
984
+ The dropout probability to use.
985
+ cross_attention_dim (`int`, *optional*):
986
+ The size of the encoder_hidden_states vector for cross attention.
987
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
988
+ Activation function to be used in feed-forward.
989
+ num_embeds_ada_norm (`int`, *optional*):
990
+ The number of diffusion steps used during training. See `Transformer2DModel`.
991
+ attention_bias (`bool`, defaults to `False`):
992
+ Configure if the attentions should contain a bias parameter.
993
+ only_cross_attention (`bool`, defaults to `False`):
994
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
995
+ double_self_attention (`bool`, defaults to `False`):
996
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
997
+ upcast_attention (`bool`, defaults to `False`):
998
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
999
+ norm_elementwise_affine (`bool`, defaults to `True`):
1000
+ Whether to use learnable elementwise affine parameters for normalization.
1001
+ norm_type (`str`, defaults to `"layer_norm"`):
1002
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
1003
+ final_dropout (`bool`, defaults to `False`):
1004
+ Whether to apply a final dropout after the last feed-forward layer.
1005
+ attention_type (`str`, defaults to `"default"`):
1006
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
1007
+ positional_embeddings (`str`, *optional*):
1008
+ The type of positional embeddings to apply.
1009
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
1010
+ The maximum number of positional embeddings to apply.
1011
+ ff_inner_dim (`int`, *optional*):
1012
+ Hidden dimension of feed-forward MLP.
1013
+ ff_bias (`bool`, defaults to `True`):
1014
+ Whether or not to use bias in feed-forward MLP.
1015
+ attention_out_bias (`bool`, defaults to `True`):
1016
+ Whether or not to use bias in attention output project layer.
1017
+ context_length (`int`, defaults to `16`):
1018
+ The maximum number of frames that the FreeNoise block processes at once.
1019
+ context_stride (`int`, defaults to `4`):
1020
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
1021
+ weighting_scheme (`str`, defaults to `"pyramid"`):
1022
+ The weighting scheme to use for the weighted averaging of processed latent frames. As described in
1023
+ Equation 9 of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default
1024
+ setting used.
1025
+ """
1026
+
1027
+ def __init__(
1028
+ self,
1029
+ dim: int,
1030
+ num_attention_heads: int,
1031
+ attention_head_dim: int,
1032
+ dropout: float = 0.0,
1033
+ cross_attention_dim: Optional[int] = None,
1034
+ activation_fn: str = "geglu",
1035
+ num_embeds_ada_norm: Optional[int] = None,
1036
+ attention_bias: bool = False,
1037
+ only_cross_attention: bool = False,
1038
+ double_self_attention: bool = False,
1039
+ upcast_attention: bool = False,
1040
+ norm_elementwise_affine: bool = True,
1041
+ norm_type: str = "layer_norm",
1042
+ norm_eps: float = 1e-5,
1043
+ final_dropout: bool = False,
1044
+ positional_embeddings: Optional[str] = None,
1045
+ num_positional_embeddings: Optional[int] = None,
1046
+ ff_inner_dim: Optional[int] = None,
1047
+ ff_bias: bool = True,
1048
+ attention_out_bias: bool = True,
1049
+ context_length: int = 16,
1050
+ context_stride: int = 4,
1051
+ weighting_scheme: str = "pyramid",
1052
+ ):
1053
+ super().__init__()
1054
+ self.dim = dim
1055
+ self.num_attention_heads = num_attention_heads
1056
+ self.attention_head_dim = attention_head_dim
1057
+ self.dropout = dropout
1058
+ self.cross_attention_dim = cross_attention_dim
1059
+ self.activation_fn = activation_fn
1060
+ self.attention_bias = attention_bias
1061
+ self.double_self_attention = double_self_attention
1062
+ self.norm_elementwise_affine = norm_elementwise_affine
1063
+ self.positional_embeddings = positional_embeddings
1064
+ self.num_positional_embeddings = num_positional_embeddings
1065
+ self.only_cross_attention = only_cross_attention
1066
+
1067
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
1068
+
1069
+ # We keep these boolean flags for backward-compatibility.
1070
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
1071
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
1072
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
1073
+ self.use_layer_norm = norm_type == "layer_norm"
1074
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
1075
+
1076
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
1077
+ raise ValueError(
1078
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
1079
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
1080
+ )
1081
+
1082
+ self.norm_type = norm_type
1083
+ self.num_embeds_ada_norm = num_embeds_ada_norm
1084
+
1085
+ if positional_embeddings and (num_positional_embeddings is None):
1086
+ raise ValueError(
1087
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
1088
+ )
1089
+
1090
+ if positional_embeddings == "sinusoidal":
1091
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
1092
+ else:
1093
+ self.pos_embed = None
1094
+
1095
+ # Define 3 blocks. Each block has its own normalization layer.
1096
+ # 1. Self-Attn
1097
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1098
+
1099
+ self.attn1 = Attention(
1100
+ query_dim=dim,
1101
+ heads=num_attention_heads,
1102
+ dim_head=attention_head_dim,
1103
+ dropout=dropout,
1104
+ bias=attention_bias,
1105
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1106
+ upcast_attention=upcast_attention,
1107
+ out_bias=attention_out_bias,
1108
+ )
1109
+
1110
+ # 2. Cross-Attn
1111
+ if cross_attention_dim is not None or double_self_attention:
1112
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1113
+
1114
+ self.attn2 = Attention(
1115
+ query_dim=dim,
1116
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1117
+ heads=num_attention_heads,
1118
+ dim_head=attention_head_dim,
1119
+ dropout=dropout,
1120
+ bias=attention_bias,
1121
+ upcast_attention=upcast_attention,
1122
+ out_bias=attention_out_bias,
1123
+ ) # is self-attn if encoder_hidden_states is none
1124
+
1125
+ # 3. Feed-forward
1126
+ self.ff = FeedForward(
1127
+ dim,
1128
+ dropout=dropout,
1129
+ activation_fn=activation_fn,
1130
+ final_dropout=final_dropout,
1131
+ inner_dim=ff_inner_dim,
1132
+ bias=ff_bias,
1133
+ )
1134
+
1135
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1136
+
1137
+ # let chunk size default to None
1138
+ self._chunk_size = None
1139
+ self._chunk_dim = 0
1140
+
1141
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1142
+ frame_indices = []
1143
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1144
+ window_start = i
1145
+ window_end = min(num_frames, i + self.context_length)
1146
+ frame_indices.append((window_start, window_end))
1147
+ return frame_indices
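+ # Illustrative example: with num_frames=25, context_length=16 and
+ # context_stride=4 this returns [(0, 16), (4, 20), (8, 24)]; the trailing
+ # partial window is appended later in `forward`.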
1148
+
1149
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1150
+ if weighting_scheme == "pyramid":
1151
+ if num_frames % 2 == 0:
1152
+ # num_frames = 4 => [1, 2, 2, 1]
1153
+ weights = list(range(1, num_frames // 2 + 1))
1154
+ weights = weights + weights[::-1]
1155
+ else:
1156
+ # num_frames = 5 => [1, 2, 3, 2, 1]
1157
+ weights = list(range(1, num_frames // 2 + 1))
1158
+ weights = weights + [num_frames // 2 + 1] + weights[::-1]
1159
+ else:
1160
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1161
+
1162
+ return weights
1163
+
1164
+ def set_free_noise_properties(
1165
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1166
+ ) -> None:
1167
+ self.context_length = context_length
1168
+ self.context_stride = context_stride
1169
+ self.weighting_scheme = weighting_scheme
1170
+
1171
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1172
+ # Sets chunk feed-forward
1173
+ self._chunk_size = chunk_size
1174
+ self._chunk_dim = dim
1175
+
1176
+ def forward(
1177
+ self,
1178
+ hidden_states: torch.Tensor,
1179
+ attention_mask: Optional[torch.Tensor] = None,
1180
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1181
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1182
+ cross_attention_kwargs: Dict[str, Any] = None,
1183
+ *args,
1184
+ **kwargs,
1185
+ ) -> torch.Tensor:
1186
+ if cross_attention_kwargs is not None:
1187
+ if cross_attention_kwargs.get("scale", None) is not None:
1188
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1189
+
1190
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1191
+
1192
+ # hidden_states: [B x H x W, F, C]
1193
+ device = hidden_states.device
1194
+ dtype = hidden_states.dtype
1195
+
1196
+ num_frames = hidden_states.size(1)
1197
+ frame_indices = self._get_frame_indices(num_frames)
1198
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1199
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
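+ # frame_weights has shape (1, context_length, 1) so it broadcasts over the
+ # batch and channel dimensions of each window.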
1200
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1201
+
1202
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1203
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1204
+ # [(0, 16), (4, 20), (8, 24), (9, 25)]
1205
+ if not is_last_frame_batch_complete:
1206
+ if num_frames < self.context_length:
1207
+ raise ValueError(f"Expected {num_frames=} to be greater than or equal to {self.context_length=}")
1208
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1209
+ frame_indices.append((num_frames - self.context_length, num_frames))
1210
+
1211
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1212
+ accumulated_values = torch.zeros_like(hidden_states)
1213
+
1214
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1215
+ # Slicing here handles the case where the last window covers fewer than
1216
+ # `context_length` frames, which happens when the user provides a video whose
1217
+ # frame count is not a multiple of `context_length`.
1218
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1219
+ weights *= frame_weights
1220
+
1221
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1222
+
1223
+ # Notice that normalization is always applied before the real computation in the following blocks.
1224
+ # 1. Self-Attention
1225
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1226
+
1227
+ if self.pos_embed is not None:
1228
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1229
+
1230
+ attn_output = self.attn1(
1231
+ norm_hidden_states,
1232
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1233
+ attention_mask=attention_mask,
1234
+ **cross_attention_kwargs,
1235
+ )
1236
+
1237
+ hidden_states_chunk = attn_output + hidden_states_chunk
1238
+ if hidden_states_chunk.ndim == 4:
1239
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1240
+
1241
+ # 2. Cross-Attention
1242
+ if self.attn2 is not None:
1243
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1244
+
1245
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1246
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1247
+
1248
+ attn_output = self.attn2(
1249
+ norm_hidden_states,
1250
+ encoder_hidden_states=encoder_hidden_states,
1251
+ attention_mask=encoder_attention_mask,
1252
+ **cross_attention_kwargs,
1253
+ )
1254
+ hidden_states_chunk = attn_output + hidden_states_chunk
1255
+
1256
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1257
+ accumulated_values[:, -last_frame_batch_length:] += (
1258
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1259
+ )
1260
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
1261
+ else:
1262
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1263
+ num_times_accumulated[:, frame_start:frame_end] += weights
1264
+
1265
+ hidden_states = torch.where(
1266
+ num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1267
+ ).to(dtype)
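+ # Each frame ends up as the pyramid-weighted average of every window that
+ # covered it; positions never covered (weight 0) keep their zero-initialized
+ # accumulated value.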
1268
+
1269
+ # 3. Feed-forward
1270
+ norm_hidden_states = self.norm3(hidden_states)
1271
+
1272
+ if self._chunk_size is not None:
1273
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1274
+ else:
1275
+ ff_output = self.ff(norm_hidden_states)
1276
+
1277
+ hidden_states = ff_output + hidden_states
1278
+ if hidden_states.ndim == 4:
1279
+ hidden_states = hidden_states.squeeze(1)
1280
+
1281
+ return hidden_states
1282
+
1283
+
1284
+ class FeedForward(nn.Module):
1285
+ r"""
1286
+ A feed-forward layer.
1287
+
1288
+ Parameters:
1289
+ dim (`int`): The number of channels in the input.
1290
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1291
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1292
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1293
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1294
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1295
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1296
+ """
1297
+
1298
+ def __init__(
1299
+ self,
1300
+ dim: int,
1301
+ dim_out: Optional[int] = None,
1302
+ mult: int = 4,
1303
+ dropout: float = 0.0,
1304
+ activation_fn: str = "geglu",
1305
+ final_dropout: bool = False,
1306
+ inner_dim=None,
1307
+ bias: bool = True,
1308
+ ):
1309
+ super().__init__()
1310
+ if inner_dim is None:
1311
+ inner_dim = int(dim * mult)
1312
+ dim_out = dim_out if dim_out is not None else dim
1313
+
1314
+ if activation_fn == "gelu":
1315
+ act_fn = GELU(dim, inner_dim, bias=bias)
1316
+ elif activation_fn == "gelu-approximate":
1317
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1318
+ elif activation_fn == "geglu":
1319
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1320
+ elif activation_fn == "geglu-approximate":
1321
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1322
+ elif activation_fn == "swiglu":
1323
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1324
+
1325
+ self.net = nn.ModuleList([])
1326
+ # project in
1327
+ self.net.append(act_fn)
1328
+ # project dropout
1329
+ self.net.append(nn.Dropout(dropout))
1330
+ # project out
1331
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1332
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1333
+ if final_dropout:
1334
+ self.net.append(nn.Dropout(dropout))
1335
+
1336
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1337
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1338
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1339
+ deprecate("scale", "1.0.0", deprecation_message)
1340
+ for module in self.net:
1341
+ hidden_states = module(hidden_states)
1342
+ return hidden_states
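+ # Minimal usage sketch (illustrative only, not part of the module):
+ # ff = FeedForward(dim=320, activation_fn="geglu")
+ # out = ff(torch.randn(2, 77, 320)) # -> torch.Size([2, 77, 320])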
Marigold/unet/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
Marigold/unet/transformer_2d.py ADDED
@@ -0,0 +1,585 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import LegacyConfigMixin, register_to_config
21
+ from diffusers.utils import deprecate, is_torch_version, logging
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from diffusers.models.modeling_utils import LegacyModelMixin
25
+ from diffusers.models.normalization import AdaLayerNormSingle
26
+ from einops import rearrange
27
+ from Marigold.unet.attention import BasicTransformerBlock
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class Transformer2DModelOutput(Transformer2DModelOutput):
33
+ def __init__(self, *args, **kwargs):
34
+ deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."
35
+ deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
36
+ super().__init__(*args, **kwargs)
37
+
38
+
39
+ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
40
+ """
41
+ A 2D Transformer model for image-like data.
42
+
43
+ Parameters:
44
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46
+ in_channels (`int`, *optional*):
47
+ The number of channels in the input and output (specify if the input is **continuous**).
48
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
52
+ This is fixed during training since it is used to learn a number of position embeddings.
53
+ num_vector_embeds (`int`, *optional*):
54
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
55
+ Includes the class for the masked latent pixel.
56
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
57
+ num_embeds_ada_norm ( `int`, *optional*):
58
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
59
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
60
+ added to the hidden states.
61
+
62
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
63
+ attention_bias (`bool`, *optional*):
64
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
65
+ """
66
+
67
+ _supports_gradient_checkpointing = True
68
+ _no_split_modules = ["BasicTransformerBlock"]
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ num_attention_heads: int = 16,
74
+ attention_head_dim: int = 88,
75
+ in_channels: Optional[int] = None,
76
+ out_channels: Optional[int] = None,
77
+ num_layers: int = 1,
78
+ dropout: float = 0.0,
79
+ norm_num_groups: int = 32,
80
+ cross_attention_dim: Optional[int] = None,
81
+ attention_bias: bool = False,
82
+ sample_size: Optional[int] = None,
83
+ num_vector_embeds: Optional[int] = None,
84
+ patch_size: Optional[int] = None,
85
+ activation_fn: str = "geglu",
86
+ num_embeds_ada_norm: Optional[int] = None,
87
+ use_linear_projection: bool = False,
88
+ only_cross_attention: bool = False,
89
+ double_self_attention: bool = False,
90
+ upcast_attention: bool = False,
91
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
92
+ norm_elementwise_affine: bool = True,
93
+ norm_eps: float = 1e-5,
94
+ attention_type: str = "default",
95
+ caption_channels: int = None,
96
+ interpolation_scale: float = None,
97
+ use_additional_conditions: Optional[bool] = None,
98
+ use_RoPE: bool = False,
99
+ ):
100
+ super().__init__()
101
+
102
+ # Validate inputs.
103
+ if patch_size is not None:
104
+ if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
105
+ raise NotImplementedError(
106
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
107
+ )
108
+ elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
109
+ raise ValueError(
110
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
111
+ )
112
+
113
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
114
+ # Define whether input is continuous or discrete depending on configuration
115
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
116
+ self.is_input_vectorized = num_vector_embeds is not None
117
+ self.is_input_patches = in_channels is not None and patch_size is not None
118
+
119
+ if self.is_input_continuous and self.is_input_vectorized:
120
+ raise ValueError(
121
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
122
+ " sure that either `in_channels` or `num_vector_embeds` is None."
123
+ )
124
+ elif self.is_input_vectorized and self.is_input_patches:
125
+ raise ValueError(
126
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
127
+ " sure that either `num_vector_embeds` or `num_patches` is None."
128
+ )
129
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
130
+ raise ValueError(
131
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
132
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
133
+ )
134
+
135
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
136
+ deprecation_message = (
137
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
138
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
139
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
140
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
141
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
142
+ )
143
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
144
+ norm_type = "ada_norm"
145
+
146
+ # Set some common variables used across the board.
147
+ self.use_linear_projection = use_linear_projection
148
+ self.interpolation_scale = interpolation_scale
149
+ self.caption_channels = caption_channels
150
+ self.num_attention_heads = num_attention_heads
151
+ self.attention_head_dim = attention_head_dim
152
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
153
+ self.in_channels = in_channels
154
+ self.out_channels = in_channels if out_channels is None else out_channels
155
+ self.gradient_checkpointing = False
156
+ self.use_RoPE = use_RoPE
157
+
158
+
159
+ if use_additional_conditions is None:
160
+ if norm_type == "ada_norm_single" and sample_size == 128:
161
+ use_additional_conditions = True
162
+ else:
163
+ use_additional_conditions = False
164
+ self.use_additional_conditions = use_additional_conditions
165
+
166
+ # 2. Initialize the right blocks.
167
+ # These functions follow a common structure:
168
+ # a. Initialize the input blocks. b. Initialize the transformer blocks.
169
+ # c. Initialize the output blocks and other projection blocks when necessary.
170
+ if self.is_input_continuous:
171
+ self._init_continuous_input(norm_type=norm_type)
172
+ elif self.is_input_vectorized:
173
+ self._init_vectorized_inputs(norm_type=norm_type)
174
+ elif self.is_input_patches:
175
+ self._init_patched_inputs(norm_type=norm_type)
176
+
177
+ def _init_continuous_input(self, norm_type):
178
+ self.norm = torch.nn.GroupNorm(
179
+ num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
180
+ )
181
+ if self.use_linear_projection:
182
+ self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
183
+ else:
184
+ self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
185
+
186
+ self.transformer_blocks = nn.ModuleList(
187
+ [
188
+ BasicTransformerBlock(
189
+ self.inner_dim,
190
+ self.config.num_attention_heads,
191
+ self.config.attention_head_dim,
192
+ dropout=self.config.dropout,
193
+ cross_attention_dim=self.config.cross_attention_dim,
194
+ activation_fn=self.config.activation_fn,
195
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
196
+ attention_bias=self.config.attention_bias,
197
+ only_cross_attention=self.config.only_cross_attention,
198
+ double_self_attention=self.config.double_self_attention,
199
+ upcast_attention=self.config.upcast_attention,
200
+ norm_type=norm_type,
201
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
202
+ norm_eps=self.config.norm_eps,
203
+ attention_type=self.config.attention_type,
204
+ use_RoPE=self.use_RoPE,
205
+ )
206
+ for _ in range(self.config.num_layers)
207
+ ]
208
+ )
209
+
210
+ if self.use_linear_projection:
211
+ self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
212
+ else:
213
+ self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
214
+
215
+ def _init_vectorized_inputs(self, norm_type):
216
+ assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
217
+ assert (
218
+ self.config.num_vector_embeds is not None
219
+ ), "Transformer2DModel over discrete input must provide num_embed"
220
+
221
+ self.height = self.config.sample_size
222
+ self.width = self.config.sample_size
223
+ self.num_latent_pixels = self.height * self.width
224
+
225
+ self.latent_image_embedding = ImagePositionalEmbeddings(
226
+ num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
227
+ )
228
+
229
+ self.transformer_blocks = nn.ModuleList(
230
+ [
231
+ BasicTransformerBlock(
232
+ self.inner_dim,
233
+ self.config.num_attention_heads,
234
+ self.config.attention_head_dim,
235
+ dropout=self.config.dropout,
236
+ cross_attention_dim=self.config.cross_attention_dim,
237
+ activation_fn=self.config.activation_fn,
238
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
239
+ attention_bias=self.config.attention_bias,
240
+ only_cross_attention=self.config.only_cross_attention,
241
+ double_self_attention=self.config.double_self_attention,
242
+ upcast_attention=self.config.upcast_attention,
243
+ norm_type=norm_type,
244
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
245
+ norm_eps=self.config.norm_eps,
246
+ attention_type=self.config.attention_type,
247
+ )
248
+ for _ in range(self.config.num_layers)
249
+ ]
250
+ )
251
+
252
+ self.norm_out = nn.LayerNorm(self.inner_dim)
253
+ self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
254
+
255
+ def _init_patched_inputs(self, norm_type):
256
+ assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
257
+
258
+ self.height = self.config.sample_size
259
+ self.width = self.config.sample_size
260
+
261
+ self.patch_size = self.config.patch_size
262
+ interpolation_scale = (
263
+ self.config.interpolation_scale
264
+ if self.config.interpolation_scale is not None
265
+ else max(self.config.sample_size // 64, 1)
266
+ )
267
+ self.pos_embed = PatchEmbed(
268
+ height=self.config.sample_size,
269
+ width=self.config.sample_size,
270
+ patch_size=self.config.patch_size,
271
+ in_channels=self.in_channels,
272
+ embed_dim=self.inner_dim,
273
+ interpolation_scale=interpolation_scale,
274
+ )
275
+
276
+ self.transformer_blocks = nn.ModuleList(
277
+ [
278
+ BasicTransformerBlock(
279
+ self.inner_dim,
280
+ self.config.num_attention_heads,
281
+ self.config.attention_head_dim,
282
+ dropout=self.config.dropout,
283
+ cross_attention_dim=self.config.cross_attention_dim,
284
+ activation_fn=self.config.activation_fn,
285
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
286
+ attention_bias=self.config.attention_bias,
287
+ only_cross_attention=self.config.only_cross_attention,
288
+ double_self_attention=self.config.double_self_attention,
289
+ upcast_attention=self.config.upcast_attention,
290
+ norm_type=norm_type,
291
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
292
+ norm_eps=self.config.norm_eps,
293
+ attention_type=self.config.attention_type,
294
+ )
295
+ for _ in range(self.config.num_layers)
296
+ ]
297
+ )
298
+
299
+ if self.config.norm_type != "ada_norm_single":
300
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
301
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
302
+ self.proj_out_2 = nn.Linear(
303
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
304
+ )
305
+ elif self.config.norm_type == "ada_norm_single":
306
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
307
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
308
+ self.proj_out = nn.Linear(
309
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
310
+ )
311
+
312
+ # PixArt-Alpha blocks.
313
+ self.adaln_single = None
314
+ if self.config.norm_type == "ada_norm_single":
315
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
316
+ # additional conditions until we find better name
317
+ self.adaln_single = AdaLayerNormSingle(
318
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
319
+ )
320
+
321
+ self.caption_projection = None
322
+ if self.caption_channels is not None:
323
+ self.caption_projection = PixArtAlphaTextProjection(
324
+ in_features=self.caption_channels, hidden_size=self.inner_dim
325
+ )
326
+
327
+ def _set_gradient_checkpointing(self, module, value=False):
328
+ if hasattr(module, "gradient_checkpointing"):
329
+ module.gradient_checkpointing = value
330
+
331
+ def forward(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ encoder_hidden_states: Optional[torch.Tensor] = None,
335
+ timestep: Optional[torch.LongTensor] = None,
336
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
337
+ class_labels: Optional[torch.LongTensor] = None,
338
+ cross_attention_kwargs: Dict[str, Any] = None,
339
+ attention_mask: Optional[torch.Tensor] = None,
340
+ encoder_attention_mask: Optional[torch.Tensor] = None,
341
+ return_dict: bool = True,
342
+ ):
343
+ """
344
+ The [`Transformer2DModel`] forward method.
345
+
346
+ Args:
347
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):
348
+ Input `hidden_states`.
349
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
350
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
351
+ self-attention.
352
+ timestep ( `torch.LongTensor`, *optional*):
353
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
354
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
355
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
356
+ `AdaLayerZeroNorm`.
357
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
358
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
359
+ `self.processor` in
360
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
361
+ attention_mask ( `torch.Tensor`, *optional*):
362
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
363
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
364
+ negative values to the attention scores corresponding to "discard" tokens.
365
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
366
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
367
+
368
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
369
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
370
+
371
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
372
+ above. This bias will be added to the cross-attention scores.
373
+ return_dict (`bool`, *optional*, defaults to `True`):
374
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
375
+ tuple.
376
+
377
+ Returns:
378
+ If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
379
+ otherwise a `tuple` where the first element is the sample tensor.
380
+ """
381
+ if cross_attention_kwargs is not None:
382
+ if cross_attention_kwargs.get("scale", None) is not None:
383
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
384
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
385
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
386
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
387
+ # expects mask of shape:
388
+ # [batch, key_tokens]
389
+ # adds singleton query_tokens dimension:
390
+ # [batch, 1, key_tokens]
391
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
392
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
393
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
394
+ if attention_mask is not None and attention_mask.ndim == 2:
395
+ # assume that mask is expressed as:
396
+ # (1 = keep, 0 = discard)
397
+ # convert mask into a bias that can be added to attention scores:
398
+ # (keep = +0, discard = -10000.0)
399
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
400
+ attention_mask = attention_mask.unsqueeze(1)
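+ # e.g. a mask of [1, 1, 0] becomes a bias of [[0.0, 0.0, -10000.0]] that is
+ # added to the attention scores for the corresponding key tokens.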
401
+
402
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
403
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
404
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
405
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
406
+
407
+ # 1. Input
408
+ if self.is_input_continuous:
409
+ batch_size, _, height, width = hidden_states.shape
410
+ residual = hidden_states
411
+ hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
412
+ elif self.is_input_vectorized:
413
+ hidden_states = self.latent_image_embedding(hidden_states)
414
+ elif self.is_input_patches:
415
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
416
+ hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
417
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
418
+ )
419
+ # print("Hidden States Shape: ", hidden_states.shape)
420
+ # hidden_states = rearrange(hidden_states, '(b t) l c -> b (t l) c', t=6)
421
+ # print("Hidden States Shape: ", hidden_states.shape)
422
+ # 2. Blocks
423
+ for block in self.transformer_blocks:
424
+ if self.training and self.gradient_checkpointing:
425
+
426
+ def create_custom_forward(module, return_dict=None):
427
+ def custom_forward(*inputs):
428
+ if return_dict is not None:
429
+ return module(*inputs, return_dict=return_dict)
430
+ else:
431
+ return module(*inputs)
432
+
433
+ return custom_forward
434
+
435
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
436
+ hidden_states = torch.utils.checkpoint.checkpoint(
437
+ create_custom_forward(block),
438
+ hidden_states,
439
+ attention_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ timestep,
443
+ cross_attention_kwargs,
444
+ class_labels,
445
+ **ckpt_kwargs,
446
+ )
447
+ else:
448
+ hidden_states = block(
449
+ hidden_states,
450
+ attention_mask=attention_mask,
451
+ encoder_hidden_states=encoder_hidden_states,
452
+ encoder_attention_mask=encoder_attention_mask,
453
+ timestep=timestep,
454
+ cross_attention_kwargs=cross_attention_kwargs,
455
+ class_labels=class_labels,
456
+ )
457
+
458
+ # 3. Output
459
+ if self.is_input_continuous:
460
+ output = self._get_output_for_continuous_inputs(
461
+ hidden_states=hidden_states,
462
+ residual=residual,
463
+ batch_size=batch_size,
464
+ height=height,
465
+ width=width,
466
+ inner_dim=inner_dim,
467
+ )
468
+ elif self.is_input_vectorized:
469
+ output = self._get_output_for_vectorized_inputs(hidden_states)
470
+ elif self.is_input_patches:
471
+ output = self._get_output_for_patched_inputs(
472
+ hidden_states=hidden_states,
473
+ timestep=timestep,
474
+ class_labels=class_labels,
475
+ embedded_timestep=embedded_timestep,
476
+ height=height,
477
+ width=width,
478
+ )
479
+
480
+ if not return_dict:
481
+ return (output,)
482
+
483
+ return Transformer2DModelOutput(sample=output)
484
+
485
+ def _operate_on_continuous_inputs(self, hidden_states):
486
+ batch, _, height, width = hidden_states.shape
487
+ # sync_norm:
488
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", b=1, h=height, w=width)
489
+ hidden_states = self.norm(hidden_states)
490
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", b=1, h=height, w=width)
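+ # The GroupNorm statistics above are computed jointly over the whole batch of
+ # cube faces ("sync_norm") rather than independently per face.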
491
+ if not self.use_linear_projection:
492
+ hidden_states = self.proj_in(hidden_states)
493
+ inner_dim = hidden_states.shape[1]
494
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
495
+ else:
496
+ inner_dim = hidden_states.shape[1]
497
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
498
+ hidden_states = self.proj_in(hidden_states)
499
+ return hidden_states, inner_dim
500
+
501
+ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
502
+ batch_size = hidden_states.shape[0]
503
+ hidden_states = self.pos_embed(hidden_states)
504
+ embedded_timestep = None
505
+
506
+ if self.adaln_single is not None:
507
+ if self.use_additional_conditions and added_cond_kwargs is None:
508
+ raise ValueError(
509
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
510
+ )
511
+ timestep, embedded_timestep = self.adaln_single(
512
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
513
+ )
514
+
515
+ if self.caption_projection is not None:
516
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
517
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
518
+
519
+ return hidden_states, encoder_hidden_states, timestep, embedded_timestep
520
+
521
+ def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
522
+ if not self.use_linear_projection:
523
+ hidden_states = (
524
+ hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
525
+ )
526
+ hidden_states = self.proj_out(hidden_states)
527
+ else:
528
+ hidden_states = self.proj_out(hidden_states)
529
+ hidden_states = (
530
+ hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
531
+ )
532
+
533
+ output = hidden_states + residual
534
+ return output
535
+
536
+ def _get_output_for_vectorized_inputs(self, hidden_states):
537
+ b, c, h, w = hidden_states.shape
538
+ # sync_norm:
539
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", b=1, h=h, w=w)
540
+ hidden_states = self.norm_out(hidden_states)
541
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", b=1, h=h, w=w)
542
+ logits = self.out(hidden_states)
543
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
544
+ logits = logits.permute(0, 2, 1)
545
+ # log(p(x_0))
546
+ output = F.log_softmax(logits.double(), dim=1).float()
547
+ return output
548
+
549
+ def _get_output_for_patched_inputs(
550
+ self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
551
+ ):
552
+ if self.config.norm_type != "ada_norm_single":
553
+ conditioning = self.transformer_blocks[0].norm1.emb(
554
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
555
+ )
556
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
557
+ # sync_norm:
558
+ b, c, h, w = hidden_states.shape
559
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", b=1, h=h, w=w)
560
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
561
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", b=1, h=h, w=w)
562
+ hidden_states = self.proj_out_2(hidden_states)
563
+ elif self.config.norm_type == "ada_norm_single":
564
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
565
+ # sync_norm:
566
+ b, c, h, w = hidden_states.shape
567
+ hidden_states = rearrange(hidden_states, "(b t) c h w -> b c (h w t)", b=1, h=h, w=w)
568
+ hidden_states = self.norm_out(hidden_states)
569
+ hidden_states = rearrange(hidden_states, "b c (h w t) -> (b t) c h w", b=1, h=h, w=w)
570
+ # Modulation
571
+ hidden_states = hidden_states * (1 + scale) + shift
572
+ hidden_states = self.proj_out(hidden_states)
573
+ hidden_states = hidden_states.squeeze(1)
574
+
575
+ # unpatchify
576
+ if self.adaln_single is None:
577
+ height = width = int(hidden_states.shape[1] ** 0.5)
578
+ hidden_states = hidden_states.reshape(
579
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
580
+ )
581
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
582
+ output = hidden_states.reshape(
583
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
584
+ )
585
+ return output
Marigold/unet/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
Marigold/unet/unet_2d_condition.py ADDED
@@ -0,0 +1,1414 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
22
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.embeddings import (
26
+ GaussianFourierProjection,
27
+ GLIGENTextBoundingboxProjection,
28
+ ImageHintTimeEmbedding,
29
+ ImageProjection,
30
+ ImageTimeEmbedding,
31
+ TextImageProjection,
32
+ TextImageTimeEmbedding,
33
+ TextTimeEmbedding,
34
+ TimestepEmbedding,
35
+ Timesteps,
36
+ )
37
+ from diffusers.models.activations import get_activation
38
+ from Marigold.unet.attention_processor import (
39
+ ADDED_KV_ATTENTION_PROCESSORS,
40
+ CROSS_ATTENTION_PROCESSORS,
41
+ Attention,
42
+ AttentionProcessor,
43
+ AttnAddedKVProcessor,
44
+ AttnProcessor,
45
+ FusedAttnProcessor2_0,
46
+ )
47
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
48
+ from Marigold.unet.unet_2d_blocks import get_mid_block, get_down_block, get_up_block
49
+ from einops import rearrange
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+
54
+ @dataclass
55
+ class UNet2DConditionOutput(BaseOutput):
56
+ """
57
+ The output of [`UNet2DConditionModel`].
58
+
59
+ Args:
60
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
61
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
62
+ """
63
+
64
+ sample: torch.Tensor = None
65
+
66
+
67
+ class UNet2DConditionModel(
68
+ ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
69
+ ):
70
+ r"""
71
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep, and returns a
72
+ sample-shaped output.
73
+
74
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
75
+ for all models (such as downloading or saving).
76
+
77
+ Parameters:
78
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
+ Height and width of input/output sample.
80
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
84
+ Whether to flip the sin to cos in the time embedding.
85
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
+ The tuple of downsample blocks to use.
88
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
+ Block type for the middle of the UNet; it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
90
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
+ The tuple of upsample blocks to use.
93
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
94
+ Whether to include self-attention in the basic transformer blocks, see
95
+ [`~models.attention.BasicTransformerBlock`].
96
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
+ The tuple of output channels for each block.
98
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
99
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
102
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
103
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
104
+ If `None`, normalization and activation layers are skipped in post-processing.
105
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
106
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
107
+ The dimension of the cross attention features.
108
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
109
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
110
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
113
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
114
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
115
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
116
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
117
+ encoder_hid_dim (`int`, *optional*, defaults to None):
118
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
119
+ dimension to `cross_attention_dim`.
120
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
121
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
122
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
123
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
124
+ num_attention_heads (`int`, *optional*):
125
+ The number of attention heads. If not defined, defaults to `attention_head_dim`.
126
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
127
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
128
+ class_embed_type (`str`, *optional*, defaults to `None`):
129
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
130
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
131
+ addition_embed_type (`str`, *optional*, defaults to `None`):
132
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
133
+ "text". "text" will use the `TextTimeEmbedding` layer.
134
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
135
+ Dimension for the timestep embeddings.
136
+ num_class_embeds (`int`, *optional*, defaults to `None`):
137
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
138
+ class conditioning with `class_embed_type` equal to `None`.
139
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
140
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
141
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
142
+ An optional override for the dimension of the projected time embedding.
143
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
144
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
145
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
146
+ timestep_post_act (`str`, *optional*, defaults to `None`):
147
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
148
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
149
+ The dimension of `cond_proj` layer in the timestep embedding.
150
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
151
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
152
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
153
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
154
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
155
+ embeddings with the class embeddings.
156
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
157
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
158
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
159
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
160
+ otherwise.
161
+ """
162
+
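+ # A hedged instantiation sketch (hypothetical values): `use_RoPE` is the main
+ # addition over the stock diffusers constructor; with the defaults above the
+ # configuration matches upstream Stable-Diffusion-style UNets, e.g.
+ #     unet = UNet2DConditionModel(sample_size=96, in_channels=8, out_channels=4, use_RoPE=True)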
163
+ _supports_gradient_checkpointing = True
164
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
165
+
166
+ @register_to_config
167
+ def __init__(
168
+ self,
169
+ sample_size: Optional[int] = None,
170
+ in_channels: int = 4,
171
+ out_channels: int = 4,
172
+ center_input_sample: bool = False,
173
+ flip_sin_to_cos: bool = True,
174
+ freq_shift: int = 0,
175
+ down_block_types: Tuple[str] = (
176
+ "CrossAttnDownBlock2D",
177
+ "CrossAttnDownBlock2D",
178
+ "CrossAttnDownBlock2D",
179
+ "DownBlock2D",
180
+ ),
181
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
182
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
183
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
184
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
185
+ layers_per_block: Union[int, Tuple[int]] = 2,
186
+ downsample_padding: int = 1,
187
+ mid_block_scale_factor: float = 1,
188
+ dropout: float = 0.0,
189
+ act_fn: str = "silu",
190
+ norm_num_groups: Optional[int] = 32,
191
+ norm_eps: float = 1e-5,
192
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
193
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
194
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
195
+ encoder_hid_dim: Optional[int] = None,
196
+ encoder_hid_dim_type: Optional[str] = None,
197
+ attention_head_dim: Union[int, Tuple[int]] = 8,
198
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
199
+ dual_cross_attention: bool = False,
200
+ use_linear_projection: bool = False,
201
+ class_embed_type: Optional[str] = None,
202
+ addition_embed_type: Optional[str] = None,
203
+ addition_time_embed_dim: Optional[int] = None,
204
+ num_class_embeds: Optional[int] = None,
205
+ upcast_attention: bool = False,
206
+ resnet_time_scale_shift: str = "default",
207
+ resnet_skip_time_act: bool = False,
208
+ resnet_out_scale_factor: float = 1.0,
209
+ time_embedding_type: str = "positional",
210
+ time_embedding_dim: Optional[int] = None,
211
+ time_embedding_act_fn: Optional[str] = None,
212
+ timestep_post_act: Optional[str] = None,
213
+ time_cond_proj_dim: Optional[int] = None,
214
+ conv_in_kernel: int = 3,
215
+ conv_out_kernel: int = 3,
216
+ projection_class_embeddings_input_dim: Optional[int] = None,
217
+ attention_type: str = "default",
218
+ class_embeddings_concat: bool = False,
219
+ mid_block_only_cross_attention: Optional[bool] = None,
220
+ cross_attention_norm: Optional[str] = None,
221
+ addition_embed_type_num_heads: int = 64,
222
+ use_RoPE: bool = False
223
+ ):
224
+ super().__init__()
225
+
226
+ self.sample_size = sample_size
227
+ self.use_RoPE = use_RoPE
228
+ self.gradient_checkpointing = False
229
+
230
+ if num_attention_heads is not None:
231
+ raise ValueError(
232
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
233
+ )
234
+
235
+ # If `num_attention_heads` is not defined (which is the case for most models)
236
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
237
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
238
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
239
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
240
+ # which is why we correct for the naming here.
241
+ num_attention_heads = num_attention_heads or attention_head_dim
242
+
243
+ # Check inputs
244
+ self._check_config(
245
+ down_block_types=down_block_types,
246
+ up_block_types=up_block_types,
247
+ only_cross_attention=only_cross_attention,
248
+ block_out_channels=block_out_channels,
249
+ layers_per_block=layers_per_block,
250
+ cross_attention_dim=cross_attention_dim,
251
+ transformer_layers_per_block=transformer_layers_per_block,
252
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
253
+ attention_head_dim=attention_head_dim,
254
+ num_attention_heads=num_attention_heads,
255
+ )
256
+
257
+ # input
258
+ conv_in_padding = (conv_in_kernel - 1) // 2
259
+ self.conv_in = nn.Conv2d(
260
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
261
+ )
262
+
263
+ # time
264
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
265
+ time_embedding_type,
266
+ block_out_channels=block_out_channels,
267
+ flip_sin_to_cos=flip_sin_to_cos,
268
+ freq_shift=freq_shift,
269
+ time_embedding_dim=time_embedding_dim,
270
+ )
271
+
272
+ self.time_embedding = TimestepEmbedding(
273
+ timestep_input_dim,
274
+ time_embed_dim,
275
+ act_fn=act_fn,
276
+ post_act_fn=timestep_post_act,
277
+ cond_proj_dim=time_cond_proj_dim,
278
+ )
279
+
280
+ self._set_encoder_hid_proj(
281
+ encoder_hid_dim_type,
282
+ cross_attention_dim=cross_attention_dim,
283
+ encoder_hid_dim=encoder_hid_dim,
284
+ )
285
+
286
+ # class embedding
287
+ self._set_class_embedding(
288
+ class_embed_type,
289
+ act_fn=act_fn,
290
+ num_class_embeds=num_class_embeds,
291
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
292
+ time_embed_dim=time_embed_dim,
293
+ timestep_input_dim=timestep_input_dim,
294
+ )
295
+
296
+ self._set_add_embedding(
297
+ addition_embed_type,
298
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
299
+ addition_time_embed_dim=addition_time_embed_dim,
300
+ cross_attention_dim=cross_attention_dim,
301
+ encoder_hid_dim=encoder_hid_dim,
302
+ flip_sin_to_cos=flip_sin_to_cos,
303
+ freq_shift=freq_shift,
304
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
305
+ time_embed_dim=time_embed_dim,
306
+ )
307
+
308
+ if time_embedding_act_fn is None:
309
+ self.time_embed_act = None
310
+ else:
311
+ self.time_embed_act = get_activation(time_embedding_act_fn)
312
+
313
+ self.down_blocks = nn.ModuleList([])
314
+ self.up_blocks = nn.ModuleList([])
315
+
316
+ if isinstance(only_cross_attention, bool):
317
+ if mid_block_only_cross_attention is None:
318
+ mid_block_only_cross_attention = only_cross_attention
319
+
320
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
321
+
322
+ if mid_block_only_cross_attention is None:
323
+ mid_block_only_cross_attention = False
324
+
325
+ if isinstance(num_attention_heads, int):
326
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
327
+
328
+ if isinstance(attention_head_dim, int):
329
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
330
+
331
+ if isinstance(cross_attention_dim, int):
332
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
333
+
334
+ if isinstance(layers_per_block, int):
335
+ layers_per_block = [layers_per_block] * len(down_block_types)
336
+
337
+ if isinstance(transformer_layers_per_block, int):
338
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
339
+
340
+ if class_embeddings_concat:
341
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
342
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
343
+ # regular time embeddings
344
+ blocks_time_embed_dim = time_embed_dim * 2
345
+ else:
346
+ blocks_time_embed_dim = time_embed_dim
347
+
348
+ # down
349
+ output_channel = block_out_channels[0]
350
+ for i, down_block_type in enumerate(down_block_types):
351
+ input_channel = output_channel
352
+ output_channel = block_out_channels[i]
353
+ is_final_block = i == len(block_out_channels) - 1
354
+
355
+ down_block = get_down_block(
356
+ down_block_type,
357
+ num_layers=layers_per_block[i],
358
+ transformer_layers_per_block=transformer_layers_per_block[i],
359
+ in_channels=input_channel,
360
+ out_channels=output_channel,
361
+ temb_channels=blocks_time_embed_dim,
362
+ add_downsample=not is_final_block,
363
+ resnet_eps=norm_eps,
364
+ resnet_act_fn=act_fn,
365
+ resnet_groups=norm_num_groups,
366
+ cross_attention_dim=cross_attention_dim[i],
367
+ num_attention_heads=num_attention_heads[i],
368
+ downsample_padding=downsample_padding,
369
+ dual_cross_attention=dual_cross_attention,
370
+ use_linear_projection=use_linear_projection,
371
+ only_cross_attention=only_cross_attention[i],
372
+ upcast_attention=upcast_attention,
373
+ resnet_time_scale_shift=resnet_time_scale_shift,
374
+ attention_type=attention_type,
375
+ resnet_skip_time_act=resnet_skip_time_act,
376
+ resnet_out_scale_factor=resnet_out_scale_factor,
377
+ cross_attention_norm=cross_attention_norm,
378
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
379
+ dropout=dropout,
380
+ use_RoPE=use_RoPE,
381
+ )
382
+ self.down_blocks.append(down_block)
383
+
384
+ # mid
385
+ self.mid_block = get_mid_block(
386
+ mid_block_type,
387
+ temb_channels=blocks_time_embed_dim,
388
+ in_channels=block_out_channels[-1],
389
+ resnet_eps=norm_eps,
390
+ resnet_act_fn=act_fn,
391
+ resnet_groups=norm_num_groups,
392
+ output_scale_factor=mid_block_scale_factor,
393
+ transformer_layers_per_block=transformer_layers_per_block[-1],
394
+ num_attention_heads=num_attention_heads[-1],
395
+ cross_attention_dim=cross_attention_dim[-1],
396
+ dual_cross_attention=dual_cross_attention,
397
+ use_linear_projection=use_linear_projection,
398
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
399
+ upcast_attention=upcast_attention,
400
+ resnet_time_scale_shift=resnet_time_scale_shift,
401
+ attention_type=attention_type,
402
+ resnet_skip_time_act=resnet_skip_time_act,
403
+ cross_attention_norm=cross_attention_norm,
404
+ attention_head_dim=attention_head_dim[-1],
405
+ dropout=dropout,
406
+ use_RoPE=use_RoPE,
407
+ )
408
+
409
+ # count how many layers upsample the images
410
+ self.num_upsamplers = 0
411
+
412
+ # up
413
+ reversed_block_out_channels = list(reversed(block_out_channels))
414
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
415
+ reversed_layers_per_block = list(reversed(layers_per_block))
416
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
417
+ reversed_transformer_layers_per_block = (
418
+ list(reversed(transformer_layers_per_block))
419
+ if reverse_transformer_layers_per_block is None
420
+ else reverse_transformer_layers_per_block
421
+ )
422
+ only_cross_attention = list(reversed(only_cross_attention))
423
+
424
+ output_channel = reversed_block_out_channels[0]
425
+ for i, up_block_type in enumerate(up_block_types):
426
+ is_final_block = i == len(block_out_channels) - 1
427
+
428
+ prev_output_channel = output_channel
429
+ output_channel = reversed_block_out_channels[i]
430
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
431
+
432
+ # add upsample block for all BUT final layer
433
+ if not is_final_block:
434
+ add_upsample = True
435
+ self.num_upsamplers += 1
436
+ else:
437
+ add_upsample = False
438
+
439
+ up_block = get_up_block(
440
+ up_block_type,
441
+ num_layers=reversed_layers_per_block[i] + 1,
442
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
443
+ in_channels=input_channel,
444
+ out_channels=output_channel,
445
+ prev_output_channel=prev_output_channel,
446
+ temb_channels=blocks_time_embed_dim,
447
+ add_upsample=add_upsample,
448
+ resnet_eps=norm_eps,
449
+ resnet_act_fn=act_fn,
450
+ resolution_idx=i,
451
+ resnet_groups=norm_num_groups,
452
+ cross_attention_dim=reversed_cross_attention_dim[i],
453
+ num_attention_heads=reversed_num_attention_heads[i],
454
+ dual_cross_attention=dual_cross_attention,
455
+ use_linear_projection=use_linear_projection,
456
+ only_cross_attention=only_cross_attention[i],
457
+ upcast_attention=upcast_attention,
458
+ resnet_time_scale_shift=resnet_time_scale_shift,
459
+ attention_type=attention_type,
460
+ resnet_skip_time_act=resnet_skip_time_act,
461
+ resnet_out_scale_factor=resnet_out_scale_factor,
462
+ cross_attention_norm=cross_attention_norm,
463
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
464
+ dropout=dropout,
465
+ use_RoPE=use_RoPE,
466
+ )
467
+ self.up_blocks.append(up_block)
468
+ prev_output_channel = output_channel
469
+
470
+ # out
471
+ if norm_num_groups is not None:
472
+ self.conv_norm_out = nn.GroupNorm(
473
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
474
+ )
475
+
476
+ self.conv_act = get_activation(act_fn)
477
+
478
+ else:
479
+ self.conv_norm_out = None
480
+ self.conv_act = None
481
+
482
+ conv_out_padding = (conv_out_kernel - 1) // 2
483
+ self.conv_out = nn.Conv2d(
484
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
485
+ )
486
+
487
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
488
+
489
+ def _check_config(
490
+ self,
491
+ down_block_types: Tuple[str],
492
+ up_block_types: Tuple[str],
493
+ only_cross_attention: Union[bool, Tuple[bool]],
494
+ block_out_channels: Tuple[int],
495
+ layers_per_block: Union[int, Tuple[int]],
496
+ cross_attention_dim: Union[int, Tuple[int]],
497
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
498
+ reverse_transformer_layers_per_block: bool,
499
+ attention_head_dim: int,
500
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
501
+ ):
502
+ if len(down_block_types) != len(up_block_types):
503
+ raise ValueError(
504
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
505
+ )
506
+
507
+ if len(block_out_channels) != len(down_block_types):
508
+ raise ValueError(
509
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
510
+ )
511
+
512
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
513
+ raise ValueError(
514
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
515
+ )
516
+
517
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
518
+ raise ValueError(
519
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
520
+ )
521
+
522
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
523
+ raise ValueError(
524
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
525
+ )
526
+
527
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
528
+ raise ValueError(
529
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
530
+ )
531
+
532
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
533
+ raise ValueError(
534
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
535
+ )
536
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
537
+ for layer_number_per_block in transformer_layers_per_block:
538
+ if isinstance(layer_number_per_block, list):
539
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
540
+
541
+ def _set_time_proj(
542
+ self,
543
+ time_embedding_type: str,
544
+ block_out_channels: int,
545
+ flip_sin_to_cos: bool,
546
+ freq_shift: float,
547
+ time_embedding_dim: int,
548
+ ) -> Tuple[int, int]:
549
+ if time_embedding_type == "fourier":
550
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
551
+ if time_embed_dim % 2 != 0:
552
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
553
+ self.time_proj = GaussianFourierProjection(
554
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
555
+ )
556
+ timestep_input_dim = time_embed_dim
557
+ elif time_embedding_type == "positional":
558
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
559
+
560
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
561
+ timestep_input_dim = block_out_channels[0]
562
+ else:
563
+ raise ValueError(
564
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
565
+ )
566
+
567
+ return time_embed_dim, timestep_input_dim
568
+
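+ # Worked shapes for the two branches above, assuming the default
+ # block_out_channels = (320, 640, 1280, 1280):
+ #     "positional": time_embed_dim = 320 * 4 = 1280, timestep_input_dim = 320
+ #     "fourier":    time_embed_dim = 320 * 2 = 640,  timestep_input_dim = 640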
569
+ def _set_encoder_hid_proj(
570
+ self,
571
+ encoder_hid_dim_type: Optional[str],
572
+ cross_attention_dim: Union[int, Tuple[int]],
573
+ encoder_hid_dim: Optional[int],
574
+ ):
575
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
576
+ encoder_hid_dim_type = "text_proj"
577
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
578
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
579
+
580
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
581
+ raise ValueError(
582
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
583
+ )
584
+
585
+ if encoder_hid_dim_type == "text_proj":
586
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
587
+ elif encoder_hid_dim_type == "text_image_proj":
588
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
589
+ # it is set to `cross_attention_dim` here, as that is exactly the required dimension for the currently
590
+ # only use case: `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
591
+ self.encoder_hid_proj = TextImageProjection(
592
+ text_embed_dim=encoder_hid_dim,
593
+ image_embed_dim=cross_attention_dim,
594
+ cross_attention_dim=cross_attention_dim,
595
+ )
596
+ elif encoder_hid_dim_type == "image_proj":
597
+ # Kandinsky 2.2
598
+ self.encoder_hid_proj = ImageProjection(
599
+ image_embed_dim=encoder_hid_dim,
600
+ cross_attention_dim=cross_attention_dim,
601
+ )
602
+ elif encoder_hid_dim_type is not None:
603
+ raise ValueError(
604
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
605
+ )
606
+ else:
607
+ self.encoder_hid_proj = None
608
+
609
+ def _set_class_embedding(
610
+ self,
611
+ class_embed_type: Optional[str],
612
+ act_fn: str,
613
+ num_class_embeds: Optional[int],
614
+ projection_class_embeddings_input_dim: Optional[int],
615
+ time_embed_dim: int,
616
+ timestep_input_dim: int,
617
+ ):
618
+ if class_embed_type is None and num_class_embeds is not None:
619
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
620
+ elif class_embed_type == "timestep":
621
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
622
+ elif class_embed_type == "identity":
623
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
624
+ elif class_embed_type == "projection":
625
+ if projection_class_embeddings_input_dim is None:
626
+ raise ValueError(
627
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
628
+ )
629
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
630
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
631
+ # 2. it projects from an arbitrary input dimension.
632
+ #
633
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
634
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
635
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
636
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
637
+ elif class_embed_type == "simple_projection":
638
+ if projection_class_embeddings_input_dim is None:
639
+ raise ValueError(
640
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
641
+ )
642
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
643
+ else:
644
+ self.class_embedding = None
645
+
646
+ def _set_add_embedding(
647
+ self,
648
+ addition_embed_type: str,
649
+ addition_embed_type_num_heads: int,
650
+ addition_time_embed_dim: Optional[int],
651
+ flip_sin_to_cos: bool,
652
+ freq_shift: float,
653
+ cross_attention_dim: Optional[int],
654
+ encoder_hid_dim: Optional[int],
655
+ projection_class_embeddings_input_dim: Optional[int],
656
+ time_embed_dim: int,
657
+ ):
658
+ if addition_embed_type == "text":
659
+ if encoder_hid_dim is not None:
660
+ text_time_embedding_from_dim = encoder_hid_dim
661
+ else:
662
+ text_time_embedding_from_dim = cross_attention_dim
663
+
664
+ self.add_embedding = TextTimeEmbedding(
665
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
666
+ )
667
+ elif addition_embed_type == "text_image":
668
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
669
+ # they are set to `cross_attention_dim` here, as that is exactly the required dimension for the currently
670
+ # only use case: `addition_embed_type == "text_image"` (Kandinsky 2.1).
671
+ self.add_embedding = TextImageTimeEmbedding(
672
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
673
+ )
674
+ elif addition_embed_type == "text_time":
675
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
676
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
677
+ elif addition_embed_type == "image":
678
+ # Kandinsky 2.2
679
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
680
+ elif addition_embed_type == "image_hint":
681
+ # Kandinsky 2.2 ControlNet
682
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
683
+ elif addition_embed_type is not None:
684
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
685
+
686
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
687
+ if attention_type in ["gated", "gated-text-image"]:
688
+ positive_len = 768
689
+ if isinstance(cross_attention_dim, int):
690
+ positive_len = cross_attention_dim
691
+ elif isinstance(cross_attention_dim, (list, tuple)):
692
+ positive_len = cross_attention_dim[0]
693
+
694
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
695
+ self.position_net = GLIGENTextBoundingboxProjection(
696
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
697
+ )
698
+
699
+ @property
700
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
701
+ r"""
702
+ Returns:
703
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
704
+ indexed by its weight name.
705
+ """
706
+ # set recursively
707
+ processors = {}
708
+
709
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
710
+ if hasattr(module, "get_processor"):
711
+ processors[f"{name}.processor"] = module.get_processor()
712
+
713
+ for sub_name, child in module.named_children():
714
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
715
+
716
+ return processors
717
+
718
+ for name, module in self.named_children():
719
+ fn_recursive_add_processors(name, module, processors)
720
+
721
+ return processors
722
+
723
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
724
+ r"""
725
+ Sets the attention processor to use to compute attention.
726
+
727
+ Parameters:
728
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
729
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
730
+ for **all** `Attention` layers.
731
+
732
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
733
+ processor. This is strongly recommended when setting trainable attention processors.
734
+
735
+ """
736
+ count = len(self.attn_processors.keys())
737
+
738
+ if isinstance(processor, dict) and len(processor) != count:
739
+ raise ValueError(
740
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
741
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
742
+ )
743
+
744
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
745
+ if hasattr(module, "set_processor"):
746
+ if not isinstance(processor, dict):
747
+ module.set_processor(processor)
748
+ else:
749
+ module.set_processor(processor.pop(f"{name}.processor"))
750
+
751
+ for sub_name, child in module.named_children():
752
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
753
+
754
+ for name, module in self.named_children():
755
+ fn_recursive_attn_processor(name, module, processor)
756
+
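+ # A hedged usage sketch of the processor API above: keys in `attn_processors`
+ # follow module paths such as
+ # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor", and either
+ # a single processor or a full {path: processor} dict can be passed back, e.g.
+ #     unet.set_attn_processor(AttnProcessor())             # same processor everywhere
+ #     unet.set_attn_processor(dict(unet.attn_processors))  # explicit per-layer dict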
757
+ def set_default_attn_processor(self):
758
+ """
759
+ Disables custom attention processors and sets the default attention implementation.
760
+ """
761
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
762
+ processor = AttnAddedKVProcessor()
763
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
764
+ processor = AttnProcessor()
765
+ else:
766
+ raise ValueError(
767
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
768
+ )
769
+
770
+ self.set_attn_processor(processor)
771
+
772
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
773
+ r"""
774
+ Enable sliced attention computation.
775
+
776
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
777
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
778
+
779
+ Args:
780
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
781
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
782
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
783
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
784
+ must be a multiple of `slice_size`.
785
+ """
786
+ sliceable_head_dims = []
787
+
788
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
789
+ if hasattr(module, "set_attention_slice"):
790
+ sliceable_head_dims.append(module.sliceable_head_dim)
791
+
792
+ for child in module.children():
793
+ fn_recursive_retrieve_sliceable_dims(child)
794
+
795
+ # retrieve number of attention layers
796
+ for module in self.children():
797
+ fn_recursive_retrieve_sliceable_dims(module)
798
+
799
+ num_sliceable_layers = len(sliceable_head_dims)
800
+
801
+ if slice_size == "auto":
802
+ # half the attention head size is usually a good trade-off between
803
+ # speed and memory
804
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
805
+ elif slice_size == "max":
806
+ # make smallest slice possible
807
+ slice_size = num_sliceable_layers * [1]
808
+
809
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
810
+
811
+ if len(slice_size) != len(sliceable_head_dims):
812
+ raise ValueError(
813
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
814
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
815
+ )
816
+
817
+ for i in range(len(slice_size)):
818
+ size = slice_size[i]
819
+ dim = sliceable_head_dims[i]
820
+ if size is not None and size > dim:
821
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
822
+
823
+ # Recursively walk through all the children.
824
+ # Any children which exposes the set_attention_slice method
825
+ # gets the message
826
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
827
+ if hasattr(module, "set_attention_slice"):
828
+ module.set_attention_slice(slice_size.pop())
829
+
830
+ for child in module.children():
831
+ fn_recursive_set_attention_slice(child, slice_size)
832
+
833
+ reversed_slice_size = list(reversed(slice_size))
834
+ for module in self.children():
835
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
836
+
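+ # Worked example of the slicing semantics above, for a layer whose sliceable
+ # head dim is 8:
+ #     slice_size="auto" -> slice of 4 (dim // 2, attention computed in two steps)
+ #     slice_size="max"  -> slice of 1 (smallest slices, lowest peak memory)
+ #     slice_size=2      -> 8 // 2 = 4 slices per attention call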
837
+ def _set_gradient_checkpointing(self):
838
+ self.gradient_checkpointing = True
839
+
840
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
841
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
842
+
843
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
844
+
845
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
846
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
847
+
848
+ Args:
849
+ s1 (`float`):
850
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
851
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
852
+ s2 (`float`):
853
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
854
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
855
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
856
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
857
+ """
858
+ for i, upsample_block in enumerate(self.up_blocks):
859
+ setattr(upsample_block, "s1", s1)
860
+ setattr(upsample_block, "s2", s2)
861
+ setattr(upsample_block, "b1", b1)
862
+ setattr(upsample_block, "b2", b2)
863
+
864
+ def disable_freeu(self):
865
+ """Disables the FreeU mechanism."""
866
+ freeu_keys = {"s1", "s2", "b1", "b2"}
867
+ for i, upsample_block in enumerate(self.up_blocks):
868
+ for k in freeu_keys:
869
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
870
+ setattr(upsample_block, k, None)
871
+
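+ # A hedged usage sketch: the FreeU repository linked above reports
+ # s1=0.9, s2=0.2, b1=1.5, b2=1.6 as working well for SD1.5-family models
+ # (other pipelines need different values):
+ #     unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
+ #     ...
+ #     unet.disable_freeu()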
872
+ def fuse_qkv_projections(self):
873
+ """
874
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
875
+ are fused. For cross-attention modules, key and value projection matrices are fused.
876
+
877
+ <Tip warning={true}>
878
+
879
+ This API is 🧪 experimental.
880
+
881
+ </Tip>
882
+ """
883
+ self.original_attn_processors = None
884
+
885
+ for _, attn_processor in self.attn_processors.items():
886
+ if "Added" in str(attn_processor.__class__.__name__):
887
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
888
+
889
+ self.original_attn_processors = self.attn_processors
890
+
891
+ for module in self.modules():
892
+ if isinstance(module, Attention):
893
+ module.fuse_projections(fuse=True)
894
+
895
+ self.set_attn_processor(FusedAttnProcessor2_0())
896
+
897
+ def unfuse_qkv_projections(self):
898
+ """Disables the fused QKV projection if enabled.
899
+
900
+ <Tip warning={true}>
901
+
902
+ This API is 🧪 experimental.
903
+
904
+ </Tip>
905
+
906
+ """
907
+ if self.original_attn_processors is not None:
908
+ self.set_attn_processor(self.original_attn_processors)
909
+
910
+ def get_time_embed(
911
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
912
+ ) -> Optional[torch.Tensor]:
913
+ timesteps = timestep
914
+ if not torch.is_tensor(timesteps):
915
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
916
+ # This would be a good case for the `match` statement (Python 3.10+)
917
+ is_mps = sample.device.type == "mps"
918
+ if isinstance(timestep, float):
919
+ dtype = torch.float32 if is_mps else torch.float64
920
+ else:
921
+ dtype = torch.int32 if is_mps else torch.int64
922
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
923
+ elif len(timesteps.shape) == 0:
924
+ timesteps = timesteps[None].to(sample.device)
925
+
926
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
927
+ timesteps = timesteps.expand(sample.shape[0])
928
+
929
+ t_emb = self.time_proj(timesteps)
930
+ # `Timesteps` does not contain any weights and will always return f32 tensors
931
+ # but time_embedding might actually be running in fp16. so we need to cast here.
932
+ # there might be better ways to encapsulate this.
933
+ t_emb = t_emb.to(dtype=sample.dtype)
934
+ return t_emb
935
+
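+ # A hedged shape sketch: a plain Python timestep is promoted to a tensor and
+ # broadcast over the batch, so for a (4, 8, 96, 96) sample with the default
+ # "positional" projection (block_out_channels[0] = 320), get_time_embed(sample, 999)
+ # returns a (4, 320) tensor cast to sample.dtype.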
936
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
937
+ class_emb = None
938
+ if self.class_embedding is not None:
939
+ if class_labels is None:
940
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
941
+
942
+ if self.config.class_embed_type == "timestep":
943
+ class_labels = self.time_proj(class_labels)
944
+
945
+ # `Timesteps` does not contain any weights and will always return f32 tensors
946
+ # there might be better ways to encapsulate this.
947
+ class_labels = class_labels.to(dtype=sample.dtype)
948
+
949
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
950
+ return class_emb
951
+
952
+ def get_aug_embed(
953
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
954
+ ) -> Optional[torch.Tensor]:
955
+ aug_emb = None
956
+ if self.config.addition_embed_type == "text":
957
+ aug_emb = self.add_embedding(encoder_hidden_states)
958
+ elif self.config.addition_embed_type == "text_image":
959
+ # Kandinsky 2.1 - style
960
+ if "image_embeds" not in added_cond_kwargs:
961
+ raise ValueError(
962
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
963
+ )
964
+
965
+ image_embs = added_cond_kwargs.get("image_embeds")
966
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
967
+ aug_emb = self.add_embedding(text_embs, image_embs)
968
+ elif self.config.addition_embed_type == "text_time":
969
+ # SDXL - style
970
+ if "text_embeds" not in added_cond_kwargs:
971
+ raise ValueError(
972
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
973
+ )
974
+ text_embeds = added_cond_kwargs.get("text_embeds")
975
+ if "time_ids" not in added_cond_kwargs:
976
+ raise ValueError(
977
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
978
+ )
979
+ time_ids = added_cond_kwargs.get("time_ids")
980
+ time_embeds = self.add_time_proj(time_ids.flatten())
981
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
982
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
983
+ add_embeds = add_embeds.to(emb.dtype)
984
+ aug_emb = self.add_embedding(add_embeds)
985
+ elif self.config.addition_embed_type == "image":
986
+ # Kandinsky 2.2 - style
987
+ if "image_embeds" not in added_cond_kwargs:
988
+ raise ValueError(
989
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
990
+ )
991
+ image_embs = added_cond_kwargs.get("image_embeds")
992
+ aug_emb = self.add_embedding(image_embs)
993
+ elif self.config.addition_embed_type == "image_hint":
994
+ # Kandinsky 2.2 - style
995
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
996
+ raise ValueError(
997
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
998
+ )
999
+ image_embs = added_cond_kwargs.get("image_embeds")
1000
+ hint = added_cond_kwargs.get("hint")
1001
+ aug_emb = self.add_embedding(image_embs, hint)
1002
+ return aug_emb
1003
+
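+ # Worked shapes for the "text_time" (SDXL-style) branch above, assuming
+ # addition_time_embed_dim=256 and projection_class_embeddings_input_dim=2816:
+ #     time_ids (B, 6) -> add_time_proj -> (B*6, 256) -> reshape -> (B, 1536)
+ #     concat with text_embeds (B, 1280) -> (B, 2816) -> add_embedding -> (B, time_embed_dim)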
1004
+ def process_encoder_hidden_states(
1005
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1006
+ ) -> torch.Tensor:
1007
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1008
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1009
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1010
+ # Kandinsky 2.1 - style
1011
+ if "image_embeds" not in added_cond_kwargs:
1012
+ raise ValueError(
1013
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1014
+ )
1015
+
1016
+ image_embeds = added_cond_kwargs.get("image_embeds")
1017
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1018
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1019
+ # Kandinsky 2.2 - style
1020
+ if "image_embeds" not in added_cond_kwargs:
1021
+ raise ValueError(
1022
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1023
+ )
1024
+ image_embeds = added_cond_kwargs.get("image_embeds")
1025
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1026
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1027
+ if "image_embeds" not in added_cond_kwargs:
1028
+ raise ValueError(
1029
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1030
+ )
1031
+
1032
+ if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
1033
+ encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1034
+
1035
+ image_embeds = added_cond_kwargs.get("image_embeds")
1036
+ image_embeds = self.encoder_hid_proj(image_embeds)
1037
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1038
+ return encoder_hidden_states
1039
+
1040
+ def forward(
1041
+ self,
1042
+ sample: torch.Tensor,
1043
+ timestep: Union[torch.Tensor, float, int],
1044
+ encoder_hidden_states: torch.Tensor,
1045
+ class_labels: Optional[torch.Tensor] = None,
1046
+ timestep_cond: Optional[torch.Tensor] = None,
1047
+ attention_mask: Optional[torch.Tensor] = None,
1048
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1049
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1050
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1051
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1052
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1053
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1054
+ return_dict: bool = True,
1055
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1056
+ r"""
1057
+ The [`UNet2DConditionModel`] forward method.
1058
+
1059
+ Args:
1060
+ sample (`torch.Tensor`):
1061
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1062
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
1063
+ encoder_hidden_states (`torch.Tensor`):
1064
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1065
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1066
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1067
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
1068
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1069
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1070
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1071
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1072
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1073
+ negative values to the attention scores corresponding to "discard" tokens.
1074
+ cross_attention_kwargs (`dict`, *optional*):
1075
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1076
+ `self.processor` in
1077
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1078
+ added_cond_kwargs (`dict`, *optional*):
1079
+ A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
1080
+ are passed along to the UNet blocks.
1081
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1082
+ A tuple of tensors that, if specified, are added to the residuals of UNet down blocks.
1083
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
1084
+ A tensor that, if specified, is added to the residual of the middle UNet block.
1085
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1086
+ Additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s).
1087
+ encoder_attention_mask (`torch.Tensor`):
1088
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1089
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1090
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1091
+ return_dict (`bool`, *optional*, defaults to `True`):
1092
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1093
+ tuple.
1094
+
1095
+ Returns:
1096
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1097
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1098
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1099
+ """
1100
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
1101
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1102
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1103
+ # on the fly if necessary.
1104
+ default_overall_up_factor = 2**self.num_upsamplers
1105
+
1106
+ # print("Sample shape: ", sample.shape)
1107
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1108
+ forward_upsample_size = False
1109
+ upsample_size = None
1110
+
1111
+ for dim in sample.shape[-2:]:
1112
+ if dim % default_overall_up_factor != 0:
1113
+ # Forward upsample size to force interpolation output size.
1114
+ forward_upsample_size = True
1115
+ break
1116
+
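+ # With the default 4-level UNet there are 3 upsamplers, so
+ # default_overall_up_factor = 2**3 = 8: a 96x72 latent passes this check,
+ # while 96x70 (70 % 8 != 0) sets forward_upsample_size = True.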
1117
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1118
+ # expects mask of shape:
1119
+ # [batch, key_tokens]
1120
+ # adds singleton query_tokens dimension:
1121
+ # [batch, 1, key_tokens]
1122
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1123
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1124
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1125
+ if attention_mask is not None:
1126
+ # assume that mask is expressed as:
1127
+ # (1 = keep, 0 = discard)
1128
+ # convert mask into a bias that can be added to attention scores:
1129
+ # (keep = +0, discard = -10000.0)
1130
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1131
+ attention_mask = attention_mask.unsqueeze(1)
1132
+
1133
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1134
+ if encoder_attention_mask is not None:
1135
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1136
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1137
+
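Both conversions above follow the same pattern: a 0/1 keep mask becomes an additive bias, so discarded keys receive effectively zero probability after the softmax. A standalone sketch (shapes are illustrative, not tied to this model):

```python
import torch

# A 0/1 keep mask becomes an additive bias: keep -> 0.0, discard -> -10000.0.
mask = torch.tensor([[1, 1, 0, 0]])        # (batch, key_tokens)
bias = (1 - mask.float()) * -10000.0
bias = bias.unsqueeze(1)                   # (batch, 1, key_tokens), broadcasts over queries

scores = torch.randn(1, 8, 5, 4)           # (batch, heads, query_tokens, key_tokens)
probs = (scores + bias).softmax(dim=-1)
print(probs[..., 2:].abs().max())          # ~0: discarded keys get no attention
```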
1138
+ # 0. center input if necessary
1139
+ if self.config.center_input_sample:
1140
+ sample = 2 * sample - 1.0
1141
+ # 1. time
1142
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1143
+ emb = self.time_embedding(t_emb, timestep_cond)
1144
+ aug_emb = None
1145
+
1146
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1147
+ if class_emb is not None:
1148
+ if self.config.class_embeddings_concat:
1149
+ emb = torch.cat([emb, class_emb], dim=-1)
1150
+ else:
1151
+ emb = emb + class_emb
1152
+
1153
+ aug_emb = self.get_aug_embed(
1154
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1155
+ )
1156
+ if self.config.addition_embed_type == "image_hint":
1157
+ aug_emb, hint = aug_emb
1158
+ sample = torch.cat([sample, hint], dim=1)
1159
+
1160
+ emb = emb + aug_emb if aug_emb is not None else emb
1161
+
1162
+ if self.time_embed_act is not None:
1163
+ emb = self.time_embed_act(emb)
1164
+
1165
+ encoder_hidden_states = self.process_encoder_hidden_states(
1166
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1167
+ )
1168
+
1169
+ # 2. pre-process
1170
+ if self.gradient_checkpointing:
1171
+ sample = torch.utils.checkpoint.checkpoint(
1172
+ self.conv_in,
1173
+ sample,
1174
+ use_reentrant=False,
1175
+ )
1176
+ else:
1177
+ sample = self.conv_in(sample)
1178
+
1179
+ # 2.5 GLIGEN position net
1180
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1181
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1182
+ gligen_args = cross_attention_kwargs.pop("gligen")
1183
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1184
+
1185
+ # 3. down
1186
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1187
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1188
+ if cross_attention_kwargs is not None:
1189
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1190
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1191
+ else:
1192
+ lora_scale = 1.0
1193
+
1194
+ if USE_PEFT_BACKEND:
1195
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1196
+ scale_lora_layers(self, lora_scale)
1197
+
1198
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1199
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1200
+ is_adapter = down_intrablock_additional_residuals is not None
1201
+ # maintain backward compatibility for legacy usage, where
1202
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1203
+ # but can only use one or the other
1204
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1205
+ deprecate(
1206
+ "T2I should not use down_block_additional_residuals",
1207
+ "1.3.0",
1208
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1209
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1210
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1211
+ standard_warn=False,
1212
+ )
1213
+ down_intrablock_additional_residuals = down_block_additional_residuals
1214
+ is_adapter = True
1215
+
1216
+ down_block_res_samples = (sample,)
1217
+ for i, downsample_block in enumerate(self.down_blocks):
1218
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1219
+ # For t2i-adapter CrossAttnDownBlock2D
1220
+ additional_residuals = {}
1221
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1222
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1223
+
1224
+ if self.gradient_checkpointing:
1225
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
1226
+ downsample_block,
1227
+ hidden_states=sample,
1228
+ temb=emb,
1229
+ encoder_hidden_states=encoder_hidden_states,
1230
+ attention_mask=attention_mask,
1231
+ cross_attention_kwargs=cross_attention_kwargs,
1232
+ encoder_attention_mask=encoder_attention_mask,
1233
+ **additional_residuals,
1234
+ use_reentrant=False,
1235
+ )
1236
+ else:
1237
+ sample, res_samples = downsample_block(
1238
+ hidden_states=sample,
1239
+ temb=emb,
1240
+ encoder_hidden_states=encoder_hidden_states,
1241
+ attention_mask=attention_mask,
1242
+ cross_attention_kwargs=cross_attention_kwargs,
1243
+ encoder_attention_mask=encoder_attention_mask,
1244
+ **additional_residuals,
1245
+ )
1246
+ else:
1247
+ if self.gradient_checkpointing:
1248
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
1249
+ downsample_block,
1250
+ hidden_states=sample,
1251
+ temb=emb,
1252
+ use_reentrant=False,
1253
+ )
1254
+ else:
1255
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1256
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1257
+ sample += down_intrablock_additional_residuals.pop(0)
1258
+
1259
+ down_block_res_samples += res_samples
1260
+
1261
+ if is_controlnet:
1262
+ new_down_block_res_samples = ()
1263
+
1264
+ for down_block_res_sample, down_block_additional_residual in zip(
1265
+ down_block_res_samples, down_block_additional_residuals
1266
+ ):
1267
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1268
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1269
+
1270
+ down_block_res_samples = new_down_block_res_samples
1271
+
1272
+ # 4. mid
1273
+ if self.mid_block is not None:
1275
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1276
+ if self.gradient_checkpointing:
1277
+ sample = torch.utils.checkpoint.checkpoint(
1278
+ self.mid_block,
1279
+ sample,
1280
+ emb,
1281
+ encoder_hidden_states=encoder_hidden_states,
1282
+ attention_mask=attention_mask,
1283
+ cross_attention_kwargs=cross_attention_kwargs,
1284
+ encoder_attention_mask=encoder_attention_mask,
1285
+ use_reentrant=False,
1286
+ )
1287
+ else:
1288
+ sample = self.mid_block(
1289
+ sample,
1290
+ emb,
1291
+ encoder_hidden_states=encoder_hidden_states,
1292
+ attention_mask=attention_mask,
1293
+ cross_attention_kwargs=cross_attention_kwargs,
1294
+ encoder_attention_mask=encoder_attention_mask,
1295
+ )
1296
+ else:
1297
+ if self.gradient_checkpointing:
1298
+ sample = torch.utils.checkpoint.checkpoint(
1299
+ self.mid_block,
1300
+ sample,
1301
+ emb,
1302
+ use_reentrant=False,
1303
+ )
1304
+ else:
1305
+ sample = self.mid_block(sample, emb)
1306
+
1307
+ # To support T2I-Adapter-XL
1308
+ if (
1309
+ is_adapter
1310
+ and len(down_intrablock_additional_residuals) > 0
1311
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1312
+ ):
1313
+ sample += down_intrablock_additional_residuals.pop(0)
1314
+
1315
+ if is_controlnet:
1316
+ sample = sample + mid_block_additional_residual
1317
+
1318
+ # 5. up
1319
+ for i, upsample_block in enumerate(self.up_blocks):
1320
+ is_final_block = i == len(self.up_blocks) - 1
1321
+
1322
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1323
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1324
+
1325
+ # if we have not reached the final block and need to forward the
1326
+ # upsample size, we do it here
1327
+ if not is_final_block and forward_upsample_size:
1328
+ upsample_size = down_block_res_samples[-1].shape[2:]
1329
+
1330
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1331
+ if self.gradient_checkpointing:
1332
+ sample = torch.utils.checkpoint.checkpoint(
1333
+ upsample_block,
1334
+ hidden_states=sample,
1335
+ temb=emb,
1336
+ res_hidden_states_tuple=res_samples,
1337
+ encoder_hidden_states=encoder_hidden_states,
1338
+ cross_attention_kwargs=cross_attention_kwargs,
1339
+ upsample_size=upsample_size,
1340
+ attention_mask=attention_mask,
1341
+ encoder_attention_mask=encoder_attention_mask,
1342
+ use_reentrant=False,
1343
+ )
1344
+ else:
1345
+ sample = upsample_block(
1346
+ hidden_states=sample,
1347
+ temb=emb,
1348
+ res_hidden_states_tuple=res_samples,
1349
+ encoder_hidden_states=encoder_hidden_states,
1350
+ cross_attention_kwargs=cross_attention_kwargs,
1351
+ upsample_size=upsample_size,
1352
+ attention_mask=attention_mask,
1353
+ encoder_attention_mask=encoder_attention_mask,
1354
+ )
1355
+ else:
1356
+ if self.gradient_checkpointing:
1357
+ sample = torch.utils.checkpoint.checkpoint(
1358
+ upsample_block,
1359
+ hidden_states=sample,
1360
+ temb=emb,
1361
+ res_hidden_states_tuple=res_samples,
1362
+ upsample_size=upsample_size,
1363
+ use_reentrant=False,
1364
+ )
1365
+ else:
1366
+ sample = upsample_block(
1367
+ hidden_states=sample,
1368
+ temb=emb,
1369
+ res_hidden_states_tuple=res_samples,
1370
+ upsample_size=upsample_size,
1371
+ )
1372
+
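The residual bookkeeping in steps 3 and 5 is a plain push/pop over a tuple: down blocks append their outputs, up blocks consume them from the end. A toy sketch with placeholder values standing in for tensors:

```python
# Down blocks append residuals; up blocks consume them from the end.
down_block_res_samples = ("conv_in", "down0_a", "down0_b", "down1_a", "down1_b")
num_resnets = 2                                         # len(upsample_block.resnets)

res_samples = down_block_res_samples[-num_resnets:]     # ('down1_a', 'down1_b')
down_block_res_samples = down_block_res_samples[:-num_resnets]
print(res_samples, down_block_res_samples)
```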
1373
+ # 6. post-process
1374
+ if self.conv_norm_out:
1375
+ b, c, h, w = sample.shape
1376
+ # sync_norm:
1377
+ sample = rearrange(sample, '(b t) c h w -> b c (h w t)', t=6, h=h, w=w)
1378
+ if self.gradient_checkpointing:
1379
+ sample = torch.utils.checkpoint.checkpoint(
1380
+ self.conv_norm_out,
1381
+ sample,
1382
+ use_reentrant=False,
1383
+ )
1384
+ else:
1385
+ sample = self.conv_norm_out(sample)
1386
+ # sync_norm:
1387
+ sample = rearrange(sample, 'b c (h w t) -> (b t) c h w', t=6, h=h, w=w)
1388
+
1389
+ if self.gradient_checkpointing:
1390
+ sample = torch.utils.checkpoint.checkpoint(
1391
+ self.conv_act,
1392
+ sample,
1393
+ use_reentrant=False,
1394
+ )
1395
+ else:
1396
+ sample = self.conv_act(sample)
1397
+
1398
+ if self.gradient_checkpointing:
1399
+ sample = torch.utils.checkpoint.checkpoint(
1400
+ self.conv_out,
1401
+ sample,
1402
+ use_reentrant=False,
1403
+ )
1404
+ else:
1405
+ sample = self.conv_out(sample)
1406
+
1407
+ if USE_PEFT_BACKEND:
1408
+ # remove `lora_scale` from each PEFT layer
1409
+ unscale_lora_layers(self, lora_scale)
1410
+
1411
+ if not return_dict:
1412
+ return (sample,)
1413
+
1414
+ return UNet2DConditionOutput(sample=sample)
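The post-processing above departs from stock diffusers: GroupNorm statistics are shared across the t=6 views packed into the batch dimension (the "sync_norm" rearranges). A self-contained sketch of the trick, with sizes chosen for illustration:

```python
import torch
import torch.nn as nn
from einops import rearrange

b, t, c, h, w = 2, 6, 32, 8, 8
x = torch.randn(b * t, c, h, w)                      # views stacked into the batch: (b t)
norm = nn.GroupNorm(num_groups=8, num_channels=c)

x = rearrange(x, '(b t) c h w -> b c (h w t)', t=t)  # one "sample" now spans all 6 views
x = norm(x)                                          # per-group statistics shared across views
x = rearrange(x, 'b c (h w t) -> (b t) c h w', t=t, h=h, w=w)
print(x.shape)                                       # torch.Size([12, 32, 8, 8])
```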
Marigold/vae/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .vae import Encoder, Decoder, DecoderOutput, DiagonalGaussianDistribution
+ from .autoencoder_kl import AutoencoderKL
2
+
3
+ __all__ = [
4
+ "AutoencoderKL",
5
+ "Encoder",
6
+ "Decoder"
7
+ ]
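With these exports, a minimal end-to-end sketch (a toy config invented for illustration; the repo's custom blocks are assumed to accept these stock-diffusers arguments, and batches must be a multiple of the hard-coded t=6):

```python
import torch
from Marigold.vae import AutoencoderKL

# Hypothetical tiny config for illustration only.
vae = AutoencoderKL(block_out_channels=(32,), norm_num_groups=8, sample_size=64)
x = torch.randn(6, 3, 64, 64)          # t=6 views, as the sync_norm rearranges expect
posterior = vae.encode(x).latent_dist
z = posterior.sample()
recon = vae.decode(z).sample           # (6, 3, 64, 64)
```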
Marigold/vae/autoencoder_kl.py ADDED
@@ -0,0 +1,531 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
21
+ from diffusers.utils.accelerate_utils import apply_forward_hook
22
+ import os
23
+ import sys
24
+ sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../unet")))
25
+ from Marigold.unet.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ Attention,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ FusedAttnProcessor2_0,
33
+ )
34
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
35
+ from diffusers.models.modeling_utils import ModelMixin
36
+ from Marigold.vae.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
37
+
38
+
39
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
40
+ r"""
41
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
42
+
43
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
44
+ for all models (such as downloading or saving).
45
+
46
+ Parameters:
47
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
48
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
49
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
50
+ Tuple of downsample block types.
51
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
52
+ Tuple of upsample block types.
53
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
54
+ Tuple of block output channels.
55
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
56
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
57
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
58
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
59
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
60
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
61
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
62
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
63
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
64
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
65
+ force_upcast (`bool`, *optional*, default to `True`):
66
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
67
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
68
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
69
+ mid_block_add_attention (`bool`, *optional*, default to `True`):
70
+ If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
71
+ mid_block will only have resnet blocks
72
+ """
73
+
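The `scaling_factor` convention in the docstring is a simple multiplicative round trip; a worked sketch with the stock Stable Diffusion value:

```python
import torch

scaling_factor = 0.18215
z = torch.randn(1, 4, 64, 64)        # a posterior sample from vae.encode(...)
z_model = z * scaling_factor         # scaled to ~unit variance for the diffusion model
z_back = z_model / scaling_factor    # unscaled again before vae.decode(...)
assert torch.allclose(z, z_back)
```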
74
+ _supports_gradient_checkpointing = True
75
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
76
+
77
+ @register_to_config
78
+ def __init__(
79
+ self,
80
+ in_channels: int = 3,
81
+ out_channels: int = 3,
82
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
83
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
84
+ block_out_channels: Tuple[int] = (64,),
85
+ layers_per_block: int = 1,
86
+ act_fn: str = "silu",
87
+ latent_channels: int = 4,
88
+ norm_num_groups: int = 32,
89
+ sample_size: int = 32,
90
+ scaling_factor: float = 0.18215,
91
+ shift_factor: Optional[float] = None,
92
+ latents_mean: Optional[Tuple[float]] = None,
93
+ latents_std: Optional[Tuple[float]] = None,
94
+ force_upcast: bool = True,
95
+ use_quant_conv: bool = True,
96
+ use_post_quant_conv: bool = True,
97
+ mid_block_add_attention: bool = True,
99
+ use_RoPE: bool = False, # whether to use RoPE positional encoding
100
+ ):
101
+ super().__init__()
102
+
103
+ # pass init params to Encoder
104
+ self.encoder = Encoder(
105
+ in_channels=in_channels,
106
+ out_channels=latent_channels,
107
+ down_block_types=down_block_types,
108
+ block_out_channels=block_out_channels,
109
+ layers_per_block=layers_per_block,
110
+ act_fn=act_fn,
111
+ norm_num_groups=norm_num_groups,
112
+ double_z=True,
113
+ mid_block_add_attention=mid_block_add_attention,
114
+
115
+ use_RoPE=use_RoPE,
116
+ )
117
+
118
+ # pass init params to Decoder
119
+ self.decoder = Decoder(
120
+ in_channels=latent_channels,
121
+ out_channels=out_channels,
122
+ up_block_types=up_block_types,
123
+ block_out_channels=block_out_channels,
124
+ layers_per_block=layers_per_block,
125
+ norm_num_groups=norm_num_groups,
126
+ act_fn=act_fn,
127
+ mid_block_add_attention=mid_block_add_attention,
128
+
129
+ use_RoPE=use_RoPE,
130
+ )
131
+
132
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
133
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
134
+
135
+ self.use_slicing = False
136
+ self.use_tiling = False
137
+
138
+ # only relevant if vae tiling is enabled
139
+ self.tile_sample_min_size = self.config.sample_size
140
+ sample_size = (
141
+ self.config.sample_size[0]
142
+ if isinstance(self.config.sample_size, (list, tuple))
143
+ else self.config.sample_size
144
+ )
145
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
146
+ self.tile_overlap_factor = 0.25
147
+
148
+ def _set_gradient_checkpointing(self):
149
+ self.encoder.gradient_checkpointing = True
150
+ self.decoder.gradient_checkpointing = True
151
+
152
+ def enable_tiling(self, use_tiling: bool = True):
153
+ r"""
154
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
155
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
156
+ processing larger images.
157
+ """
158
+ self.use_tiling = use_tiling
159
+
160
+ def disable_tiling(self):
161
+ r"""
162
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
163
+ decoding in one step.
164
+ """
165
+ self.enable_tiling(False)
166
+
167
+ def enable_slicing(self):
168
+ r"""
169
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
170
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
171
+ """
172
+ self.use_slicing = True
173
+
174
+ def disable_slicing(self):
175
+ r"""
176
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
177
+ decoding in one step.
178
+ """
179
+ self.use_slicing = False
180
+
181
+ @property
182
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
183
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
184
+ r"""
185
+ Returns:
186
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
187
+ indexed by their weight names.
188
+ """
189
+ # set recursively
190
+ processors = {}
191
+
192
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
193
+ if hasattr(module, "get_processor"):
194
+ processors[f"{name}.processor"] = module.get_processor()
195
+
196
+ for sub_name, child in module.named_children():
197
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
198
+
199
+ return processors
200
+
201
+ for name, module in self.named_children():
202
+ fn_recursive_add_processors(name, module, processors)
203
+
204
+ return processors
205
+
206
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
207
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
208
+ r"""
209
+ Sets the attention processor to use to compute attention.
210
+
211
+ Parameters:
212
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
213
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
214
+ for **all** `Attention` layers.
215
+
216
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
217
+ processor. This is strongly recommended when setting trainable attention processors.
218
+
219
+ """
220
+ count = len(self.attn_processors.keys())
221
+
222
+ if isinstance(processor, dict) and len(processor) != count:
223
+ raise ValueError(
224
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
225
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
226
+ )
227
+
228
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
229
+ if hasattr(module, "set_processor"):
230
+ if not isinstance(processor, dict):
231
+ module.set_processor(processor)
232
+ else:
233
+ module.set_processor(processor.pop(f"{name}.processor"))
234
+
235
+ for sub_name, child in module.named_children():
236
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
237
+
238
+ for name, module in self.named_children():
239
+ fn_recursive_attn_processor(name, module, processor)
240
+
241
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
242
+ def set_default_attn_processor(self):
243
+ """
244
+ Disables custom attention processors and sets the default attention implementation.
245
+ """
246
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
247
+ processor = AttnAddedKVProcessor()
248
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
249
+ processor = AttnProcessor()
250
+ else:
251
+ raise ValueError(
252
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
253
+ )
254
+
255
+ self.set_attn_processor(processor)
256
+
257
+ @apply_forward_hook
258
+ def encode(
259
+ self, x: torch.Tensor, deterministic: bool = False, return_dict: bool = True
260
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution], torch.Tensor]:
261
+ """
262
+ Encode a batch of images into latents.
263
+
264
+ Args:
265
+ x (`torch.Tensor`): Input batch of images.
+ deterministic (`bool`, *optional*, defaults to `False`):
+ If `True`, skip building the posterior and return the mean latent (the first half of the moments) directly.
266
+ return_dict (`bool`, *optional*, defaults to `True`):
267
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
268
+
269
+ Returns:
270
+ The latent representations of the encoded images. If `return_dict` is True, a
271
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
272
+ """
273
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
274
+ return self.tiled_encode(x, return_dict=return_dict)
275
+
276
+ if self.use_slicing and x.shape[0] > 1:
277
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
278
+ h = torch.cat(encoded_slices)
279
+ else:
280
+ h = self.encoder(x)
281
+
282
+ if self.quant_conv is not None:
283
+ moments = self.quant_conv(h)
284
+ else:
285
+ moments = h
286
+
287
+ if deterministic:
288
+ latent, _ = torch.chunk(moments, 2, dim=1)
289
+ return latent
290
+
291
+ posterior = DiagonalGaussianDistribution(moments)
292
+
293
+ if not return_dict:
294
+ return (posterior,)
295
+
296
+ return AutoencoderKLOutput(latent_dist=posterior)
297
+
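The `moments` tensor packs mean and log-variance along the channel axis; a sketch of how a diagonal Gaussian posterior consumes it (the reparameterized sample is the standard formulation, not code from this repo):

```python
import torch

moments = torch.randn(1, 8, 64, 64)            # 2 * latent_channels
mean, logvar = torch.chunk(moments, 2, dim=1)  # (1, 4, 64, 64) each
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(mean)        # reparameterized sample
```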
298
+ def _decode(self, z: torch.Tensor, deterministic: bool = False, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
299
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
300
+ return self.tiled_decode(z, return_dict=return_dict)
301
+
302
+ if self.post_quant_conv is not None:
303
+ z = self.post_quant_conv(z)
304
+
305
+ dec = self.decoder(z)
306
+
307
+ if deterministic:
308
+ return dec
309
+
310
+ if not return_dict:
311
+ return (dec,)
312
+
313
+ return DecoderOutput(sample=dec)
314
+
315
+ @apply_forward_hook
316
+ def decode(
317
+ self, z: torch.FloatTensor, deterministic: bool = False, return_dict: bool = True, generator=None
318
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
319
+ """
320
+ Decode a batch of images.
321
+
322
+ Args:
323
+ z (`torch.Tensor`): Input batch of latent vectors.
+ deterministic (`bool`, *optional*, defaults to `False`):
+ If `True`, return the decoded tensor directly instead of wrapping it in a [`~models.vae.DecoderOutput`].
324
+ return_dict (`bool`, *optional*, defaults to `True`):
325
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
326
+
327
+ Returns:
328
+ [`~models.vae.DecoderOutput`] or `tuple`:
329
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
330
+ returned.
331
+
332
+ """
333
+ if self.use_slicing and z.shape[0] > 1:
334
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
335
+ decoded = torch.cat(decoded_slices)
336
+ else:
337
+ decoded = self._decode(z, deterministic)
338
+
339
+ if deterministic:
340
+ return decoded
341
+ else:
342
+ decoded = decoded.sample
343
+
344
+ if not return_dict:
345
+ return (decoded,)
346
+
347
+ return DecoderOutput(sample=decoded)
348
+
349
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
351
+ for y in range(blend_extent):
352
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
353
+ return b
354
+
355
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
359
+ return b
360
+
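`blend_v` and `blend_h` are linear cross-fades over the overlap region; a tiny numeric check of the horizontal case:

```python
import torch

a = torch.ones(1, 1, 4, 8)     # left tile
b = torch.zeros(1, 1, 4, 8)    # right tile
blend_extent = 4
for x in range(blend_extent):  # b's first columns fade in from a's last columns
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
print(b[0, 0, 0, :blend_extent])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])
```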
361
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
362
+ r"""Encode a batch of images using a tiled encoder.
363
+
364
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
365
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
366
+ different from non-tiled encoding because the encoder sees each tile independently. To avoid tiling artifacts, the
367
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
368
+ output, but they should be much less noticeable.
369
+
370
+ Args:
371
+ x (`torch.Tensor`): Input batch of images.
372
+ return_dict (`bool`, *optional*, defaults to `True`):
373
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
374
+
375
+ Returns:
376
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
377
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
378
+ `tuple` is returned.
379
+ """
380
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
381
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
382
+ row_limit = self.tile_latent_min_size - blend_extent
383
+
384
+ # Split the image into 512x512 tiles and encode them separately.
385
+ rows = []
386
+ for i in range(0, x.shape[2], overlap_size):
387
+ row = []
388
+ for j in range(0, x.shape[3], overlap_size):
389
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
390
+ tile = self.encoder(tile)
391
+ if self.config.use_quant_conv:
392
+ tile = self.quant_conv(tile)
393
+ row.append(tile)
394
+ rows.append(row)
395
+ result_rows = []
396
+ for i, row in enumerate(rows):
397
+ result_row = []
398
+ for j, tile in enumerate(row):
399
+ # blend the above tile and the left tile
400
+ # to the current tile and add the current tile to the result row
401
+ if i > 0:
402
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
403
+ if j > 0:
404
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
405
+ result_row.append(tile[:, :, :row_limit, :row_limit])
406
+ result_rows.append(torch.cat(result_row, dim=3))
407
+
408
+ moments = torch.cat(result_rows, dim=2)
409
+ posterior = DiagonalGaussianDistribution(moments)
410
+
411
+ if not return_dict:
412
+ return (posterior,)
413
+
414
+ return AutoencoderKLOutput(latent_dist=posterior)
415
+
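With the stock defaults, the tiling constants used above work out as follows (a worked example, not additional configuration):

```python
tile_sample_min_size = 512   # pixel-space tile edge
tile_latent_min_size = 64    # 512 / 2**(n_blocks - 1) with the SD config
tile_overlap_factor = 0.25

overlap_size = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 384: stride between tiles
blend_extent = int(tile_latent_min_size * tile_overlap_factor)        # 16 latent rows cross-faded
row_limit = tile_latent_min_size - blend_extent                       # 48 rows kept per tile
print(overlap_size, blend_extent, row_limit)                          # 384 16 48
```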
416
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
417
+ r"""
418
+ Decode a batch of images using a tiled decoder.
419
+
420
+ Args:
421
+ z (`torch.Tensor`): Input batch of latent vectors.
422
+ return_dict (`bool`, *optional*, defaults to `True`):
423
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
424
+
425
+ Returns:
426
+ [`~models.vae.DecoderOutput`] or `tuple`:
427
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
428
+ returned.
429
+ """
430
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
431
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
432
+ row_limit = self.tile_sample_min_size - blend_extent
433
+
434
+ # Split z into overlapping 64x64 tiles and decode them separately.
435
+ # The tiles have an overlap to avoid seams between tiles.
436
+ rows = []
437
+ for i in range(0, z.shape[2], overlap_size):
438
+ row = []
439
+ for j in range(0, z.shape[3], overlap_size):
440
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
441
+ if self.config.use_post_quant_conv:
442
+ tile = self.post_quant_conv(tile)
443
+ decoded = self.decoder(tile)
444
+ row.append(decoded)
445
+ rows.append(row)
446
+ result_rows = []
447
+ for i, row in enumerate(rows):
448
+ result_row = []
449
+ for j, tile in enumerate(row):
450
+ # blend the above tile and the left tile
451
+ # to the current tile and add the current tile to the result row
452
+ if i > 0:
453
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
454
+ if j > 0:
455
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
456
+ result_row.append(tile[:, :, :row_limit, :row_limit])
457
+ result_rows.append(torch.cat(result_row, dim=3))
458
+
459
+ dec = torch.cat(result_rows, dim=2)
460
+ if not return_dict:
461
+ return (dec,)
462
+
463
+ return DecoderOutput(sample=dec)
464
+
465
+ def forward(
466
+ self,
467
+ sample: torch.Tensor,
468
+ sample_posterior: bool = False,
469
+ return_dict: bool = True,
470
+ generator: Optional[torch.Generator] = None,
471
+ ) -> Union[DecoderOutput, torch.Tensor]:
472
+ r"""
473
+ Args:
474
+ sample (`torch.Tensor`): Input sample.
475
+ sample_posterior (`bool`, *optional*, defaults to `False`):
476
+ Whether to sample from the posterior.
477
+ return_dict (`bool`, *optional*, defaults to `True`):
478
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
479
+ """
480
+ x = sample
481
+ posterior = self.encode(x).latent_dist
482
+ if sample_posterior:
483
+ z = posterior.sample(generator=generator)
484
+ else:
485
+ z = posterior.mode()
486
+ dec = self.decode(z).sample
487
+
488
+ if not return_dict:
489
+ return (dec,)
490
+
491
+ return DecoderOutput(sample=dec)
492
+
493
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
494
+ def fuse_qkv_projections(self):
495
+ """
496
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
497
+ are fused. For cross-attention modules, key and value projection matrices are fused.
498
+
499
+ <Tip warning={true}>
500
+
501
+ This API is 🧪 experimental.
502
+
503
+ </Tip>
504
+ """
505
+ self.original_attn_processors = None
506
+
507
+ for _, attn_processor in self.attn_processors.items():
508
+ if "Added" in str(attn_processor.__class__.__name__):
509
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
510
+
511
+ self.original_attn_processors = self.attn_processors
512
+
513
+ for module in self.modules():
514
+ if isinstance(module, Attention):
515
+ module.fuse_projections(fuse=True)
516
+
517
+ self.set_attn_processor(FusedAttnProcessor2_0())
518
+
519
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
520
+ def unfuse_qkv_projections(self):
521
+ """Disables the fused QKV projection if enabled.
522
+
523
+ <Tip warning={true}>
524
+
525
+ This API is 🧪 experimental.
526
+
527
+ </Tip>
528
+
529
+ """
530
+ if self.original_attn_processors is not None:
531
+ self.set_attn_processor(self.original_attn_processors)
Marigold/vae/vae.py ADDED
@@ -0,0 +1,1015 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.utils import BaseOutput, is_torch_version
22
+ from diffusers.utils.torch_utils import randn_tensor
23
+ from diffusers.models.activations import get_activation
24
+ import sys
25
+ import os
26
+ sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../unet")))
27
+
28
+ from Marigold.unet.attention_processor import SpatialNorm
29
+
30
+ from Marigold.unet.unet_2d_blocks import (
31
+ AutoencoderTinyBlock,
32
+ UNetMidBlock2D,
33
+ get_down_block,
34
+ get_up_block,
35
+ )
36
+ from einops import rearrange
37
+
38
+ @dataclass
39
+ class DecoderOutput(BaseOutput):
40
+ r"""
41
+ Output of decoding method.
42
+
43
+ Args:
44
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
45
+ The decoded output sample from the last layer of the model.
46
+ """
47
+
48
+ sample: torch.Tensor
49
+ commit_loss: Optional[torch.FloatTensor] = None
50
+
51
+
52
+ class Encoder(nn.Module):
53
+ r"""
54
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
55
+
56
+ Args:
57
+ in_channels (`int`, *optional*, defaults to 3):
58
+ The number of input channels.
59
+ out_channels (`int`, *optional*, defaults to 3):
60
+ The number of output channels.
61
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
62
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
63
+ options.
64
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
65
+ The number of output channels for each block.
66
+ layers_per_block (`int`, *optional*, defaults to 2):
67
+ The number of layers per block.
68
+ norm_num_groups (`int`, *optional*, defaults to 32):
69
+ The number of groups for normalization.
70
+ act_fn (`str`, *optional*, defaults to `"silu"`):
71
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
72
+ double_z (`bool`, *optional*, defaults to `True`):
73
+ Whether to double the number of output channels for the last block.
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ in_channels: int = 3,
79
+ out_channels: int = 3,
80
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
81
+ block_out_channels: Tuple[int, ...] = (64,),
82
+ layers_per_block: int = 2,
83
+ norm_num_groups: int = 32,
84
+ act_fn: str = "silu",
85
+ double_z: bool = True,
86
+ mid_block_add_attention=True,
87
+
88
+ padding_mode: str = "zeros",
89
+ use_RoPE: bool = False, # whether to use RoPE positional encoding
90
+ ):
91
+ super().__init__()
92
+ self.layers_per_block = layers_per_block
93
+
94
+ self.conv_in = nn.Conv2d(
95
+ in_channels,
96
+ block_out_channels[0],
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=1,
100
+ padding_mode=padding_mode,
101
+ )
102
+
103
+ self.down_blocks = nn.ModuleList([])
104
+
105
+ # down
106
+ output_channel = block_out_channels[0]
107
+ for i, down_block_type in enumerate(down_block_types):
108
+ input_channel = output_channel
109
+ output_channel = block_out_channels[i]
110
+ is_final_block = i == len(block_out_channels) - 1
111
+
112
+ down_block = get_down_block(
113
+ down_block_type,
114
+ num_layers=self.layers_per_block,
115
+ in_channels=input_channel,
116
+ out_channels=output_channel,
117
+ add_downsample=not is_final_block,
118
+ resnet_eps=1e-6,
119
+ downsample_padding=0,
120
+ resnet_act_fn=act_fn,
121
+ resnet_groups=norm_num_groups,
122
+ attention_head_dim=output_channel,
123
+ temb_channels=None,
124
+
125
+ )
126
+ self.down_blocks.append(down_block)
127
+
128
+ # mid
129
+ self.mid_block = UNetMidBlock2D(
130
+ in_channels=block_out_channels[-1],
131
+ resnet_eps=1e-6,
132
+ resnet_act_fn=act_fn,
133
+ output_scale_factor=1,
134
+ resnet_time_scale_shift="default",
135
+ attention_head_dim=block_out_channels[-1],
136
+ resnet_groups=norm_num_groups,
137
+ temb_channels=None,
138
+ add_attention=mid_block_add_attention,
139
+
140
+ use_RoPE=use_RoPE,
141
+ )
142
+
143
+ # out
144
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
145
+ self.conv_act = nn.SiLU()
146
+
147
+ conv_out_channels = 2 * out_channels if double_z else out_channels
148
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1, padding_mode=padding_mode)
149
+
150
+ self.gradient_checkpointing = False
151
+
152
+ def forward(self, sample: torch.Tensor) -> torch.Tensor:
153
+ r"""The forward method of the `Encoder` class."""
154
+
155
+ sample = self.conv_in(sample)
156
+
157
+ if self.training and self.gradient_checkpointing:
158
+
159
+ def create_custom_forward(module):
160
+ def custom_forward(*inputs):
161
+ return module(*inputs)
162
+
163
+ return custom_forward
164
+
165
+ # down
166
+ if is_torch_version(">=", "1.11.0"):
167
+ for down_block in self.down_blocks:
168
+ sample = torch.utils.checkpoint.checkpoint(
169
+ create_custom_forward(down_block), sample, use_reentrant=False
170
+ )
171
+ # middle
172
+ sample = torch.utils.checkpoint.checkpoint(
173
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
174
+ )
175
+ else:
176
+ for down_block in self.down_blocks:
177
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
178
+ # middle
179
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
180
+
181
+ else:
182
+ # down
183
+ for down_block in self.down_blocks:
184
+ sample = down_block(sample)
185
+
186
+ # middle
187
+ sample = self.mid_block(sample)
188
+
189
+ b, c, h, w = sample.shape
190
+ # post-process
191
+ # sync_norm:
192
+ sample = rearrange(sample, '(b t) c h w -> b c (h w t)', t=6, h=h, w=w)
193
+
194
+ sample = self.conv_norm_out(sample)
195
+
196
+ # sync_norm:
197
+ sample = rearrange(sample, 'b c (h w t) -> (b t) c h w', t=6, h=h, w=w)
198
+
199
+ sample = self.conv_act(sample)
200
+ sample = self.conv_out(sample)
201
+
202
+ return sample
203
+
204
+
205
+ class Decoder(nn.Module):
206
+ r"""
207
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
208
+
209
+ Args:
210
+ in_channels (`int`, *optional*, defaults to 3):
211
+ The number of input channels.
212
+ out_channels (`int`, *optional*, defaults to 3):
213
+ The number of output channels.
214
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
215
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
216
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
217
+ The number of output channels for each block.
218
+ layers_per_block (`int`, *optional*, defaults to 2):
219
+ The number of layers per block.
220
+ norm_num_groups (`int`, *optional*, defaults to 32):
221
+ The number of groups for normalization.
222
+ act_fn (`str`, *optional*, defaults to `"silu"`):
223
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
224
+ norm_type (`str`, *optional*, defaults to `"group"`):
225
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
226
+ """
227
+
228
+ def __init__(
229
+ self,
230
+ in_channels: int = 3,
231
+ out_channels: int = 3,
232
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
233
+ block_out_channels: Tuple[int, ...] = (64,),
234
+ layers_per_block: int = 2,
235
+ norm_num_groups: int = 32,
236
+ act_fn: str = "silu",
237
+ norm_type: str = "group", # group, spatial
238
+ mid_block_add_attention=True,
239
+
240
+ use_RoPE: bool = False, # whether to use RoPE positional encoding
241
+ ):
242
+ super().__init__()
243
+ self.layers_per_block = layers_per_block
244
+
245
+ self.conv_in = nn.Conv2d(
246
+ in_channels,
247
+ block_out_channels[-1],
248
+ kernel_size=3,
249
+ stride=1,
250
+ padding=1,
251
+ )
252
+
253
+ self.up_blocks = nn.ModuleList([])
254
+
255
+ temb_channels = in_channels if norm_type == "spatial" else None
256
+
257
+ # mid
258
+ self.mid_block = UNetMidBlock2D(
259
+ in_channels=block_out_channels[-1],
260
+ resnet_eps=1e-6,
261
+ resnet_act_fn=act_fn,
262
+ output_scale_factor=1,
263
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
264
+ attention_head_dim=block_out_channels[-1],
265
+ resnet_groups=norm_num_groups,
266
+ temb_channels=temb_channels,
267
+ add_attention=mid_block_add_attention,
268
+
269
+ use_RoPE=use_RoPE,
270
+ )
271
+
272
+ # up
273
+ reversed_block_out_channels = list(reversed(block_out_channels))
274
+ output_channel = reversed_block_out_channels[0]
275
+ for i, up_block_type in enumerate(up_block_types):
276
+ prev_output_channel = output_channel
277
+ output_channel = reversed_block_out_channels[i]
278
+
279
+ is_final_block = i == len(block_out_channels) - 1
280
+
281
+ up_block = get_up_block(
282
+ up_block_type,
283
+ num_layers=self.layers_per_block + 1,
284
+ in_channels=prev_output_channel,
285
+ out_channels=output_channel,
286
+ prev_output_channel=None,
287
+ add_upsample=not is_final_block,
288
+ resnet_eps=1e-6,
289
+ resnet_act_fn=act_fn,
290
+ resnet_groups=norm_num_groups,
291
+ attention_head_dim=output_channel,
292
+ temb_channels=temb_channels,
293
+ resnet_time_scale_shift=norm_type,
294
+
295
+ )
296
+ self.up_blocks.append(up_block)
297
+ prev_output_channel = output_channel
298
+
299
+ # out
300
+ if norm_type == "spatial":
301
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
302
+ else:
303
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
304
+ self.conv_act = nn.SiLU()
305
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
306
+
307
+ self.gradient_checkpointing = False
308
+
309
+ def forward(
310
+ self,
311
+ sample: torch.Tensor,
312
+ latent_embeds: Optional[torch.Tensor] = None,
313
+ ) -> torch.Tensor:
314
+ r"""The forward method of the `Decoder` class."""
315
+
316
+ sample = self.conv_in(sample)
317
+
318
+
319
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
320
+ if self.gradient_checkpointing:
321
+
322
+ def create_custom_forward(module):
323
+ def custom_forward(*inputs):
324
+ return module(*inputs)
325
+
326
+ return custom_forward
327
+
328
+ if is_torch_version(">=", "1.11.0"):
329
+ # middle
330
+ sample = torch.utils.checkpoint.checkpoint(
331
+ create_custom_forward(self.mid_block),
332
+ sample,
333
+ latent_embeds,
334
+ use_reentrant=False,
335
+ )
336
+ sample = sample.to(upscale_dtype)
337
+
338
+ # up
339
+ for up_block in self.up_blocks:
340
+ sample = torch.utils.checkpoint.checkpoint(
341
+ create_custom_forward(up_block),
342
+ sample,
343
+ latent_embeds,
344
+ use_reentrant=False,
345
+ )
346
+ else:
347
+ # middle
348
+ sample = torch.utils.checkpoint.checkpoint(
349
+ create_custom_forward(self.mid_block), sample, latent_embeds
350
+ )
351
+ sample = sample.to(upscale_dtype)
352
+
353
+ # up
354
+ for up_block in self.up_blocks:
355
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
356
+ else:
357
+ # middle
358
+ sample = self.mid_block(sample, latent_embeds)
359
+ sample = sample.to(upscale_dtype)
360
+
361
+ # up
362
+ for up_block in self.up_blocks:
363
+ sample = up_block(sample, latent_embeds)
364
+
365
+ b, c, h, w = sample.shape
366
+ # sync_norm:
367
+ sample = rearrange(sample, '(b t) c h w -> b c (h w t)', t=6, h=h, w=w)
368
+
369
+ # post-process
370
+ if latent_embeds is None:
371
+ sample = self.conv_norm_out(sample)
372
+ else:
373
+ sample = self.conv_norm_out(sample, latent_embeds)
374
+
375
+ # sync_norm:
376
+ sample = rearrange(sample, 'b c (h w t) -> (b t) c h w', t=6, h=h, w=w)
377
+
378
+ sample = self.conv_act(sample)
379
+ sample = self.conv_out(sample)
380
+
381
+ return sample
382
+
383
+
384
+ class UpSample(nn.Module):
385
+ r"""
386
+ The `UpSample` layer of a variational autoencoder that upsamples its input.
387
+
388
+ Args:
389
+ in_channels (`int`, *optional*, defaults to 3):
390
+ The number of input channels.
391
+ out_channels (`int`, *optional*, defaults to 3):
392
+ The number of output channels.
393
+ """
394
+
395
+ def __init__(
396
+ self,
397
+ in_channels: int,
398
+ out_channels: int,
399
+ ) -> None:
400
+ super().__init__()
401
+ self.in_channels = in_channels
402
+ self.out_channels = out_channels
403
+ self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
404
+
405
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
406
+ r"""The forward method of the `UpSample` class."""
407
+ x = torch.relu(x)
408
+ x = self.deconv(x)
409
+ return x
410
+
411
+
412
+ class MaskConditionEncoder(nn.Module):
413
+ """
414
+ used in AsymmetricAutoencoderKL
415
+ """
416
+
417
+ def __init__(
418
+ self,
419
+ in_ch: int,
420
+ out_ch: int = 192,
421
+ res_ch: int = 768,
422
+ stride: int = 16,
423
+ ) -> None:
424
+ super().__init__()
425
+
426
+ channels = []
427
+ while stride > 1:
428
+ stride = stride // 2
429
+ in_ch_ = out_ch * 2
430
+ if out_ch > res_ch:
431
+ out_ch = res_ch
432
+ if stride == 1:
433
+ in_ch_ = res_ch
434
+ channels.append((in_ch_, out_ch))
435
+ out_ch *= 2
436
+
437
+ out_channels = []
438
+ for _in_ch, _out_ch in channels:
439
+ out_channels.append(_out_ch)
440
+ out_channels.append(channels[-1][0])
441
+
442
+ layers = []
443
+ in_ch_ = in_ch
444
+ for l in range(len(out_channels)):
445
+ out_ch_ = out_channels[l]
446
+ if l == 0 or l == 1:
447
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
448
+ else:
449
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
450
+ in_ch_ = out_ch_
451
+
452
+ self.layers = nn.Sequential(*layers)
453
+
454
+ def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
455
+ r"""The forward method of the `MaskConditionEncoder` class."""
456
+ out = {}
457
+ for l in range(len(self.layers)):
458
+ layer = self.layers[l]
459
+ x = layer(x)
460
+ out[str(tuple(x.shape))] = x
461
+ x = torch.relu(x)
462
+ return out
463
+
464
+
465
+ class MaskConditionDecoder(nn.Module):
466
+ r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
467
+ decoder with a conditioner on the mask and masked image.
468
+
469
+ Args:
470
+ in_channels (`int`, *optional*, defaults to 3):
471
+ The number of input channels.
472
+ out_channels (`int`, *optional*, defaults to 3):
473
+ The number of output channels.
474
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
475
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
476
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
477
+ The number of output channels for each block.
478
+ layers_per_block (`int`, *optional*, defaults to 2):
479
+ The number of layers per block.
480
+ norm_num_groups (`int`, *optional*, defaults to 32):
481
+ The number of groups for normalization.
482
+ act_fn (`str`, *optional*, defaults to `"silu"`):
483
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
484
+ norm_type (`str`, *optional*, defaults to `"group"`):
485
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
486
+ """
487
+
488
+ def __init__(
489
+ self,
490
+ in_channels: int = 3,
491
+ out_channels: int = 3,
492
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
493
+ block_out_channels: Tuple[int, ...] = (64,),
494
+ layers_per_block: int = 2,
495
+ norm_num_groups: int = 32,
496
+ act_fn: str = "silu",
497
+ norm_type: str = "group", # group, spatial
498
+ ):
499
+ super().__init__()
500
+ self.layers_per_block = layers_per_block
501
+
502
+ self.conv_in = nn.Conv2d(
503
+ in_channels,
504
+ block_out_channels[-1],
505
+ kernel_size=3,
506
+ stride=1,
507
+ padding=1,
508
+ )
509
+
510
+ self.up_blocks = nn.ModuleList([])
511
+
512
+ temb_channels = in_channels if norm_type == "spatial" else None
513
+
514
+ # mid
515
+ self.mid_block = UNetMidBlock2D(
516
+ in_channels=block_out_channels[-1],
517
+ resnet_eps=1e-6,
518
+ resnet_act_fn=act_fn,
519
+ output_scale_factor=1,
520
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
521
+ attention_head_dim=block_out_channels[-1],
522
+ resnet_groups=norm_num_groups,
523
+ temb_channels=temb_channels,
524
+ )
525
+
526
+ # up
527
+ reversed_block_out_channels = list(reversed(block_out_channels))
528
+ output_channel = reversed_block_out_channels[0]
529
+ for i, up_block_type in enumerate(up_block_types):
530
+ prev_output_channel = output_channel
531
+ output_channel = reversed_block_out_channels[i]
532
+
533
+ is_final_block = i == len(block_out_channels) - 1
534
+
535
+ up_block = get_up_block(
536
+ up_block_type,
537
+ num_layers=self.layers_per_block + 1,
538
+ in_channels=prev_output_channel,
539
+ out_channels=output_channel,
540
+ prev_output_channel=None,
541
+ add_upsample=not is_final_block,
542
+ resnet_eps=1e-6,
543
+ resnet_act_fn=act_fn,
544
+ resnet_groups=norm_num_groups,
545
+ attention_head_dim=output_channel,
546
+ temb_channels=temb_channels,
547
+ resnet_time_scale_shift=norm_type,
548
+ )
549
+ self.up_blocks.append(up_block)
550
+ prev_output_channel = output_channel
551
+
552
+ # condition encoder
553
+ self.condition_encoder = MaskConditionEncoder(
554
+ in_ch=out_channels,
555
+ out_ch=block_out_channels[0],
556
+ res_ch=block_out_channels[-1],
557
+ )
558
+
559
+ # out
560
+ if norm_type == "spatial":
561
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
562
+ else:
563
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
564
+ self.conv_act = nn.SiLU()
565
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
566
+
567
+ self.gradient_checkpointing = False
568
+
+    def forward(
+        self,
+        z: torch.Tensor,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""The forward method of the `MaskConditionDecoder` class."""
+        sample = z
+        sample = self.conv_in(sample)
+
+        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+        if self.training and self.gradient_checkpointing:
+
+            # Wrap each submodule so torch.utils.checkpoint recomputes its
+            # activations during the backward pass instead of storing them.
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block),
+                    sample,
+                    latent_embeds,
+                    use_reentrant=False,
+                )
+                sample = sample.to(upscale_dtype)
+
+                # condition encoder
+                if image is not None and mask is not None:
+                    masked_image = (1 - mask) * image
+                    im_x = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(self.condition_encoder),
+                        masked_image,
+                        mask,
+                        use_reentrant=False,
+                    )
+
+                # up
+                for up_block in self.up_blocks:
+                    if image is not None and mask is not None:
+                        sample_ = im_x[str(tuple(sample.shape))]
+                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+                        sample = sample * mask_ + sample_ * (1 - mask_)
+                    sample = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(up_block),
+                        sample,
+                        latent_embeds,
+                        use_reentrant=False,
+                    )
+                if image is not None and mask is not None:
+                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+            else:
+                # middle
+                sample = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(self.mid_block), sample, latent_embeds
+                )
+                sample = sample.to(upscale_dtype)
+
+                # condition encoder
+                if image is not None and mask is not None:
+                    masked_image = (1 - mask) * image
+                    im_x = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(self.condition_encoder),
+                        masked_image,
+                        mask,
+                    )
+
+                # up
+                for up_block in self.up_blocks:
+                    if image is not None and mask is not None:
+                        sample_ = im_x[str(tuple(sample.shape))]
+                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+                        sample = sample * mask_ + sample_ * (1 - mask_)
+                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+                if image is not None and mask is not None:
+                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+        else:
+            # middle
+            sample = self.mid_block(sample, latent_embeds)
+            sample = sample.to(upscale_dtype)
+
+            # condition encoder
+            if image is not None and mask is not None:
+                masked_image = (1 - mask) * image
+                im_x = self.condition_encoder(masked_image, mask)
+
+            # up
+            for up_block in self.up_blocks:
+                if image is not None and mask is not None:
+                    sample_ = im_x[str(tuple(sample.shape))]
+                    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+                    sample = sample * mask_ + sample_ * (1 - mask_)
+                sample = up_block(sample, latent_embeds)
+            if image is not None and mask is not None:
+                sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+
+        # post-process
+        if latent_embeds is None:
+            sample = self.conv_norm_out(sample)
+        else:
+            sample = self.conv_norm_out(sample, latent_embeds)
+        sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        return sample
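+
+# Note on the blending above (explanatory sketch, not part of the upstream module):
+# the condition encoder returns a dict of feature maps keyed by their shape string
+# (hence the `im_x[str(tuple(sample.shape))]` lookups); at each resolution the
+# decoder features are combined with them as
+#   sample = sample * mask_ + cond_features * (1 - mask_)
+# so regions where `mask` is 1 are generated by the decoder, while regions where
+# `mask` is 0 are reconstructed from the encoded masked image.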
+
+
+class VectorQuantizer(nn.Module):
+    """
+    Improved version of the original VectorQuantizer that can be used as a drop-in replacement. Mostly avoids costly
+    matrix multiplications and allows for post-hoc remapping of indices.
+    """
+
+    # NOTE: due to a bug the beta term was applied to the wrong term. For
+    # backwards compatibility we use the buggy version by default, but you can
+    # specify legacy=False to fix it.
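+    # With sg[.] denoting stop-gradient, the forward pass below computes
+    #   legacy=False (paper form): loss = beta * MSE(z, sg[z_q]) + MSE(z_q, sg[z])
+    #   legacy=True  (default):    loss = MSE(z, sg[z_q]) + beta * MSE(z_q, sg[z])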
+    def __init__(
+        self,
+        n_e: int,
+        vq_embed_dim: int,
+        beta: float,
+        remap: Optional[str] = None,
+        unknown_index: str = "random",
+        sane_index_shape: bool = False,
+        legacy: bool = True,
+    ):
+        super().__init__()
+        self.n_e = n_e
+        self.vq_embed_dim = vq_embed_dim
+        self.beta = beta
+        self.legacy = legacy
+
+        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+        self.remap = remap
+        if self.remap is not None:
+            self.register_buffer("used", torch.tensor(np.load(self.remap)))
+            self.used: torch.Tensor
+            self.re_embed = self.used.shape[0]
+            self.unknown_index = unknown_index  # "random" or "extra" or integer
+            if self.unknown_index == "extra":
+                self.unknown_index = self.re_embed
+                self.re_embed = self.re_embed + 1
+            print(
+                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+                f"Using {self.unknown_index} for unknown indices."
+            )
+        else:
+            self.re_embed = n_e
+
+        self.sane_index_shape = sane_index_shape
+
+    def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
+        # Map raw codebook indices onto the reduced set of "used" indices.
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        match = (inds[:, :, None] == used[None, None, ...]).long()
+        new = match.argmax(-1)
+        unknown = match.sum(2) < 1
+        if self.unknown_index == "random":
+            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+        else:
+            new[unknown] = self.unknown_index
+        return new.reshape(ishape)
+
+    def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
+        # Inverse of `remap_to_used`: recover indices into the full codebook.
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        if self.re_embed > self.used.shape[0]:  # extra token
+            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+        return back.reshape(ishape)
+
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
+        # reshape z -> (batch, height, width, channel) and flatten
+        z = z.permute(0, 2, 3, 1).contiguous()
+        z_flattened = z.view(-1, self.vq_embed_dim)
+
+        # nearest codebook entry for each latent vector: argmin_j ||z - e_j||
+        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
+
+        z_q = self.embedding(min_encoding_indices).view(z.shape)
+        perplexity = None
+        min_encodings = None
+
+        # compute loss for embedding
+        if not self.legacy:
+            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+        else:
+            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+        # preserve gradients (straight-through estimator: gradients flow to z)
+        z_q: torch.Tensor = z + (z_q - z).detach()
+
+        # reshape back to match original input shape
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        if self.remap is not None:
+            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
+            min_encoding_indices = self.remap_to_used(min_encoding_indices)
+            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten
+
+        if self.sane_index_shape:
+            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
+        # shape specifying (batch, height, width, channel)
+        if self.remap is not None:
+            indices = indices.reshape(shape[0], -1)  # add batch axis
+            indices = self.unmap_to_all(indices)
+            indices = indices.reshape(-1)  # flatten again
+
+        # get quantized latent vectors
+        z_q: torch.Tensor = self.embedding(indices)
+
+        if shape is not None:
+            z_q = z_q.view(shape)
+            # reshape back to match original input shape
+            z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q
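+
+# Usage sketch (illustrative values, not part of the upstream module):
+#   vq = VectorQuantizer(n_e=512, vq_embed_dim=4, beta=0.25)
+#   z = torch.randn(2, 4, 8, 8)            # encoder features (B, C, H, W)
+#   z_q, loss, (_, _, indices) = vq(z)     # z_q has the same shape as z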
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = randn_tensor(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        # reparameterization: x = mean + std * eps, with eps ~ N(0, I)
+        x = self.mean + self.std * sample
+        return x
+
+    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
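+        # Closed-form KL between diagonal Gaussians, summed over dims [1, 2, 3]:
+        #   KL(N(mu, var) || N(0, I))   = 0.5 * sum(mu^2 + var - 1 - logvar)
+        #   KL(N(m1, v1) || N(m2, v2))  = 0.5 * sum((m1 - m2)^2 / v2 + v1 / v2 - 1 - log(v1) + log(v2))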
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3],
+                )
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
+
+    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        # negative log-likelihood of `sample` under N(mean, var), summed over `dims`
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
+
+    def mode(self) -> torch.Tensor:
+        return self.mean
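+
+# Usage sketch (hypothetical shapes, not part of the upstream module): a VAE
+# encoder typically predicts 2 * latent_channels "moments" that parameterize
+# this distribution:
+#   moments = torch.randn(1, 8, 32, 32)    # [mean | logvar] along dim=1
+#   posterior = DiagonalGaussianDistribution(moments)
+#   z = posterior.sample()                 # shape (1, 4, 32, 32)
+#   kl = posterior.kl()                    # KL to N(0, I), one value per sample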
855
+
856
+
857
+ class EncoderTiny(nn.Module):
858
+ r"""
859
+ The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
860
+
861
+ Args:
862
+ in_channels (`int`):
863
+ The number of input channels.
864
+ out_channels (`int`):
865
+ The number of output channels.
866
+ num_blocks (`Tuple[int, ...]`):
867
+ Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
868
+ use.
869
+ block_out_channels (`Tuple[int, ...]`):
870
+ The number of output channels for each block.
871
+ act_fn (`str`):
872
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
873
+ """
874
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: Tuple[int, ...],
+        block_out_channels: Tuple[int, ...],
+        act_fn: str,
+    ):
+        super().__init__()
+        layers = []
+        for i, num_block in enumerate(num_blocks):
+            num_channels = block_out_channels[i]
+
+            if i == 0:
+                layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
+            else:
+                # each later stage opens with a stride-2 convolution that halves
+                # the spatial resolution
+                layers.append(
+                    nn.Conv2d(
+                        num_channels,
+                        num_channels,
+                        kernel_size=3,
+                        padding=1,
+                        stride=2,
+                        bias=False,
+                    )
+                )
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+        layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        r"""The forward method of the `EncoderTiny` class."""
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            if is_torch_version(">=", "1.11.0"):
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+            else:
+                x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+        else:
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.add(1).div(2))
+
+        return x
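+
+# Instantiation sketch (hypothetical, TAESD-like configuration; not part of the
+# upstream module):
+#   enc = EncoderTiny(in_channels=3, out_channels=4, num_blocks=(1, 3, 3, 3),
+#                     block_out_channels=(64, 64, 64, 64), act_fn="relu")
+#   latents = enc(torch.randn(1, 3, 512, 512))   # three stride-2 convs -> (1, 4, 64, 64)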
930
+
931
+
932
+ class DecoderTiny(nn.Module):
933
+ r"""
934
+ The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
935
+
936
+ Args:
937
+ in_channels (`int`):
938
+ The number of input channels.
939
+ out_channels (`int`):
940
+ The number of output channels.
941
+ num_blocks (`Tuple[int, ...]`):
942
+ Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
943
+ use.
944
+ block_out_channels (`Tuple[int, ...]`):
945
+ The number of output channels for each block.
946
+ upsampling_scaling_factor (`int`):
947
+ The scaling factor to use for upsampling.
948
+ act_fn (`str`):
949
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
950
+ """
951
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: Tuple[int, ...],
+        block_out_channels: Tuple[int, ...],
+        upsampling_scaling_factor: int,
+        act_fn: str,
+        upsample_fn: str,
+    ):
+        super().__init__()
+
+        layers = [
+            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
+            get_activation(act_fn),
+        ]
+
+        for i, num_block in enumerate(num_blocks):
+            is_final_block = i == (len(num_blocks) - 1)
+            num_channels = block_out_channels[i]
+
+            for _ in range(num_block):
+                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+            if not is_final_block:
+                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
+
+            conv_out_channel = num_channels if not is_final_block else out_channels
+            layers.append(
+                nn.Conv2d(
+                    num_channels,
+                    conv_out_channel,
+                    kernel_size=3,
+                    padding=1,
+                    bias=is_final_block,
+                )
+            )
+
+        self.layers = nn.Sequential(*layers)
+        self.gradient_checkpointing = False
+
994
+ r"""The forward method of the `DecoderTiny` class."""
995
+ # Clamp.
996
+ x = torch.tanh(x / 3) * 3
997
+
998
+ if self.training and self.gradient_checkpointing:
999
+
1000
+ def create_custom_forward(module):
1001
+ def custom_forward(*inputs):
1002
+ return module(*inputs)
1003
+
1004
+ return custom_forward
1005
+
1006
+ if is_torch_version(">=", "1.11.0"):
1007
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
1008
+ else:
1009
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
1010
+
1011
+ else:
1012
+ x = self.layers(x)
1013
+
1014
+ # scale image from [0, 1] to [-1, 1] to match diffusers convention
1015
+ return x.mul(2).sub(1)
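+
+# Instantiation sketch (hypothetical, TAESD-like configuration; not part of the
+# upstream module):
+#   dec = DecoderTiny(in_channels=4, out_channels=3, num_blocks=(3, 3, 3, 1),
+#                     block_out_channels=(64, 64, 64, 64),
+#                     upsampling_scaling_factor=2, act_fn="relu", upsample_fn="nearest")
+#   image = dec(torch.randn(1, 4, 64, 64))   # three 2x upsamples -> (1, 3, 512, 512), values in [-1, 1]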