xiaoanyu123 committed on
Commit
ddf41bf
·
verified ·
1 Parent(s): 4b39171

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/lumina_nextdit2d.py +342 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/pixart_transformer_2d.py +430 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/prior_transformer.py +384 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/sana_transformer.py +597 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/stable_audio_transformer.py +439 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_heun_discrete.py +610 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_ipndm.py +224 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +617 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +589 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_karras_ve_flax.py +238 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_lcm.py +653 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_lms_discrete.py +552 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_lms_discrete_flax.py +283 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/state_dict_utils.cpython-310.pyc +0 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc +0 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc +0 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/typing_utils.cpython-310.pyc +0 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/versions.cpython-310.pyc +0 -0
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/lumina_nextdit2d.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ...configuration_utils import ConfigMixin, register_to_config
21
+ from ...utils import logging
22
+ from ..attention import LuminaFeedForward
23
+ from ..attention_processor import Attention, LuminaAttnProcessor2_0
24
+ from ..embeddings import (
25
+ LuminaCombinedTimestepCaptionEmbedding,
26
+ LuminaPatchEmbed,
27
+ )
28
+ from ..modeling_outputs import Transformer2DModelOutput
29
+ from ..modeling_utils import ModelMixin
30
+ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
class LuminaNextDiTBlock(nn.Module):
    """
    A LuminaNextDiTBlock for LuminaNextDiT2DModel.

    Parameters:
        dim (`int`): Embedding dimension of the input features.
        num_attention_heads (`int`): Number of attention heads.
        num_kv_heads (`int`):
            Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
        multiple_of (`int`): The number of multiple of ffn layer.
        ffn_dim_multiplier (`float`): The multiplier factor of ffn layer dimension.
        norm_eps (`float`): The eps for norm layer.
        qk_norm (`bool`): normalization for query and key.
        cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
        norm_elementwise_affine (`bool`, *optional*, defaults to True):
            Whether the normalization layers learn an element-wise affine scale.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        cross_attention_dim: int,
        norm_elementwise_affine: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads

        # Learnable per-head gate for the cross-attention branch; zero-initialized so the
        # cross-attention contribution starts at tanh(0) == 0 (a no-op at initialization).
        self.gate = nn.Parameter(torch.zeros([num_attention_heads]))

        # Self-attention
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )
        # The output projection is applied once to the mixed (self + cross) attention
        # output in `forward` via `self.attn2.to_out`, so attn1's projection is disabled.
        self.attn1.to_out = nn.Identity()

        # Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=int(4 * 2 * dim / 3),
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        self.norm1 = LuminaRMSNormZero(
            embedding_dim=dim,
            norm_eps=norm_eps,
            norm_elementwise_affine=norm_elementwise_affine,
        )
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        temb: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Perform a forward pass through the LuminaNextDiTBlock.

        Parameters:
            hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock.
            attention_mask (`torch.Tensor`): The input of hidden_states corresponding attention mask.
            image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
            encoder_hidden_states (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder.
            encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask.
            temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*): kwargs for cross attention.
        """
        # Bug fix: the parameter defaults to None but was `**`-unpacked below, which raises
        # `TypeError: argument after ** must be a mapping` when the caller omits it.
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        residual = hidden_states

        # Self-attention
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        self_attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            query_rotary_emb=image_rotary_emb,
            key_rotary_emb=image_rotary_emb,
            **cross_attention_kwargs,
        )

        # Cross-attention
        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
        cross_attn_output = self.attn2(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=encoder_mask,
            query_rotary_emb=image_rotary_emb,
            key_rotary_emb=None,
            **cross_attention_kwargs,
        )
        # Scale the cross-attention branch by the learned per-head gate (broadcast over
        # batch, sequence, and head-dim axes), then mix with the self-attention branch.
        cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
        mixed_attn_output = self_attn_output + cross_attn_output
        mixed_attn_output = mixed_attn_output.flatten(-2)
        # linear proj
        hidden_states = self.attn2.to_out[0](mixed_attn_output)

        hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)

        mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))

        hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)

        return hidden_states
176
+
177
+
178
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
    """
    LuminaNextDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 4):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 4096):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, default to 32):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 32):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to 8):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        learn_sigma (`bool`, *optional*, defaults to True):
            Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
            predictions.
        qk_norm (`bool`, *optional*, defaults to True):
            Indicates if the queries and keys in the attention mechanism should be normalized.
        cross_attention_dim (`int`, *optional*, defaults to 2048):
            The dimensionality of the text embeddings. This parameter defines the size of the text representations used
            in the model.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    _skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: Optional[int] = 2,
        in_channels: Optional[int] = 4,
        hidden_size: Optional[int] = 2304,
        num_layers: Optional[int] = 32,
        num_attention_heads: Optional[int] = 32,
        num_kv_heads: Optional[int] = None,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: Optional[float] = 1e-5,
        learn_sigma: Optional[bool] = True,
        qk_norm: Optional[bool] = True,
        cross_attention_dim: Optional[int] = 2048,
        scaling_factor: Optional[float] = 1.0,
    ) -> None:
        super().__init__()
        self.sample_size = sample_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        # When learning sigma the model predicts both the denoised sample and a variance
        # term, doubling the output channel count.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.scaling_factor = scaling_factor

        self.patch_embedder = LuminaPatchEmbed(
            patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
        )

        self.pad_token = nn.Parameter(torch.empty(hidden_size))

        self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
            hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
        )

        self.layers = nn.ModuleList(
            [
                LuminaNextDiTBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    cross_attention_dim,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )
        # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)

        assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
        """
        Forward pass of LuminaNextDiT.

        Parameters:
            hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
            timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
            encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
            encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
            image_rotary_emb (torch.Tensor): Precomputed rotary positional frequencies.
            cross_attention_kwargs (Dict[str, Any], *optional*): kwargs forwarded to each block's attention.
            return_dict (bool, *optional*, defaults to True):
                Whether to return a [`Transformer2DModelOutput`] instead of a plain tuple.
        """
        # Bug fix: the transformer blocks `**`-unpack `cross_attention_kwargs`; forwarding
        # None would raise a TypeError, so normalize it to an empty dict here.
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
        image_rotary_emb = image_rotary_emb.to(hidden_states.device)

        temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)

        encoder_mask = encoder_mask.bool()
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                mask,
                image_rotary_emb,
                encoder_hidden_states,
                encoder_mask,
                temb=temb,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        hidden_states = self.norm_out(hidden_states, temb)

        # unpatchify: fold the token sequence back into an image of shape (N, C_out, H, W)
        height_tokens = width_tokens = self.patch_size
        height, width = img_size[0]
        batch_size = hidden_states.size(0)
        sequence_length = (height // height_tokens) * (width // width_tokens)
        hidden_states = hidden_states[:, :sequence_length].view(
            batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
        )
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/pixart_transformer_2d.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Union
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from ...configuration_utils import ConfigMixin, register_to_config
20
+ from ...utils import logging
21
+ from ..attention import BasicTransformerBlock
22
+ from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
23
+ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
24
+ from ..modeling_outputs import Transformer2DModelOutput
25
+ from ..modeling_utils import ModelMixin
26
+ from ..normalization import AdaLayerNormSingle
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
33
+ r"""
34
+ A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
35
+ https://huggingface.co/papers/2403.04692).
36
+
37
+ Parameters:
38
+ num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
39
+ attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
40
+ in_channels (int, defaults to 4): The number of channels in the input.
41
+ out_channels (int, optional):
42
+ The number of channels in the output. Specify this parameter if the output channel number differs from the
43
+ input.
44
+ num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
45
+ dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
46
+ norm_num_groups (int, optional, defaults to 32):
47
+ Number of groups for group normalization within Transformer blocks.
48
+ cross_attention_dim (int, optional):
49
+ The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
50
+ attention_bias (bool, optional, defaults to True):
51
+ Configure if the Transformer blocks' attention should contain a bias parameter.
52
+ sample_size (int, defaults to 128):
53
+ The width of the latent images. This parameter is fixed during training.
54
+ patch_size (int, defaults to 2):
55
+ Size of the patches the model processes, relevant for architectures working on non-sequential data.
56
+ activation_fn (str, optional, defaults to "gelu-approximate"):
57
+ Activation function to use in feed-forward networks within Transformer blocks.
58
+ num_embeds_ada_norm (int, optional, defaults to 1000):
59
+ Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
60
+ inference.
61
+ upcast_attention (bool, optional, defaults to False):
62
+ If true, upcasts the attention mechanism dimensions for potentially improved performance.
63
+ norm_type (str, optional, defaults to "ada_norm_zero"):
64
+ Specifies the type of normalization used, can be 'ada_norm_zero'.
65
+ norm_elementwise_affine (bool, optional, defaults to False):
66
+ If true, enables element-wise affine parameters in the normalization layers.
67
+ norm_eps (float, optional, defaults to 1e-6):
68
+ A small constant added to the denominator in normalization layers to prevent division by zero.
69
+ interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
70
+ use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
71
+ attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
72
+ caption_channels (int, optional, defaults to None):
73
+ Number of channels to use for projecting the caption embeddings.
74
+ use_linear_projection (bool, optional, defaults to False):
75
+ Deprecated argument. Will be removed in a future version.
76
+ num_vector_embeds (bool, optional, defaults to False):
77
+ Deprecated argument. Will be removed in a future version.
78
+ """
79
+
80
+ _supports_gradient_checkpointing = True
81
+ _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
82
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
83
+
84
+ @register_to_config
85
+ def __init__(
86
+ self,
87
+ num_attention_heads: int = 16,
88
+ attention_head_dim: int = 72,
89
+ in_channels: int = 4,
90
+ out_channels: Optional[int] = 8,
91
+ num_layers: int = 28,
92
+ dropout: float = 0.0,
93
+ norm_num_groups: int = 32,
94
+ cross_attention_dim: Optional[int] = 1152,
95
+ attention_bias: bool = True,
96
+ sample_size: int = 128,
97
+ patch_size: int = 2,
98
+ activation_fn: str = "gelu-approximate",
99
+ num_embeds_ada_norm: Optional[int] = 1000,
100
+ upcast_attention: bool = False,
101
+ norm_type: str = "ada_norm_single",
102
+ norm_elementwise_affine: bool = False,
103
+ norm_eps: float = 1e-6,
104
+ interpolation_scale: Optional[int] = None,
105
+ use_additional_conditions: Optional[bool] = None,
106
+ caption_channels: Optional[int] = None,
107
+ attention_type: Optional[str] = "default",
108
+ ):
109
+ super().__init__()
110
+
111
+ # Validate inputs.
112
+ if norm_type != "ada_norm_single":
113
+ raise NotImplementedError(
114
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
115
+ )
116
+ elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
117
+ raise ValueError(
118
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
119
+ )
120
+
121
+ # Set some common variables used across the board.
122
+ self.attention_head_dim = attention_head_dim
123
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
124
+ self.out_channels = in_channels if out_channels is None else out_channels
125
+ if use_additional_conditions is None:
126
+ if sample_size == 128:
127
+ use_additional_conditions = True
128
+ else:
129
+ use_additional_conditions = False
130
+ self.use_additional_conditions = use_additional_conditions
131
+
132
+ self.gradient_checkpointing = False
133
+
134
+ # 2. Initialize the position embedding and transformer blocks.
135
+ self.height = self.config.sample_size
136
+ self.width = self.config.sample_size
137
+
138
+ interpolation_scale = (
139
+ self.config.interpolation_scale
140
+ if self.config.interpolation_scale is not None
141
+ else max(self.config.sample_size // 64, 1)
142
+ )
143
+ self.pos_embed = PatchEmbed(
144
+ height=self.config.sample_size,
145
+ width=self.config.sample_size,
146
+ patch_size=self.config.patch_size,
147
+ in_channels=self.config.in_channels,
148
+ embed_dim=self.inner_dim,
149
+ interpolation_scale=interpolation_scale,
150
+ )
151
+
152
+ self.transformer_blocks = nn.ModuleList(
153
+ [
154
+ BasicTransformerBlock(
155
+ self.inner_dim,
156
+ self.config.num_attention_heads,
157
+ self.config.attention_head_dim,
158
+ dropout=self.config.dropout,
159
+ cross_attention_dim=self.config.cross_attention_dim,
160
+ activation_fn=self.config.activation_fn,
161
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
162
+ attention_bias=self.config.attention_bias,
163
+ upcast_attention=self.config.upcast_attention,
164
+ norm_type=norm_type,
165
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
166
+ norm_eps=self.config.norm_eps,
167
+ attention_type=self.config.attention_type,
168
+ )
169
+ for _ in range(self.config.num_layers)
170
+ ]
171
+ )
172
+
173
+ # 3. Output blocks.
174
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
175
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
176
+ self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
177
+
178
+ self.adaln_single = AdaLayerNormSingle(
179
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
180
+ )
181
+ self.caption_projection = None
182
+ if self.config.caption_channels is not None:
183
+ self.caption_projection = PixArtAlphaTextProjection(
184
+ in_features=self.config.caption_channels, hidden_size=self.inner_dim
185
+ )
186
+
187
+ @property
188
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
189
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
190
+ r"""
191
+ Returns:
192
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
193
+ indexed by its weight name.
194
+ """
195
+ # set recursively
196
+ processors = {}
197
+
198
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
199
+ if hasattr(module, "get_processor"):
200
+ processors[f"{name}.processor"] = module.get_processor()
201
+
202
+ for sub_name, child in module.named_children():
203
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
204
+
205
+ return processors
206
+
207
+ for name, module in self.named_children():
208
+ fn_recursive_add_processors(name, module, processors)
209
+
210
+ return processors
211
+
212
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
213
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
214
+ r"""
215
+ Sets the attention processor to use to compute attention.
216
+
217
+ Parameters:
218
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
219
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
220
+ for **all** `Attention` layers.
221
+
222
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
223
+ processor. This is strongly recommended when setting trainable attention processors.
224
+
225
+ """
226
+ count = len(self.attn_processors.keys())
227
+
228
+ if isinstance(processor, dict) and len(processor) != count:
229
+ raise ValueError(
230
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
231
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
232
+ )
233
+
234
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
235
+ if hasattr(module, "set_processor"):
236
+ if not isinstance(processor, dict):
237
+ module.set_processor(processor)
238
+ else:
239
+ module.set_processor(processor.pop(f"{name}.processor"))
240
+
241
+ for sub_name, child in module.named_children():
242
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
243
+
244
+ for name, module in self.named_children():
245
+ fn_recursive_attn_processor(name, module, processor)
246
+
247
+ def set_default_attn_processor(self):
248
+ """
249
+ Disables custom attention processors and sets the default attention implementation.
250
+
251
+ Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
252
+ """
253
+ self.set_attn_processor(AttnProcessor())
254
+
255
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
256
+ def fuse_qkv_projections(self):
257
+ """
258
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
259
+ are fused. For cross-attention modules, key and value projection matrices are fused.
260
+
261
+ <Tip warning={true}>
262
+
263
+ This API is 🧪 experimental.
264
+
265
+ </Tip>
266
+ """
267
+ self.original_attn_processors = None
268
+
269
+ for _, attn_processor in self.attn_processors.items():
270
+ if "Added" in str(attn_processor.__class__.__name__):
271
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
272
+
273
+ self.original_attn_processors = self.attn_processors
274
+
275
+ for module in self.modules():
276
+ if isinstance(module, Attention):
277
+ module.fuse_projections(fuse=True)
278
+
279
+ self.set_attn_processor(FusedAttnProcessor2_0())
280
+
281
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
282
+ def unfuse_qkv_projections(self):
283
+ """Disables the fused QKV projection if enabled.
284
+
285
+ <Tip warning={true}>
286
+
287
+ This API is 🧪 experimental.
288
+
289
+ </Tip>
290
+
291
+ """
292
+ if self.original_attn_processors is not None:
293
+ self.set_attn_processor(self.original_attn_processors)
294
+
295
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PixArtTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if self.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input: patchify the latent image and embed the timestep (and optional added conditions).
        batch_size = hidden_states.shape[0]
        height, width = (
            hidden_states.shape[-2] // self.config.patch_size,
            hidden_states.shape[-1] // self.config.patch_size,
        )
        hidden_states = self.pos_embed(hidden_states)

        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )

        if self.caption_projection is not None:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks: run each transformer block, optionally through gradient checkpointing
        # (positional argument order below must match the block's keyword order in the else-branch).
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

        # 3. Output: final AdaLN-style shift/scale modulation followed by the output projection.
        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
        ).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.squeeze(1)

        # unpatchify: (B, H*W, p*p*C) -> (B, C, H*p, W*p)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
        )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/prior_transformer.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ...configuration_utils import ConfigMixin, register_to_config
9
+ from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
10
+ from ...utils import BaseOutput
11
+ from ..attention import BasicTransformerBlock
12
+ from ..attention_processor import (
13
+ ADDED_KV_ATTENTION_PROCESSORS,
14
+ CROSS_ATTENTION_PROCESSORS,
15
+ AttentionProcessor,
16
+ AttnAddedKVProcessor,
17
+ AttnProcessor,
18
+ )
19
+ from ..embeddings import TimestepEmbedding, Timesteps
20
+ from ..modeling_utils import ModelMixin
21
+
22
+
23
@dataclass
class PriorTransformerOutput(BaseOutput):
    """
    The output of [`PriorTransformer`].

    Args:
        predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
            The predicted CLIP image embedding conditioned on the CLIP text embedding input.
    """

    # Denoised CLIP image embedding predicted by the prior for the current step.
    predicted_image_embedding: torch.Tensor
34
+
35
+
36
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    """
    A Prior Transformer model.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
        num_embeddings (`int`, *optional*, defaults to 77):
            The number of embeddings of the model input `hidden_states`
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
            The activation function to use to create timestep embeddings.
        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
            passing to Transformer blocks. Set it to `None` if normalization is not needed.
        embedding_proj_norm_type (`str`, *optional*, defaults to None):
            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
            needed.
        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
            `encoder_hidden_states` is `None`.
        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
            Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
            product between the text embedding and image embedding as proposed in the unclip paper
            https://huggingface.co/papers/2204.06125 If it is `None`, no additional embeddings will be prepended.
        time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
            If None, will be set to `num_attention_heads * attention_head_dim`
        embedding_proj_dim (`int`, *optional*, default to None):
            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
        clip_embed_dim (`int`, *optional*, default to None):
            The dimension of the output. If None, will be set to `embedding_dim`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        time_embed_act_fn: str = "silu",
        norm_in_type: Optional[str] = None,  # layer
        embedding_proj_norm_type: Optional[str] = None,  # layer
        encoder_hid_proj_type: Optional[str] = "linear",  # linear
        added_emb_type: Optional[str] = "prd",  # prd
        time_embed_dim: Optional[int] = None,
        embedding_proj_dim: Optional[int] = None,
        clip_embed_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        # Unset dimensions fall back to sensible defaults derived from the other config values.
        time_embed_dim = time_embed_dim or inner_dim
        embedding_proj_dim = embedding_proj_dim or embedding_dim
        clip_embed_dim = clip_embed_dim or embedding_dim

        # Timestep conditioning: sinusoidal projection followed by a small MLP embedder.
        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        if embedding_proj_norm_type is None:
            self.embedding_proj_norm = None
        elif embedding_proj_norm_type == "layer":
            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
        else:
            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")

        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)

        if encoder_hid_proj_type is None:
            self.encoder_hidden_states_proj = None
        elif encoder_hid_proj_type == "linear":
            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        else:
            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

        # Learned positional embedding covering the text tokens plus the appended conditioning tokens.
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        if added_emb_type == "prd":
            # Single learned token prepended as in the unCLIP paper ("prd" token).
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        elif added_emb_type is None:
            self.prd_embedding = None
        else:
            raise ValueError(
                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        if norm_in_type == "layer":
            self.norm_in = nn.LayerNorm(inner_dim)
        elif norm_in_type is None:
            self.norm_in = None
        else:
            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")

        self.norm_out = nn.LayerNorm(inner_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

        # Upper-triangular -10000 bias implements causal masking over the full token sequence.
        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        # Statistics used by `post_process_latents` to un-normalize predicted latents.
        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Keys in the dict are consumed as they are matched to module paths.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                Current denoising step.
            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of
                a plain tuple.

        Returns:
            [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`:
                If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        batch_size = hidden_states.shape[0]

        # Normalize `timestep` to a 1-D tensor on the right device.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        if self.embedding_proj_norm is not None:
            proj_embedding = self.embedding_proj_norm(proj_embedding)

        proj_embeddings = self.embedding_proj(proj_embedding)
        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")

        hidden_states = self.proj_in(hidden_states)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        # Build the token sequence: [text tokens?, proj embedding, time embedding, image embedding, prd?].
        additional_embeds = []
        additional_embeddings_len = 0

        if encoder_hidden_states is not None:
            additional_embeds.append(encoder_hidden_states)
            additional_embeddings_len += encoder_hidden_states.shape[1]

        if len(proj_embeddings.shape) == 2:
            proj_embeddings = proj_embeddings[:, None, :]

        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states[:, None, :]

        additional_embeds = additional_embeds + [
            proj_embeddings,
            time_embeddings[:, None, :],
            hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            additional_embeds.append(prd_embedding)

        hidden_states = torch.cat(
            additional_embeds,
            dim=1,
        )

        # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            positional_embeddings = F.pad(
                positional_embeddings,
                (
                    0,
                    0,
                    additional_embeddings_len,
                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
                ),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            # Convert boolean keep-mask to an additive bias, extend it over the appended tokens,
            # combine with the causal bias, and expand per attention head.
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(
                self.config.num_attention_heads,
                dim=0,
                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
            )

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)

        # Read the prediction either from the prd token (last position) or from the image-embedding tokens.
        if self.prd_embedding is not None:
            hidden_states = hidden_states[:, -1]
        else:
            hidden_states = hidden_states[:, additional_embeddings_len:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents: torch.Tensor) -> torch.Tensor:
        """Un-normalize predicted latents using the learned CLIP statistics (`clip_mean`/`clip_std`)."""
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/sana_transformer.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24
+ from ..attention_processor import (
25
+ Attention,
26
+ AttentionProcessor,
27
+ SanaLinearAttnProcessor2_0,
28
+ )
29
+ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
30
+ from ..modeling_outputs import Transformer2DModelOutput
31
+ from ..modeling_utils import ModelMixin
32
+ from ..normalization import AdaLayerNormSingle, RMSNorm
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
class GLUMBConv(nn.Module):
    """Gated MBConv block used by Sana: 1x1 expansion, 3x3 depthwise conv with a GLU-style
    gate, then a 1x1 projection, with an optional RMSNorm and residual connection."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expand_ratio: float = 4,
        norm_type: Optional[str] = None,
        residual_connection: bool = True,
    ) -> None:
        super().__init__()

        expanded_channels = int(expand_ratio * in_channels)
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        self.nonlinearity = nn.SiLU()
        # Expansion produces 2x the expanded width; half is the value path, half is the gate.
        self.conv_inverted = nn.Conv2d(in_channels, expanded_channels * 2, 1, 1, 0)
        self.conv_depth = nn.Conv2d(
            expanded_channels * 2, expanded_channels * 2, 3, 1, 1, groups=expanded_channels * 2
        )
        self.conv_point = nn.Conv2d(expanded_channels, out_channels, 1, 1, 0, bias=False)

        self.norm = (
            RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
            if norm_type == "rms_norm"
            else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states if self.residual_connection else None

        hidden_states = self.nonlinearity(self.conv_inverted(hidden_states))

        hidden_states = self.conv_depth(hidden_states)
        value, gate = torch.chunk(hidden_states, 2, dim=1)
        hidden_states = self.conv_point(value * self.nonlinearity(gate))

        if self.norm_type == "rms_norm":
            # RMSNorm normalizes the channel axis, so move channels last and back again.
            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)

        if residual is not None:
            hidden_states = hidden_states + residual

        return hidden_states
83
+
84
+
85
class SanaModulatedNorm(nn.Module):
    """LayerNorm followed by AdaLN-style shift/scale modulation built from a learned
    scale-shift table plus the timestep embedding."""

    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(
        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
    ) -> torch.Tensor:
        normed = self.norm(hidden_states)
        # (1, 2, dim) table broadcast against (batch, 1, dim) embedding -> per-sample shift/scale.
        modulation = scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)
        shift, scale = modulation.chunk(2, dim=1)
        return normed * (1 + scale) + shift
97
+
98
+
99
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
    """Embeds the diffusion timestep and a guidance condition, sums the two embeddings, and
    returns both the 6*dim modulation vector (via SiLU + Linear) and the raw conditioning."""

    def __init__(self, embedding_dim):
        super().__init__()
        # Timestep branch: sinusoidal projection followed by an MLP embedder.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        # Guidance branch mirrors the timestep branch.
        self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
        # Sinusoidal projections are produced in fp32; cast to the model dtype before the embedders.
        timestep_embedding = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_dtype))
        guidance_embedding = self.guidance_embedder(self.guidance_condition_proj(guidance).to(dtype=hidden_dtype))

        conditioning = timestep_embedding + guidance_embedding
        return self.linear(self.silu(conditioning)), conditioning
120
+
121
+
122
class SanaAttnProcessor2_0:
    r"""
    Attention processor backed by `F.scaled_dot_product_attention` (requires PyTorch 2.0),
    used for Sana's cross-attention layers.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention when no encoder states are supplied.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = context.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask shaped
            # (batch, heads, source_length, target_length).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        head_dim = key.shape[-1] // attn.heads

        # (batch, tokens, inner) -> (batch, heads, tokens, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # TODO: add support for attn.scale when we move to Torch 2.1
        attn_output = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attn_output = attn_output.to(query.dtype)

        attn_output = attn.to_out[0](attn_output)  # output projection
        attn_output = attn.to_out[1](attn_output)  # dropout

        return attn_output / attn.rescale_output_factor
186
+
187
+
188
class SanaTransformerBlock(nn.Module):
    r"""
    Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).

    The block runs, in order: adaLN-modulated self attention (linear attention processor), optional cross
    attention against the caption sequence, and a GLUMBConv feed-forward that operates on the 2D
    (height x width) token grid.

    Args:
        dim (`int`, defaults to `2240`): The number of channels in the input and output.
        num_attention_heads (`int`, defaults to `70`): The number of heads for self-attention.
        attention_head_dim (`int`, defaults to `32`): The number of channels in each self-attention head.
        dropout (`float`, defaults to `0.0`): The dropout probability.
        num_cross_attention_heads (`int`, *optional*, defaults to `20`): The number of cross-attention heads.
        cross_attention_head_dim (`int`, *optional*, defaults to `112`): Channels per cross-attention head.
        cross_attention_dim (`int`, *optional*, defaults to `2240`):
            The dimension of the encoder hidden states; when `None`, no cross-attention layer is created.
        attention_bias (`bool`, defaults to `True`): Whether to use bias in the self-attention projections.
        norm_elementwise_affine (`bool`, defaults to `False`): Whether `norm2` has learnable affine parameters.
        norm_eps (`float`, defaults to `1e-6`): Epsilon for the normalization layers.
        attention_out_bias (`bool`, defaults to `True`): Whether the cross-attention output projection has bias.
        mlp_ratio (`float`, defaults to `2.5`): Expansion ratio of the GLUMBConv feed-forward.
        qk_norm (`str`, *optional*, defaults to `None`): Query/key normalization type passed to `Attention`.
    """

    def __init__(
        self,
        dim: int = 2240,
        num_attention_heads: int = 70,
        attention_head_dim: int = 32,
        dropout: float = 0.0,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
        cross_attention_dim: Optional[int] = 2240,
        attention_bias: bool = True,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        attention_out_bias: bool = True,
        mlp_ratio: float = 2.5,
        qk_norm: Optional[str] = None,
    ) -> None:
        super().__init__()

        # 1. Self Attention (linear attention processor)
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            kv_heads=num_attention_heads if qk_norm is not None else None,
            qk_norm=qk_norm,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
            processor=SanaLinearAttnProcessor2_0(),
        )

        # 2. Cross Attention (optional)
        # NOTE: `norm2` must exist even without cross-attention, because `forward` also uses it to
        # normalize the feed-forward input. Previously both `norm2` and `attn2` were only created when
        # `cross_attention_dim` was set, so `forward` raised `AttributeError` for
        # `cross_attention_dim=None`. Creating `norm2` unconditionally and defaulting `attn2` to `None`
        # fixes that while keeping the module tree identical for all previously-working configurations.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
        if cross_attention_dim is not None:
            self.attn2 = Attention(
                query_dim=dim,
                qk_norm=qk_norm,
                kv_heads=num_cross_attention_heads if qk_norm is not None else None,
                cross_attention_dim=cross_attention_dim,
                heads=num_cross_attention_heads,
                dim_head=cross_attention_head_dim,
                dropout=dropout,
                bias=True,
                out_bias=attention_out_bias,
                processor=SanaAttnProcessor2_0(),
            )
        else:
            self.attn2 = None

        # 3. Feed-forward
        self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

        # Per-block adaLN table: 6 vectors (shift/scale/gate for attention and for the MLP).
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        height: int = None,
        width: int = None,
    ) -> torch.Tensor:
        """Apply modulated self-attention, optional cross-attention, and the conv feed-forward.

        `timestep` is expected to be the already-embedded adaLN conditioning, reshapeable to
        `(batch, 6, dim)`; `height`/`width` describe the 2D token grid for the conv feed-forward.
        """
        batch_size = hidden_states.shape[0]

        # 1. Modulation: combine the learned table with the timestep embedding into 6 modulation vectors.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)

        # 2. Self Attention
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
        norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

        attn_output = self.attn1(norm_hidden_states)
        hidden_states = hidden_states + gate_msa * attn_output

        # 3. Cross Attention
        if self.attn2 is not None:
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward (GLUMBConv operates on a channels-first 2D grid, so reshape around it)
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
        ff_output = self.ff(norm_hidden_states)
        ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
        hidden_states = hidden_states + gate_mlp * ff_output

        return hidden_states
290
+
291
+
292
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.

    Args:
        in_channels (`int`, defaults to `32`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `32`):
            The number of channels in the output.
        num_attention_heads (`int`, defaults to `70`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `32`):
            The number of channels in each head.
        num_layers (`int`, defaults to `20`):
            The number of layers of Transformer blocks to use.
        num_cross_attention_heads (`int`, *optional*, defaults to `20`):
            The number of heads to use for cross-attention.
        cross_attention_head_dim (`int`, *optional*, defaults to `112`):
            The number of channels in each head for cross-attention.
        cross_attention_dim (`int`, *optional*, defaults to `2240`):
            The number of channels in the cross-attention output.
        caption_channels (`int`, defaults to `2304`):
            The number of channels in the caption embeddings.
        mlp_ratio (`float`, defaults to `2.5`):
            The expansion ratio to use in the GLUMBConv layer.
        dropout (`float`, defaults to `0.0`):
            The dropout probability.
        attention_bias (`bool`, defaults to `False`):
            Whether to use bias in the attention layer.
        sample_size (`int`, defaults to `32`):
            The base size of the input latent.
        patch_size (`int`, defaults to `1`):
            The size of the patches to use in the patch embedding layer.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether to use elementwise affinity in the normalization layer.
        norm_eps (`float`, defaults to `1e-6`):
            The epsilon value for the normalization layer.
        interpolation_scale (`int`, *optional*, defaults to `None`):
            Interpolation scale for the positional embedding; sincos positional embeddings are only enabled
            when this is provided.
        guidance_embeds (`bool`, defaults to `False`):
            Whether to embed the guidance scale alongside the timestep (selects the time embedding class).
        guidance_embeds_scale (`float`, defaults to `0.1`):
            Scale for the guidance embeddings (recorded in the config; not read directly in this module).
        qk_norm (`str`, *optional*, defaults to `None`):
            The normalization to use for the query and key.
        timestep_scale (`float`, defaults to `1.0`):
            The scale to use for the timesteps.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 32,
        out_channels: Optional[int] = 32,
        num_attention_heads: int = 70,
        attention_head_dim: int = 32,
        num_layers: int = 20,
        num_cross_attention_heads: Optional[int] = 20,
        cross_attention_head_dim: Optional[int] = 112,
        cross_attention_dim: Optional[int] = 2240,
        caption_channels: int = 2304,
        mlp_ratio: float = 2.5,
        dropout: float = 0.0,
        attention_bias: bool = False,
        sample_size: int = 32,
        patch_size: int = 1,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        guidance_embeds: bool = False,
        guidance_embeds_scale: float = 0.1,
        qk_norm: Optional[str] = None,
        timestep_scale: float = 1.0,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch Embedding
        self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
            pos_embed_type="sincos" if interpolation_scale is not None else None,
        )

        # 2. Additional condition embeddings
        # Guidance-distilled checkpoints embed the guidance scale jointly with the timestep.
        if guidance_embeds:
            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
        else:
            self.time_embed = AdaLayerNormSingle(inner_dim)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

        # 3. Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                SanaTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    num_cross_attention_heads=num_cross_attention_heads,
                    cross_attention_head_dim=cross_attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output blocks
        # scale_shift_table supplies the (shift, scale) pair consumed by the modulated output norm.
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        guidance: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
        """Denoise a latent image conditioned on caption embeddings and a timestep.

        `hidden_states` is the latent of shape `(batch, in_channels, height, width)`;
        `encoder_hidden_states` are the caption embeddings. When `controlnet_block_samples` is given, the
        residuals are added after the corresponding transformer blocks (starting after block 0). Returns a
        `Transformer2DModelOutput` (or a 1-tuple when `return_dict=False`).
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        batch_size, num_channels, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        hidden_states = self.patch_embed(hidden_states)

        # The two time_embed variants (guidance vs. plain) take different keyword arguments.
        if guidance is not None:
            timestep, embedded_timestep = self.time_embed(
                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
            )
        else:
            timestep, embedded_timestep = self.time_embed(
                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
            )

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        encoder_hidden_states = self.caption_norm(encoder_hidden_states)

        # 2. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for index_block, block in enumerate(self.transformer_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                )
                # ControlNet residuals are injected after blocks 1..len(controlnet_block_samples).
                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

        else:
            for index_block, block in enumerate(self.transformer_blocks):
                hidden_states = block(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    post_patch_height,
                    post_patch_width,
                )
                # ControlNet residuals are injected after blocks 1..len(controlnet_block_samples).
                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

        # 3. Normalization (modulated by the embedded timestep via scale_shift_table)
        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)

        # 4. Output projection
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify: (B, Hp*Wp, p*p*C) -> (B, C, Hp*p, Wp*p)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
        )
        hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
        output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
pythonProject/.venv/Lib/site-packages/diffusers/models/transformers/stable_audio_transformer.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Dict, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from ...configuration_utils import ConfigMixin, register_to_config
24
+ from ...utils import logging
25
+ from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention import FeedForward
27
+ from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
28
+ from ..modeling_utils import ModelMixin
29
+ from ..transformers.transformer_2d import Transformer2DModelOutput
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
class StableAudioGaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Projects a 1D batch of (optionally log-transformed) noise levels onto a fixed bank of random
    frequencies and returns the concatenated sine/cosine features (output width is twice
    `embedding_size`).
    """

    # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        # Frozen random frequency bank; never trained.
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # Legacy path: re-sample the frequencies under the name `W`, then alias the result back to
            # `weight` so the parameter is registered under a single name.
            # to delete later
            del self.weight
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            self.weight = self.W
            del self.W

    def forward(self, x):
        """Return Fourier features of shape `(batch, 2 * embedding_size)` for the 1D input `x`."""
        if self.log:
            x = torch.log(x)

        # Outer product of the inputs with the frequency bank, scaled to angular frequency.
        angles = 2 * np.pi * x[:, None] @ self.weight[None, :]

        # The order of the sine/cosine halves depends on `flip_sin_to_cos`.
        first, second = (torch.cos, torch.sin) if self.flip_sin_to_cos else (torch.sin, torch.cos)
        return torch.cat([first(angles), second(angles)], dim=-1)
65
+
66
+
67
@maybe_allow_in_graph
class StableAudioDiTBlock(nn.Module):
    r"""
    Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip
    connection and QKNorm

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for the query states.
        num_key_value_attention_heads (`int`): The number of heads to use for the key and value states.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon shared by the three pre-norm LayerNorms.
        ff_inner_dim (`int`, *optional*): Inner dimension of the feed-forward; `FeedForward` picks its own default
            when this is `None`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_key_value_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        upcast_attention: bool = False,
        norm_eps: float = 1e-5,
        ff_inner_dim: Optional[int] = None,
    ):
        super().__init__()
        # Three pre-norm sub-blocks, each with its own LayerNorm.
        # 1. Self-attention over the audio token sequence.
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            out_bias=False,
            processor=StableAudioAttnProcessor2_0(),
        )

        # 2. Cross-attention against the conditioning sequence. `kv_heads` < `heads` gives
        #    grouped-query attention.
        self.norm2 = nn.LayerNorm(dim, norm_eps, True)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            kv_heads=num_key_value_attention_heads,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
            out_bias=False,
            processor=StableAudioAttnProcessor2_0(),
        )  # behaves as self-attention when encoder_hidden_states is None

        # 3. Gated (SwiGLU) feed-forward.
        self.norm3 = nn.LayerNorm(dim, norm_eps, True)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn="swiglu",
            final_dropout=False,
            inner_dim=ff_inner_dim,
            bias=True,
        )

        # Chunked feed-forward is off by default.
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """Configure (or disable, with `chunk_size=None`) chunked feed-forward."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        rotary_embedding: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention and feed-forward, each pre-normed and residual."""
        # 1. Self-attention (with rotary position embeddings) + residual.
        hidden_states = hidden_states + self.attn1(
            self.norm1(hidden_states),
            attention_mask=attention_mask,
            rotary_emb=rotary_embedding,
        )

        # 2. Cross-attention + residual.
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )

        # 3. Feed-forward + residual.
        hidden_states = hidden_states + self.ff(self.norm3(hidden_states))

        return hidden_states
184
+
185
+
186
class StableAudioDiTModel(ModelMixin, ConfigMixin):
    """
    The Diffusion Transformer model introduced in Stable Audio.

    Reference: https://github.com/Stability-AI/stable-audio-tools

    Parameters:
        sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
        num_key_value_attention_heads (`int`, *optional*, defaults to 12):
            The number of heads to use for the key and value states.
        out_channels (`int`, defaults to 64): Number of output channels.
        cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
        time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection.
        global_states_input_dim ( `int`, *optional*, defaults to 1536):
            Input dimension of the global hidden states projection.
        cross_attention_input_dim ( `int`, *optional*, defaults to 768):
            Input dimension of the cross-attention projection
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 1024,
        in_channels: int = 64,
        num_layers: int = 24,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        num_key_value_attention_heads: int = 12,
        out_channels: int = 64,
        cross_attention_dim: int = 768,
        time_proj_dim: int = 256,
        global_states_input_dim: int = 1536,
        cross_attention_input_dim: int = 768,
    ):
        super().__init__()
        self.sample_size = sample_size
        self.out_channels = out_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Fourier features for the timestep (no log transform; cos/sin order flipped).
        self.time_proj = StableAudioGaussianFourierProjection(
            embedding_size=time_proj_dim // 2,
            flip_sin_to_cos=True,
            log=False,
            set_W_to_weight=False,
        )

        # MLP lifting the timestep Fourier features to the transformer width.
        self.timestep_proj = nn.Sequential(
            nn.Linear(time_proj_dim, self.inner_dim, bias=True),
            nn.SiLU(),
            nn.Linear(self.inner_dim, self.inner_dim, bias=True),
        )

        # Projection of the global conditioning states to the transformer width.
        self.global_proj = nn.Sequential(
            nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
            nn.SiLU(),
            nn.Linear(self.inner_dim, self.inner_dim, bias=False),
        )

        # Projection of the conditioning sequence used as cross-attention key/value input.
        self.cross_attention_proj = nn.Sequential(
            nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
            nn.SiLU(),
            nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
        )

        # Residual 1x1 convolutions applied before/after the transformer stack.
        self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
        self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)

        self.transformer_blocks = nn.ModuleList(
            [
                StableAudioDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    num_key_value_attention_heads=num_key_value_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                )
                for i in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
        self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(StableAudioAttnProcessor2_0())

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        timestep: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        global_hidden_states: torch.FloatTensor = None,
        rotary_embedding: torch.FloatTensor = None,
        return_dict: bool = True,
        attention_mask: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`StableAudioDiTModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
                Input `hidden_states`.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
                Global embeddings that will be prepended to the hidden states.
            rotary_embedding (`torch.Tensor`):
                The rotary embeddings to apply on query and key tensors during attention calculation.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
                Mask to avoid performing attention on padding token indices, formed by concatenating the attention
                masks
                  for the two text encoders together. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
                Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating
                the attention masks
                  for the two text encoders together. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Project the conditioning inputs to the transformer's working widths.
        cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
        global_hidden_states = self.global_proj(global_hidden_states)
        time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype)))

        # Fold the timestep embedding into the (prepended) global token(s).
        global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)

        hidden_states = self.preprocess_conv(hidden_states) + hidden_states
        # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.proj_in(hidden_states)

        # prepend global states to hidden states
        hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
        if attention_mask is not None:
            # Extend the mask so the single prepended global token is always attended to.
            prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool)
            attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    attention_mask,
                    cross_attention_hidden_states,
                    encoder_attention_mask,
                    rotary_embedding,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=cross_attention_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    rotary_embedding=rotary_embedding,
                )

        hidden_states = self.proj_out(hidden_states)

        # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
        # remove prepend length that has been added by global hidden states
        # NOTE(review): exactly one token is stripped, matching the single-token `prepend_mask` above —
        # this assumes `global_hidden_states` has sequence length 1; confirm against callers.
        hidden_states = hidden_states.transpose(1, 2)[:, :, 1:]
        hidden_states = self.postprocess_conv(hidden_states) + hidden_states

        if not return_dict:
            return (hidden_states,)

        return Transformer2DModelOutput(sample=hidden_states)
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_flow_match_lcm.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, is_scipy_available, logging
24
+ from ..utils.torch_utils import randn_tensor
25
+ from .scheduling_utils import SchedulerMixin
26
+
27
+
28
+ if is_scipy_available():
29
+ import scipy.stats
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
@dataclass
class FlowMatchLCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    # The denoised sample for the next step; produced by `FlowMatchLCMScheduler.step`.
    prev_sample: torch.FloatTensor
46
+
47
+
48
class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
    """
    LCM scheduler for Flow Matching.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        use_dynamic_shifting (`bool`, defaults to False):
            Whether to apply timestep shifting on-the-fly based on the image resolution.
        base_shift (`float`, defaults to 0.5):
            Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
            with desired output.
        max_shift (`float`, defaults to 1.15):
            Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
            more exaggerated or stylized.
        base_image_seq_len (`int`, defaults to 256):
            The base image sequence length.
        max_image_seq_len (`int`, defaults to 4096):
            The maximum image sequence length.
        invert_sigmas (`bool`, defaults to False):
            Whether to invert the sigmas.
        shift_terminal (`float`, defaults to None):
            The end value of the shifted timestep schedule.
        use_karras_sigmas (`bool`, defaults to False):
            Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
        use_exponential_sigmas (`bool`, defaults to False):
            Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
        use_beta_sigmas (`bool`, defaults to False):
            Whether to use beta sigmas for step sizes in the noise schedule during sampling.
        time_shift_type (`str`, defaults to "exponential"):
            The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
        scale_factors (`list`, *optional*, defaults to None):
            It defines how to scale the latents at which predictions are made (scale-wise generation).
        upscale_mode (`str`, *optional*, defaults to `"bicubic"`):
            Upscaling method, applied if scale-wise generation is considered.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
        invert_sigmas: bool = False,
        shift_terminal: Optional[float] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        time_shift_type: str = "exponential",
        scale_factors: Optional[List[float]] = None,
        upscale_mode: Optional[str] = "bicubic",
    ):
        # Validate mutually exclusive sigma-schedule options up front.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if time_shift_type not in {"exponential", "linear"}:
            raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")

        # Training schedule: timesteps run from num_train_timesteps down to 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self._shift = shift

        # Scale-wise generation state; `_init_size` is captured lazily from the first model output in `step`.
        self._init_size = None
        self._scale_factors = scale_factors
        self._upscale_mode = upscale_mode

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def shift(self):
        """
        The value used for shifting.
        """
        return self._shift

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_shift(self, shift: float):
        # Overrides the `shift` used when `set_timesteps` recomputes the sigma schedule.
        self._shift = shift

    def set_scale_factors(self, scale_factors: list, upscale_mode):
        """
        Sets scale factors for a scale-wise generation regime.

        Args:
            scale_factors (`list`):
                The scale factors for each step
            upscale_mode (`str`):
                Upscaling method
        """
        self._scale_factors = scale_factors
        self._upscale_mode = upscale_mode

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                Noise to mix into the sample. NOTE(review): despite being optional, it is used unconditionally
                below — callers are expected to pass it.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Broadcast sigma over all trailing sample dimensions.
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        # Linear interpolation between noise and data (flow-matching forward process).
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        # Sigmas live in [0, 1]; timesteps are scaled to [0, num_train_timesteps].
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        # Dispatch to the configured resolution-dependent shifting function.
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)

    def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
        value.

        Reference:
        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

        Args:
            t (`torch.Tensor`):
                A tensor of timesteps to be stretched and shifted.

        Returns:
            `torch.Tensor`:
                A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
        """
        one_minus_z = 1 - t
        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
        stretched_t = 1 - (one_minus_z / scale_factor)
        return stretched_t

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[float]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
                automatically.
            mu (`float`, *optional*):
                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
                shifting.
            timesteps (`List[float]`, *optional*):
                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
                automatically.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")

        if sigmas is not None and timesteps is not None:
            if len(sigmas) != len(timesteps):
                raise ValueError("`sigmas` and `timesteps` should have the same length")

        if num_inference_steps is not None:
            if (sigmas is not None and len(sigmas) != num_inference_steps) or (
                timesteps is not None and len(timesteps) != num_inference_steps
            ):
                raise ValueError(
                    "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
                )
        else:
            num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)

        self.num_inference_steps = num_inference_steps

        # 1. Prepare default sigmas
        is_timesteps_provided = timesteps is not None

        if is_timesteps_provided:
            timesteps = np.array(timesteps).astype(np.float32)

        if sigmas is None:
            if timesteps is None:
                timesteps = np.linspace(
                    self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
                )
            sigmas = timesteps / self.config.num_train_timesteps
        else:
            sigmas = np.array(sigmas).astype(np.float32)
            num_inference_steps = len(sigmas)

        # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
        # "exponential" or "linear" type is applied
        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
        if self.config.shift_terminal:
            sigmas = self.stretch_shift_to_terminal(sigmas)

        # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)

        # 5. Convert sigmas and timesteps to tensors and move to specified device
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        if not is_timesteps_provided:
            timesteps = sigmas * self.config.num_train_timesteps
        else:
            timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)

        # 6. Append the terminal sigma value.
        # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
        # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
        if self.config.invert_sigmas:
            sigmas = 1.0 - sigmas
            timesteps = sigmas * self.config.num_train_timesteps
            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
        else:
            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.timesteps = timesteps
        self.sigmas = sigmas
        # Reset the step bookkeeping so the next `step` call re-initializes it.
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        # Maps a timestep value to its position in the schedule.
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Initializes `_step_index` either from the timestep value or from `begin_index` if set.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchLCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_flow_match_lcm.FlowMatchLCMSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_flow_match_lcm.FlowMatchLCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_flow_match_lcm.FlowMatchLCMSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchLCMScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self._scale_factors and self._upscale_mode and len(self.timesteps) != len(self._scale_factors) + 1:
            raise ValueError(
                "`_scale_factors` should have the same length as `timesteps` - 1, if `_scale_factors` are set."
            )

        # Capture the spatial size of the first model output; used as the base size for scale-wise upscaling.
        # (model_output is assumed to be (batch, channels, *spatial) — spatial dims start at index 2.)
        if self._init_size is None or self.step_index is None:
            self._init_size = model_output.size()[2:]

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
        # LCM-style x0 prediction from the flow-matching velocity output.
        x0_pred = sample - sigma * model_output

        if self._scale_factors and self._upscale_mode:
            if self._step_index < len(self._scale_factors):
                size = [round(self._scale_factors[self._step_index] * size) for size in self._init_size]
                x0_pred = torch.nn.functional.interpolate(x0_pred, size=size, mode=self._upscale_mode)

        # Re-noise the x0 prediction to the next sigma level (stochastic LCM update).
        noise = randn_tensor(x0_pred.shape, generator=generator, device=x0_pred.device, dtype=x0_pred.dtype)
        prev_sample = (1 - sigma_next) * x0_pred + sigma_next * noise

        # upon completion increase step index by one
        self._step_index += 1
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        if not return_dict:
            return (prev_sample,)

        return FlowMatchLCMSchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def _time_shift_exponential(self, mu, sigma, t):
        # Exponential resolution-dependent shift; `t` may be a numpy array of sigmas in (0, 1).
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def _time_shift_linear(self, mu, sigma, t):
        # Linear resolution-dependent shift; same domain as `_time_shift_exponential`.
        return mu / (mu + (1 / t - 1) ** sigma)

    def __len__(self):
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_heun_discrete.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, is_scipy_available
24
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
25
+
26
+
27
+ if is_scipy_available():
28
+ import scipy.stats
29
+
30
+
31
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
class HeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample for the next denoising step.
    prev_sample: torch.Tensor
    # Optional x0 prediction for previews/guidance; may be None when not computed.
    pred_original_sample: Optional[torch.Tensor] = None
48
+
49
+
50
# Discretizes a continuous alpha-bar schedule into per-step betas (same schedule as
# diffusers' DDPM `betas_for_alpha_bar`).
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    # Select the continuous cumulative-alpha function for the requested schedule.
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta to avoid
    # the singularity as alpha_bar approaches zero.
    grid = [step / num_diffusion_timesteps for step in range(num_diffusion_timesteps + 1)]
    betas = [
        min(1 - alpha_bar_fn(t_next) / alpha_bar_fn(t_curr), max_beta)
        for t_curr, t_next in zip(grid[:-1], grid[1:])
    ]
    return torch.tensor(betas, dtype=torch.float32)
93
+
94
+
95
+ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
96
+ """
97
+ Scheduler with Heun steps for discrete beta schedules.
98
+
99
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
100
+ methods the library implements for all schedulers such as loading and saving.
101
+
102
+ Args:
103
+ num_train_timesteps (`int`, defaults to 1000):
104
+ The number of diffusion steps to train the model.
105
+ beta_start (`float`, defaults to 0.0001):
106
+ The starting `beta` value of inference.
107
+ beta_end (`float`, defaults to 0.02):
108
+ The final `beta` value.
109
+ beta_schedule (`str`, defaults to `"linear"`):
110
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
111
+ `linear` or `scaled_linear`.
112
+ trained_betas (`np.ndarray`, *optional*):
113
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
114
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
115
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
116
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
117
+ Video](https://imagen.research.google/video/paper.pdf) paper).
118
+ clip_sample (`bool`, defaults to `True`):
119
+ Clip the predicted sample for numerical stability.
120
+ clip_sample_range (`float`, defaults to 1.0):
121
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
122
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
123
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
124
+ the sigmas are determined according to a sequence of noise levels {σi}.
125
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
126
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
127
+ use_beta_sigmas (`bool`, *optional*, defaults to `False`):
128
+ Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
129
+ Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
130
+ timestep_spacing (`str`, defaults to `"linspace"`):
131
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
132
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
133
+ steps_offset (`int`, defaults to 0):
134
+ An offset added to the inference steps, as required by some model families.
135
+ """
136
+
137
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
138
+ order = 2
139
+
140
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    use_karras_sigmas: Optional[bool] = False,
    use_exponential_sigmas: Optional[bool] = False,
    use_beta_sigmas: Optional[bool] = False,
    clip_sample: Optional[bool] = False,
    clip_sample_range: float = 1.0,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    # NOTE(review): the body reads `self.config.*` immediately, so
    # `@register_to_config` presumably populates `self.config` before the body
    # runs -- confirm against ConfigMixin.
    if self.config.use_beta_sigmas and not is_scipy_available():
        raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
    # The three alternative sigma-spacing options are mutually exclusive.
    if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
        raise ValueError(
            "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
        )
    # Build the training beta schedule, or take the user-supplied one verbatim.
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
    elif beta_schedule == "exp":
        self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # set all values (initializes self.sigmas / self.timesteps to the full training schedule)
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
    self.use_karras_sigmas = use_karras_sigmas

    # Step bookkeeping is resolved lazily on the first `step()` / `scale_model_input()` call.
    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
188
+
189
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the schedule index corresponding to `timestep`.

    When the schedule contains duplicated timesteps (Heun repeats interior
    entries), the *second* occurrence is returned so that starting mid-schedule
    (e.g. image-to-image) does not skip a sigma.
    """
    timeline = self.timesteps if schedule_timesteps is None else schedule_timesteps

    matches = (timeline == timestep).nonzero()

    # Pick the second match when there is one, otherwise the only match.
    if len(matches) > 1:
        chosen = matches[1]
    else:
        chosen = matches[0]
    return chosen.item()
203
+
204
@property
def init_noise_sigma(self):
    # standard deviation of the initial noise distribution
    if self.config.timestep_spacing in ["linspace", "trailing"]:
        return self.sigmas.max()

    # For other spacings (i.e. "leading"), use the sqrt(sigma_max^2 + 1) magnitude.
    return (self.sigmas.max() ** 2 + 1) ** 0.5

@property
def step_index(self):
    """
    The index counter for current timestep. It will increase 1 after each scheduler step.
    """
    return self._step_index

@property
def begin_index(self):
    """
    The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
    """
    return self._begin_index
225
+
226
def set_begin_index(self, begin_index: int = 0):
    """Record the index of the first timestep to be used for denoising.

    Pipelines call this before inference so the scheduler can start mid-schedule
    (e.g. image-to-image) without inferring the index from a timestep value.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    self._begin_index = begin_index
236
+
237
def scale_model_input(
    self,
    sample: torch.Tensor,
    timestep: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep.

    Args:
        sample (`torch.Tensor`):
            The input sample.
        timestep (`int`, *optional*):
            The current timestep in the diffusion chain.

    Returns:
        `torch.Tensor`:
            A scaled input sample.
    """
    # Lazily resolve the step index on the first call (supports mid-schedule starts).
    if self.step_index is None:
        self._init_step_index(timestep)

    sigma = self.sigmas[self.step_index]
    # Divide by sqrt(sigma^2 + 1), the inverse of the magnitude `init_noise_sigma`
    # reports for "leading" spacing.
    sample = sample / ((sigma**2 + 1) ** 0.5)
    return sample
262
+
263
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, *optional*):
            The number of diffusion steps used when training the model. If `None`, the default
            `num_train_timesteps` attribute is used.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
            generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
            must be `None`, and `timestep_spacing` attribute will be ignored.
    """
    # Exactly one of `num_inference_steps` / `timesteps` must be given, and custom
    # timesteps are incompatible with any of the resampled sigma schedules.
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
    if timesteps is not None and self.config.use_karras_sigmas:
        raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
    if timesteps is not None and self.config.use_exponential_sigmas:
        raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
    if timesteps is not None and self.config.use_beta_sigmas:
        raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

    num_inference_steps = num_inference_steps or len(timesteps)
    self.num_inference_steps = num_inference_steps
    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    if timesteps is not None:
        timesteps = np.array(timesteps, dtype=np.float32)
    else:
        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

    # Continuous sigma schedule derived from the training alphas, interpolated at
    # the (possibly fractional) chosen timesteps.
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    # Optionally resample the sigmas, then map each sigma back to a fractional timestep.
    if self.config.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
    elif self.config.use_exponential_sigmas:
        sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
    elif self.config.use_beta_sigmas:
        sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)
    # Interior entries are duplicated because Heun evaluates the model twice per output step.
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

    timesteps = torch.from_numpy(timesteps)
    timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

    self.timesteps = timesteps.to(device=device, dtype=torch.float32)

    # empty dt and derivative (puts the solver back into first-order mode)
    self.prev_derivative = None
    self.dt = None

    self._step_index = None
    self._begin_index = None
    self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
354
+
355
def _sigma_to_t(self, sigma, log_sigmas):
    """Invert the sigma schedule: map `sigma` to a fractional timestep via
    piecewise-linear interpolation in log-sigma space."""
    # Work in log space; clamp so log(0) cannot occur.
    log_sigma = np.log(np.maximum(sigma, 1e-10))

    # Signed distance from the query to every schedule entry.
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # Bracket the query: last entry still >= query, capped so high_idx stays in range.
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1

    low, high = log_sigmas[low_idx], log_sigmas[high_idx]

    # Interpolation weight between the two bracketing entries, clamped to [0, 1].
    frac = np.clip((low - log_sigma) / (low - high), 0, 1)

    # Blend the integer indices into a fractional timestep, matching the input shape.
    t = ((1 - frac) * low_idx + frac * high_idx).reshape(sigma.shape)
    return t
378
+
379
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
    """Construct the rho-spaced noise schedule of Karras et al. (2022)."""
    # Config-level overrides win when present; otherwise the range is taken from
    # the endpoints of the incoming schedule. `getattr` keeps schedulers whose
    # configs lack these entries working.
    sigma_min = getattr(self.config, "sigma_min", None)
    sigma_max = getattr(self.config, "sigma_max", None)
    if sigma_min is None:
        sigma_min = in_sigmas[-1].item()
    if sigma_max is None:
        sigma_max = in_sigmas[0].item()

    rho = 7.0  # 7.0 is the value used in the paper
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then undo the power transform.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
404
+
405
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
    """Construct an exponential (log-linearly spaced) noise schedule."""
    # Config-level overrides win when present; otherwise derive the range from
    # the endpoints of the incoming schedule.
    sigma_min = getattr(self.config, "sigma_min", None)
    sigma_max = getattr(self.config, "sigma_max", None)
    if sigma_min is None:
        sigma_min = in_sigmas[-1].item()
    if sigma_max is None:
        sigma_max = in_sigmas[0].item()

    # Evenly spaced in log-sigma from sigma_max down to sigma_min.
    return np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
426
+
427
def _convert_to_beta(
    self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
    """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
    # Config-level overrides win when present; otherwise derive the range from
    # the endpoints of the incoming schedule.
    sigma_min = getattr(self.config, "sigma_min", None)
    sigma_max = getattr(self.config, "sigma_max", None)
    if sigma_min is None:
        sigma_min = in_sigmas[-1].item()
    if sigma_max is None:
        sigma_max = in_sigmas[0].item()

    # Quantiles of a Beta(alpha, beta) distribution, sampled on a reversed unit grid,
    # stretched onto the [sigma_min, sigma_max] range.
    quantiles = [scipy.stats.beta.ppf(timestep, alpha, beta) for timestep in 1 - np.linspace(0, 1, num_inference_steps)]
    sigmas = np.array([sigma_min + (q * (sigma_max - sigma_min)) for q in quantiles])
    return sigmas
458
+
459
@property
def state_in_first_order(self):
    # True when no half-finished Heun step is pending: `self.dt` is only
    # populated between the first (Euler) and second (correction) model calls.
    return self.dt is None

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
    # Resolve the starting index either by locating `timestep` in the schedule,
    # or from an explicitly set begin index (set by pipelines via `set_begin_index`).
    if self.begin_index is None:
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        self._step_index = self.index_for_timestep(timestep)
    else:
        self._step_index = self._begin_index
471
+
472
def step(
    self,
    model_output: Union[torch.Tensor, np.ndarray],
    timestep: Union[float, torch.Tensor],
    sample: Union[torch.Tensor, np.ndarray],
    return_dict: bool = True,
) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Two consecutive calls form one Heun step: the first call performs an Euler
    step and stashes its derivative/dt, the second averages the derivatives and
    re-applies the step from the stored sample.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
            tuple.

    Returns:
        [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
            returned, otherwise a tuple is returned where the first element is the sample tensor.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    if self.state_in_first_order:
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
    else:
        # 2nd order / Heun's method
        sigma = self.sigmas[self.step_index - 1]
        sigma_next = self.sigmas[self.step_index]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        sigma_input = sigma_hat if self.state_in_first_order else sigma_next
        pred_original_sample = sample - sigma_input * model_output
    elif self.config.prediction_type == "v_prediction":
        sigma_input = sigma_hat if self.state_in_first_order else sigma_next
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    else:
        # Fixed: the message previously omitted `sample` even though the branch
        # above supports it.
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
        )

    if self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    if self.state_in_first_order:
        # 2. Convert to an ODE derivative for 1st order
        derivative = (sample - pred_original_sample) / sigma_hat
        # 3. delta timestep
        dt = sigma_next - sigma_hat

        # store for 2nd order step
        self.prev_derivative = derivative
        self.dt = dt
        self.sample = sample
    else:
        # 2. 2nd order / Heun's method: average the two slope estimates
        derivative = (sample - pred_original_sample) / sigma_next
        derivative = (self.prev_derivative + derivative) / 2

        # 3. take prev timestep & sample
        dt = self.dt
        sample = self.sample

        # free dt and derivative
        # Note, this puts the scheduler in "first order mode"
        self.prev_derivative = None
        self.dt = None
        self.sample = None

    prev_sample = sample + derivative * dt

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (
            prev_sample,
            pred_original_sample,
        )

    return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
574
+
575
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """Diffuse `original_samples` to the noise level given by `timesteps`:
    returns `original_samples + noise * sigma(t)` per batch element."""
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
    if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
        timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

    # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
    if self.begin_index is None:
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after first denoising step (for inpainting)
        step_indices = [self.step_index] * timesteps.shape[0]
    else:
        # add noise is called before first denoising step to create initial latent(img2img)
        step_indices = [self.begin_index] * timesteps.shape[0]

    sigma = sigmas[step_indices].flatten()
    # Broadcast sigma over the trailing sample dimensions.
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    noisy_samples = original_samples + noise * sigma
    return noisy_samples
608
+
609
def __len__(self):
    """The scheduler's length is its configured number of training timesteps."""
    train_steps = self.config.num_train_timesteps
    return train_steps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_ipndm.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Zhejiang University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from .scheduling_utils import SchedulerMixin, SchedulerOutput
23
+
24
+
25
class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    A fourth-order Improved Pseudo Linear Multistep scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
    """

    # Solver order exposed to pipelines; one model evaluation per output step.
    order = 1

    @register_to_config
    def __init__(
        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
    ):
        # set `betas`, `alphas`, `timesteps`
        self.set_timesteps(num_train_timesteps)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://huggingface.co/papers/2202.09778
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.ets = []  # history of intermediate predictions fed to the multistep formula
        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        # Evenly spaced grid from 1 down to (but excluding) 0, with 0 re-appended
        # so `betas`/`alphas` have one extra "previous" entry used by `step`.
        steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
        steps = torch.cat([steps, torch.tensor([0.0])])

        if self.config.trained_betas is not None:
            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
        else:
            # Cosine-style schedule: beta = sin(pi/2 * s)^2, so beta in [0, 1].
            self.betas = torch.sin(steps * math.pi / 2) ** 2

        # alpha^2 + beta^2 = 1 by construction.
        self.alphas = (1.0 - self.betas**2) ** 0.5

        # Angle parameterization of (beta, alpha), rescaled to (0, 1]; the final
        # zero entry is dropped since it is only a "previous" level, not a step.
        timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
        self.timesteps = timesteps.to(device)

        # Reset multistep history and step bookkeeping.
        self.ets = []
        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        # Resolve the starting index either by locating `timestep` in the schedule
        # or from an explicitly set begin index.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the linear multistep method. It performs one forward pass multiple times to approximate the solution.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        if self.step_index is None:
            self._init_step_index(timestep)

        timestep_index = self.step_index
        prev_timestep_index = self.step_index + 1

        # Intermediate prediction combining the sample and model output (formula (9) family).
        ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
        # NOTE(review): the history list grows for the whole run; only the last
        # four entries are ever read below.
        self.ets.append(ets)

        # Linear multistep extrapolation; the 55/-59/37/-9 over 24 weights match
        # the classic fourth-order Adams-Bashforth coefficients, with lower-order
        # warm-up while fewer than four history entries exist.
        if len(self.ets) == 1:
            ets = self.ets[-1]
        elif len(self.ets) == 2:
            ets = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # IPNDM requires no input scaling.
        return sample

    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        """Move `sample` from the noise level at `timestep_index` to the one at
        `prev_timestep_index` using the extrapolated prediction `ets`."""
        alpha = self.alphas[timestep_index]
        sigma = self.betas[timestep_index]

        next_alpha = self.alphas[prev_timestep_index]
        next_sigma = self.betas[prev_timestep_index]

        # Recover the clean-sample estimate, then re-noise it at the next level.
        # The 1e-8 floor guards the division: alpha is 0 at the first grid point
        # (steps start at 1, so beta = 1 there).
        pred = (sample - sigma * ets) / max(alpha, 1e-8)
        prev_sample = next_alpha * pred + ets * next_sigma

        return prev_sample

    def __len__(self):
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, is_scipy_available
24
+ from ..utils.torch_utils import randn_tensor
25
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
26
+
27
+
28
+ if is_scipy_available():
29
+ import scipy.stats
30
+
31
+
32
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Denoised sample to feed into the next solver iteration.
    prev_sample: torch.Tensor
    # Optional preview of the fully denoised (x_0) estimate.
    pred_original_sample: Optional[torch.Tensor] = None
49
+
50
+
51
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a discrete beta schedule from a continuous alpha-bar function, where alpha_bar(t) is the
    cumulative product of (1 - beta) over normalized time t in [0, 1].

    Each beta_i is recovered from the ratio of alpha_bar at two consecutive grid points and clamped
    at `max_beta` to avoid the singularity as alpha_bar approaches zero.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper clamp applied to every beta; keep below 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): shape of the alpha_bar curve,
            one of `cosine` or `exp`.

    Returns:
        `torch.Tensor`: the betas used by the scheduler to step the model outputs, dtype `float32`.
    """
    # Dispatch table of supported alpha_bar curves.
    transforms = {
        "cosine": lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        "exp": lambda t: math.exp(t * -12.0),
    }
    if alpha_transform_type not in transforms:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    alpha_bar_fn = transforms[alpha_transform_type]

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clamped at max_beta.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    KDPM2DiscreteScheduler with ancestral sampling is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating
    the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # Second-order scheduler: each sampling step consumes two model evaluations
    # (see the interleaved timestep/sigma tables built in `set_timesteps`).
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # The sigma-rescaling modes are mutually exclusive; beta sigmas additionally need scipy.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        # For "linspace"/"trailing" spacing the schedule starts at the maximum sigma directly;
        # otherwise the variance-expanded form is used.
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # First-order stage scales with the current sigma; the second stage of the
        # 2nd-order step uses the interpolated (midpoint) sigma from the previous index.
        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
        else:
            sigma = self.sigmas_interpol[self.step_index - 1]

        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Training-schedule sigmas, then interpolate them at the chosen inference timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)

        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
        # Append a terminal sigma of 0 so the last step lands on the clean sample.
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)

        # compute up and down sigmas
        # (ancestral split: sigma_down is the deterministic target, sigma_up the noise injected after the step)
        sigmas_next = sigmas.roll(-1)
        sigmas_next[-1] = 0.0
        sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
        sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
        sigmas_down[-1] = 0.0

        # compute interpolated sigmas
        # (log-space midpoint between sigma and sigma_down, used for the 2nd model evaluation)
        sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp()
        sigmas_interpol[-2:] = 0.0

        # set sigmas
        # Each inner entry is duplicated so that every 2nd-order step has two table entries,
        # one per model evaluation.
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
        self.sigmas_interpol = torch.cat(
            [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
        )
        self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
        self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

        if str(device).startswith("mps"):
            # mps does not support float64
            timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            timesteps = torch.from_numpy(timesteps).to(device)

        sigmas_interpol = sigmas_interpol.cpu()
        log_sigmas = self.log_sigmas.cpu()
        timesteps_interpol = np.array(
            [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
        )

        # Interleave interpolated (midpoint) timesteps with the base timesteps so the
        # schedule alternates first-order and second-order evaluations.
        timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
        interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()

        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

        # Holds the first-order sample between the two halves of a 2nd-order step.
        self.sample = None

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """Map a sigma value to a (fractional) training timestep by log-linear interpolation."""
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    @property
    def state_in_first_order(self):
        """`True` when no intermediate sample is buffered, i.e. the next `step` call is the first-order half."""
        return self.sample is None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        # Derive the step index from the timestep unless a begin index was set by the pipeline.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: Union[torch.Tensor, np.ndarray],
        timestep: Union[float, torch.Tensor],
        sample: Union[torch.Tensor, np.ndarray],
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # Select sigmas for this half-step from the interleaved tables built in `set_timesteps`.
        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_interpol = self.sigmas_interpol[self.step_index]
            sigma_up = self.sigmas_up[self.step_index]
            # sigma_down is read one entry back in the interleaved table here.
            sigma_down = self.sigmas_down[self.step_index - 1]
        else:
            # 2nd order / KPDM2's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_interpol = self.sigmas_interpol[self.step_index - 1]
            sigma_up = self.sigmas_up[self.step_index - 1]
            sigma_down = self.sigmas_down[self.step_index - 1]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_interpol - sigma_hat

            # store for 2nd order step
            self.sample = sample
            self.dt = dt
            prev_sample = sample + derivative * dt
        else:
            # DPM-Solver-2
            # 2. Convert to an ODE derivative for 2nd order
            derivative = (sample - pred_original_sample) / sigma_interpol
            # 3. delta timestep
            dt = sigma_down - sigma_hat

            # Step from the sample buffered in the first-order half, then clear the buffer.
            sample = self.sample
            self.sample = None

            prev_sample = sample + derivative * dt
            # Ancestral part: re-inject noise with std sigma_up after the deterministic step.
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
            prev_sample = prev_sample + noise * sigma_up

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return KDPM2AncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Forward-diffuse `original_samples` to the noise levels given by `timesteps` (x = x0 + sigma * noise)."""
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Right-pad sigma with singleton dims so it broadcasts against the sample tensor.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        """Number of training timesteps the scheduler was configured with."""
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_k_dpm_2_discrete.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import BaseOutput, is_scipy_available
24
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
25
+
26
+
27
+ if is_scipy_available():
28
+ import scipy.stats
29
+
30
+
31
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete
class KDPM2DiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample to feed back into the model on the next denoising iteration.
    prev_sample: torch.Tensor
    # Predicted x_0 for the current step; may be None depending on how the output is built.
    pred_original_sample: Optional[torch.Tensor] = None
48
+
49
+
50
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a discrete beta schedule by sampling a continuous alpha-bar function, where alpha_bar(t)
    is the cumulative product of (1 - beta) over normalized time t in [0, 1].

    Every beta_i comes from the ratio of alpha_bar at consecutive grid points and is clamped at
    `max_beta` so the schedule stays away from the singularity near t = 1.

    Args:
        num_diffusion_timesteps (`int`): number of betas to generate.
        max_beta (`float`): upper clamp applied to each beta; use values below 1.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): alpha_bar curve to use,
            either `cosine` or `exp`.

    Returns:
        `torch.Tensor`: the betas used by the scheduler to step the model outputs (`float32`).
    """
    # Supported alpha_bar curves, keyed by name.
    curves = {
        "cosine": lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        "exp": lambda t: math.exp(t * -12.0),
    }
    if alpha_transform_type not in curves:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
    alpha_bar_fn = curves[alpha_transform_type]

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clamped at max_beta.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
93
+
94
+
95
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    KDPM2DiscreteScheduler is inspired by the DPMSolver2 and Algorithm 2 from the [Elucidating the Design Space of
    Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # Second-order scheduler: every sampling step after the first consumes two model
    # evaluations (one at sigma, one at the interpolated sigma).
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # Beta sigmas need scipy.stats.beta.ppf; fail early if scipy is missing.
        if self.config.use_beta_sigmas and not is_scipy_available():
            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
        # The three alternative sigma schedules are mutually exclusive.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(
        self,
        sample: torch.Tensor,
        timestep: Union[float, torch.Tensor],
    ) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # First-order phase uses the regular sigma; second-order phase evaluates the
        # model at the interpolated (midpoint) sigma.
        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
        else:
            sigma = self.sigmas_interpol[self.step_index]

        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, *optional*):
                The number of training timesteps to space the inference timesteps over. Defaults to
                `self.config.num_train_timesteps`.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optionally re-space the sigmas with one of the alternative schedules, then
        # recover the matching (possibly fractional) timesteps from the log-sigmas.
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)

        # interpolate sigmas: geometric mean of each sigma and its predecessor
        # (midpoint in log-sigma space), used for the second model evaluation.
        sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()

        # Every sigma after the first is repeated twice so that each second-order step
        # (two model calls) advances through a pair of identical schedule entries.
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
        self.sigmas_interpol = torch.cat(
            [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
        )

        timesteps = torch.from_numpy(timesteps).to(device)

        # interpolate timesteps
        sigmas_interpol = sigmas_interpol.cpu()
        log_sigmas = self.log_sigmas.cpu()
        timesteps_interpol = np.array(
            [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
        )
        timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
        interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()

        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

        # `sample` caches the first-order input between the two halves of a step.
        self.sample = None

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def state_in_first_order(self):
        # True while no first-order result is cached, i.e. the next `step` call is the
        # first (Euler) half of the two-evaluation KDPM2 update.
        return self.sample is None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """Invert the sigma schedule: map a sigma to a (fractional) training timestep
        by piecewise-linear interpolation in log-sigma space."""
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def step(
        self,
        model_output: Union[torch.Tensor, np.ndarray],
        timestep: Union[float, torch.Tensor],
        sample: Union[torch.Tensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or
                tuple.

        Returns:
            [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_interpol = self.sigmas_interpol[self.step_index + 1]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / KDPM2's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_interpol = self.sigmas_interpol[self.step_index]
            sigma_next = self.sigmas[self.step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_interpol - sigma_hat

            # store for 2nd order step
            self.sample = sample
        else:
            # DPM-Solver-2
            # 2. Convert to an ODE derivative for 2nd order
            derivative = (sample - pred_original_sample) / sigma_interpol

            # 3. delta timestep
            dt = sigma_next - sigma_hat

            # the 2nd-order update is applied to the sample cached by the 1st-order half
            sample = self.sample
            self.sample = None

        # upon completion increase step index by one
        self._step_index += 1

        prev_sample = sample + derivative * dt

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_karras_ve_flax.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import flax
20
+ import jax
21
+ import jax.numpy as jnp
22
+ from jax import random
23
+
24
+ from ..configuration_utils import ConfigMixin, register_to_config
25
+ from ..utils import BaseOutput
26
+ from .scheduling_utils_flax import FlaxSchedulerMixin
27
+
28
+
29
@flax.struct.dataclass
class KarrasVeSchedulerState:
    """
    Immutable state container for [`FlaxKarrasVeScheduler`]; updated functionally via `.replace(...)`.
    """

    # setable values
    num_inference_steps: Optional[int] = None  # populated by `set_timesteps`
    timesteps: Optional[jnp.ndarray] = None  # descending integer step indices
    schedule: Optional[jnp.ndarray] = None  # sigma(t_i)

    @classmethod
    def create(cls):
        """Return a fresh state with all fields unset."""
        return cls()
39
+
40
+
41
@dataclass
class FlaxKarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
        state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
    """

    prev_sample: jnp.ndarray
    derivative: jnp.ndarray
    state: KarrasVeSchedulerState
58
+
59
+
60
class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
    the VE column of Table 1 from [1] for reference.

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://huggingface.co/papers/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
    differential equations." https://huggingface.co/papers/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
    Diffusion-Based Generative Models." https://huggingface.co/papers/2206.00364. The grid search values used to find
    the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.

    Args:
        sigma_min (`float`): minimum noise magnitude
        sigma_max (`float`): maximum noise magnitude
        s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
            A reasonable range is [1.000, 1.011].
        s_churn (`float`): the parameter controlling the overall amount of stochasticity.
            A reasonable range is [0, 100].
        s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
            A reasonable range is [0, 10].
        s_max (`float`): the end value of the sigma range where we add noise.
            A reasonable range is [0.2, 80].
    """

    @property
    def has_state(self):
        # Signals the Flax scheduler API that this scheduler carries external state
        # (a KarrasVeSchedulerState) instead of mutating attributes.
        return True

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.02,
        sigma_max: float = 100,
        s_noise: float = 1.007,
        s_churn: float = 80,
        s_min: float = 0.05,
        s_max: float = 50,
    ):
        # All configuration is captured by @register_to_config; no instance state here.
        pass

    def create_state(self):
        """Create and return an empty `KarrasVeSchedulerState`."""
        return KarrasVeSchedulerState.create()

    def set_timesteps(
        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> KarrasVeSchedulerState:
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`KarrasVeSchedulerState`):
                the `FlaxKarrasVeScheduler` state data class.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.

        Returns:
            `KarrasVeSchedulerState`: a new state with `num_inference_steps`, `schedule`, and `timesteps` set.
        """
        timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
        # Geometric interpolation between sigma_max**2 and sigma_min**2.
        # NOTE(review): these values are *squared* sigmas, although the state field is
        # annotated as sigma(t_i) — confirm against the consumer before relying on units.
        schedule = [
            (
                self.config.sigma_max**2
                * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
            )
            for i in timesteps
        ]

        return state.replace(
            num_inference_steps=num_inference_steps,
            schedule=jnp.array(schedule, dtype=jnp.float32),
            timesteps=timesteps,
        )

    def add_noise_to_input(
        self,
        state: KarrasVeSchedulerState,
        sample: jnp.ndarray,
        sigma: float,
        key: jax.Array,
    ) -> Tuple[jnp.ndarray, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
        higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.

        TODO Args:
        """
        # Churn is only applied inside the configured [s_min, s_max] sigma band.
        if self.config.s_min <= sigma <= self.config.s_max:
            gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I)
        # NOTE(review): random.split(key, num=1) yields a batched key of shape (1, 2),
        # which is then passed to random.normal — verify this is intended vs. random.split(key)[0].
        key = random.split(key, num=1)
        eps = self.config.s_noise * random.normal(key=key, shape=sample.shape)
        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        state: KarrasVeSchedulerState,
        model_output: jnp.ndarray,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxKarrasVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
            model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
            sigma_hat (`float`): the churned (increased) sigma level produced by `add_noise_to_input`.
            sigma_prev (`float`): the target sigma level of the previous timestep.
            sample_hat (`torch.Tensor` or `np.ndarray`): the churned sample produced by `add_noise_to_input`.
            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

        Returns:
            [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
            chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
        """

        # Euler step of Algorithm 2 (first-order half of Heun's method).
        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        if not return_dict:
            return (sample_prev, derivative, state)

        return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)

    def step_correct(
        self,
        state: KarrasVeSchedulerState,
        model_output: jnp.ndarray,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: jnp.ndarray,
        sample_prev: jnp.ndarray,
        derivative: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxKarrasVeOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. TODO complete description

        Args:
            state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
            model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
            sigma_hat (`float`): the churned sigma level used in the preceding `step`.
            sigma_prev (`float`): the target sigma level of the previous timestep.
            sample_hat (`torch.Tensor` or `np.ndarray`): the churned sample from `add_noise_to_input`.
            sample_prev (`torch.Tensor` or `np.ndarray`): the Euler prediction returned by `step`.
            derivative (`torch.Tensor` or `np.ndarray`): the derivative returned by `step`.
            return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

        Returns:
            prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO

        """
        # Second-order (Heun) correction: average the Euler derivative with the
        # derivative re-evaluated at the predicted point.
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        # NOTE(review): the *uncorrected* derivative is returned here, not derivative_corr —
        # this mirrors the reference implementation, but confirm it is intentional.
        if not return_dict:
            return (sample_prev, derivative, state)

        return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)

    def add_noise(self, state: KarrasVeSchedulerState, original_samples, noise, timesteps):
        # Training-time noising is not supported by this VE scheduler.
        raise NotImplementedError()
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_lcm.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from ..configuration_utils import ConfigMixin, register_to_config
26
+ from ..utils import BaseOutput, logging
27
+ from ..utils.torch_utils import randn_tensor
28
+ from .scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
@dataclass
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        denoised (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*):
            The denoised sample `(x_{0})` predicted via the consistency-model boundary conditions at the current
            timestep. `denoised` can be used to preview progress or for guidance.
    """

    # NOTE(review): the upstream docstring described a `pred_original_sample` field that
    # never existed on this class; the actual attribute has always been `denoised`.
    prev_sample: torch.Tensor
    denoised: Optional[torch.Tensor] = None
50
+
51
+
52
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
53
+ def betas_for_alpha_bar(
54
+ num_diffusion_timesteps,
55
+ max_beta=0.999,
56
+ alpha_transform_type="cosine",
57
+ ):
58
+ """
59
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
60
+ (1-beta) over time from t = [0,1].
61
+
62
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
63
+ to that part of the diffusion process.
64
+
65
+
66
+ Args:
67
+ num_diffusion_timesteps (`int`): the number of betas to produce.
68
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
69
+ prevent singularities.
70
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
71
+ Choose from `cosine` or `exp`
72
+
73
+ Returns:
74
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
75
+ """
76
+ if alpha_transform_type == "cosine":
77
+
78
+ def alpha_bar_fn(t):
79
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
80
+
81
+ elif alpha_transform_type == "exp":
82
+
83
+ def alpha_bar_fn(t):
84
+ return math.exp(t * -12.0)
85
+
86
+ else:
87
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
88
+
89
+ betas = []
90
+ for i in range(num_diffusion_timesteps):
91
+ t1 = i / num_diffusion_timesteps
92
+ t2 = (i + 1) / num_diffusion_timesteps
93
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
94
+ return torch.tensor(betas, dtype=torch.float32)
95
+
96
+
97
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
132
+
133
+
134
class LCMScheduler(SchedulerMixin, ConfigMixin):
    """
    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
    attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
    accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
    functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"scaled_linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        original_inference_steps (`int`, *optional*, defaults to 50):
            The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
            will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
        clip_sample (`bool`, defaults to `False`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        timestep_scaling (`float`, defaults to 10.0):
            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
            error at the default of `10.0` is already pretty small).
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        original_inference_steps: int = 50,
        clip_sample: bool = False,
        clip_sample_range: float = 1.0,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        timestep_scaling: float = 10.0,
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
        self.custom_timesteps = False

        self._step_index = None
        self._begin_index = None

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the index of `timestep` within the (given or current) timestep schedule."""
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Initialize `_step_index` from `timestep`, or from `begin_index` when set by the pipeline."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def step_index(self):
        # Index of the current step in the inference schedule; `None` until the first `step` call.
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # LCM requires no input scaling; the sample is returned unchanged.
        return sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        original_inference_steps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
        strength: float = 1.0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            original_inference_steps (`int`, *optional*):
                The original number of inference steps, which will be used to generate a linearly-spaced timestep
                schedule (which is different from the standard `diffusers` implementation). We will then take
                `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
                our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
            strength (`float`, defaults to 1.0):
                Denoising strength in [0, 1]; values below 1 keep only the tail of the schedule (img2img-style).
                NOTE(review): was annotated `int` upstream while defaulting to the float `1.0`; corrected to `float`.
        """
        # 0. Check inputs
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

        # 1. Calculate the LCM original training/distillation timestep schedule.
        original_steps = (
            original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
        )

        if original_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        # LCM Timesteps Setting
        # The skipping step parameter k from the paper.
        k = self.config.num_train_timesteps // original_steps
        # LCM Training/Distillation Steps Schedule
        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
        lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1

        # 2. Calculate the LCM inference timestep schedule.
        if timesteps is not None:
            # 2.1 Handle custom timestep schedules.
            train_timesteps = set(lcm_origin_timesteps)
            non_train_timesteps = []
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")

                if timesteps[i] not in train_timesteps:
                    non_train_timesteps.append(timesteps[i])

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
                )

            # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
            if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
                logger.warning(
                    f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                    f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                    f" unexpected results when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
            if non_train_timesteps:
                logger.warning(
                    f"The custom timestep schedule contains the following timesteps which are not on the original"
                    f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                    f" when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule is longer than original_steps
            if len(timesteps) > original_steps:
                logger.warning(
                    f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                    f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                    f" unexpected results when using this timestep schedule."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True

            # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
            init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
            t_start = max(self.num_inference_steps - init_timestep, 0)
            timesteps = timesteps[t_start * self.order :]
            # TODO: also reset self.num_inference_steps?
        else:
            # 2.2 Create the "standard" LCM inference timestep schedule.
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            skipping_step = len(lcm_origin_timesteps) // num_inference_steps

            if skipping_step < 1:
                raise ValueError(
                    f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
                )

            self.num_inference_steps = num_inference_steps

            if num_inference_steps > original_steps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                    f" {original_steps} because the final timestep schedule will be a subset of the"
                    f" `original_inference_steps`-sized initial timestep schedule."
                )

            # LCM Inference Steps Schedule
            lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
            # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
            inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
            inference_indices = np.floor(inference_indices).astype(np.int64)
            timesteps = lcm_origin_timesteps[inference_indices]

        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

        self._step_index = None
        self._begin_index = None

    def get_scalings_for_boundary_condition_discrete(self, timestep):
        """Compute the consistency-model boundary-condition scalings `(c_skip, c_out)` for `timestep`."""
        self.sigma_data = 0.5  # Default: 0.5
        scaled_timestep = timestep * self.config.timestep_scaling

        c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
        c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[LCMSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
        Returns:
            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # 1. get previous step value
        prev_step_index = self.step_index + 1
        if prev_step_index < len(self.timesteps):
            prev_timestep = self.timesteps[prev_step_index]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Compute the predicted original sample x_0 based on the model parameterization
        if self.config.prediction_type == "epsilon":  # noise-prediction
            predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
        elif self.config.prediction_type == "sample":  # x-prediction
            predicted_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":  # v-prediction
            predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for `LCMScheduler`."
            )

        # 5. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            predicted_original_sample = self._threshold_sample(predicted_original_sample)
        elif self.config.clip_sample:
            predicted_original_sample = predicted_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 6. Denoise model output using boundary conditions
        denoised = c_out * predicted_original_sample + c_skip * sample

        # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
        # Noise is not used on the final timestep of the timestep schedule.
        # This also means that noise is not used for one-step sampling.
        if self.step_index != self.num_inference_steps - 1:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
            )
            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
        else:
            prev_sample = denoised

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, denoised)

        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """Diffuse `original_samples` to the noise levels given by `timesteps` (forward process q(x_t | x_0))."""
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
        # for the subsequent add_noise calls
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """Compute the v-prediction target `v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample`."""
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        # Length of the scheduler is the number of training timesteps.
        return self.config.num_train_timesteps

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
        """Return the timestep immediately after `timestep` in the current schedule (or `timestep - 1`)."""
        if self.custom_timesteps or self.num_inference_steps:
            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
            if index == self.timesteps.shape[0] - 1:
                prev_t = torch.tensor(-1)
            else:
                prev_t = self.timesteps[index + 1]
        else:
            prev_t = timestep - 1
        return prev_t
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_lms_discrete.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import warnings
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import scipy.stats
21
+ import torch
22
+ from scipy import integrate
23
+
24
+ from ..configuration_utils import ConfigMixin, register_to_config
25
+ from ..utils import BaseOutput
26
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
27
+
28
+
29
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Sample to feed back into the model on the next denoising iteration.
    prev_sample: torch.Tensor
    # Optional x_0 estimate; `None` when the caller does not need it.
    pred_original_sample: Optional[torch.Tensor] = None
46
+
47
+
48
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
49
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    The schedule is derived from a continuous `alpha_bar(t)` curve: each discrete beta is the relative drop of
    `alpha_bar` over one timestep, clipped at `max_beta` to avoid singularities near t = 1.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor` of dtype `float32`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":
        # Nichol & Dhariwal cosine schedule (shifted by 0.008 so alpha_bar(0) < 1).
        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":
        # Exponential decay of alpha_bar.
        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped to max_beta.
    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / steps) / alpha_bar_fn(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
91
+
92
+
93
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    A linear multistep scheduler for discrete beta schedules.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    """

    # Schedulers this one can be swapped with through `from_config`.
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # Solver order exposed to pipelines (warmup steps); the LMS history order is the
    # `order` argument of `step`, not this attribute.
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        use_karras_sigmas: Optional[bool] = False,
        use_exponential_sigmas: Optional[bool] = False,
        use_beta_sigmas: Optional[bool] = False,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        # `register_to_config` has already stored the arguments on `self.config`,
        # so the config attributes are readable from this point on.
        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), reversed to run from high to
        # low noise, with a trailing 0.0 for the final (fully denoised) state.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        self.use_karras_sigmas = use_karras_sigmas
        # Initialize with the full training schedule; pipelines call `set_timesteps`
        # again with the actual number of inference steps.
        self.set_timesteps(num_train_timesteps, None)
        self.derivatives = []
        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        # "leading" spacing uses the variance-preserving convention instead.
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales by `1 / sqrt(sigma^2 + 1)` to match the K-LMS algorithm.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`float` or `torch.Tensor`):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # Flag consumed by `step` to warn pipelines that forgot to scale the input.
        self.is_scale_input_called = True
        return sample

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute the linear multistep coefficient.

        Args:
            order (`int`): number of derivative history entries used by the multistep update.
            t (`int`): index of the current sigma in `self.sigmas`.
            current_order (`int`): which history entry (0 = most recent) this coefficient weighs.
        """

        # Lagrange basis polynomial over past sigmas; its integral from sigma_t to
        # sigma_{t+1} gives the Adams-Bashforth-style weight for that derivative.
        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate the training sigmas at the (possibly fractional) timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        # Optionally re-space the sigmas; timesteps are then re-derived from the new
        # sigmas so model conditioning stays consistent.
        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
        elif self.config.use_beta_sigmas:
            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        # Trailing 0.0 corresponds to the final, fully denoised sample.
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

        # Reset the derivative history; stale entries would corrupt the multistep update.
        self.derivatives = []

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        # clamp avoids log(0) for the terminal sigma.
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """Constructs an exponential noise schedule (log-linear between sigma_max and sigma_min)."""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
    def _convert_to_beta(
        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
    ) -> torch.Tensor:
        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""

        # Hack to make sure that other schedulers which copy this function don't break
        # TODO: Add this logic to the other schedulers
        if hasattr(self.config, "sigma_min"):
            sigma_min = self.config.sigma_min
        else:
            sigma_min = None

        if hasattr(self.config, "sigma_max"):
            sigma_max = self.config.sigma_max
        else:
            sigma_max = None

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        # Map a Beta(alpha, beta) quantile ramp onto the [sigma_min, sigma_max] range.
        sigmas = np.array(
            [
                sigma_min + (ppf * (sigma_max - sigma_min))
                for ppf in [
                    scipy.stats.beta.ppf(timestep, alpha, beta)
                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
                ]
            ]
        )
        return sigmas

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[float, torch.Tensor],
        sample: torch.Tensor,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`, defaults to 4):
                The order of the linear multistep method.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
        if not self.is_scale_input_called:
            warnings.warn(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            # NOTE(review): the message omits `sample`, which is handled above.
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        # Keep at most `order` past derivatives for the multistep update.
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        # Effective order is limited by how many steps have been taken so far.
        order = min(self.step_index + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, self.step_index, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        # Coefficients pair with derivatives newest-first, hence `reversed`.
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (
                prev_sample,
                pred_original_sample,
            )

        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast sigma over the sample's trailing dimensions.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        # Length of the training schedule, not the inference schedule.
        return self.config.num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/schedulers/scheduling_lms_discrete_flax.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import flax
19
+ import jax.numpy as jnp
20
+ from scipy import integrate
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from .scheduling_utils_flax import (
24
+ CommonSchedulerState,
25
+ FlaxKarrasDiffusionSchedulers,
26
+ FlaxSchedulerMixin,
27
+ FlaxSchedulerOutput,
28
+ broadcast_to_shape_from_left,
29
+ )
30
+
31
+
32
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
    """
    Immutable state container for [`FlaxLMSDiscreteScheduler`]; updated functionally via `.replace(...)`.
    """

    common: CommonSchedulerState

    # setable values
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    sigmas: jnp.ndarray
    num_inference_steps: Optional[int] = None

    # running values
    # derivative history consumed by the linear multistep update; populated in `set_timesteps`/`step`
    derivatives: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
    ):
        """Build an initial state; running values keep their defaults until `set_timesteps` is called."""
        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
50
+
51
+
52
@dataclass
class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
    """Output of `FlaxLMSDiscreteScheduler.step`: the base class's `prev_sample` plus the updated scheduler state."""

    state: LMSDiscreteSchedulerState
55
+
56
+
57
+ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
58
+ """
59
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
60
+ Katherine Crowson:
61
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
62
+
63
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
64
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
65
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
66
+ [`~SchedulerMixin.from_pretrained`] functions.
67
+
68
+ Args:
69
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
70
+ beta_start (`float`): the starting `beta` value of inference.
71
+ beta_end (`float`): the final `beta` value.
72
+ beta_schedule (`str`):
73
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
74
+ `linear` or `scaled_linear`.
75
+ trained_betas (`jnp.ndarray`, optional):
76
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
77
+ prediction_type (`str`, default `epsilon`, optional):
78
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
79
+ process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
80
+ https://imagen.research.google/video/paper.pdf)
81
+ dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
82
+ the `dtype` used for params and computation.
83
+ """
84
+
85
    # Schedulers this one can be swapped with through `from_config`.
    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    # dtype used for params and computation; set in `__init__`.
    dtype: jnp.dtype

    @property
    def has_state(self):
        # Signals that this scheduler carries explicit functional state
        # (a `LMSDiscreteSchedulerState`) rather than mutable attributes.
        return True
92
+
93
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        """Store the computation dtype; all schedule arrays live in the functional state, not on `self`."""
        # `register_to_config` records the other arguments on `self.config`;
        # the schedule itself is built lazily in `create_state`.
        self.dtype = dtype
105
+
106
    def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
        """
        Build the initial scheduler state from the common (beta/alpha) schedule.

        Args:
            common (`CommonSchedulerState`, *optional*):
                Shared schedule state; created from this scheduler's config when not provided.

        Returns:
            `LMSDiscreteSchedulerState`: state with full-training-schedule timesteps and sigmas.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # Timesteps run from high noise to low noise (num_train_timesteps - 1 .. 0).
        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)
        sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5

        # standard deviation of the initial noise distribution
        init_noise_sigma = sigmas.max()

        return LMSDiscreteSchedulerState.create(
            common=common,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
            sigmas=sigmas,
        )
122
+
123
+ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
124
+ """
125
+ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
126
+
127
+ Args:
128
+ state (`LMSDiscreteSchedulerState`):
129
+ the `FlaxLMSDiscreteScheduler` state data class instance.
130
+ sample (`jnp.ndarray`):
131
+ current instance of sample being created by diffusion process.
132
+ timestep (`int`):
133
+ current discrete timestep in the diffusion chain.
134
+
135
+ Returns:
136
+ `jnp.ndarray`: scaled input sample
137
+ """
138
+ (step_index,) = jnp.where(state.timesteps == timestep, size=1)
139
+ step_index = step_index[0]
140
+
141
+ sigma = state.sigmas[step_index]
142
+ sample = sample / ((sigma**2 + 1) ** 0.5)
143
+ return sample
144
+
145
+ def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
146
+ """
147
+ Compute a linear multistep coefficient.
148
+
149
+ Args:
150
+ order (TODO):
151
+ t (TODO):
152
+ current_order (TODO):
153
+ """
154
+
155
+ def lms_derivative(tau):
156
+ prod = 1.0
157
+ for k in range(order):
158
+ if current_order == k:
159
+ continue
160
+ prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
161
+ return prod
162
+
163
+ integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
164
+
165
+ return integrated_coeff
166
+
167
    def set_timesteps(
        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> LMSDiscreteSchedulerState:
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`LMSDiscreteSchedulerState`):
                the `FlaxLMSDiscreteScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, *optional*):
                sample shape used to size the (initially empty) derivative history.

        Returns:
            `LMSDiscreteSchedulerState`: new state with inference timesteps, sigmas and reset running values.
        """

        # Fractional timesteps from high to low noise.
        timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)

        # Linearly interpolate the training sigmas at the fractional timesteps.
        low_idx = jnp.floor(timesteps).astype(jnp.int32)
        high_idx = jnp.ceil(timesteps).astype(jnp.int32)

        frac = jnp.mod(timesteps, 1.0)

        sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        # Trailing 0.0 corresponds to the final, fully denoised sample.
        sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])

        timesteps = timesteps.astype(jnp.int32)

        # initial running values
        derivatives = jnp.zeros((0,) + shape, dtype=self.dtype)

        return state.replace(
            timesteps=timesteps,
            sigmas=sigmas,
            num_inference_steps=num_inference_steps,
            derivatives=derivatives,
        )
202
+
203
    def step(
        self,
        state: LMSDiscreteSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        order: int = 4,
        return_dict: bool = True,
    ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            order: coefficient for multi-step inference.
            return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

        Returns:
            [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # NOTE(review): `timestep` is used directly as an index into `state.sigmas`,
        # i.e. callers are expected to pass a schedule position, not a training timestep
        # value — confirm against the calling pipeline.
        sigma = state.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        # Functional update: append to the derivative history, trimming it to `order` entries.
        # NOTE(review): `jnp.append`/`jnp.delete` without an axis flatten the buffer —
        # verify the intended history layout against the PyTorch implementation.
        state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
        if len(state.derivatives) > order:
            state = state.replace(derivatives=jnp.delete(state.derivatives, 0))

        # 3. Compute linear multistep coefficients
        # Effective order is limited by how many steps have been taken so far.
        order = min(timestep + 1, order)
        lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        # Coefficients pair with derivatives newest-first, hence `reversed`.
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
        )

        if not return_dict:
            return (prev_sample, state)

        return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
267
+
268
def add_noise(
    self,
    state: LMSDiscreteSchedulerState,
    original_samples: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """Forward-diffuse *original_samples*: add *noise* scaled by each timestep's sigma."""
    # One sigma per sample, padded with trailing singleton axes (via the
    # left-broadcast helper) so it multiplies cleanly over the sample dims.
    step_sigmas = broadcast_to_shape_from_left(state.sigmas[timesteps].flatten(), noise.shape)
    return original_samples + noise * step_sigmas
281
+
282
def __len__(self):
    """Report scheduler length as the configured number of training timesteps."""
    num_train_timesteps = self.config.num_train_timesteps
    return num_train_timesteps
pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.16 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/state_dict_utils.cpython-310.pyc ADDED
Binary file (11.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/testing_utils.cpython-310.pyc ADDED
Binary file (49.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/torch_utils.cpython-310.pyc ADDED
Binary file (9.37 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/typing_utils.cpython-310.pyc ADDED
Binary file (3.43 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/utils/__pycache__/versions.cpython-310.pyc ADDED
Binary file (3.13 kB). View file