xiaoanyu123 committed on
Commit
20eef34
·
verified ·
1 Parent(s): 58b6f82

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__init__.py +50 -0
  2. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc +0 -0
  4. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc +0 -0
  5. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +200 -0
  6. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/__init__.py +48 -0
  7. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/__pycache__/__init__.cpython-310.pyc +0 -0
  8. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/__pycache__/pipeline_latte.cpython-310.pyc +0 -0
  9. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/pipeline_latte.py +910 -0
  10. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/qwenimage/__pycache__/__init__.cpython-310.pyc +0 -0
  11. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__init__.py +53 -0
  12. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_output.cpython-310.pyc +0 -0
  13. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana.cpython-310.pyc +0 -0
  14. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana_controlnet.cpython-310.pyc +0 -0
  15. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana_sprint.cpython-310.pyc +0 -0
  16. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_output.py +21 -0
  17. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana.py +1011 -0
  18. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  19. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint.py +893 -0
  20. pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_torch_available,
9
+ is_transformers_available,
10
+ )
11
+
12
+
13
+ _dummy_objects = {}
14
+ _import_structure = {}
15
+
16
+ try:
17
+ if not (is_transformers_available() and is_torch_available()):
18
+ raise OptionalDependencyNotAvailable()
19
+ except OptionalDependencyNotAvailable:
20
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
21
+
22
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
23
+ else:
24
+ _import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"]
25
+ _import_structure["pipeline_latent_diffusion_superresolution"] = ["LDMSuperResolutionPipeline"]
26
+
27
+
28
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
29
+ try:
30
+ if not (is_transformers_available() and is_torch_available()):
31
+ raise OptionalDependencyNotAvailable()
32
+
33
+ except OptionalDependencyNotAvailable:
34
+ from ...utils.dummy_torch_and_transformers_objects import *
35
+ else:
36
+ from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
37
+ from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
38
+
39
+ else:
40
+ import sys
41
+
42
+ sys.modules[__name__] = _LazyModule(
43
+ __name__,
44
+ globals()["__file__"],
45
+ _import_structure,
46
+ module_spec=__spec__,
47
+ )
48
+
49
+ for name, value in _dummy_objects.items():
50
+ setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.2 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion.cpython-310.pyc ADDED
Binary file (22.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/__pycache__/pipeline_latent_diffusion_superresolution.cpython-310.pyc ADDED
Binary file (7.5 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.utils.checkpoint
8
+
9
+ from ...models import UNet2DModel, VQModel
10
+ from ...schedulers import (
11
+ DDIMScheduler,
12
+ DPMSolverMultistepScheduler,
13
+ EulerAncestralDiscreteScheduler,
14
+ EulerDiscreteScheduler,
15
+ LMSDiscreteScheduler,
16
+ PNDMScheduler,
17
+ )
18
+ from ...utils import PIL_INTERPOLATION, is_torch_xla_available
19
+ from ...utils.torch_utils import randn_tensor
20
+ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
21
+
22
+
23
+ if is_torch_xla_available():
24
+ import torch_xla.core.xla_model as xm
25
+
26
+ XLA_AVAILABLE = True
27
+ else:
28
+ XLA_AVAILABLE = False
29
+
30
+
31
def preprocess(image):
    """Convert a PIL image to a normalized NCHW float tensor in [-1, 1].

    Both dimensions are first rounded down to the nearest multiple of 32 and
    the image is resized with Lanczos resampling.
    """
    width, height = image.size
    # Snap each dimension down to an integer multiple of 32.
    width -= width % 32
    height -= height % 32
    image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
    pixels = np.array(image).astype(np.float32) / 255.0
    # HWC -> NCHW with a singleton batch dimension.
    pixels = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(pixels)
    # Rescale [0, 1] -> [-1, 1].
    return 2.0 * tensor - 1.0
39
+
40
+
41
class LDMSuperResolutionPipeline(DiffusionPipeline):
    r"""
    A pipeline for image super-resolution using latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: VQModel,
        unet: UNet2DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()
        # register_modules makes the sub-models participate in device placement,
        # saving/loading, etc., handled by DiffusionPipeline.
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        num_inference_steps: Optional[int] = 100,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`torch.Tensor` or `PIL.Image.Image`):
                `Image` or tensor representing an image batch to be used as the starting point for the process.
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> import requests
        >>> from PIL import Image
        >>> from io import BytesIO
        >>> from diffusers import LDMSuperResolutionPipeline
        >>> import torch

        >>> # load model and scheduler
        >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages")
        >>> pipeline = pipeline.to("cuda")

        >>> # let's download an image
        >>> url = (
        ...     "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png"
        ... )
        >>> response = requests.get(url)
        >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
        >>> low_res_img = low_res_img.resize((128, 128))

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1).images[0]
        >>> # save image
        >>> upscaled_image.save("ldm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images
        """
        # Resolve the effective batch size from the input; a PIL image is
        # always a single sample, a tensor contributes its leading dimension.
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")

        if isinstance(image, PIL.Image.Image):
            # preprocess() rounds dimensions down to a multiple of 32 and maps to [-1, 1].
            image = preprocess(image)

        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
        latents_dtype = next(self.unet.parameters()).dtype

        latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

        image = image.to(device=self.device, dtype=latents_dtype)

        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps_tensor = self.scheduler.timesteps

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_kwargs = {}
        if accepts_eta:
            extra_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps_tensor):
            # concat latents and low resolution image in the channel dimension.
            latents_input = torch.cat([latents, image], dim=1)
            latents_input = self.scheduler.scale_model_input(latents_input, t)
            # predict the noise residual
            noise_pred = self.unet(latents_input, t).sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

            if XLA_AVAILABLE:
                # Flush pending XLA ops once per denoising step.
                xm.mark_step()

        # decode the image latents with the VQVAE
        image = self.vqvae.decode(latents).sample
        # Map decoder output from [-1, 1] back to [0, 1] and NCHW -> NHWC numpy.
        image = torch.clamp(image, -1.0, 1.0)
        image = image / 2 + 0.5
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Names exported as inert "dummy" objects when optional dependencies are missing.
_dummy_objects = {}
# Submodule name -> list of public names; consumed by _LazyModule below.
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # torch/transformers unavailable: export placeholders that raise a helpful
    # error on use instead of failing here at import time.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_latte"] = ["LattePipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: used by static type checkers or when lazy import is disabled.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_latte import LattePipeline

else:
    import sys

    # Lazy path: replace this module object with a proxy that imports the
    # submodules listed in _import_structure on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach the dummy placeholders (if any) to the lazy module.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/__pycache__/pipeline_latte.cpython-310.pyc ADDED
Binary file (27.4 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/latte/pipeline_latte.py ADDED
@@ -0,0 +1,910 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 the Latte Team and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import html
17
+ import inspect
18
+ import re
19
+ import urllib.parse as ul
20
+ from dataclasses import dataclass
21
+ from typing import Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from transformers import T5EncoderModel, T5Tokenizer
25
+
26
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from ...models import AutoencoderKL, LatteTransformer3DModel
28
+ from ...pipelines.pipeline_utils import DiffusionPipeline
29
+ from ...schedulers import KarrasDiffusionSchedulers
30
+ from ...utils import (
31
+ BACKENDS_MAPPING,
32
+ BaseOutput,
33
+ deprecate,
34
+ is_bs4_available,
35
+ is_ftfy_available,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ )
40
+ from ...utils.torch_utils import is_compiled_module, randn_tensor
41
+ from ...video_processor import VideoProcessor
42
+
43
+
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+
54
+ if is_bs4_available():
55
+ from bs4 import BeautifulSoup
56
+
57
+ if is_ftfy_available():
58
+ import ftfy
59
+
60
+
61
+ EXAMPLE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> import torch
65
+ >>> from diffusers import LattePipeline
66
+ >>> from diffusers.utils import export_to_gif
67
+
68
+ >>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
69
+ >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
70
+ >>> # Enable memory optimizations.
71
+ >>> pipe.enable_model_cpu_offload()
72
+
73
+ >>> prompt = "A small cactus with a happy face in the Sahara desert."
74
+ >>> videos = pipe(prompt).frames[0]
75
+ >>> export_to_gif(videos, "latte.gif")
76
+ ```
77
+ """
78
+
79
+
80
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
81
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler's timestep schedule and return it.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the schedule;
    passing both `timesteps` and `sigmas` is an error. Extra kwargs are forwarded to
    `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: the scheduler derives the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom path: forward whichever of `timesteps` / `sigmas` was supplied, after
    # verifying the scheduler's `set_timesteps` actually accepts that keyword.
    arg_name, values, noun = (
        ("timesteps", timesteps, "timestep") if timesteps is not None else ("sigmas", sigmas, "sigmas")
    )
    if arg_name not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {noun} schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(device=device, **{arg_name: values}, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
138
+
139
+
140
@dataclass
class LattePipelineOutput(BaseOutput):
    """Output of `LattePipeline`.

    Attributes:
        frames (`torch.Tensor`): The generated video frames.
    """

    frames: torch.Tensor
143
+
144
+
145
+ class LattePipeline(DiffusionPipeline):
146
+ r"""
147
+ Pipeline for text-to-video generation using Latte.
148
+
149
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
150
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
151
+
152
+ Args:
153
+ vae ([`AutoencoderKL`]):
154
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
155
+ text_encoder ([`T5EncoderModel`]):
156
+ Frozen text-encoder. Latte uses
157
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
158
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
159
+ tokenizer (`T5Tokenizer`):
160
+ Tokenizer of class
161
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
162
+ transformer ([`LatteTransformer3DModel`]):
163
+ A text conditioned `LatteTransformer3DModel` to denoise the encoded video latents.
164
+ scheduler ([`SchedulerMixin`]):
165
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
166
+ """
167
+
168
    # Runs of punctuation/symbol characters; presumably stripped during caption
    # cleaning — confirm against the `_clean_caption` helper. NOTE(review)
    bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\\*]{1,}")

    # Components that may be absent at load time (e.g. when precomputed prompt
    # embeddings are passed in instead of raw text).
    _optional_components = ["tokenizer", "text_encoder"]
    # Order in which sub-models are moved to the accelerator under model cpu-offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    # Tensor names exposed to `callback_on_step_end` callbacks.
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]
178
+
179
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKL,
        transformer: LatteTransformer3DModel,
        scheduler: KarrasDiffusionSchedulers,
    ):
        """Register all sub-models and derive VAE-dependent processing settings."""
        super().__init__()

        # register_modules lets DiffusionPipeline manage device placement,
        # saving/loading, and cpu-offload for every sub-model.
        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

        # Spatial downscaling factor of the VAE: a factor of 2 per down block;
        # falls back to 8 when no VAE attribute was registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
195
+
196
+ # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
197
+ def mask_text_embeddings(self, emb, mask):
198
+ if emb.shape[0] == 1:
199
+ keep_index = mask.sum().item()
200
+ return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096
201
+ else:
202
+ masked_feature = emb * mask[:, None, :, None] # 1 120 4096
203
+ return masked_feature, emb.shape[2]
204
+
205
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
206
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        clean_caption: bool = False,
        mask_feature: bool = True,
        dtype=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                Latte, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of videos that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. For Latte, these should be the embeddings of the "" string.
            clean_caption (bool, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            mask_feature: (bool, defaults to `True`):
                If `True`, the function will mask the text embeddings.
        """
        # If the caller supplied both embedding tensors, skip the extra masking at the end.
        embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None

        if device is None:
            device = self._execution_device

        # Batch size comes from the prompt(s) if given, else from the provided embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Fixed tokenizer sequence length used for the positive prompt.
        max_length = 120
        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn when the prompt had to be truncated to fit `max_length`.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {max_length} tokens: {removed_text}"
                )

            attention_mask = text_inputs.attention_mask.to(device)
            prompt_embeds_attention_mask = attention_mask

            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
            prompt_embeds = prompt_embeds[0]
        else:
            # NOTE(review): this mask has the same (3-D) shape as the embeddings,
            # unlike the 2-D tokenizer mask above — presumably intentional so that
            # `mask_text_embeddings` keeps everything; confirm against upstream.
            prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)

        # Prefer the text encoder's dtype, falling back to the transformer's.
        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
        prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            # Pad the negative prompt to the positive prompt's sequence length.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None

        # Perform additional masking.
        if mask_feature and not embeds_initially_provided:
            prompt_embeds = prompt_embeds.unsqueeze(1)
            masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
            masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
            # Truncate the negative embeddings to the same kept length.
            masked_negative_prompt_embeds = (
                negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
            )

            return masked_prompt_embeds, masked_negative_prompt_embeds

        return prompt_embeds, negative_prompt_embeds
354
+
355
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Assemble the optional keyword arguments for `self.scheduler.step`.

    eta (η) corresponds to η in the DDIM paper (https://huggingface.co/papers/2010.02502)
    and should be in [0, 1]; it is only consumed by schedulers whose `step`
    signature accepts it (e.g. DDIMScheduler). The same gating applies to
    `generator` — anything the scheduler does not accept is simply omitted.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
372
+
373
def check_inputs(
    self,
    prompt,
    height,
    width,
    negative_prompt,
    callback_on_step_end_tensor_inputs,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call, raising `ValueError` on any conflict.

    Checks, in order: spatial sizes divisible by 8, callback tensor names
    registered on the pipeline, mutually exclusive prompt / prompt_embeds
    inputs, and matching shapes for paired embeddings.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown}"
            )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
423
+
424
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
425
+ def _text_preprocessing(self, text, clean_caption=False):
426
+ if clean_caption and not is_bs4_available():
427
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
428
+ logger.warning("Setting `clean_caption` to False...")
429
+ clean_caption = False
430
+
431
+ if clean_caption and not is_ftfy_available():
432
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
433
+ logger.warning("Setting `clean_caption` to False...")
434
+ clean_caption = False
435
+
436
+ if not isinstance(text, (tuple, list)):
437
+ text = [text]
438
+
439
+ def process(text: str):
440
+ if clean_caption:
441
+ text = self._clean_caption(text)
442
+ text = self._clean_caption(text)
443
+ else:
444
+ text = text.lower().strip()
445
+ return text
446
+
447
+ return [process(t) for t in text]
448
+
449
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Aggressively scrub a raw caption string for text-encoder consumption.

    Removes URLs, HTML markup, @-mentions, CJK character ranges, numeric
    ids/hashes/filenames and promotional boilerplate, and canonicalizes
    dashes, quotes and whitespace. Requires `bs4` and `ftfy`; callers gate
    on their availability via `_text_preprocessing`. The substitution order
    is significant — do not reorder.
    """
    caption = str(caption)
    # Undo URL percent-encoding and '+'-encoded spaces before pattern matching.
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html: strip any remaining markup, keeping only the visible text
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # все виды тире / all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # кавычки к одному стандарту / normalize all quote styles to " and '
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids: trailing "H:MM"-style suffixes
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n — literal backslash-n sequences, not newline characters
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # collapse repeated quotes / dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    # bad_punct_regex is built on the pipeline instance (e.g. runs of *, #, etc.)
    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat — only de-hyphenate when it
    # looks like a slug (more than 3 separators), not ordinary hyphenation
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    # fix mojibake, then resolve (possibly double-escaped) HTML entities
    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    # promotional boilerplate
    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimension strings like "1920x1080" (ascii x, cyrillic х, or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    # normalize spacing around punctuation, then collapse whitespace
    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): result discarded — str.strip() is not in-place; harmless
    # because the return statement strips again (kept to match upstream).
    caption.strip()

    # drop wrapping quotes and stray leading/trailing punctuation
    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
563
+
564
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
    self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
    """Sample (or relocate) the initial latent video tensor and scale it for the scheduler.

    When `latents` is None a fresh Gaussian tensor of shape
    (batch, channels, frames, height // vae_scale_factor, width // vae_scale_factor)
    is drawn with `generator`; otherwise the provided tensor is moved to
    `device`. Either way the result is multiplied by the scheduler's
    `init_noise_sigma`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        shape = (
            batch_size,
            num_channels_latents,
            num_frames,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
589
+
590
@property
def guidance_scale(self):
    # Classifier-free guidance weight; set from the `guidance_scale` argument
    # at the start of `__call__`.
    return self._guidance_scale

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    # CFG is active only when the guidance scale is strictly greater than 1.
    return self._guidance_scale > 1

@property
def num_timesteps(self):
    # Number of scheduler timesteps used by the most recent `__call__`.
    return self._num_timesteps

@property
def current_timestep(self):
    # Timestep currently being denoised; None outside the denoising loop.
    return self._current_timestep

@property
def interrupt(self):
    # Cooperative cancellation flag checked once per denoising iteration.
    return self._interrupt
612
+
613
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    negative_prompt: str = "",
    num_inference_steps: int = 50,
    timesteps: Optional[List[int]] = None,
    guidance_scale: float = 7.5,
    num_images_per_prompt: int = 1,
    video_length: int = 16,
    height: int = 512,
    width: int = 512,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    clean_caption: bool = True,
    mask_feature: bool = True,
    enable_temporal_attentions: bool = True,
    decode_chunk_size: int = 14,
) -> Union[LattePipelineOutput, Tuple]:
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
            the text `prompt`, usually at the expense of lower video quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        video_length (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated video.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
            applies to [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not provided,
            negative_prompt_embeds will be generated from `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate video. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] instead of a plain
            tuple.
        callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A callback function or a list of callback functions to be called at the end of each denoising step.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
            A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
            inputs will be passed.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
        enable_temporal_attentions (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions
        decode_chunk_size (`int`, *optional*):
            The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
            expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
            For lower memory usage, reduce `decode_chunk_size`.

    Examples:

    Returns:
        [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned,
            otherwise a `tuple` is returned where the first element is a list with the generated images
    """
    # Callback objects carry their own declared tensor inputs; they override the argument.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 0. Default
    decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else video_length

    # 1. Check inputs. Raise error if not correct
    height = height or self.transformer.config.sample_size * self.vae_scale_factor
    width = width or self.transformer.config.sample_size * self.vae_scale_factor
    self.check_inputs(
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds,
        negative_prompt_embeds,
    )
    self._guidance_scale = guidance_scale
    self._current_timestep = None
    self._interrupt = False

    # 2. Default height and width to transformer
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # check_inputs guarantees prompt_embeds is set when prompt is None
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
        mask_feature=mask_feature,
    )
    # Concatenate [uncond, cond] so a single forward pass serves both CFG branches.
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    self._num_timesteps = len(timesteps)

    # 5. Prepare latents.
    latent_channels = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        latent_channels,
        video_length,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            self._current_timestep = t
            # Duplicate latents along batch for the two CFG branches.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            current_timestep = t
            if not torch.is_tensor(current_timestep):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                is_npu = latent_model_input.device.type == "npu"
                if isinstance(current_timestep, float):
                    # mps/npu lack float64/int64 support, hence the narrower dtypes
                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                else:
                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
            elif len(current_timestep.shape) == 0:
                current_timestep = current_timestep[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            current_timestep = current_timestep.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                encoder_hidden_states=prompt_embeds,
                timestep=current_timestep,
                enable_temporal_attentions=enable_temporal_attentions,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # use learned sigma?
            # If the scheduler does not model variance, drop the predicted-variance
            # half of the channel dimension.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred = noise_pred.chunk(2, dim=1)[0]

            # compute previous video: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # Callbacks may replace these tensors in-flight.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    self._current_timestep = None

    if output_type == "latents":
        deprecation_message = (
            "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
        )
        deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False)
        output_type = "latent"

    if not output_type == "latent":
        video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
        video = self.video_processor.postprocess_video(video=video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return LattePipelineOutput(frames=video)
881
+
882
# Similar to diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.decode_latents
def decode_latents(self, latents: torch.Tensor, video_length: int, decode_chunk_size: int = 14):
    """Decode latent video through the VAE in chunks and reassemble the frame axis.

    `latents` is (batch, channels, frames, height, width); frames are folded
    into the batch dimension, decoded `decode_chunk_size` at a time to bound
    memory, then unfolded back. Output is float32 frames of shape
    (batch, channels, frames, H, W).
    """
    # [batch, channels, frames, height, width] -> [batch*frames, channels, height, width]
    flat_latents = latents.permute(0, 2, 1, 3, 4).flatten(0, 1)

    flat_latents = 1 / self.vae.config.scaling_factor * flat_latents

    # `torch.compile` wraps the module; unwrap to inspect the real signature.
    forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
    accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

    # decode decode_chunk_size frames at a time to avoid OOM
    decoded_chunks = []
    for start in range(0, flat_latents.shape[0], decode_chunk_size):
        chunk = flat_latents[start : start + decode_chunk_size]
        decode_kwargs = {}
        if accepts_num_frames:
            # we only pass the chunk's frame count if the VAE expects it
            decode_kwargs["num_frames"] = chunk.shape[0]
        decoded_chunks.append(self.vae.decode(chunk, **decode_kwargs).sample)
    frames = torch.cat(decoded_chunks, dim=0)

    # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
    frames = frames.reshape(-1, video_length, *frames.shape[1:]).permute(0, 2, 1, 3, 4)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return frames.float()
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/qwenimage/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Placeholder objects substituted when optional dependencies are missing,
# and the lazy-import table consumed by `_LazyModule` below.
_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Without torch/transformers, expose dummy classes that raise helpful
    # errors on instantiation instead of failing at import time.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_sana"] = ["SanaPipeline"]
    _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
    _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
    _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports: taken by static type checkers and when the
    # DIFFUSERS_SLOW_IMPORT escape hatch is set.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_sana import SanaPipeline
        from .pipeline_sana_controlnet import SanaControlNetPipeline
        from .pipeline_sana_sprint import SanaSprintPipeline
        from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
else:
    import sys

    # Replace this module object with a lazy proxy so each pipeline submodule
    # is only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_output.cpython-310.pyc ADDED
Binary file (995 Bytes). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana.cpython-310.pyc ADDED
Binary file (32.7 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana_controlnet.cpython-310.pyc ADDED
Binary file (35.8 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana_sprint.cpython-310.pyc ADDED
Binary file (29.3 kB). View file
 
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_output.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class SanaPipelineOutput(BaseOutput):
    """
    Output class for Sana pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    # Decoded images in whichever format the caller's `output_type` selected.
    images: Union[List[PIL.Image.Image], np.ndarray]
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana.py ADDED
@@ -0,0 +1,1011 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ import warnings
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...image_processor import PixArtImageProcessor
27
+ from ...loaders import SanaLoraLoaderMixin
28
+ from ...models import AutoencoderDC, SanaTransformer2DModel
29
+ from ...schedulers import DPMSolverMultistepScheduler
30
+ from ...utils import (
31
+ BACKENDS_MAPPING,
32
+ USE_PEFT_BACKEND,
33
+ is_bs4_available,
34
+ is_ftfy_available,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
42
+ from ..pipeline_utils import DiffusionPipeline
43
+ from ..pixart_alpha.pipeline_pixart_alpha import (
44
+ ASPECT_RATIO_512_BIN,
45
+ ASPECT_RATIO_1024_BIN,
46
+ )
47
+ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
48
+ from .pipeline_output import SanaPipelineOutput
49
+
50
+
51
+ if is_torch_xla_available():
52
+ import torch_xla.core.xla_model as xm
53
+
54
+ XLA_AVAILABLE = True
55
+ else:
56
+ XLA_AVAILABLE = False
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ if is_bs4_available():
61
+ from bs4 import BeautifulSoup
62
+
63
+ if is_ftfy_available():
64
+ import ftfy
65
+
66
+
67
+ ASPECT_RATIO_4096_BIN = {
68
+ "0.25": [2048.0, 8192.0],
69
+ "0.26": [2048.0, 7936.0],
70
+ "0.27": [2048.0, 7680.0],
71
+ "0.28": [2048.0, 7424.0],
72
+ "0.32": [2304.0, 7168.0],
73
+ "0.33": [2304.0, 6912.0],
74
+ "0.35": [2304.0, 6656.0],
75
+ "0.4": [2560.0, 6400.0],
76
+ "0.42": [2560.0, 6144.0],
77
+ "0.48": [2816.0, 5888.0],
78
+ "0.5": [2816.0, 5632.0],
79
+ "0.52": [2816.0, 5376.0],
80
+ "0.57": [3072.0, 5376.0],
81
+ "0.6": [3072.0, 5120.0],
82
+ "0.68": [3328.0, 4864.0],
83
+ "0.72": [3328.0, 4608.0],
84
+ "0.78": [3584.0, 4608.0],
85
+ "0.82": [3584.0, 4352.0],
86
+ "0.88": [3840.0, 4352.0],
87
+ "0.94": [3840.0, 4096.0],
88
+ "1.0": [4096.0, 4096.0],
89
+ "1.07": [4096.0, 3840.0],
90
+ "1.13": [4352.0, 3840.0],
91
+ "1.21": [4352.0, 3584.0],
92
+ "1.29": [4608.0, 3584.0],
93
+ "1.38": [4608.0, 3328.0],
94
+ "1.46": [4864.0, 3328.0],
95
+ "1.67": [5120.0, 3072.0],
96
+ "1.75": [5376.0, 3072.0],
97
+ "2.0": [5632.0, 2816.0],
98
+ "2.09": [5888.0, 2816.0],
99
+ "2.4": [6144.0, 2560.0],
100
+ "2.5": [6400.0, 2560.0],
101
+ "2.89": [6656.0, 2304.0],
102
+ "3.0": [6912.0, 2304.0],
103
+ "3.11": [7168.0, 2304.0],
104
+ "3.62": [7424.0, 2048.0],
105
+ "3.75": [7680.0, 2048.0],
106
+ "3.88": [7936.0, 2048.0],
107
+ "4.0": [8192.0, 2048.0],
108
+ }
109
+
110
+ EXAMPLE_DOC_STRING = """
111
+ Examples:
112
+ ```py
113
+ >>> import torch
114
+ >>> from diffusers import SanaPipeline
115
+
116
+ >>> pipe = SanaPipeline.from_pretrained(
117
+ ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32
118
+ ... )
119
+ >>> pipe.to("cuda")
120
+ >>> pipe.text_encoder.to(torch.bfloat16)
121
+ >>> pipe.transformer = pipe.transformer.to(torch.bfloat16)
122
+
123
+ >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0]
124
+ >>> image[0].save("output.png")
125
+ ```
126
+ """
127
+
128
+
129
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
130
+ def retrieve_timesteps(
131
+ scheduler,
132
+ num_inference_steps: Optional[int] = None,
133
+ device: Optional[Union[str, torch.device]] = None,
134
+ timesteps: Optional[List[int]] = None,
135
+ sigmas: Optional[List[float]] = None,
136
+ **kwargs,
137
+ ):
138
+ r"""
139
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
140
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
141
+
142
+ Args:
143
+ scheduler (`SchedulerMixin`):
144
+ The scheduler to get timesteps from.
145
+ num_inference_steps (`int`):
146
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
147
+ must be `None`.
148
+ device (`str` or `torch.device`, *optional*):
149
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
150
+ timesteps (`List[int]`, *optional*):
151
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
152
+ `num_inference_steps` and `sigmas` must be `None`.
153
+ sigmas (`List[float]`, *optional*):
154
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
155
+ `num_inference_steps` and `timesteps` must be `None`.
156
+
157
+ Returns:
158
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
159
+ second element is the number of inference steps.
160
+ """
161
+ if timesteps is not None and sigmas is not None:
162
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
163
+ if timesteps is not None:
164
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
165
+ if not accepts_timesteps:
166
+ raise ValueError(
167
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
168
+ f" timestep schedules. Please check whether you are using the correct scheduler."
169
+ )
170
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
171
+ timesteps = scheduler.timesteps
172
+ num_inference_steps = len(timesteps)
173
+ elif sigmas is not None:
174
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
175
+ if not accept_sigmas:
176
+ raise ValueError(
177
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
178
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
179
+ )
180
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
181
+ timesteps = scheduler.timesteps
182
+ num_inference_steps = len(timesteps)
183
+ else:
184
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
185
+ timesteps = scheduler.timesteps
186
+ return timesteps, num_inference_steps
187
+
188
+
189
+ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
190
+ r"""
191
+ Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629).
192
+ """
193
+
194
+ # fmt: off
195
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
196
+ # fmt: on
197
+
198
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
199
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
200
+
201
+ def __init__(
202
+ self,
203
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
204
+ text_encoder: Gemma2PreTrainedModel,
205
+ vae: AutoencoderDC,
206
+ transformer: SanaTransformer2DModel,
207
+ scheduler: DPMSolverMultistepScheduler,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
213
+ )
214
+
215
+ self.vae_scale_factor = (
216
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
217
+ if hasattr(self, "vae") and self.vae is not None
218
+ else 32
219
+ )
220
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
221
+
222
+ def enable_vae_slicing(self):
223
+ r"""
224
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
225
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
226
+ """
227
+ self.vae.enable_slicing()
228
+
229
+ def disable_vae_slicing(self):
230
+ r"""
231
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
232
+ computing decoding in one step.
233
+ """
234
+ self.vae.disable_slicing()
235
+
236
+ def enable_vae_tiling(self):
237
+ r"""
238
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
239
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
240
+ processing larger images.
241
+ """
242
+ self.vae.enable_tiling()
243
+
244
+ def disable_vae_tiling(self):
245
+ r"""
246
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
247
+ computing decoding in one step.
248
+ """
249
+ self.vae.disable_tiling()
250
+
251
+ def _get_gemma_prompt_embeds(
252
+ self,
253
+ prompt: Union[str, List[str]],
254
+ device: torch.device,
255
+ dtype: torch.dtype,
256
+ clean_caption: bool = False,
257
+ max_sequence_length: int = 300,
258
+ complex_human_instruction: Optional[List[str]] = None,
259
+ ):
260
+ r"""
261
+ Encodes the prompt into text encoder hidden states.
262
+
263
+ Args:
264
+ prompt (`str` or `List[str]`, *optional*):
265
+ prompt to be encoded
266
+ device: (`torch.device`, *optional*):
267
+ torch device to place the resulting embeddings on
268
+ clean_caption (`bool`, defaults to `False`):
269
+ If `True`, the function will preprocess and clean the provided caption before encoding.
270
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
271
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
272
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
273
+ the prompt.
274
+ """
275
+ prompt = [prompt] if isinstance(prompt, str) else prompt
276
+
277
+ if getattr(self, "tokenizer", None) is not None:
278
+ self.tokenizer.padding_side = "right"
279
+
280
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
281
+
282
+ # prepare complex human instruction
283
+ if not complex_human_instruction:
284
+ max_length_all = max_sequence_length
285
+ else:
286
+ chi_prompt = "\n".join(complex_human_instruction)
287
+ prompt = [chi_prompt + p for p in prompt]
288
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
289
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
290
+
291
+ text_inputs = self.tokenizer(
292
+ prompt,
293
+ padding="max_length",
294
+ max_length=max_length_all,
295
+ truncation=True,
296
+ add_special_tokens=True,
297
+ return_tensors="pt",
298
+ )
299
+ text_input_ids = text_inputs.input_ids
300
+
301
+ prompt_attention_mask = text_inputs.attention_mask
302
+ prompt_attention_mask = prompt_attention_mask.to(device)
303
+
304
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
305
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
306
+
307
+ return prompt_embeds, prompt_attention_mask
308
+
309
+ def encode_prompt(
310
+ self,
311
+ prompt: Union[str, List[str]],
312
+ do_classifier_free_guidance: bool = True,
313
+ negative_prompt: str = "",
314
+ num_images_per_prompt: int = 1,
315
+ device: Optional[torch.device] = None,
316
+ prompt_embeds: Optional[torch.Tensor] = None,
317
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
318
+ prompt_attention_mask: Optional[torch.Tensor] = None,
319
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
320
+ clean_caption: bool = False,
321
+ max_sequence_length: int = 300,
322
+ complex_human_instruction: Optional[List[str]] = None,
323
+ lora_scale: Optional[float] = None,
324
+ ):
325
+ r"""
326
+ Encodes the prompt into text encoder hidden states.
327
+
328
+ Args:
329
+ prompt (`str` or `List[str]`, *optional*):
330
+ prompt to be encoded
331
+ negative_prompt (`str` or `List[str]`, *optional*):
332
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
333
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
334
+ PixArt-Alpha, this should be "".
335
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
336
+ whether to use classifier free guidance or not
337
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
338
+ number of images that should be generated per prompt
339
+ device: (`torch.device`, *optional*):
340
+ torch device to place the resulting embeddings on
341
+ prompt_embeds (`torch.Tensor`, *optional*):
342
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
343
+ provided, text embeddings will be generated from `prompt` input argument.
344
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
345
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
346
+ clean_caption (`bool`, defaults to `False`):
347
+ If `True`, the function will preprocess and clean the provided caption before encoding.
348
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
349
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
350
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
351
+ the prompt.
352
+ """
353
+
354
+ if device is None:
355
+ device = self._execution_device
356
+
357
+ if self.text_encoder is not None:
358
+ dtype = self.text_encoder.dtype
359
+ else:
360
+ dtype = None
361
+
362
+ # set lora scale so that monkey patched LoRA
363
+ # function of text encoder can correctly access it
364
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
365
+ self._lora_scale = lora_scale
366
+
367
+ # dynamically adjust the LoRA scale
368
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
369
+ scale_lora_layers(self.text_encoder, lora_scale)
370
+
371
+ if prompt is not None and isinstance(prompt, str):
372
+ batch_size = 1
373
+ elif prompt is not None and isinstance(prompt, list):
374
+ batch_size = len(prompt)
375
+ else:
376
+ batch_size = prompt_embeds.shape[0]
377
+
378
+ if getattr(self, "tokenizer", None) is not None:
379
+ self.tokenizer.padding_side = "right"
380
+
381
+ # See Section 3.1. of the paper.
382
+ max_length = max_sequence_length
383
+ select_index = [0] + list(range(-max_length + 1, 0))
384
+
385
+ if prompt_embeds is None:
386
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
387
+ prompt=prompt,
388
+ device=device,
389
+ dtype=dtype,
390
+ clean_caption=clean_caption,
391
+ max_sequence_length=max_sequence_length,
392
+ complex_human_instruction=complex_human_instruction,
393
+ )
394
+
395
+ prompt_embeds = prompt_embeds[:, select_index]
396
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
397
+
398
+ bs_embed, seq_len, _ = prompt_embeds.shape
399
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
400
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
401
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
402
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
403
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
404
+
405
+ # get unconditional embeddings for classifier free guidance
406
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
407
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
408
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
409
+ prompt=negative_prompt,
410
+ device=device,
411
+ dtype=dtype,
412
+ clean_caption=clean_caption,
413
+ max_sequence_length=max_sequence_length,
414
+ complex_human_instruction=False,
415
+ )
416
+
417
+ if do_classifier_free_guidance:
418
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
419
+ seq_len = negative_prompt_embeds.shape[1]
420
+
421
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
422
+
423
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
424
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
425
+
426
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
427
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
428
+ else:
429
+ negative_prompt_embeds = None
430
+ negative_prompt_attention_mask = None
431
+
432
+ if self.text_encoder is not None:
433
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
434
+ # Retrieve the original scale by scaling back the LoRA layers
435
+ unscale_lora_layers(self.text_encoder, lora_scale)
436
+
437
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
438
+
439
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
440
+ def prepare_extra_step_kwargs(self, generator, eta):
441
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
442
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
443
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
444
+ # and should be between [0, 1]
445
+
446
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
447
+ extra_step_kwargs = {}
448
+ if accepts_eta:
449
+ extra_step_kwargs["eta"] = eta
450
+
451
+ # check if the scheduler accepts generator
452
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
453
+ if accepts_generator:
454
+ extra_step_kwargs["generator"] = generator
455
+ return extra_step_kwargs
456
+
457
+ def check_inputs(
458
+ self,
459
+ prompt,
460
+ height,
461
+ width,
462
+ callback_on_step_end_tensor_inputs=None,
463
+ negative_prompt=None,
464
+ prompt_embeds=None,
465
+ negative_prompt_embeds=None,
466
+ prompt_attention_mask=None,
467
+ negative_prompt_attention_mask=None,
468
+ ):
469
+ if height % 32 != 0 or width % 32 != 0:
470
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
471
+
472
+ if callback_on_step_end_tensor_inputs is not None and not all(
473
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
474
+ ):
475
+ raise ValueError(
476
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
477
+ )
478
+
479
+ if prompt is not None and prompt_embeds is not None:
480
+ raise ValueError(
481
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
482
+ " only forward one of the two."
483
+ )
484
+ elif prompt is None and prompt_embeds is None:
485
+ raise ValueError(
486
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
487
+ )
488
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
489
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
490
+
491
+ if prompt is not None and negative_prompt_embeds is not None:
492
+ raise ValueError(
493
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
494
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
495
+ )
496
+
497
+ if negative_prompt is not None and negative_prompt_embeds is not None:
498
+ raise ValueError(
499
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
500
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
501
+ )
502
+
503
+ if prompt_embeds is not None and prompt_attention_mask is None:
504
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
505
+
506
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
507
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
508
+
509
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
510
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
511
+ raise ValueError(
512
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
513
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
514
+ f" {negative_prompt_embeds.shape}."
515
+ )
516
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
517
+ raise ValueError(
518
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
519
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
520
+ f" {negative_prompt_attention_mask.shape}."
521
+ )
522
+
523
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
524
+ def _text_preprocessing(self, text, clean_caption=False):
525
+ if clean_caption and not is_bs4_available():
526
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
527
+ logger.warning("Setting `clean_caption` to False...")
528
+ clean_caption = False
529
+
530
+ if clean_caption and not is_ftfy_available():
531
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
532
+ logger.warning("Setting `clean_caption` to False...")
533
+ clean_caption = False
534
+
535
+ if not isinstance(text, (tuple, list)):
536
+ text = [text]
537
+
538
+ def process(text: str):
539
+ if clean_caption:
540
+ text = self._clean_caption(text)
541
+ text = self._clean_caption(text)
542
+ else:
543
+ text = text.lower().strip()
544
+ return text
545
+
546
+ return [process(t) for t in text]
547
+
548
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
549
+ def _clean_caption(self, caption):
550
+ caption = str(caption)
551
+ caption = ul.unquote_plus(caption)
552
+ caption = caption.strip().lower()
553
+ caption = re.sub("<person>", "person", caption)
554
+ # urls:
555
+ caption = re.sub(
556
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
557
+ "",
558
+ caption,
559
+ ) # regex for urls
560
+ caption = re.sub(
561
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
562
+ "",
563
+ caption,
564
+ ) # regex for urls
565
+ # html:
566
+ caption = BeautifulSoup(caption, features="html.parser").text
567
+
568
+ # @<nickname>
569
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
570
+
571
+ # 31C0—31EF CJK Strokes
572
+ # 31F0—31FF Katakana Phonetic Extensions
573
+ # 3200—32FF Enclosed CJK Letters and Months
574
+ # 3300—33FF CJK Compatibility
575
+ # 3400—4DBF CJK Unified Ideographs Extension A
576
+ # 4DC0—4DFF Yijing Hexagram Symbols
577
+ # 4E00—9FFF CJK Unified Ideographs
578
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
579
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
580
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
581
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
582
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
583
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
584
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
585
+ #######################################################
586
+
587
+ # все виды тире / all types of dash --> "-"
588
+ caption = re.sub(
589
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
590
+ "-",
591
+ caption,
592
+ )
593
+
594
+ # кавычки к одному стандарту
595
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
596
+ caption = re.sub(r"[‘’]", "'", caption)
597
+
598
+ # &quot;
599
+ caption = re.sub(r"&quot;?", "", caption)
600
+ # &amp
601
+ caption = re.sub(r"&amp", "", caption)
602
+
603
+ # ip addresses:
604
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
605
+
606
+ # article ids:
607
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
608
+
609
+ # \n
610
+ caption = re.sub(r"\\n", " ", caption)
611
+
612
+ # "#123"
613
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
614
+ # "#12345.."
615
+ caption = re.sub(r"#\d{5,}\b", "", caption)
616
+ # "123456.."
617
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
618
+ # filenames:
619
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
620
+
621
+ #
622
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
623
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
624
+
625
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
626
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
627
+
628
+ # this-is-my-cute-cat / this_is_my_cute_cat
629
+ regex2 = re.compile(r"(?:\-|\_)")
630
+ if len(re.findall(regex2, caption)) > 3:
631
+ caption = re.sub(regex2, " ", caption)
632
+
633
+ caption = ftfy.fix_text(caption)
634
+ caption = html.unescape(html.unescape(caption))
635
+
636
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
637
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
638
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
639
+
640
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
641
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
642
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
643
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
644
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
645
+
646
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
647
+
648
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
649
+
650
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
651
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
652
+ caption = re.sub(r"\s+", " ", caption)
653
+
654
+ caption.strip()
655
+
656
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
657
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
658
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
659
+ caption = re.sub(r"^\.\S+$", "", caption)
660
+
661
+ return caption.strip()
662
+
663
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
664
+ if latents is not None:
665
+ return latents.to(device=device, dtype=dtype)
666
+
667
+ shape = (
668
+ batch_size,
669
+ num_channels_latents,
670
+ int(height) // self.vae_scale_factor,
671
+ int(width) // self.vae_scale_factor,
672
+ )
673
+ if isinstance(generator, list) and len(generator) != batch_size:
674
+ raise ValueError(
675
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
676
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
677
+ )
678
+
679
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
680
+ return latents
681
+
682
+ @property
683
+ def guidance_scale(self):
684
+ return self._guidance_scale
685
+
686
+ @property
687
+ def attention_kwargs(self):
688
+ return self._attention_kwargs
689
+
690
+ @property
691
+ def do_classifier_free_guidance(self):
692
+ return self._guidance_scale > 1.0
693
+
694
+ @property
695
+ def num_timesteps(self):
696
+ return self._num_timesteps
697
+
698
+ @property
699
+ def interrupt(self):
700
+ return self._interrupt
701
+
702
+ @torch.no_grad()
703
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
704
+ def __call__(
705
+ self,
706
+ prompt: Union[str, List[str]] = None,
707
+ negative_prompt: str = "",
708
+ num_inference_steps: int = 20,
709
+ timesteps: List[int] = None,
710
+ sigmas: List[float] = None,
711
+ guidance_scale: float = 4.5,
712
+ num_images_per_prompt: Optional[int] = 1,
713
+ height: int = 1024,
714
+ width: int = 1024,
715
+ eta: float = 0.0,
716
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
717
+ latents: Optional[torch.Tensor] = None,
718
+ prompt_embeds: Optional[torch.Tensor] = None,
719
+ prompt_attention_mask: Optional[torch.Tensor] = None,
720
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
721
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
722
+ output_type: Optional[str] = "pil",
723
+ return_dict: bool = True,
724
+ clean_caption: bool = False,
725
+ use_resolution_binning: bool = True,
726
+ attention_kwargs: Optional[Dict[str, Any]] = None,
727
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
728
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
729
+ max_sequence_length: int = 300,
730
+ complex_human_instruction: List[str] = [
731
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
732
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
733
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
734
+ "Here are examples of how to transform or refine prompts:",
735
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
736
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
737
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
738
+ "User Prompt: ",
739
+ ],
740
+ ) -> Union[SanaPipelineOutput, Tuple]:
741
+ """
742
+ Function invoked when calling the pipeline for generation.
743
+
744
+ Args:
745
+ prompt (`str` or `List[str]`, *optional*):
746
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
747
+ instead.
748
+ negative_prompt (`str` or `List[str]`, *optional*):
749
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
750
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
751
+ less than `1`).
752
+ num_inference_steps (`int`, *optional*, defaults to 20):
753
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
754
+ expense of slower inference.
755
+ timesteps (`List[int]`, *optional*):
756
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
757
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
758
+ passed will be used. Must be in descending order.
759
+ sigmas (`List[float]`, *optional*):
760
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
761
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
762
+ will be used.
763
+ guidance_scale (`float`, *optional*, defaults to 4.5):
764
+ Guidance scale as defined in [Classifier-Free Diffusion
765
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
766
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
767
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
768
+ the text `prompt`, usually at the expense of lower image quality.
769
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
770
+ The number of images to generate per prompt.
771
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
772
+ The height in pixels of the generated image.
773
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
774
+ The width in pixels of the generated image.
775
+ eta (`float`, *optional*, defaults to 0.0):
776
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
777
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
778
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
779
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
780
+ to make generation deterministic.
781
+ latents (`torch.Tensor`, *optional*):
782
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
783
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
784
+ tensor will be generated by sampling using the supplied random `generator`.
785
+ prompt_embeds (`torch.Tensor`, *optional*):
786
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
787
+ provided, text embeddings will be generated from `prompt` input argument.
788
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
789
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
790
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
791
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
792
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
793
+ Pre-generated attention mask for negative text embeddings.
794
+ output_type (`str`, *optional*, defaults to `"pil"`):
795
+ The output format of the generate image. Choose between
796
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
797
+ return_dict (`bool`, *optional*, defaults to `True`):
798
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
799
+ attention_kwargs:
800
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
801
+ `self.processor` in
802
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
803
+ clean_caption (`bool`, *optional*, defaults to `True`):
804
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
805
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
806
+ prompt.
807
+ use_resolution_binning (`bool` defaults to `True`):
808
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
809
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
810
+ the requested resolution. Useful for generating non-square images.
811
+ callback_on_step_end (`Callable`, *optional*):
812
+ A function that calls at the end of each denoising steps during the inference. The function is called
813
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
814
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
815
+ `callback_on_step_end_tensor_inputs`.
816
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
817
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
818
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
819
+ `._callback_tensor_inputs` attribute of your pipeline class.
820
+ max_sequence_length (`int` defaults to `300`):
821
+ Maximum sequence length to use with the `prompt`.
822
+ complex_human_instruction (`List[str]`, *optional*):
823
+ Instructions for complex human attention:
824
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
825
+
826
+ Examples:
827
+
828
+ Returns:
829
+ [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
830
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
831
+ otherwise a `tuple` is returned where the first element is a list with the generated images
832
+ """
833
+
834
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
835
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
836
+
837
+ # 1. Check inputs. Raise error if not correct
838
+ if use_resolution_binning:
839
+ if self.transformer.config.sample_size == 128:
840
+ aspect_ratio_bin = ASPECT_RATIO_4096_BIN
841
+ elif self.transformer.config.sample_size == 64:
842
+ aspect_ratio_bin = ASPECT_RATIO_2048_BIN
843
+ elif self.transformer.config.sample_size == 32:
844
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
845
+ elif self.transformer.config.sample_size == 16:
846
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
847
+ else:
848
+ raise ValueError("Invalid sample size")
849
+ orig_height, orig_width = height, width
850
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
851
+
852
+ self.check_inputs(
853
+ prompt,
854
+ height,
855
+ width,
856
+ callback_on_step_end_tensor_inputs,
857
+ negative_prompt,
858
+ prompt_embeds,
859
+ negative_prompt_embeds,
860
+ prompt_attention_mask,
861
+ negative_prompt_attention_mask,
862
+ )
863
+
864
+ self._guidance_scale = guidance_scale
865
+ self._attention_kwargs = attention_kwargs
866
+ self._interrupt = False
867
+
868
+ # 2. Default height and width to transformer
869
+ if prompt is not None and isinstance(prompt, str):
870
+ batch_size = 1
871
+ elif prompt is not None and isinstance(prompt, list):
872
+ batch_size = len(prompt)
873
+ else:
874
+ batch_size = prompt_embeds.shape[0]
875
+
876
+ device = self._execution_device
877
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
878
+
879
+ # 3. Encode input prompt
880
+ (
881
+ prompt_embeds,
882
+ prompt_attention_mask,
883
+ negative_prompt_embeds,
884
+ negative_prompt_attention_mask,
885
+ ) = self.encode_prompt(
886
+ prompt,
887
+ self.do_classifier_free_guidance,
888
+ negative_prompt=negative_prompt,
889
+ num_images_per_prompt=num_images_per_prompt,
890
+ device=device,
891
+ prompt_embeds=prompt_embeds,
892
+ negative_prompt_embeds=negative_prompt_embeds,
893
+ prompt_attention_mask=prompt_attention_mask,
894
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
895
+ clean_caption=clean_caption,
896
+ max_sequence_length=max_sequence_length,
897
+ complex_human_instruction=complex_human_instruction,
898
+ lora_scale=lora_scale,
899
+ )
900
+ if self.do_classifier_free_guidance:
901
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
902
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
903
+
904
+ # 4. Prepare timesteps
905
+ timesteps, num_inference_steps = retrieve_timesteps(
906
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
907
+ )
908
+
909
+ # 5. Prepare latents.
910
+ latent_channels = self.transformer.config.in_channels
911
+ latents = self.prepare_latents(
912
+ batch_size * num_images_per_prompt,
913
+ latent_channels,
914
+ height,
915
+ width,
916
+ torch.float32,
917
+ device,
918
+ generator,
919
+ latents,
920
+ )
921
+
922
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
923
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
924
+
925
+ # 7. Denoising loop
926
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
927
+ self._num_timesteps = len(timesteps)
928
+
929
+ transformer_dtype = self.transformer.dtype
930
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
931
+ for i, t in enumerate(timesteps):
932
+ if self.interrupt:
933
+ continue
934
+
935
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
936
+
937
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
938
+ timestep = t.expand(latent_model_input.shape[0])
939
+ timestep = timestep * self.transformer.config.timestep_scale
940
+
941
+ # predict noise model_output
942
+ noise_pred = self.transformer(
943
+ latent_model_input.to(dtype=transformer_dtype),
944
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
945
+ encoder_attention_mask=prompt_attention_mask,
946
+ timestep=timestep,
947
+ return_dict=False,
948
+ attention_kwargs=self.attention_kwargs,
949
+ )[0]
950
+ noise_pred = noise_pred.float()
951
+
952
+ # perform guidance
953
+ if self.do_classifier_free_guidance:
954
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
955
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
956
+
957
+ # learned sigma
958
+ if self.transformer.config.out_channels // 2 == latent_channels:
959
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
960
+
961
+ # compute previous image: x_t -> x_t-1
962
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
963
+
964
+ if callback_on_step_end is not None:
965
+ callback_kwargs = {}
966
+ for k in callback_on_step_end_tensor_inputs:
967
+ callback_kwargs[k] = locals()[k]
968
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
969
+
970
+ latents = callback_outputs.pop("latents", latents)
971
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
972
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
973
+
974
+ # call the callback, if provided
975
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
976
+ progress_bar.update()
977
+
978
+ if XLA_AVAILABLE:
979
+ xm.mark_step()
980
+
981
+ if output_type == "latent":
982
+ image = latents
983
+ else:
984
+ latents = latents.to(self.vae.dtype)
985
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
986
+ oom_error = (
987
+ torch.OutOfMemoryError
988
+ if is_torch_version(">=", "2.5.0")
989
+ else torch_accelerator_module.OutOfMemoryError
990
+ )
991
+ try:
992
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
993
+ except oom_error as e:
994
+ warnings.warn(
995
+ f"{e}. \n"
996
+ f"Try to use VAE tiling for large images. For example: \n"
997
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
998
+ )
999
+ if use_resolution_binning:
1000
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
1001
+
1002
+ if not output_type == "latent":
1003
+ image = self.image_processor.postprocess(image, output_type=output_type)
1004
+
1005
+ # Offload all models
1006
+ self.maybe_free_model_hooks()
1007
+
1008
+ if not return_dict:
1009
+ return (image,)
1010
+
1011
+ return SanaPipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana_controlnet.py ADDED
@@ -0,0 +1,1106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ import warnings
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...image_processor import PipelineImageInput, PixArtImageProcessor
27
+ from ...loaders import SanaLoraLoaderMixin
28
+ from ...models import AutoencoderDC, SanaControlNetModel, SanaTransformer2DModel
29
+ from ...schedulers import DPMSolverMultistepScheduler
30
+ from ...utils import (
31
+ BACKENDS_MAPPING,
32
+ USE_PEFT_BACKEND,
33
+ is_bs4_available,
34
+ is_ftfy_available,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
42
+ from ..pipeline_utils import DiffusionPipeline
43
+ from ..pixart_alpha.pipeline_pixart_alpha import (
44
+ ASPECT_RATIO_512_BIN,
45
+ ASPECT_RATIO_1024_BIN,
46
+ )
47
+ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
48
+ from .pipeline_output import SanaPipelineOutput
49
+
50
+
51
# TPU support: when torch_xla is present, `xm.mark_step()` is issued inside the denoising loop.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger, following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Optional caption-cleaning dependencies; only exercised when callers pass `clean_caption=True`.
if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy
65
+
66
+
67
# Aspect-ratio bins for the 4096px model class, complementing the 512/1024/2048 bins imported
# from the PixArt pipelines. Each key is the ratio of the first value to the second
# (e.g. "0.25" -> 2048 / 8192); the pairs are the binned target resolutions consumed by
# `PixArtImageProcessor.classify_height_width_bin` when `use_resolution_binning` is enabled.
ASPECT_RATIO_4096_BIN = {
    "0.25": [2048.0, 8192.0],
    "0.26": [2048.0, 7936.0],
    "0.27": [2048.0, 7680.0],
    "0.28": [2048.0, 7424.0],
    "0.32": [2304.0, 7168.0],
    "0.33": [2304.0, 6912.0],
    "0.35": [2304.0, 6656.0],
    "0.4": [2560.0, 6400.0],
    "0.42": [2560.0, 6144.0],
    "0.48": [2816.0, 5888.0],
    "0.5": [2816.0, 5632.0],
    "0.52": [2816.0, 5376.0],
    "0.57": [3072.0, 5376.0],
    "0.6": [3072.0, 5120.0],
    "0.68": [3328.0, 4864.0],
    "0.72": [3328.0, 4608.0],
    "0.78": [3584.0, 4608.0],
    "0.82": [3584.0, 4352.0],
    "0.88": [3840.0, 4352.0],
    "0.94": [3840.0, 4096.0],
    "1.0": [4096.0, 4096.0],
    "1.07": [4096.0, 3840.0],
    "1.13": [4352.0, 3840.0],
    "1.21": [4352.0, 3584.0],
    "1.29": [4608.0, 3584.0],
    "1.38": [4608.0, 3328.0],
    "1.46": [4864.0, 3328.0],
    "1.67": [5120.0, 3072.0],
    "1.75": [5376.0, 3072.0],
    "2.0": [5632.0, 2816.0],
    "2.09": [5888.0, 2816.0],
    "2.4": [6144.0, 2560.0],
    "2.5": [6400.0, 2560.0],
    "2.89": [6656.0, 2304.0],
    "3.0": [6912.0, 2304.0],
    "3.11": [7168.0, 2304.0],
    "3.62": [7424.0, 2048.0],
    "3.75": [7680.0, 2048.0],
    "3.88": [7936.0, 2048.0],
    "4.0": [8192.0, 2048.0],
}
109
+
110
# Usage example injected into the pipeline's `__call__` docstring via `replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import SanaControlNetPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = SanaControlNetPipeline.from_pretrained(
        ...     "ishan24/Sana_600M_1024px_ControlNetPlus_diffusers",
        ...     variant="fp16",
        ...     torch_dtype={"default": torch.bfloat16, "controlnet": torch.float16, "transformer": torch.float16},
        ...     device_map="balanced",
        ... )
        >>> cond_image = load_image(
        ...     "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png"
        ... )
        >>> prompt = 'a cat with a neon sign that says "Sana"'
        >>> image = pipe(
        ...     prompt,
        ...     control_image=cond_image,
        ... ).images[0]
        >>> image.save("output.png")
        ```
"""
134
+
135
+
136
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler's timestep schedule and return it.

    Exactly one schedule source may be supplied: an explicit `timesteps` list, an explicit `sigmas` list, or a
    plain `num_inference_steps`. Custom schedules are only forwarded when the scheduler's `set_timesteps`
    signature accepts the corresponding keyword; otherwise a `ValueError` is raised. Extra `kwargs` are passed
    through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose schedule is being configured.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps; used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device to move the timesteps to; `None` leaves them where the scheduler puts them.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the (possibly recomputed) number of inference
        steps. With a custom schedule the step count is derived from the schedule length; otherwise the
        caller-supplied `num_inference_steps` is returned unchanged.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Custom integer schedule: only valid when the scheduler advertises a `timesteps` kwarg.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: only valid when the scheduler advertises a `sigmas` kwarg.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: let the scheduler derive its own schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps
194
+
195
+
196
+ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
197
+ r"""
198
+ Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629).
199
+ """
200
+
201
+ # fmt: off
202
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
203
+ # fmt: on
204
+
205
+ model_cpu_offload_seq = "text_encoder->controlnet->transformer->vae"
206
+ _callback_tensor_inputs = ["latents", "control_image", "prompt_embeds", "negative_prompt_embeds"]
207
+
208
+ def __init__(
209
+ self,
210
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
211
+ text_encoder: Gemma2PreTrainedModel,
212
+ vae: AutoencoderDC,
213
+ transformer: SanaTransformer2DModel,
214
+ controlnet: SanaControlNetModel,
215
+ scheduler: DPMSolverMultistepScheduler,
216
+ ):
217
+ super().__init__()
218
+
219
+ self.register_modules(
220
+ tokenizer=tokenizer,
221
+ text_encoder=text_encoder,
222
+ vae=vae,
223
+ transformer=transformer,
224
+ controlnet=controlnet,
225
+ scheduler=scheduler,
226
+ )
227
+
228
+ self.vae_scale_factor = (
229
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
230
+ if hasattr(self, "vae") and self.vae is not None
231
+ else 32
232
+ )
233
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
234
+
235
+ def enable_vae_slicing(self):
236
+ r"""
237
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
238
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
239
+ """
240
+ self.vae.enable_slicing()
241
+
242
+ def disable_vae_slicing(self):
243
+ r"""
244
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
245
+ computing decoding in one step.
246
+ """
247
+ self.vae.disable_slicing()
248
+
249
+ def enable_vae_tiling(self):
250
+ r"""
251
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
252
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
253
+ processing larger images.
254
+ """
255
+ self.vae.enable_tiling()
256
+
257
+ def disable_vae_tiling(self):
258
+ r"""
259
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
260
+ computing decoding in one step.
261
+ """
262
+ self.vae.disable_tiling()
263
+
264
    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
    def _get_gemma_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        device: torch.device,
        dtype: torch.dtype,
        clean_caption: bool = False,
        max_sequence_length: int = 300,
        complex_human_instruction: Optional[List[str]] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states using the Gemma text encoder.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            dtype (`torch.dtype`):
                dtype the returned embeddings are cast to
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
                the prompt.

        Returns:
            `tuple`: `(prompt_embeds, prompt_attention_mask)` — the encoder's last hidden state on `device`
            in `dtype`, and the matching attention mask.
        """
        # Normalize to a list so a single string and a batch are handled uniformly.
        prompt = [prompt] if isinstance(prompt, str) else prompt

        # Right-padding keeps real tokens at the front of each padded sequence.
        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"

        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)

        # prepare complex human instruction
        if not complex_human_instruction:
            max_length_all = max_sequence_length
        else:
            # Prepend the instruction to every prompt and widen the tokenizer budget so the
            # instruction does not eat into the prompt's `max_sequence_length` tokens.
            chi_prompt = "\n".join(complex_human_instruction)
            prompt = [chi_prompt + p for p in prompt]
            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
            max_length_all = num_chi_prompt_tokens + max_sequence_length - 2

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length_all,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        # Index [0] selects the encoder's last hidden state from the model output.
        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)

        return prompt_embeds, prompt_attention_mask
322
+
323
    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        clean_caption: bool = False,
        max_sequence_length: int = 300,
        complex_human_instruction: Optional[List[str]] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt — and, when classifier-free guidance is active, the negative prompt — into text
        encoder hidden states, tiled `num_images_per_prompt` times.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                PixArt-Alpha, this should be "".
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask matching `prompt_embeds`.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask matching `negative_prompt_embeds`.
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
                the prompt.
            lora_scale (`float`, *optional*):
                Scale applied to LoRA layers on the text encoder while encoding, when the PEFT backend is in use.

        Returns:
            `tuple`: `(prompt_embeds, prompt_attention_mask, negative_prompt_embeds,
            negative_prompt_attention_mask)`; the negative pair is `None` when
            `do_classifier_free_guidance` is `False`.
        """

        if device is None:
            device = self._execution_device

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        # Derive the batch size from the prompt when given, otherwise from the pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"

        # See Section 3.1. of the paper.
        max_length = max_sequence_length
        # Keep position 0 plus the last (max_length - 1) positions; with right padding this trims any
        # complex-human-instruction prefix tokens prepended in `_get_gemma_prompt_embeds` —
        # presumably position 0 is the BOS token; verify against the Gemma tokenizer settings.
        select_index = [0] + list(range(-max_length + 1, 0))

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                device=device,
                dtype=dtype,
                clean_caption=clean_caption,
                max_sequence_length=max_sequence_length,
                complex_human_instruction=complex_human_instruction,
            )

            prompt_embeds = prompt_embeds[:, select_index]
            prompt_attention_mask = prompt_attention_mask[:, select_index]

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # A single negative string is broadcast to every prompt in the batch.
            negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=negative_prompt,
                device=device,
                dtype=dtype,
                clean_caption=clean_caption,
                max_sequence_length=max_sequence_length,
                complex_human_instruction=False,
            )

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # NOTE(review): the negative mask is reshaped with `bs_embed` (taken from the positive
            # embeddings) — assumes positive and negative batch sizes match; verify when passing
            # pre-computed `prompt_embeds` together with a freshly encoded `negative_prompt`.
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None

        if self.text_encoder is not None:
            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
453
+
454
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
455
+ def prepare_extra_step_kwargs(self, generator, eta):
456
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
457
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
458
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
459
+ # and should be between [0, 1]
460
+
461
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
462
+ extra_step_kwargs = {}
463
+ if accepts_eta:
464
+ extra_step_kwargs["eta"] = eta
465
+
466
+ # check if the scheduler accepts generator
467
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
468
+ if accepts_generator:
469
+ extra_step_kwargs["generator"] = generator
470
+ return extra_step_kwargs
471
+
472
+ def check_inputs(
473
+ self,
474
+ prompt,
475
+ height,
476
+ width,
477
+ callback_on_step_end_tensor_inputs=None,
478
+ negative_prompt=None,
479
+ prompt_embeds=None,
480
+ negative_prompt_embeds=None,
481
+ prompt_attention_mask=None,
482
+ negative_prompt_attention_mask=None,
483
+ ):
484
+ if height % 32 != 0 or width % 32 != 0:
485
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
486
+
487
+ if callback_on_step_end_tensor_inputs is not None and not all(
488
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
489
+ ):
490
+ raise ValueError(
491
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
492
+ )
493
+
494
+ if prompt is not None and prompt_embeds is not None:
495
+ raise ValueError(
496
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
497
+ " only forward one of the two."
498
+ )
499
+ elif prompt is None and prompt_embeds is None:
500
+ raise ValueError(
501
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
502
+ )
503
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
504
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
505
+
506
+ if prompt is not None and negative_prompt_embeds is not None:
507
+ raise ValueError(
508
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
509
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
510
+ )
511
+
512
+ if negative_prompt is not None and negative_prompt_embeds is not None:
513
+ raise ValueError(
514
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
515
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
516
+ )
517
+
518
+ if prompt_embeds is not None and prompt_attention_mask is None:
519
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
520
+
521
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
522
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
523
+
524
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
525
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
526
+ raise ValueError(
527
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
528
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
529
+ f" {negative_prompt_embeds.shape}."
530
+ )
531
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
532
+ raise ValueError(
533
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
534
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
535
+ f" {negative_prompt_attention_mask.shape}."
536
+ )
537
+
538
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
539
+ def _text_preprocessing(self, text, clean_caption=False):
540
+ if clean_caption and not is_bs4_available():
541
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
542
+ logger.warning("Setting `clean_caption` to False...")
543
+ clean_caption = False
544
+
545
+ if clean_caption and not is_ftfy_available():
546
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
547
+ logger.warning("Setting `clean_caption` to False...")
548
+ clean_caption = False
549
+
550
+ if not isinstance(text, (tuple, list)):
551
+ text = [text]
552
+
553
+ def process(text: str):
554
+ if clean_caption:
555
+ text = self._clean_caption(text)
556
+ text = self._clean_caption(text)
557
+ else:
558
+ text = text.lower().strip()
559
+ return text
560
+
561
+ return [process(t) for t in text]
562
+
563
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
564
+ def _clean_caption(self, caption):
565
+ caption = str(caption)
566
+ caption = ul.unquote_plus(caption)
567
+ caption = caption.strip().lower()
568
+ caption = re.sub("<person>", "person", caption)
569
+ # urls:
570
+ caption = re.sub(
571
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
572
+ "",
573
+ caption,
574
+ ) # regex for urls
575
+ caption = re.sub(
576
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
577
+ "",
578
+ caption,
579
+ ) # regex for urls
580
+ # html:
581
+ caption = BeautifulSoup(caption, features="html.parser").text
582
+
583
+ # @<nickname>
584
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
585
+
586
+ # 31C0—31EF CJK Strokes
587
+ # 31F0—31FF Katakana Phonetic Extensions
588
+ # 3200—32FF Enclosed CJK Letters and Months
589
+ # 3300—33FF CJK Compatibility
590
+ # 3400—4DBF CJK Unified Ideographs Extension A
591
+ # 4DC0—4DFF Yijing Hexagram Symbols
592
+ # 4E00—9FFF CJK Unified Ideographs
593
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
594
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
595
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
596
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
597
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
598
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
599
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
600
+ #######################################################
601
+
602
+ # все виды тире / all types of dash --> "-"
603
+ caption = re.sub(
604
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
605
+ "-",
606
+ caption,
607
+ )
608
+
609
+ # кавычки к одному стандарту
610
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
611
+ caption = re.sub(r"[‘’]", "'", caption)
612
+
613
+ # &quot;
614
+ caption = re.sub(r"&quot;?", "", caption)
615
+ # &amp
616
+ caption = re.sub(r"&amp", "", caption)
617
+
618
+ # ip addresses:
619
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
620
+
621
+ # article ids:
622
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
623
+
624
+ # \n
625
+ caption = re.sub(r"\\n", " ", caption)
626
+
627
+ # "#123"
628
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
629
+ # "#12345.."
630
+ caption = re.sub(r"#\d{5,}\b", "", caption)
631
+ # "123456.."
632
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
633
+ # filenames:
634
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
635
+
636
+ #
637
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
638
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
639
+
640
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
641
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
642
+
643
+ # this-is-my-cute-cat / this_is_my_cute_cat
644
+ regex2 = re.compile(r"(?:\-|\_)")
645
+ if len(re.findall(regex2, caption)) > 3:
646
+ caption = re.sub(regex2, " ", caption)
647
+
648
+ caption = ftfy.fix_text(caption)
649
+ caption = html.unescape(html.unescape(caption))
650
+
651
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
652
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
653
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
654
+
655
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
656
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
657
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
658
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
659
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
660
+
661
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
662
+
663
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
664
+
665
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
666
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
667
+ caption = re.sub(r"\s+", " ", caption)
668
+
669
+ caption.strip()
670
+
671
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
672
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
673
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
674
+ caption = re.sub(r"^\.\S+$", "", caption)
675
+
676
+ return caption.strip()
677
+
678
+ def prepare_image(
679
+ self,
680
+ image,
681
+ width,
682
+ height,
683
+ batch_size,
684
+ num_images_per_prompt,
685
+ device,
686
+ dtype,
687
+ do_classifier_free_guidance=False,
688
+ guess_mode=False,
689
+ ):
690
+ if isinstance(image, torch.Tensor):
691
+ pass
692
+ else:
693
+ image = self.image_processor.preprocess(image, height=height, width=width)
694
+
695
+ image_batch_size = image.shape[0]
696
+
697
+ if image_batch_size == 1:
698
+ repeat_by = batch_size
699
+ else:
700
+ # image batch size is the same as prompt batch size
701
+ repeat_by = num_images_per_prompt
702
+
703
+ image = image.repeat_interleave(repeat_by, dim=0)
704
+
705
+ image = image.to(device=device, dtype=dtype)
706
+
707
+ if do_classifier_free_guidance and not guess_mode:
708
+ image = torch.cat([image] * 2)
709
+
710
+ return image
711
+
712
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
713
+ if latents is not None:
714
+ return latents.to(device=device, dtype=dtype)
715
+
716
+ shape = (
717
+ batch_size,
718
+ num_channels_latents,
719
+ int(height) // self.vae_scale_factor,
720
+ int(width) // self.vae_scale_factor,
721
+ )
722
+ if isinstance(generator, list) and len(generator) != batch_size:
723
+ raise ValueError(
724
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
725
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
726
+ )
727
+
728
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
729
+ return latents
730
+
731
    @property
    def guidance_scale(self):
        # CFG scale stored by __call__ at the start of generation.
        return self._guidance_scale

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to attention processors (e.g. LoRA "scale").
        return self._attention_kwargs

    @property
    def do_classifier_free_guidance(self):
        # Classifier-free guidance is active whenever the scale exceeds 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps used by the last __call__ invocation.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cooperative flag checked on every denoising step to abort generation.
        return self._interrupt
750
+
751
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: str = "",
        num_inference_steps: int = 20,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 4.5,
        control_image: PipelineImageInput = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        num_images_per_prompt: Optional[int] = 1,
        height: int = 1024,
        width: int = 1024,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        clean_caption: bool = False,
        use_resolution_binning: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],  # NOTE(review): mutable default — safe only because it is rebound, never mutated in place
        max_sequence_length: int = 300,
        complex_human_instruction: List[str] = [
            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
            "Here are examples of how to transform or refine prompts:",
            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
            "User Prompt: ",
        ],
    ) -> Union[SanaPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
                `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet.
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
                applies to [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            attention_kwargs:
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            use_resolution_binning (`bool` defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions using
                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                the requested resolution. Useful for generating non-square images.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to `300`):
                Maximum sequence length to use with the `prompt`.
            complex_human_instruction (`List[str]`, *optional*):
                Instructions for complex human attention:
                https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.

        Examples:

        Returns:
            [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images
        """

        # Unwrap PipelineCallback containers so their declared tensor inputs are honored.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        if use_resolution_binning:
            # Map the transformer's latent sample size to the matching aspect-ratio bin table.
            if self.transformer.config.sample_size == 128:
                aspect_ratio_bin = ASPECT_RATIO_4096_BIN
            elif self.transformer.config.sample_size == 64:
                aspect_ratio_bin = ASPECT_RATIO_2048_BIN
            elif self.transformer.config.sample_size == 32:
                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
            elif self.transformer.config.sample_size == 16:
                aspect_ratio_bin = ASPECT_RATIO_512_BIN
            else:
                raise ValueError("Invalid sample size")
            # Remember the requested resolution so the decoded image can be resized back.
            orig_height, orig_width = height, width
            height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

        self.check_inputs(
            prompt,
            height,
            width,
            callback_on_step_end_tensor_inputs,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        )

        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Default height and width to transformer
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # A LoRA scale, if any, rides in through attention_kwargs under the "scale" key.
        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,
            complex_human_instruction=complex_human_instruction,
            lora_scale=lora_scale,
        )
        # For CFG, stack negative + positive conditioning into one batch (uncond half first).
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # 4. Prepare control image
        if isinstance(self.controlnet, SanaControlNetModel):
            control_image = self.prepare_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.vae.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=False,
            )
            height, width = control_image.shape[-2:]

            # Encode the conditioning image into VAE latent space and apply the same
            # scaling factor used for the denoised latents.
            control_image = self.vae.encode(control_image).latent
            control_image = control_image * self.vae.config.scaling_factor
        else:
            raise ValueError("`controlnet` must be of type `SanaControlNetModel`.")

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 6. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        controlnet_dtype = self.controlnet.dtype
        transformer_dtype = self.transformer.dtype
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # controlnet(s) inference
                controlnet_block_samples = self.controlnet(
                    latent_model_input.to(dtype=controlnet_dtype),
                    encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype),
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=timestep,
                    return_dict=False,
                    attention_kwargs=self.attention_kwargs,
                    controlnet_cond=control_image,
                    conditioning_scale=controlnet_conditioning_scale,
                )[0]

                # predict noise model_output
                noise_pred = self.transformer(
                    latent_model_input.to(dtype=transformer_dtype),
                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=timestep,
                    return_dict=False,
                    attention_kwargs=self.attention_kwargs,
                    controlnet_block_samples=tuple(t.to(dtype=transformer_dtype) for t in controlnet_block_samples),
                )[0]
                # Latents are kept in fp32 (see prepare_latents call above); match dtype
                # before the guidance math and scheduler step.
                noise_pred = noise_pred.float()

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # learned sigma
                if self.transformer.config.out_channels // 2 == latent_channels:
                    noise_pred = noise_pred.chunk(2, dim=1)[0]

                # compute previous image: x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # NOTE(review): pulls tensors by name from local scope; each requested
                        # key must exactly match a local variable name in this method.
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = latents.to(self.vae.dtype)
            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
            oom_error = (
                torch.OutOfMemoryError
                if is_torch_version(">=", "2.5.0")
                else torch_accelerator_module.OutOfMemoryError
            )
            try:
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            except oom_error as e:
                # NOTE(review): on OOM `image` is left unbound here, so the
                # resize/postprocess below raises NameError instead of a clear
                # error — consider re-raising or falling back to tiled decoding.
                warnings.warn(
                    f"{e}. \n"
                    f"Try to use VAE tiling for large images. For example: \n"
                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
                )
            if use_resolution_binning:
                image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return SanaPipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint.py ADDED
@@ -0,0 +1,893 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ import warnings
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...image_processor import PixArtImageProcessor
27
+ from ...loaders import SanaLoraLoaderMixin
28
+ from ...models import AutoencoderDC, SanaTransformer2DModel
29
+ from ...schedulers import DPMSolverMultistepScheduler
30
+ from ...utils import (
31
+ BACKENDS_MAPPING,
32
+ USE_PEFT_BACKEND,
33
+ is_bs4_available,
34
+ is_ftfy_available,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
42
+ from ..pipeline_utils import DiffusionPipeline
43
+ from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
44
+ from .pipeline_output import SanaPipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ if is_bs4_available():
57
+ from bs4 import BeautifulSoup
58
+
59
+ if is_ftfy_available():
60
+ import ftfy
61
+
62
+
63
# Minimal usage example; spliced into `__call__`'s docstring by `@replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import SanaSprintPipeline

        >>> pipe = SanaSprintPipeline.from_pretrained(
        ...     "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
        ... )
        >>> pipe.to("cuda")

        >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
        >>> image[0].save("output.png")
        ```
"""
78
+
79
+
80
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Call `scheduler.set_timesteps` and return the resulting timestep schedule.

    Exactly one of `timesteps` or `sigmas` may be given to override the scheduler's default
    spacing; otherwise `num_inference_steps` is forwarded as-is. Extra `kwargs` are passed
    through to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps; must be `None` when `timesteps` or `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps should be moved to. If `None`, they are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; requires scheduler support.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; requires scheduler support.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(param_name: str) -> bool:
        # Introspect `set_timesteps` once per call to see whether it supports the custom argument.
        return param_name in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    else:
        # Default path: trust the caller-supplied step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    return scheduler.timesteps, num_inference_steps
138
+
139
+
140
+ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
141
+ r"""
142
+ Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
143
+ """
144
+
145
+ # fmt: off
146
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
147
+ # fmt: on
148
+
149
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
150
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
151
+
152
+ def __init__(
153
+ self,
154
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
155
+ text_encoder: Gemma2PreTrainedModel,
156
+ vae: AutoencoderDC,
157
+ transformer: SanaTransformer2DModel,
158
+ scheduler: DPMSolverMultistepScheduler,
159
+ ):
160
+ super().__init__()
161
+
162
+ self.register_modules(
163
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
164
+ )
165
+
166
+ self.vae_scale_factor = (
167
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
168
+ if hasattr(self, "vae") and self.vae is not None
169
+ else 32
170
+ )
171
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
172
+
173
+ def enable_vae_slicing(self):
174
+ r"""
175
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
176
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
177
+ """
178
+ self.vae.enable_slicing()
179
+
180
+ def disable_vae_slicing(self):
181
+ r"""
182
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
183
+ computing decoding in one step.
184
+ """
185
+ self.vae.disable_slicing()
186
+
187
+ def enable_vae_tiling(self):
188
+ r"""
189
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
190
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
191
+ processing larger images.
192
+ """
193
+ self.vae.enable_tiling()
194
+
195
+ def disable_vae_tiling(self):
196
+ r"""
197
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
198
+ computing decoding in one step.
199
+ """
200
+ self.vae.disable_tiling()
201
+
202
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
203
+ def _get_gemma_prompt_embeds(
204
+ self,
205
+ prompt: Union[str, List[str]],
206
+ device: torch.device,
207
+ dtype: torch.dtype,
208
+ clean_caption: bool = False,
209
+ max_sequence_length: int = 300,
210
+ complex_human_instruction: Optional[List[str]] = None,
211
+ ):
212
+ r"""
213
+ Encodes the prompt into text encoder hidden states.
214
+
215
+ Args:
216
+ prompt (`str` or `List[str]`, *optional*):
217
+ prompt to be encoded
218
+ device: (`torch.device`, *optional*):
219
+ torch device to place the resulting embeddings on
220
+ clean_caption (`bool`, defaults to `False`):
221
+ If `True`, the function will preprocess and clean the provided caption before encoding.
222
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
223
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
224
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
225
+ the prompt.
226
+ """
227
+ prompt = [prompt] if isinstance(prompt, str) else prompt
228
+
229
+ if getattr(self, "tokenizer", None) is not None:
230
+ self.tokenizer.padding_side = "right"
231
+
232
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
233
+
234
+ # prepare complex human instruction
235
+ if not complex_human_instruction:
236
+ max_length_all = max_sequence_length
237
+ else:
238
+ chi_prompt = "\n".join(complex_human_instruction)
239
+ prompt = [chi_prompt + p for p in prompt]
240
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
241
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
242
+
243
+ text_inputs = self.tokenizer(
244
+ prompt,
245
+ padding="max_length",
246
+ max_length=max_length_all,
247
+ truncation=True,
248
+ add_special_tokens=True,
249
+ return_tensors="pt",
250
+ )
251
+ text_input_ids = text_inputs.input_ids
252
+
253
+ prompt_attention_mask = text_inputs.attention_mask
254
+ prompt_attention_mask = prompt_attention_mask.to(device)
255
+
256
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
257
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
258
+
259
+ return prompt_embeds, prompt_attention_mask
260
+
261
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        clean_caption: bool = False,
        max_sequence_length: int = 300,
        complex_human_instruction: Optional[List[str]] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask matching `prompt_embeds`; required when `prompt_embeds` is given.
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
                the prompt.
            lora_scale (`float`, *optional*):
                LoRA scale applied to the text encoder's LoRA layers (PEFT backend only); the layers
                are scaled before encoding and unscaled again afterwards.

        Returns:
            `(prompt_embeds, prompt_attention_mask)`, each repeated `num_images_per_prompt` times along the batch.
        """

        if device is None:
            device = self._execution_device

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"

        # See Section 3.1. of the paper.
        # Keep the BOS token plus the last `max_length - 1` positions of the encoded sequence.
        max_length = max_sequence_length
        select_index = [0] + list(range(-max_length + 1, 0))

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                device=device,
                dtype=dtype,
                clean_caption=clean_caption,
                max_sequence_length=max_sequence_length,
                complex_human_instruction=complex_human_instruction,
            )

            prompt_embeds = prompt_embeds[:, select_index]
            prompt_attention_mask = prompt_attention_mask[:, select_index]

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        if self.text_encoder is not None:
            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, prompt_attention_mask
345
+
346
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
354
+ extra_step_kwargs = {}
355
+ if accepts_eta:
356
+ extra_step_kwargs["eta"] = eta
357
+
358
+ # check if the scheduler accepts generator
359
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
360
+ if accepts_generator:
361
+ extra_step_kwargs["generator"] = generator
362
+ return extra_step_kwargs
363
+
364
    def check_inputs(
        self,
        prompt,
        height,
        width,
        num_inference_steps,
        timesteps,
        max_timesteps,
        intermediate_timesteps,
        callback_on_step_end_tensor_inputs=None,
        prompt_embeds=None,
        prompt_attention_mask=None,
    ):
        """
        Validate the arguments of `__call__`, raising `ValueError` on the first
        inconsistency: resolution divisibility, callback tensor names, prompt vs.
        `prompt_embeds` exclusivity, and the SCM timestep-schedule constraints.
        """
        if height % 32 != 0 or width % 32 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one prompt source must be given, and a raw prompt must be str or list.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        # A custom SCM schedule supplies num_inference_steps + 1 boundary timesteps.
        if timesteps is not None and len(timesteps) != num_inference_steps + 1:
            raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")

        if timesteps is not None and max_timesteps is not None:
            raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")

        if timesteps is None and max_timesteps is None:
            raise ValueError("Should provide either `timesteps` or `max_timesteps`.")

        # The intermediate timestep is only meaningful for the two-step SCM schedule.
        if intermediate_timesteps is not None and num_inference_steps != 2:
            raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
413
+
414
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
415
+ def _text_preprocessing(self, text, clean_caption=False):
416
+ if clean_caption and not is_bs4_available():
417
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
418
+ logger.warning("Setting `clean_caption` to False...")
419
+ clean_caption = False
420
+
421
+ if clean_caption and not is_ftfy_available():
422
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
423
+ logger.warning("Setting `clean_caption` to False...")
424
+ clean_caption = False
425
+
426
+ if not isinstance(text, (tuple, list)):
427
+ text = [text]
428
+
429
+ def process(text: str):
430
+ if clean_caption:
431
+ text = self._clean_caption(text)
432
+ text = self._clean_caption(text)
433
+ else:
434
+ text = text.lower().strip()
435
+ return text
436
+
437
+ return [process(t) for t in text]
438
+
439
    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
    def _clean_caption(self, caption):
        """
        Heuristically clean a caption: strip URLs, HTML, @-mentions, CJK runs, IDs and
        file names, normalize dashes/quotes/punctuation, and collapse whitespace.
        The regex passes are order-sensitive; do not reorder them.
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # все виды тире / all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # normalize quotation marks to one standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)

        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        # unescape twice to resolve doubly-encoded entities (e.g. "&amp;quot;")
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        # NOTE(review): this is a no-op — the stripped result is discarded (likely meant
        # `caption = caption.strip()`). Kept as-is for parity with the upstream copy;
        # "fixing" it would change how the anchored patterns below match.
        caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()
553
+
554
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
555
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
556
+ if latents is not None:
557
+ return latents.to(device=device, dtype=dtype)
558
+
559
+ shape = (
560
+ batch_size,
561
+ num_channels_latents,
562
+ int(height) // self.vae_scale_factor,
563
+ int(width) // self.vae_scale_factor,
564
+ )
565
+ if isinstance(generator, list) and len(generator) != batch_size:
566
+ raise ValueError(
567
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
568
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
569
+ )
570
+
571
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
572
+ return latents
573
+
574
    @property
    def guidance_scale(self):
        # Guidance scale recorded by `__call__`; values > 1 enable embedded guidance.
        return self._guidance_scale

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors (recorded by `__call__`).
        return self._attention_kwargs

    @property
    def num_timesteps(self):
        # Number of denoising timesteps — presumably set during the sampling loop
        # (the loop is outside this view); TODO confirm.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cancellation flag initialized to False in `__call__`; checked by the loop.
        return self._interrupt
589
+
590
+ @torch.no_grad()
591
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
592
+ def __call__(
593
+ self,
594
+ prompt: Union[str, List[str]] = None,
595
+ num_inference_steps: int = 2,
596
+ timesteps: List[int] = None,
597
+ max_timesteps: float = 1.57080,
598
+ intermediate_timesteps: float = 1.3,
599
+ guidance_scale: float = 4.5,
600
+ num_images_per_prompt: Optional[int] = 1,
601
+ height: int = 1024,
602
+ width: int = 1024,
603
+ eta: float = 0.0,
604
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
605
+ latents: Optional[torch.Tensor] = None,
606
+ prompt_embeds: Optional[torch.Tensor] = None,
607
+ prompt_attention_mask: Optional[torch.Tensor] = None,
608
+ output_type: Optional[str] = "pil",
609
+ return_dict: bool = True,
610
+ clean_caption: bool = False,
611
+ use_resolution_binning: bool = True,
612
+ attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
614
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
615
+ max_sequence_length: int = 300,
616
+ complex_human_instruction: List[str] = [
617
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
618
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
619
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
620
+ "Here are examples of how to transform or refine prompts:",
621
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
622
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
623
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
624
+ "User Prompt: ",
625
+ ],
626
+ ) -> Union[SanaPipelineOutput, Tuple]:
627
+ """
628
+ Function invoked when calling the pipeline for generation.
629
+
630
+ Args:
631
+ prompt (`str` or `List[str]`, *optional*):
632
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
633
+ instead.
634
+ num_inference_steps (`int`, *optional*, defaults to 20):
635
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
636
+ expense of slower inference.
637
+ max_timesteps (`float`, *optional*, defaults to 1.57080):
638
+ The maximum timestep value used in the SCM scheduler.
639
+ intermediate_timesteps (`float`, *optional*, defaults to 1.3):
640
+ The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
641
+ timesteps (`List[int]`, *optional*):
642
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
643
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
644
+ passed will be used. Must be in descending order.
645
+ guidance_scale (`float`, *optional*, defaults to 4.5):
646
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
647
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
648
+
649
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
650
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
651
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
652
+ The number of images to generate per prompt.
653
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
654
+ The height in pixels of the generated image.
655
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
656
+ The width in pixels of the generated image.
657
+ eta (`float`, *optional*, defaults to 0.0):
658
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
659
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
660
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
661
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
662
+ to make generation deterministic.
663
+ latents (`torch.Tensor`, *optional*):
664
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
665
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
666
+ tensor will be generated by sampling using the supplied random `generator`.
667
+ prompt_embeds (`torch.Tensor`, *optional*):
668
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
669
+ provided, text embeddings will be generated from `prompt` input argument.
670
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
671
+ output_type (`str`, *optional*, defaults to `"pil"`):
672
+ The output format of the generate image. Choose between
673
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
674
+ return_dict (`bool`, *optional*, defaults to `True`):
675
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
676
+ attention_kwargs:
677
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
678
+ `self.processor` in
679
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
680
+ clean_caption (`bool`, *optional*, defaults to `True`):
681
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
682
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
683
+ prompt.
684
+ use_resolution_binning (`bool` defaults to `True`):
685
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
686
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
687
+ the requested resolution. Useful for generating non-square images.
688
+ callback_on_step_end (`Callable`, *optional*):
689
+ A function that calls at the end of each denoising steps during the inference. The function is called
690
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
691
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
692
+ `callback_on_step_end_tensor_inputs`.
693
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
694
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
695
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
696
+ `._callback_tensor_inputs` attribute of your pipeline class.
697
+ max_sequence_length (`int` defaults to `300`):
698
+ Maximum sequence length to use with the `prompt`.
699
+ complex_human_instruction (`List[str]`, *optional*):
700
+ Instructions for complex human attention:
701
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
702
+
703
+ Examples:
704
+
705
+ Returns:
706
+ [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
707
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
708
+ otherwise a `tuple` is returned where the first element is a list with the generated images
709
+ """
710
+
711
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
712
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
713
+
714
+ # 1. Check inputs. Raise error if not correct
715
+ if use_resolution_binning:
716
+ if self.transformer.config.sample_size == 32:
717
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
718
+ else:
719
+ raise ValueError("Invalid sample size")
720
+ orig_height, orig_width = height, width
721
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
722
+
723
+ self.check_inputs(
724
+ prompt=prompt,
725
+ height=height,
726
+ width=width,
727
+ num_inference_steps=num_inference_steps,
728
+ timesteps=timesteps,
729
+ max_timesteps=max_timesteps,
730
+ intermediate_timesteps=intermediate_timesteps,
731
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
732
+ prompt_embeds=prompt_embeds,
733
+ prompt_attention_mask=prompt_attention_mask,
734
+ )
735
+
736
+ self._guidance_scale = guidance_scale
737
+ self._attention_kwargs = attention_kwargs
738
+ self._interrupt = False
739
+
740
+ # 2. Default height and width to transformer
741
+ if prompt is not None and isinstance(prompt, str):
742
+ batch_size = 1
743
+ elif prompt is not None and isinstance(prompt, list):
744
+ batch_size = len(prompt)
745
+ else:
746
+ batch_size = prompt_embeds.shape[0]
747
+
748
+ device = self._execution_device
749
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
750
+
751
+ # 3. Encode input prompt
752
+ (
753
+ prompt_embeds,
754
+ prompt_attention_mask,
755
+ ) = self.encode_prompt(
756
+ prompt,
757
+ num_images_per_prompt=num_images_per_prompt,
758
+ device=device,
759
+ prompt_embeds=prompt_embeds,
760
+ prompt_attention_mask=prompt_attention_mask,
761
+ clean_caption=clean_caption,
762
+ max_sequence_length=max_sequence_length,
763
+ complex_human_instruction=complex_human_instruction,
764
+ lora_scale=lora_scale,
765
+ )
766
+
767
+ # 4. Prepare timesteps
768
+ timesteps, num_inference_steps = retrieve_timesteps(
769
+ self.scheduler,
770
+ num_inference_steps,
771
+ device,
772
+ timesteps,
773
+ sigmas=None,
774
+ max_timesteps=max_timesteps,
775
+ intermediate_timesteps=intermediate_timesteps,
776
+ )
777
+ if hasattr(self.scheduler, "set_begin_index"):
778
+ self.scheduler.set_begin_index(0)
779
+
780
+ # 5. Prepare latents.
781
+ latent_channels = self.transformer.config.in_channels
782
+ latents = self.prepare_latents(
783
+ batch_size * num_images_per_prompt,
784
+ latent_channels,
785
+ height,
786
+ width,
787
+ torch.float32,
788
+ device,
789
+ generator,
790
+ latents,
791
+ )
792
+
793
+ latents = latents * self.scheduler.config.sigma_data
794
+
795
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
796
+ guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
797
+ guidance = guidance * self.transformer.config.guidance_embeds_scale
798
+
799
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
800
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
801
+
802
+ # 7. Denoising loop
803
+ timesteps = timesteps[:-1]
804
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
805
+ self._num_timesteps = len(timesteps)
806
+
807
+ transformer_dtype = self.transformer.dtype
808
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
809
+ for i, t in enumerate(timesteps):
810
+ if self.interrupt:
811
+ continue
812
+
813
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
814
+ timestep = t.expand(latents.shape[0])
815
+ latents_model_input = latents / self.scheduler.config.sigma_data
816
+
817
+ scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
818
+
819
+ scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
820
+ latent_model_input = latents_model_input * torch.sqrt(
821
+ scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
822
+ )
823
+
824
+ # predict noise model_output
825
+ noise_pred = self.transformer(
826
+ latent_model_input.to(dtype=transformer_dtype),
827
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
828
+ encoder_attention_mask=prompt_attention_mask,
829
+ guidance=guidance,
830
+ timestep=scm_timestep,
831
+ return_dict=False,
832
+ attention_kwargs=self.attention_kwargs,
833
+ )[0]
834
+
835
+ noise_pred = (
836
+ (1 - 2 * scm_timestep_expanded) * latent_model_input
837
+ + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
838
+ ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
839
+ noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
840
+
841
+ # compute previous image: x_t -> x_t-1
842
+ latents, denoised = self.scheduler.step(
843
+ noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
844
+ )
845
+
846
+ if callback_on_step_end is not None:
847
+ callback_kwargs = {}
848
+ for k in callback_on_step_end_tensor_inputs:
849
+ callback_kwargs[k] = locals()[k]
850
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
851
+
852
+ latents = callback_outputs.pop("latents", latents)
853
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
854
+
855
+ # call the callback, if provided
856
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
857
+ progress_bar.update()
858
+
859
+ if XLA_AVAILABLE:
860
+ xm.mark_step()
861
+
862
+ latents = denoised / self.scheduler.config.sigma_data
863
+ if output_type == "latent":
864
+ image = latents
865
+ else:
866
+ latents = latents.to(self.vae.dtype)
867
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
868
+ oom_error = (
869
+ torch.OutOfMemoryError
870
+ if is_torch_version(">=", "2.5.0")
871
+ else torch_accelerator_module.OutOfMemoryError
872
+ )
873
+ try:
874
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
875
+ except oom_error as e:
876
+ warnings.warn(
877
+ f"{e}. \n"
878
+ f"Try to use VAE tiling for large images. For example: \n"
879
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
880
+ )
881
+ if use_resolution_binning:
882
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
883
+
884
+ if not output_type == "latent":
885
+ image = self.image_processor.postprocess(image, output_type=output_type)
886
+
887
+ # Offload all models
888
+ self.maybe_free_model_hooks()
889
+
890
+ if not return_dict:
891
+ return (image,)
892
+
893
+ return SanaPipelineOutput(images=image)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py ADDED
@@ -0,0 +1,981 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import re
18
+ import urllib.parse as ul
19
+ import warnings
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
25
+
26
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from ...image_processor import PipelineImageInput, PixArtImageProcessor
28
+ from ...loaders import SanaLoraLoaderMixin
29
+ from ...models import AutoencoderDC, SanaTransformer2DModel
30
+ from ...schedulers import DPMSolverMultistepScheduler
31
+ from ...utils import (
32
+ BACKENDS_MAPPING,
33
+ USE_PEFT_BACKEND,
34
+ is_bs4_available,
35
+ is_ftfy_available,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
43
+ from ..pipeline_utils import DiffusionPipeline
44
+ from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
45
+ from .pipeline_output import SanaPipelineOutput
46
+
47
+
48
+ if is_torch_xla_available():
49
+ import torch_xla.core.xla_model as xm
50
+
51
+ XLA_AVAILABLE = True
52
+ else:
53
+ XLA_AVAILABLE = False
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ if is_bs4_available():
58
+ from bs4 import BeautifulSoup
59
+
60
+ if is_ftfy_available():
61
+ import ftfy
62
+
63
# Usage example rendered into the pipeline docs. Note `.images[0]` already
# selects the first PIL image, so it is saved directly (the previous example
# indexed it a second time with `image[0].save`, which fails on a PIL image).
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import SanaSprintImg2ImgPipeline
        >>> from diffusers.utils.loading_utils import load_image

        >>> pipe = SanaSprintImg2ImgPipeline.from_pretrained(
        ...     "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
        ... )
        >>> pipe.to("cuda")

        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
        ... )

        >>> image = pipe(prompt="a cute pink bear", image=image, strength=0.5, height=832, width=480).images[0]
        >>> image.save("output.png")
        ```
"""
84
+
85
+
86
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        # Custom timesteps are only usable if this scheduler's `set_timesteps` signature accepts them.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        # The scheduler may re-derive the schedule, so the count comes from the scheduler, not the input.
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Same signature introspection for custom sigma schedules.
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: the caller's `num_inference_steps` is passed through unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
144
+
145
+
146
+ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
147
+ r"""
148
+ Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
149
+ """
150
+
151
    # fmt: off
    # Matches runs of punctuation/symbol characters that `_clean_caption` collapses
    # into a single space (***AUSVERKAUFT***, #AUSVERKAUFT, ...).
    bad_punct_regex = re.compile(
        r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
    # fmt: on

    # Order in which submodules are moved onto the accelerator during model CPU offload.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    # Tensor names that `callback_on_step_end_tensor_inputs` is allowed to reference.
    _callback_tensor_inputs = ["latents", "prompt_embeds"]
158
+
159
+ def __init__(
160
+ self,
161
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
162
+ text_encoder: Gemma2PreTrainedModel,
163
+ vae: AutoencoderDC,
164
+ transformer: SanaTransformer2DModel,
165
+ scheduler: DPMSolverMultistepScheduler,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.register_modules(
170
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
171
+ )
172
+
173
+ self.vae_scale_factor = (
174
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
175
+ if hasattr(self, "vae") and self.vae is not None
176
+ else 32
177
+ )
178
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
179
+
180
    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Pure delegation: state lives on the VAE, not the pipeline.
        self.vae.enable_slicing()
187
+
188
    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        # Pure delegation: state lives on the VAE, not the pipeline.
        self.vae.disable_slicing()
195
+
196
    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_tiling
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        # Pure delegation: state lives on the VAE, not the pipeline.
        self.vae.enable_tiling()
204
+
205
+ def disable_vae_tiling(self):
206
+ r"""
207
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
208
+ computing decoding in one step.
209
+ """
210
+ self.vae.disable_tiling()
211
+
212
    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
    def _get_gemma_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        device: torch.device,
        dtype: torch.dtype,
        clean_caption: bool = False,
        max_sequence_length: int = 300,
        complex_human_instruction: Optional[List[str]] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            dtype (`torch.dtype`):
                dtype the returned embeddings are cast to
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
                the prompt.

        Returns:
            `(prompt_embeds, prompt_attention_mask)` on `device`, embeddings cast to `dtype`.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"

        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)

        # prepare complex human instruction
        if not complex_human_instruction:
            max_length_all = max_sequence_length
        else:
            # Prepend the joined instruction to every prompt and widen the tokenizer
            # budget so the instruction does not eat into `max_sequence_length`.
            chi_prompt = "\n".join(complex_human_instruction)
            prompt = [chi_prompt + p for p in prompt]
            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
            # NOTE(review): the `- 2` presumably compensates for special tokens counted
            # in both encodings — confirm against the tokenizer's behavior.
            max_length_all = num_chi_prompt_tokens + max_sequence_length - 2

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length_all,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        prompt_attention_mask = text_inputs.attention_mask
        prompt_attention_mask = prompt_attention_mask.to(device)

        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
        # First element of the encoder output (the hidden states), moved to the requested dtype/device.
        prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)

        return prompt_embeds, prompt_attention_mask
270
+
271
    # Copied from diffusers.pipelines.sana.pipeline_sana_sprint.SanaSprintPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        clean_caption: bool = False,
        max_sequence_length: int = 300,
        complex_human_instruction: Optional[List[str]] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded

            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
                the prompt.
            lora_scale (`float`, *optional*):
                Scale applied to the text encoder's LoRA layers while encoding (PEFT backend only).

        Returns:
            `(prompt_embeds, prompt_attention_mask)` repeated for `num_images_per_prompt`.
        """

        if device is None:
            device = self._execution_device

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)

        if getattr(self, "tokenizer", None) is not None:
            self.tokenizer.padding_side = "right"

        # See Section 3.1. of the paper.
        max_length = max_sequence_length
        # Keep the first token position plus the trailing `max_length - 1` positions of the encoding.
        select_index = [0] + list(range(-max_length + 1, 0))

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                device=device,
                dtype=dtype,
                clean_caption=clean_caption,
                max_sequence_length=max_sequence_length,
                complex_human_instruction=complex_human_instruction,
            )

            # The selection is applied only to freshly computed embeddings;
            # user-supplied `prompt_embeds` are assumed to be pre-trimmed.
            prompt_embeds = prompt_embeds[:, select_index]
            prompt_attention_mask = prompt_attention_mask[:, select_index]

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        if self.text_encoder is not None:
            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, prompt_attention_mask
356
+
357
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        """Build the kwargs dict for `scheduler.step`, forwarding only arguments its signature accepts."""
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
374
+
375
    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        """Trim the schedule for img2img: skip the first `(1 - strength)` fraction of steps.

        Returns the remaining timesteps and their count, and realigns the
        scheduler's begin index when the scheduler supports one.
        """
        # get the original timestep using init_timestep
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
        # `order` accounts for schedulers that emit multiple timesteps per step.
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start
386
+
387
+ def check_inputs(
388
+ self,
389
+ prompt,
390
+ strength,
391
+ height,
392
+ width,
393
+ num_inference_steps,
394
+ timesteps,
395
+ max_timesteps,
396
+ intermediate_timesteps,
397
+ callback_on_step_end_tensor_inputs=None,
398
+ prompt_embeds=None,
399
+ prompt_attention_mask=None,
400
+ ):
401
+ if strength < 0 or strength > 1:
402
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
403
+
404
+ if height % 32 != 0 or width % 32 != 0:
405
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
406
+
407
+ if callback_on_step_end_tensor_inputs is not None and not all(
408
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
409
+ ):
410
+ raise ValueError(
411
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
412
+ )
413
+
414
+ if prompt is not None and prompt_embeds is not None:
415
+ raise ValueError(
416
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
417
+ " only forward one of the two."
418
+ )
419
+ elif prompt is None and prompt_embeds is None:
420
+ raise ValueError(
421
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
422
+ )
423
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
424
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
425
+
426
+ if prompt_embeds is not None and prompt_attention_mask is None:
427
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
428
+
429
+ if timesteps is not None and len(timesteps) != num_inference_steps + 1:
430
+ raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
431
+
432
+ if timesteps is not None and max_timesteps is not None:
433
+ raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
434
+
435
+ if timesteps is None and max_timesteps is None:
436
+ raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
437
+
438
+ if intermediate_timesteps is not None and num_inference_steps != 2:
439
+ raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
440
+
441
    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
    def _text_preprocessing(self, text, clean_caption=False):
        """Normalize raw prompt text, always returning a list of strings.

        With `clean_caption=True` (and `bs4` + `ftfy` installed) each prompt is
        run through `_clean_caption`; otherwise it is only lowercased/stripped.
        """
        if clean_caption and not is_bs4_available():
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        if clean_caption and not is_ftfy_available():
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        if not isinstance(text, (tuple, list)):
            text = [text]

        def process(text: str):
            if clean_caption:
                text = self._clean_caption(text)
                # NOTE(review): cleaning is intentionally applied twice (matches the
                # upstream IF pipeline) — presumably so a second pass catches patterns
                # exposed by the first (e.g. after entity unescaping); confirm upstream.
                text = self._clean_caption(text)
            else:
                text = text.lower().strip()
            return text

        return [process(t) for t in text]
465
+
466
    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
    def _clean_caption(self, caption):
        """Aggressively clean a single caption string.

        Strips URLs, HTML, @-handles, CJK character runs, normalizes dashes and
        quotes, removes IP addresses, numeric IDs, filenames and common web
        boilerplate ("free shipping", "click for ...", page numbers), then
        collapses whitespace. Requires `bs4` and `ftfy` (callers guard via
        `_text_preprocessing`).
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # normalize all quote characters to a single standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)

        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        # NOTE(review): no-op — `str.strip` returns a new string that is discarded
        # here (kept for byte-identity with the upstream copy); the final
        # `return caption.strip()` below performs the actual stripping.
        caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()
580
+
581
+ def prepare_image(
582
+ self,
583
+ image: PipelineImageInput,
584
+ width: int,
585
+ height: int,
586
+ device: torch.device,
587
+ dtype: torch.dtype,
588
+ ):
589
+ if isinstance(image, torch.Tensor):
590
+ if image.ndim == 3:
591
+ image = image.unsqueeze(0)
592
+ # Resize if current dimensions do not match target dimensions.
593
+ if image.shape[2] != height or image.shape[3] != width:
594
+ image = F.interpolate(image, size=(height, width), mode="bilinear", align_corners=False)
595
+
596
+ image = self.image_processor.preprocess(image, height=height, width=width)
597
+
598
+ else:
599
+ image = self.image_processor.preprocess(image, height=height, width=width)
600
+
601
+ image = image.to(device=device, dtype=dtype)
602
+
603
+ return image
604
+
605
+ def prepare_latents(
606
+ self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None
607
+ ):
608
+ if latents is not None:
609
+ return latents.to(device=device, dtype=dtype)
610
+
611
+ shape = (
612
+ batch_size,
613
+ num_channels_latents,
614
+ int(height) // self.vae_scale_factor,
615
+ int(width) // self.vae_scale_factor,
616
+ )
617
+
618
+ if image.shape[1] != num_channels_latents:
619
+ image = self.vae.encode(image).latent
620
+ image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
621
+ else:
622
+ image_latents = image
623
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
624
+ # expand init_latents for batch_size
625
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
626
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
627
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
628
+ raise ValueError(
629
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
630
+ )
631
+ else:
632
+ image_latents = torch.cat([image_latents], dim=0)
633
+
634
+ if isinstance(generator, list) and len(generator) != batch_size:
635
+ raise ValueError(
636
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
637
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
638
+ )
639
+
640
+ # adapt from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py
641
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data
642
+ latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise
643
+ return latents
644
+
645
    @property
    def guidance_scale(self):
        """Guidance scale recorded by the most recent `__call__` invocation."""
        return self._guidance_scale
648
+
649
    @property
    def attention_kwargs(self):
        """Attention-processor kwargs recorded by the most recent `__call__` invocation."""
        return self._attention_kwargs
652
+
653
    @property
    def num_timesteps(self):
        """Number of denoising timesteps used by the most recent `__call__` invocation."""
        return self._num_timesteps
656
+
657
    @property
    def interrupt(self):
        """Flag checked inside the denoising loop to allow early interruption."""
        return self._interrupt
660
+
661
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 2,
        timesteps: List[int] = None,
        max_timesteps: float = 1.57080,
        intermediate_timesteps: float = 1.3,
        guidance_scale: float = 4.5,
        image: PipelineImageInput = None,
        strength: float = 0.6,
        num_images_per_prompt: Optional[int] = 1,
        height: int = 1024,
        width: int = 1024,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        clean_caption: bool = False,
        use_resolution_binning: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 300,
        complex_human_instruction: List[str] = [
            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
            "Here are examples of how to transform or refine prompts:",
            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
            "User Prompt: ",
        ],
    ) -> Union[SanaPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            num_inference_steps (`int`, *optional*, defaults to 2):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            max_timesteps (`float`, *optional*, defaults to 1.57080):
                The maximum timestep value used in the SCM scheduler.
            intermediate_timesteps (`float`, *optional*, defaults to 1.3):
                The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            image (`PipelineImageInput`, *optional*):
                The image(s) to use as the starting point for the generation.
            strength (`float`, *optional*, defaults to 0.6):
                Indicates how much to transform the reference `image`. The denoising schedule is truncated as a
                function of `strength` via `get_timesteps`, so higher values apply more noise/transformation.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            attention_kwargs:
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clean_caption (`bool`, *optional*, defaults to `False`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            use_resolution_binning (`bool` defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions using
                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                the requested resolution. Useful for generating non-square images.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to `300`):
                Maximum sequence length to use with the `prompt`.
            complex_human_instruction (`List[str]`, *optional*):
                Instructions for complex human attention:
                https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.

        Examples:

        Returns:
            [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images
        """

        # Pipeline-callback objects carry their own list of tensor inputs.
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct.
        if use_resolution_binning:
            # Only the 32-sample-size (1024px) transformer has a resolution bin table here.
            if self.transformer.config.sample_size == 32:
                aspect_ratio_bin = ASPECT_RATIO_1024_BIN
            else:
                raise ValueError("Invalid sample size")
            # Remember the requested size so the decoded image can be resized back at the end.
            orig_height, orig_width = height, width
            height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

        self.check_inputs(
            prompt=prompt,
            strength=strength,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            max_timesteps=max_timesteps,
            intermediate_timesteps=intermediate_timesteps,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
        )

        # Stash call state for the read-only properties used below and by callbacks.
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._interrupt = False

        # 2. Define call parameters (effective batch size follows prompt / prompt_embeds).
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # LoRA scale, if any, is forwarded to prompt encoding via attention_kwargs.
        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

        # 3. Preprocess the input image to a tensor in the VAE's dtype.
        init_image = self.prepare_image(image, width, height, device, self.vae.dtype)

        # 4. Encode input prompt.
        (
            prompt_embeds,
            prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,
            complex_human_instruction=complex_human_instruction,
            lora_scale=lora_scale,
        )

        # 5. Prepare timesteps (SCM scheduler), then truncate by `strength` for img2img.
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas=None,
            max_timesteps=max_timesteps,
            intermediate_timesteps=intermediate_timesteps,
        )
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(0)

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        # The first (largest) timestep sets the noise level mixed into the image latents.
        latent_timestep = timesteps[:1]

        # 6. Prepare latents from the encoded image plus noise at `latent_timestep`.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        # NOTE(review): sigma_data scaling already happens inside prepare_latents,
        # so re-scaling here would double-apply it — kept disabled on purpose.
        # latents = latents * self.scheduler.config.sigma_data

        # 7. Guidance embedding (Sana-Sprint embeds the CFG scale instead of doing two passes).
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
        guidance = guidance * self.transformer.config.guidance_embeds_scale

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop. The final timestep is dropped: the last scheduler.step
        # already produces `denoised` for it.
        timesteps = timesteps[:-1]
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        transformer_dtype = self.transformer.dtype
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0])
                latents_model_input = latents / self.scheduler.config.sigma_data

                # Map the trigonometric timestep into the SCM parameterization
                # (sin/(cos+sin)); presumably TrigFlow-style — see SANA-Sprint paper.
                scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))

                scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
                # Rescale the input so its norm matches the SCM training-time convention.
                latent_model_input = latents_model_input * torch.sqrt(
                    scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
                )

                # predict noise model_output
                noise_pred = self.transformer(
                    latent_model_input.to(dtype=transformer_dtype),
                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
                    encoder_attention_mask=prompt_attention_mask,
                    guidance=guidance,
                    timestep=scm_timestep,
                    return_dict=False,
                    attention_kwargs=self.attention_kwargs,
                )[0]

                # Convert the raw network output back into the scheduler's expected
                # prediction space (inverse of the input rescaling above).
                noise_pred = (
                    (1 - 2 * scm_timestep_expanded) * latent_model_input
                    + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
                ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
                noise_pred = noise_pred.float() * self.scheduler.config.sigma_data

                # compute previous image: x_t -> x_t-1 (`denoised` is the clean estimate)
                latents, denoised = self.scheduler.step(
                    noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
                )

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    # locals() lookup lets callers request any tensor declared in this scope.
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 9. Use the final clean estimate (not the noisy `latents`) for decoding.
        latents = denoised / self.scheduler.config.sigma_data
        if output_type == "latent":
            image = latents
        else:
            latents = latents.to(self.vae.dtype)
            # torch.OutOfMemoryError is only available on torch >= 2.5; fall back to
            # the accelerator module's error class on older versions.
            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
            oom_error = (
                torch.OutOfMemoryError
                if is_torch_version(">=", "2.5.0")
                else torch_accelerator_module.OutOfMemoryError
            )
            try:
                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            except oom_error as e:
                # Best-effort hint; NOTE(review): `image` stays undefined if decode OOMs.
                warnings.warn(
                    f"{e}. \n"
                    f"Try to use VAE tiling for large images. For example: \n"
                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
                )
            if use_resolution_binning:
                # Undo the aspect-ratio binning applied in step 1.
                image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)

        if not output_type == "latent":
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return SanaPipelineOutput(images=image)