Add files using upload-large-folder tool
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_output.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/__init__.py +153 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__init__.py +53 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +124 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +990 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1045 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py +28 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__init__.py +54 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_skyreels_image2video.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_video.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_video_framepack.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_video_image2video.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_output.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/pipeline_output.py +39 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/__init__.py +46 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/__pycache__/pipeline_i2vgen_xl.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +797 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/__init__.py +66 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_combined.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +419 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +817 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +505 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +647 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +559 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/text_encoder.py +27 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana_sprint_img2img.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__init__.py +49 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_output.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py +25 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +733 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/shap_e/camera.py +147 -0
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.06 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-310.pyc
ADDED
Binary file (23.1 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-310.pyc
ADDED
Binary file (26.1 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-310.pyc
ADDED
Binary file (28.9 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-310.pyc
ADDED
Binary file (28.6 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-310.pyc
ADDED
Binary file (31.4 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-310.pyc
ADDED
Binary file (25.5 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_output.cpython-310.pyc
ADDED
Binary file (1.57 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-310.pyc
ADDED
Binary file (1.97 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-310.pyc
ADDED
Binary file (3.22 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-310.pyc
ADDED
Binary file (1.71 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/__init__.py
ADDED
@@ -0,0 +1,153 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_librosa_available,
+    is_note_seq_available,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_pt_objects
+
+    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
+else:
+    _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
+    _import_structure["pndm"] = ["PNDMPipeline"]
+    _import_structure["repaint"] = ["RePaintPipeline"]
+    _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
+    _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["alt_diffusion"] = [
+        "AltDiffusionImg2ImgPipeline",
+        "AltDiffusionPipeline",
+        "AltDiffusionPipelineOutput",
+    ]
+    _import_structure["versatile_diffusion"] = [
+        "VersatileDiffusionDualGuidedPipeline",
+        "VersatileDiffusionImageVariationPipeline",
+        "VersatileDiffusionPipeline",
+        "VersatileDiffusionTextToImagePipeline",
+    ]
+    _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
+    _import_structure["stable_diffusion_variants"] = [
+        "CycleDiffusionPipeline",
+        "StableDiffusionInpaintPipelineLegacy",
+        "StableDiffusionPix2PixZeroPipeline",
+        "StableDiffusionParadigmsPipeline",
+        "StableDiffusionModelEditingPipeline",
+    ]
+
+try:
+    if not (is_torch_available() and is_librosa_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_librosa_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
+
+else:
+    _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
+
+try:
+    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
+
+else:
+    _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_pt_objects import *
+
+    else:
+        from .latent_diffusion_uncond import LDMPipeline
+        from .pndm import PNDMPipeline
+        from .repaint import RePaintPipeline
+        from .score_sde_ve import ScoreSdeVePipeline
+        from .stochastic_karras_ve import KarrasVePipeline
+
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+
+    else:
+        from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AltDiffusionPipelineOutput
+        from .audio_diffusion import AudioDiffusionPipeline, Mel
+        from .spectrogram_diffusion import SpectrogramDiffusionPipeline
+        from .stable_diffusion_variants import (
+            CycleDiffusionPipeline,
+            StableDiffusionInpaintPipelineLegacy,
+            StableDiffusionModelEditingPipeline,
+            StableDiffusionParadigmsPipeline,
+            StableDiffusionPix2PixZeroPipeline,
+        )
+        from .stochastic_karras_ve import KarrasVePipeline
+        from .versatile_diffusion import (
+            VersatileDiffusionDualGuidedPipeline,
+            VersatileDiffusionImageVariationPipeline,
+            VersatileDiffusionPipeline,
+            VersatileDiffusionTextToImagePipeline,
+        )
+        from .vq_diffusion import VQDiffusionPipeline
+
+    try:
+        if not (is_torch_available() and is_librosa_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_librosa_objects import *
+    else:
+        from .audio_diffusion import AudioDiffusionPipeline, Mel
+
+    try:
+        if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403
+    else:
+        from .spectrogram_diffusion import (
+            MidiProcessor,
+            SpectrogramDiffusionPipeline,
+        )
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
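
The __init__.py above uses diffusers' lazy-import pattern: the pipelines listed in _import_structure are only imported when an attribute is first accessed through the _LazyModule proxy, and missing optional dependencies are replaced by dummy objects. A minimal sketch of how a caller would trigger that path (assuming torch is installed and this diffusers build still ships the deprecated pipelines; the checkpoint id is an illustrative example, not part of this commit):

    # Hypothetical usage sketch, not part of the uploaded files.
    from diffusers.pipelines.deprecated import LDMPipeline  # attribute access resolves through _LazyModule

    pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")  # example checkpoint
    image = pipe(num_inference_steps=50).images[0]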
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__init__.py
ADDED
@@ -0,0 +1,53 @@
+from typing import TYPE_CHECKING
+
+from ....utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ....utils import dummy_torch_and_transformers_objects
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"]
+    _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"]
+    _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"]
+
+    _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ....utils.dummy_torch_and_transformers_objects import *
+
+    else:
+        from .modeling_roberta_series import RobertaSeriesModelWithTransformation
+        from .pipeline_alt_diffusion import AltDiffusionPipeline
+        from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
+        from .pipeline_output import AltDiffusionPipelineOutput
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/modeling_roberta_series.cpython-310.pyc
ADDED
Binary file (4.85 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-310.pyc
ADDED
Binary file (33.6 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-310.pyc
ADDED
Binary file (36.1 kB)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py
ADDED
@@ -0,0 +1,124 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers.utils import ModelOutput
+
+
+@dataclass
+class TransformationModelOutput(ModelOutput):
+    """
+    Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+            The text embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one
+            for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    projection_state: Optional[torch.Tensor] = None
+    last_hidden_state: torch.Tensor = None
+    hidden_states: Optional[Tuple[torch.Tensor]] = None
+    attentions: Optional[Tuple[torch.Tensor]] = None
+
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+    def __init__(
+        self,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        project_dim=512,
+        pooler_fn="cls",
+        learn_encoder=False,
+        use_attention_mask=True,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+        self.use_attention_mask = use_attention_mask
+
+
+class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    base_model_prefix = "roberta"
+    config_class = RobertaSeriesConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.roberta = XLMRobertaModel(config)
+        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ):
+        r""" """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.base_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
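
RobertaSeriesModelWithTransformation wraps an XLM-RoBERTa encoder and projects its hidden states to config.project_dim, which is what the AltDiffusion pipeline below consumes as text conditioning. A minimal sketch of a forward pass with randomly initialized weights (the config sizes and token ids are illustrative assumptions, not values from this commit):

    # Hypothetical usage sketch, not part of the uploaded files.
    import torch

    config = RobertaSeriesConfig(vocab_size=250002, project_dim=768)  # illustrative sizes
    model = RobertaSeriesModelWithTransformation(config).eval()  # untrained weights, structure demo only
    input_ids = torch.tensor([[0, 35378, 2]])  # toy token ids within the vocab range
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    print(out.projection_state.shape)  # torch.Size([1, 3, 768]): one projected vector per token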
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
ADDED
@@ -0,0 +1,990 @@
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from packaging import version
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
|
| 21 |
+
|
| 22 |
+
from ....configuration_utils import FrozenDict
|
| 23 |
+
from ....image_processor import PipelineImageInput, VaeImageProcessor
|
| 24 |
+
from ....loaders import (
|
| 25 |
+
FromSingleFileMixin,
|
| 26 |
+
IPAdapterMixin,
|
| 27 |
+
StableDiffusionLoraLoaderMixin,
|
| 28 |
+
TextualInversionLoaderMixin,
|
| 29 |
+
)
|
| 30 |
+
from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
| 31 |
+
from ....models.lora import adjust_lora_scale_text_encoder
|
| 32 |
+
from ....schedulers import KarrasDiffusionSchedulers
|
| 33 |
+
from ....utils import (
|
| 34 |
+
USE_PEFT_BACKEND,
|
| 35 |
+
deprecate,
|
| 36 |
+
logging,
|
| 37 |
+
replace_example_docstring,
|
| 38 |
+
scale_lora_layers,
|
| 39 |
+
unscale_lora_layers,
|
| 40 |
+
)
|
| 41 |
+
from ....utils.torch_utils import randn_tensor
|
| 42 |
+
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
| 43 |
+
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 44 |
+
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
|
| 45 |
+
from .pipeline_output import AltDiffusionPipelineOutput
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 49 |
+
|
| 50 |
+
EXAMPLE_DOC_STRING = """
|
| 51 |
+
Examples:
|
| 52 |
+
```py
|
| 53 |
+
>>> import torch
|
| 54 |
+
>>> from diffusers import AltDiffusionPipeline
|
| 55 |
+
|
| 56 |
+
>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
|
| 57 |
+
>>> pipe = pipe.to("cuda")
|
| 58 |
+
|
| 59 |
+
>>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap"
|
| 60 |
+
>>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图"
|
| 61 |
+
>>> image = pipe(prompt).images[0]
|
| 62 |
+
```
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
| 67 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
| 68 |
+
r"""
|
| 69 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
| 70 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
| 71 |
+
Flawed](https://huggingface.co/papers/2305.08891).
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
noise_cfg (`torch.Tensor`):
|
| 75 |
+
The predicted noise tensor for the guided diffusion process.
|
| 76 |
+
noise_pred_text (`torch.Tensor`):
|
| 77 |
+
The predicted noise tensor for the text-guided diffusion process.
|
| 78 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 79 |
+
A rescale factor applied to the noise predictions.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
| 83 |
+
"""
|
| 84 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 85 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 86 |
+
# rescale the results from guidance (fixes overexposure)
|
| 87 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 88 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
| 89 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 90 |
+
return noise_cfg
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 94 |
+
def retrieve_timesteps(
|
| 95 |
+
scheduler,
|
| 96 |
+
num_inference_steps: Optional[int] = None,
|
| 97 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 98 |
+
timesteps: Optional[List[int]] = None,
|
| 99 |
+
sigmas: Optional[List[float]] = None,
|
| 100 |
+
**kwargs,
|
| 101 |
+
):
|
| 102 |
+
r"""
|
| 103 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 104 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
scheduler (`SchedulerMixin`):
|
| 108 |
+
The scheduler to get timesteps from.
|
| 109 |
+
num_inference_steps (`int`):
|
| 110 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 111 |
+
must be `None`.
|
| 112 |
+
device (`str` or `torch.device`, *optional*):
|
| 113 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 114 |
+
timesteps (`List[int]`, *optional*):
|
| 115 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 116 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 117 |
+
sigmas (`List[float]`, *optional*):
|
| 118 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 119 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 123 |
+
second element is the number of inference steps.
|
| 124 |
+
"""
|
| 125 |
+
if timesteps is not None and sigmas is not None:
|
| 126 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 127 |
+
if timesteps is not None:
|
| 128 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 129 |
+
if not accepts_timesteps:
|
| 130 |
+
raise ValueError(
|
| 131 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 132 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 133 |
+
)
|
| 134 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 135 |
+
timesteps = scheduler.timesteps
|
| 136 |
+
num_inference_steps = len(timesteps)
|
| 137 |
+
elif sigmas is not None:
|
| 138 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 139 |
+
if not accept_sigmas:
|
| 140 |
+
raise ValueError(
|
| 141 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 142 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 143 |
+
)
|
| 144 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 145 |
+
timesteps = scheduler.timesteps
|
| 146 |
+
num_inference_steps = len(timesteps)
|
| 147 |
+
else:
|
| 148 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 149 |
+
timesteps = scheduler.timesteps
|
| 150 |
+
return timesteps, num_inference_steps
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class AltDiffusionPipeline(
|
| 154 |
+
DiffusionPipeline,
|
| 155 |
+
StableDiffusionMixin,
|
| 156 |
+
TextualInversionLoaderMixin,
|
| 157 |
+
StableDiffusionLoraLoaderMixin,
|
| 158 |
+
IPAdapterMixin,
|
| 159 |
+
FromSingleFileMixin,
|
| 160 |
+
):
|
| 161 |
+
r"""
|
| 162 |
+
Pipeline for text-to-image generation using Alt Diffusion.
|
| 163 |
+
|
| 164 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 165 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 166 |
+
|
| 167 |
+
The pipeline also inherits the following loading methods:
|
| 168 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 169 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 170 |
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 171 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 172 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
vae ([`AutoencoderKL`]):
|
| 176 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 177 |
+
text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]):
|
| 178 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 179 |
+
tokenizer ([`~transformers.XLMRobertaTokenizer`]):
|
| 180 |
+
A `XLMRobertaTokenizer` to tokenize text.
|
| 181 |
+
unet ([`UNet2DConditionModel`]):
|
| 182 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 183 |
+
scheduler ([`SchedulerMixin`]):
|
| 184 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 185 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 186 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 187 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 188 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
| 189 |
+
about a model's potential harms.
|
| 190 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 191 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
| 195 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
| 196 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
| 197 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 198 |
+
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
vae: AutoencoderKL,
|
| 202 |
+
text_encoder: RobertaSeriesModelWithTransformation,
|
| 203 |
+
tokenizer: XLMRobertaTokenizer,
|
| 204 |
+
unet: UNet2DConditionModel,
|
| 205 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 206 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 207 |
+
feature_extractor: CLIPImageProcessor,
|
| 208 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 209 |
+
requires_safety_checker: bool = True,
|
| 210 |
+
):
|
| 211 |
+
super().__init__()
|
| 212 |
+
|
| 213 |
+
if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
|
| 214 |
+
deprecation_message = (
|
| 215 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
| 216 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
| 217 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
| 218 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
| 219 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
| 220 |
+
" file"
|
| 221 |
+
)
|
| 222 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
| 223 |
+
new_config = dict(scheduler.config)
|
| 224 |
+
new_config["steps_offset"] = 1
|
| 225 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 226 |
+
|
| 227 |
+
if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
|
| 228 |
+
deprecation_message = (
|
| 229 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
| 230 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
| 231 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
| 232 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
| 233 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
| 234 |
+
)
|
| 235 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
| 236 |
+
new_config = dict(scheduler.config)
|
| 237 |
+
new_config["clip_sample"] = False
|
| 238 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 239 |
+
|
| 240 |
+
if safety_checker is None and requires_safety_checker:
|
| 241 |
+
logger.warning(
|
| 242 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 243 |
+
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
|
| 244 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| 245 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 246 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 247 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
if safety_checker is not None and feature_extractor is None:
|
| 251 |
+
raise ValueError(
|
| 252 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 253 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
is_unet_version_less_0_9_0 = (
|
| 257 |
+
unet is not None
|
| 258 |
+
and hasattr(unet.config, "_diffusers_version")
|
| 259 |
+
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
|
| 260 |
+
)
|
| 261 |
+
is_unet_sample_size_less_64 = (
|
| 262 |
+
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
| 263 |
+
)
|
| 264 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
| 265 |
+
deprecation_message = (
|
| 266 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
| 267 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
| 268 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
| 269 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
| 270 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
| 271 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
| 272 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
| 273 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
| 274 |
+
" the `unet/config.json` file"
|
| 275 |
+
)
|
| 276 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
| 277 |
+
new_config = dict(unet.config)
|
| 278 |
+
new_config["sample_size"] = 64
|
| 279 |
+
unet._internal_dict = FrozenDict(new_config)
|
| 280 |
+
|
| 281 |
+
self.register_modules(
|
| 282 |
+
vae=vae,
|
| 283 |
+
text_encoder=text_encoder,
|
| 284 |
+
tokenizer=tokenizer,
|
| 285 |
+
unet=unet,
|
| 286 |
+
scheduler=scheduler,
|
| 287 |
+
safety_checker=safety_checker,
|
| 288 |
+
feature_extractor=feature_extractor,
|
| 289 |
+
image_encoder=image_encoder,
|
| 290 |
+
)
|
| 291 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 292 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 293 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 294 |
+
|
| 295 |
+
def _encode_prompt(
|
| 296 |
+
self,
|
| 297 |
+
prompt,
|
| 298 |
+
device,
|
| 299 |
+
num_images_per_prompt,
|
| 300 |
+
do_classifier_free_guidance,
|
| 301 |
+
negative_prompt=None,
|
| 302 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 303 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 304 |
+
lora_scale: Optional[float] = None,
|
| 305 |
+
**kwargs,
|
| 306 |
+
):
|
| 307 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
| 308 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
| 309 |
+
|
| 310 |
+
prompt_embeds_tuple = self.encode_prompt(
|
| 311 |
+
prompt=prompt,
|
| 312 |
+
device=device,
|
| 313 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 314 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 315 |
+
negative_prompt=negative_prompt,
|
| 316 |
+
prompt_embeds=prompt_embeds,
|
| 317 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 318 |
+
lora_scale=lora_scale,
|
| 319 |
+
**kwargs,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# concatenate for backwards comp
|
| 323 |
+
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
| 324 |
+
|
| 325 |
+
return prompt_embeds
|
| 326 |
+
|
| 327 |
+
def encode_prompt(
|
| 328 |
+
self,
|
| 329 |
+
prompt,
|
| 330 |
+
device,
|
| 331 |
+
num_images_per_prompt,
|
| 332 |
+
do_classifier_free_guidance,
|
| 333 |
+
negative_prompt=None,
|
| 334 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 335 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 336 |
+
lora_scale: Optional[float] = None,
|
| 337 |
+
clip_skip: Optional[int] = None,
|
| 338 |
+
):
|
| 339 |
+
r"""
|
| 340 |
+
Encodes the prompt into text encoder hidden states.
|
| 341 |
+
|
| 342 |
+
Args:
|
| 343 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 344 |
+
prompt to be encoded
|
| 345 |
+
device: (`torch.device`):
|
| 346 |
+
torch device
|
| 347 |
+
num_images_per_prompt (`int`):
|
| 348 |
+
number of images that should be generated per prompt
|
| 349 |
+
do_classifier_free_guidance (`bool`):
|
| 350 |
+
whether to use classifier free guidance or not
|
| 351 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 352 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 353 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 354 |
+
less than `1`).
|
| 355 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 356 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 357 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 358 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 359 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 360 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 361 |
+
argument.
|
| 362 |
+
lora_scale (`float`, *optional*):
|
| 363 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 364 |
+
clip_skip (`int`, *optional*):
|
| 365 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 366 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 367 |
+
"""
|
| 368 |
+
# set lora scale so that monkey patched LoRA
|
| 369 |
+
# function of text encoder can correctly access it
|
| 370 |
+
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
|
| 371 |
+
self._lora_scale = lora_scale
|
| 372 |
+
|
| 373 |
+
# dynamically adjust the LoRA scale
|
| 374 |
+
if not USE_PEFT_BACKEND:
|
| 375 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 376 |
+
else:
|
| 377 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 378 |
+
|
| 379 |
+
if prompt is not None and isinstance(prompt, str):
|
| 380 |
+
batch_size = 1
|
| 381 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 382 |
+
batch_size = len(prompt)
|
| 383 |
+
else:
|
| 384 |
+
batch_size = prompt_embeds.shape[0]
|
| 385 |
+
|
| 386 |
+
if prompt_embeds is None:
|
| 387 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 388 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 389 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 390 |
+
|
| 391 |
+
text_inputs = self.tokenizer(
|
| 392 |
+
prompt,
|
| 393 |
+
padding="max_length",
|
| 394 |
+
max_length=self.tokenizer.model_max_length,
|
| 395 |
+
truncation=True,
|
| 396 |
+
return_tensors="pt",
|
| 397 |
+
)
|
| 398 |
+
text_input_ids = text_inputs.input_ids
|
| 399 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 400 |
+
|
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
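    # Illustrative sketch (not from the upstream module): what `encode_prompt` hands back under
    # classifier-free guidance, assuming a loaded pipeline instance `pipe` on CUDA:
    #
    #   cond, uncond = pipe.encode_prompt(
    #       "a fantasy landscape", device="cuda", num_images_per_prompt=2,
    #       do_classifier_free_guidance=True, negative_prompt="low quality",
    #   )
    #   # both tensors: (batch_size * num_images_per_prompt, seq_len, hidden_dim), in the text encoder's dtype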
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
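    # Illustrative sketch (not from the upstream module): the latent shape produced by
    # `prepare_latents` above, assuming the usual `vae_scale_factor == 8` and 4 latent channels:
    #
    #   latents = pipe.prepare_latents(2, 4, 512, 512, torch.float16, torch.device("cuda"), generator=None)
    #   # latents.shape == (2, 4, 512 // 8, 512 // 8) == (2, 4, 64, 64), pre-scaled by scheduler.init_noise_sigma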
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            timesteps (`torch.Tensor`):
                generate embedding vectors at these timesteps
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
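    # Illustrative sketch (not from the upstream module): the sinusoidal guidance embedding
    # computed above, assuming a batch of two guidance weights:
    #
    #   w = torch.tensor([1.5, 7.5])
    #   emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
    #   # emb.shape == (2, 256); the first 128 columns are sin terms, the last 128 are cos terms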
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
            )
            if self.do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://huggingface.co/papers/2305.08891
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
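A minimal usage sketch for the text-to-image pipeline defined above. This is illustrative only: the checkpoint id `BAAI/AltDiffusion-m9` is borrowed from the img2img docstring example further down, and the top-level `AltDiffusionPipeline` import is assumed to be re-exported by `diffusers` in the same way `AltDiffusionImg2ImgPipeline` is.

```py
import torch

from diffusers import AltDiffusionPipeline

# Assumed checkpoint; any Alt Diffusion checkpoint laid out for this pipeline should behave the same.
pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# The text encoder is an XLM-Roberta series model, so non-English prompts are supported as well.
prompt = "a fantasy landscape, trending on artstation"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("alt_diffusion_text2img.png")
```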
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
ADDED
|
@@ -0,0 +1,1045 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer

from ....configuration_utils import FrozenDict
from ....image_processor import PipelineImageInput, VaeImageProcessor
from ....loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ....models.lora import adjust_lora_scale_text_encoder
from ....schedulers import KarrasDiffusionSchedulers
from ....utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ....utils.torch_utils import randn_tensor
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_output import AltDiffusionPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from io import BytesIO

        >>> from diffusers import AltDiffusionImg2ImgPipeline

        >>> device = "cuda"
        >>> model_id_or_path = "BAAI/AltDiffusion-m9"
        >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
        >>> pipe = pipe.to(device)

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        >>> response = requests.get(url)
        >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> init_image = init_image.resize((768, 512))

        >>> # "A fantasy landscape, trending on artstation"
        >>> prompt = "幻想风景, artstation"

        >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
        >>> images[0].save("幻想风景.png")
        ```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
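# Illustrative sketch (not from the upstream module): the three ways `retrieve_timesteps`
# can be driven, assuming `pipe.scheduler` supports the corresponding `set_timesteps` keywords:
#
#   timesteps, n = retrieve_timesteps(pipe.scheduler, num_inference_steps=30, device="cuda")
#   timesteps, n = retrieve_timesteps(pipe.scheduler, device="cuda", timesteps=[999, 749, 499, 249, 0])
#   timesteps, n = retrieve_timesteps(pipe.scheduler, device="cuda", sigmas=[14.6, 6.1, 2.3, 0.8, 0.0])
#   # passing both `timesteps` and `sigmas` raises a ValueError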
class AltDiffusionImg2ImgPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for text-guided image-to-image generation using Alt Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.XLMRobertaTokenizer`]):
            A `XLMRobertaTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: RobertaSeriesModelWithTransformation,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = (
            unet is not None
            and hasattr(unet.config, "_diffusers_version")
            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
        )
        is_unet_sample_size_less_64 = (
            unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        )
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
| 502 |
+
uncond_input = self.tokenizer(
|
| 503 |
+
uncond_tokens,
|
| 504 |
+
padding="max_length",
|
| 505 |
+
max_length=max_length,
|
| 506 |
+
truncation=True,
|
| 507 |
+
return_tensors="pt",
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 511 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
| 512 |
+
else:
|
| 513 |
+
attention_mask = None
|
| 514 |
+
|
| 515 |
+
negative_prompt_embeds = self.text_encoder(
|
| 516 |
+
uncond_input.input_ids.to(device),
|
| 517 |
+
attention_mask=attention_mask,
|
| 518 |
+
)
|
| 519 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 520 |
+
|
| 521 |
+
if do_classifier_free_guidance:
|
| 522 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 523 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 524 |
+
|
| 525 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 526 |
+
|
| 527 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 528 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 529 |
+
|
| 530 |
+
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 531 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 532 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 533 |
+
|
| 534 |
+
return prompt_embeds, negative_prompt_embeds
|
| 535 |
+
|
| 536 |
+
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
| 537 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 538 |
+
|
| 539 |
+
if not isinstance(image, torch.Tensor):
|
| 540 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 541 |
+
|
| 542 |
+
image = image.to(device=device, dtype=dtype)
|
| 543 |
+
if output_hidden_states:
|
| 544 |
+
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
| 545 |
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
| 546 |
+
uncond_image_enc_hidden_states = self.image_encoder(
|
| 547 |
+
torch.zeros_like(image), output_hidden_states=True
|
| 548 |
+
).hidden_states[-2]
|
| 549 |
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
| 550 |
+
num_images_per_prompt, dim=0
|
| 551 |
+
)
|
| 552 |
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
| 553 |
+
else:
|
| 554 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 555 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 556 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
| 557 |
+
|
| 558 |
+
return image_embeds, uncond_image_embeds
|
| 559 |
+
|
| 560 |
+
def run_safety_checker(self, image, device, dtype):
|
| 561 |
+
if self.safety_checker is None:
|
| 562 |
+
has_nsfw_concept = None
|
| 563 |
+
else:
|
| 564 |
+
if torch.is_tensor(image):
|
| 565 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
| 566 |
+
else:
|
| 567 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 568 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
| 569 |
+
image, has_nsfw_concept = self.safety_checker(
|
| 570 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| 571 |
+
)
|
| 572 |
+
return image, has_nsfw_concept
|
| 573 |
+
|
| 574 |
+
def decode_latents(self, latents):
|
| 575 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| 576 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
| 577 |
+
|
| 578 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 579 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 580 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 581 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 582 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 583 |
+
return image
|
| 584 |
+
|
| 585 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 586 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 587 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 588 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 589 |
+
# and should be between [0, 1]
|
| 590 |
+
|
| 591 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 592 |
+
extra_step_kwargs = {}
|
| 593 |
+
if accepts_eta:
|
| 594 |
+
extra_step_kwargs["eta"] = eta
|
| 595 |
+
|
| 596 |
+
# check if the scheduler accepts generator
|
| 597 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 598 |
+
if accepts_generator:
|
| 599 |
+
extra_step_kwargs["generator"] = generator
|
| 600 |
+
return extra_step_kwargs
|
| 601 |
+
|
| 602 |
+
def check_inputs(
|
| 603 |
+
self,
|
| 604 |
+
prompt,
|
| 605 |
+
strength,
|
| 606 |
+
callback_steps,
|
| 607 |
+
negative_prompt=None,
|
| 608 |
+
prompt_embeds=None,
|
| 609 |
+
negative_prompt_embeds=None,
|
| 610 |
+
callback_on_step_end_tensor_inputs=None,
|
| 611 |
+
):
|
| 612 |
+
if strength < 0 or strength > 1:
|
| 613 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 614 |
+
|
| 615 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| 616 |
+
raise ValueError(
|
| 617 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 618 |
+
f" {type(callback_steps)}."
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 622 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 623 |
+
):
|
| 624 |
+
raise ValueError(
|
| 625 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 626 |
+
)
|
| 627 |
+
if prompt is not None and prompt_embeds is not None:
|
| 628 |
+
raise ValueError(
|
| 629 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 630 |
+
" only forward one of the two."
|
| 631 |
+
)
|
| 632 |
+
elif prompt is None and prompt_embeds is None:
|
| 633 |
+
raise ValueError(
|
| 634 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 635 |
+
)
|
| 636 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 637 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 638 |
+
|
| 639 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 640 |
+
raise ValueError(
|
| 641 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 642 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 646 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 647 |
+
raise ValueError(
|
| 648 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 649 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 650 |
+
f" {negative_prompt_embeds.shape}."
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
| 654 |
+
# get the original timestep using init_timestep
|
| 655 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 656 |
+
|
| 657 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 658 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
| 659 |
+
|
| 660 |
+
return timesteps, num_inference_steps - t_start
|
| 661 |
+
|
| 662 |
+
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
| 663 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
| 664 |
+
raise ValueError(
|
| 665 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
image = image.to(device=device, dtype=dtype)
|
| 669 |
+
|
| 670 |
+
batch_size = batch_size * num_images_per_prompt
|
| 671 |
+
|
| 672 |
+
if image.shape[1] == 4:
|
| 673 |
+
init_latents = image
|
| 674 |
+
|
| 675 |
+
else:
|
| 676 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 677 |
+
raise ValueError(
|
| 678 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 679 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
elif isinstance(generator, list):
|
| 683 |
+
init_latents = [
|
| 684 |
+
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
| 685 |
+
for i in range(batch_size)
|
| 686 |
+
]
|
| 687 |
+
init_latents = torch.cat(init_latents, dim=0)
|
| 688 |
+
else:
|
| 689 |
+
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
| 690 |
+
|
| 691 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
| 692 |
+
|
| 693 |
+
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
| 694 |
+
# expand init_latents for batch_size
|
| 695 |
+
deprecation_message = (
|
| 696 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
| 697 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
| 698 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
| 699 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
| 700 |
+
)
|
| 701 |
+
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
| 702 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
| 703 |
+
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
| 704 |
+
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
| 705 |
+
raise ValueError(
|
| 706 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
| 707 |
+
)
|
| 708 |
+
else:
|
| 709 |
+
init_latents = torch.cat([init_latents], dim=0)
|
| 710 |
+
|
| 711 |
+
shape = init_latents.shape
|
| 712 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 713 |
+
|
| 714 |
+
# get latents
|
| 715 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
| 716 |
+
latents = init_latents
|
| 717 |
+
|
| 718 |
+
return latents
|
| 719 |
+
|
| 720 |
+
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
| 721 |
+
"""
|
| 722 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 723 |
+
|
| 724 |
+
Args:
|
| 725 |
+
timesteps (`torch.Tensor`):
|
| 726 |
+
generate embedding vectors at these timesteps
|
| 727 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 728 |
+
dimension of the embeddings to generate
|
| 729 |
+
dtype:
|
| 730 |
+
data type of the generated embeddings
|
| 731 |
+
|
| 732 |
+
Returns:
|
| 733 |
+
`torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
| 734 |
+
"""
|
| 735 |
+
assert len(w.shape) == 1
|
| 736 |
+
w = w * 1000.0
|
| 737 |
+
|
| 738 |
+
half_dim = embedding_dim // 2
|
| 739 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 740 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 741 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 742 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 743 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 744 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 745 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
| 746 |
+
return emb
|
| 747 |
+
|
| 748 |
+
@property
|
| 749 |
+
def guidance_scale(self):
|
| 750 |
+
return self._guidance_scale
|
| 751 |
+
|
| 752 |
+
@property
|
| 753 |
+
def clip_skip(self):
|
| 754 |
+
return self._clip_skip
|
| 755 |
+
|
| 756 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 757 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 758 |
+
# corresponds to doing no classifier free guidance.
|
| 759 |
+
@property
|
| 760 |
+
def do_classifier_free_guidance(self):
|
| 761 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
| 762 |
+
|
| 763 |
+
@property
|
| 764 |
+
def cross_attention_kwargs(self):
|
| 765 |
+
return self._cross_attention_kwargs
|
| 766 |
+
|
| 767 |
+
@property
|
| 768 |
+
def num_timesteps(self):
|
| 769 |
+
return self._num_timesteps
|
| 770 |
+
|
| 771 |
+
@torch.no_grad()
|
| 772 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 773 |
+
def __call__(
|
| 774 |
+
self,
|
| 775 |
+
prompt: Union[str, List[str]] = None,
|
| 776 |
+
image: PipelineImageInput = None,
|
| 777 |
+
strength: float = 0.8,
|
| 778 |
+
num_inference_steps: Optional[int] = 50,
|
| 779 |
+
timesteps: List[int] = None,
|
| 780 |
+
sigmas: List[float] = None,
|
| 781 |
+
guidance_scale: Optional[float] = 7.5,
|
| 782 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 783 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 784 |
+
eta: Optional[float] = 0.0,
|
| 785 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 786 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 787 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 788 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 789 |
+
output_type: Optional[str] = "pil",
|
| 790 |
+
return_dict: bool = True,
|
| 791 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 792 |
+
clip_skip: int = None,
|
| 793 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 794 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 795 |
+
**kwargs,
|
| 796 |
+
):
|
| 797 |
+
r"""
|
| 798 |
+
The call function to the pipeline for generation.
|
| 799 |
+
|
| 800 |
+
Args:
|
| 801 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 802 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 803 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 804 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
| 805 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
| 806 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
| 807 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
| 808 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
| 809 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 810 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
| 811 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
| 812 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
| 813 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
| 814 |
+
essentially ignores `image`.
|
| 815 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 816 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 817 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 818 |
+
timesteps (`List[int]`, *optional*):
|
| 819 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 820 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 821 |
+
passed will be used. Must be in descending order.
|
| 822 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 823 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 824 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 825 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 826 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 827 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 828 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 829 |
+
The number of images to generate per prompt.
|
| 830 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 831 |
+
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
|
| 832 |
+
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 833 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 834 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 835 |
+
generation deterministic.
|
| 836 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 837 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 838 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 839 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 840 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 841 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 842 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 843 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 844 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 845 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 846 |
+
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
|
| 847 |
+
plain tuple.
|
| 848 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 849 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 850 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 851 |
+
clip_skip (`int`, *optional*):
|
| 852 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 853 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 854 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 855 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 856 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 857 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 858 |
+
`callback_on_step_end_tensor_inputs`.
|
| 859 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 860 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 861 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 862 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 863 |
+
Examples:
|
| 864 |
+
|
| 865 |
+
Returns:
|
| 866 |
+
[`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
|
| 867 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned,
|
| 868 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 869 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 870 |
+
"not-safe-for-work" (nsfw) content.
|
| 871 |
+
"""
|
| 872 |
+
|
| 873 |
+
callback = kwargs.pop("callback", None)
|
| 874 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 875 |
+
|
| 876 |
+
if callback is not None:
|
| 877 |
+
deprecate(
|
| 878 |
+
"callback",
|
| 879 |
+
"1.0.0",
|
| 880 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 881 |
+
)
|
| 882 |
+
if callback_steps is not None:
|
| 883 |
+
deprecate(
|
| 884 |
+
"callback_steps",
|
| 885 |
+
"1.0.0",
|
| 886 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
# 1. Check inputs. Raise error if not correct
|
| 890 |
+
self.check_inputs(
|
| 891 |
+
prompt,
|
| 892 |
+
strength,
|
| 893 |
+
callback_steps,
|
| 894 |
+
negative_prompt,
|
| 895 |
+
prompt_embeds,
|
| 896 |
+
negative_prompt_embeds,
|
| 897 |
+
callback_on_step_end_tensor_inputs,
|
| 898 |
+
)
|
| 899 |
+
|
| 900 |
+
self._guidance_scale = guidance_scale
|
| 901 |
+
self._clip_skip = clip_skip
|
| 902 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 903 |
+
|
| 904 |
+
# 2. Define call parameters
|
| 905 |
+
if prompt is not None and isinstance(prompt, str):
|
| 906 |
+
batch_size = 1
|
| 907 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 908 |
+
batch_size = len(prompt)
|
| 909 |
+
else:
|
| 910 |
+
batch_size = prompt_embeds.shape[0]
|
| 911 |
+
|
| 912 |
+
device = self._execution_device
|
| 913 |
+
|
| 914 |
+
# 3. Encode input prompt
|
| 915 |
+
text_encoder_lora_scale = (
|
| 916 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 917 |
+
)
|
| 918 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 919 |
+
prompt,
|
| 920 |
+
device,
|
| 921 |
+
num_images_per_prompt,
|
| 922 |
+
self.do_classifier_free_guidance,
|
| 923 |
+
negative_prompt,
|
| 924 |
+
prompt_embeds=prompt_embeds,
|
| 925 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 926 |
+
lora_scale=text_encoder_lora_scale,
|
| 927 |
+
clip_skip=self.clip_skip,
|
| 928 |
+
)
|
| 929 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 930 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 931 |
+
# to avoid doing two forward passes
|
| 932 |
+
if self.do_classifier_free_guidance:
|
| 933 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 934 |
+
|
| 935 |
+
if ip_adapter_image is not None:
|
| 936 |
+
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
| 937 |
+
image_embeds, negative_image_embeds = self.encode_image(
|
| 938 |
+
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
| 939 |
+
)
|
| 940 |
+
if self.do_classifier_free_guidance:
|
| 941 |
+
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
| 942 |
+
|
| 943 |
+
# 4. Preprocess image
|
| 944 |
+
image = self.image_processor.preprocess(image)
|
| 945 |
+
|
| 946 |
+
# 5. set timesteps
|
| 947 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 948 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 949 |
+
)
|
| 950 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
| 951 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 952 |
+
|
| 953 |
+
# 6. Prepare latent variables
|
| 954 |
+
latents = self.prepare_latents(
|
| 955 |
+
image,
|
| 956 |
+
latent_timestep,
|
| 957 |
+
batch_size,
|
| 958 |
+
num_images_per_prompt,
|
| 959 |
+
prompt_embeds.dtype,
|
| 960 |
+
device,
|
| 961 |
+
generator,
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 965 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 966 |
+
|
| 967 |
+
# 7.1 Add image embeds for IP-Adapter
|
| 968 |
+
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
| 969 |
+
|
| 970 |
+
# 7.2 Optionally get Guidance Scale Embedding
|
| 971 |
+
timestep_cond = None
|
| 972 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 973 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
| 974 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 975 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 976 |
+
).to(device=device, dtype=latents.dtype)
|
| 977 |
+
|
| 978 |
+
# 8. Denoising loop
|
| 979 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 980 |
+
self._num_timesteps = len(timesteps)
|
| 981 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 982 |
+
for i, t in enumerate(timesteps):
|
| 983 |
+
# expand the latents if we are doing classifier free guidance
|
| 984 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 985 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 986 |
+
|
| 987 |
+
# predict the noise residual
|
| 988 |
+
noise_pred = self.unet(
|
| 989 |
+
latent_model_input,
|
| 990 |
+
t,
|
| 991 |
+
encoder_hidden_states=prompt_embeds,
|
| 992 |
+
timestep_cond=timestep_cond,
|
| 993 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 994 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 995 |
+
return_dict=False,
|
| 996 |
+
)[0]
|
| 997 |
+
|
| 998 |
+
# perform guidance
|
| 999 |
+
if self.do_classifier_free_guidance:
|
| 1000 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1001 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1002 |
+
|
| 1003 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1004 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1005 |
+
|
| 1006 |
+
if callback_on_step_end is not None:
|
| 1007 |
+
callback_kwargs = {}
|
| 1008 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1009 |
+
callback_kwargs[k] = locals()[k]
|
| 1010 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1011 |
+
|
| 1012 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1013 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1014 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 1015 |
+
|
| 1016 |
+
# call the callback, if provided
|
| 1017 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1018 |
+
progress_bar.update()
|
| 1019 |
+
if callback is not None and i % callback_steps == 0:
|
| 1020 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 1021 |
+
callback(step_idx, t, latents)
|
| 1022 |
+
|
| 1023 |
+
if not output_type == "latent":
|
| 1024 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
| 1025 |
+
0
|
| 1026 |
+
]
|
| 1027 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
| 1028 |
+
else:
|
| 1029 |
+
image = latents
|
| 1030 |
+
has_nsfw_concept = None
|
| 1031 |
+
|
| 1032 |
+
if has_nsfw_concept is None:
|
| 1033 |
+
do_denormalize = [True] * image.shape[0]
|
| 1034 |
+
else:
|
| 1035 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 1036 |
+
|
| 1037 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 1038 |
+
|
| 1039 |
+
# Offload all models
|
| 1040 |
+
self.maybe_free_model_hooks()
|
| 1041 |
+
|
| 1042 |
+
if not return_dict:
|
| 1043 |
+
return (image, has_nsfw_concept)
|
| 1044 |
+
|
| 1045 |
+
return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
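For reference, a minimal usage sketch of the image-to-image `__call__` defined above. The pipeline class name (`AltDiffusionImg2ImgPipeline`) and the `BAAI/AltDiffusion` checkpoint are assumptions for illustration, not taken from this diff; the sketch only exercises arguments documented in the signature above.

# Illustrative only: class name and checkpoint are assumed, not part of this diff.
import torch

from diffusers import AltDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = AltDiffusionImg2ImgPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16).to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
)
# strength=0.75 noises the init image and runs roughly 75% of num_inference_steps;
# guidance_scale > 1 enables classifier-free guidance (see do_classifier_free_guidance above).
result = pipe(
    prompt="a fantasy landscape, highly detailed",
    image=init_image,
    strength=0.75,
    num_inference_steps=50,
    guidance_scale=7.5,
)
result.images[0].save("out.png")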
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py
ADDED
|
@@ -0,0 +1,28 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image

from ....utils import (
    BaseOutput,
)


@dataclass
# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Alt Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        nsfw_content_detected (`List[bool]`)
            List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
            `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__init__.py
ADDED
|
@@ -0,0 +1,54 @@
|
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
    _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
    _import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"]
    _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
        from .pipeline_hunyuan_video import HunyuanVideoPipeline
        from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline
        from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
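The module above only registers import paths; a short sketch of how the lazy loading behaves from the user side (illustrative, and assumes a diffusers build where torch and transformers are available so the real pipelines, not the dummy objects, are exposed):

# Illustrative sketch: attribute access triggers _LazyModule to import the real submodule.
import diffusers

cls = diffusers.HunyuanVideoPipeline  # resolved via _import_structure["pipeline_hunyuan_video"]
print(cls.__module__)  # -> diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video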
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.39 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_skyreels_image2video.cpython-310.pyc
ADDED
Binary file (28.5 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_video.cpython-310.pyc
ADDED
Binary file (26 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_video_framepack.cpython-310.pyc
ADDED
Binary file (36.4 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_hunyuan_video_image2video.cpython-310.pyc
ADDED
Binary file (31.9 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/__pycache__/pipeline_output.cpython-310.pyc
ADDED
Binary file (1.94 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/hunyuan_video/pipeline_output.py
ADDED
|
@@ -0,0 +1,39 @@
|
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image
import torch

from diffusers.utils import BaseOutput


@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
    r"""
    Output class for HunyuanVideo pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor


@dataclass
class HunyuanVideoFramepackPipelineOutput(BaseOutput):
    r"""
    Output class for HunyuanVideo pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor
            corresponds to a latent that decodes to multiple frames.
    """

    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]]
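Because these output classes derive from `BaseOutput`, they can be read as attributes, as mapping keys, or converted to a tuple; a small illustrative sketch (assumes torch and diffusers are installed, and uses an arbitrary example tensor shape):

# Illustrative sketch of BaseOutput access patterns.
import torch

from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput

out = HunyuanVideoPipelineOutput(frames=torch.zeros(1, 8, 3, 64, 64))
assert out.frames is out["frames"]       # attribute and key access return the same object
assert out.to_tuple()[0] is out.frames   # tuple form follows the declared field order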
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_i2vgen_xl"] = ["I2VGenXLPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_i2vgen_xl import I2VGenXLPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.04 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/__pycache__/pipeline_i2vgen_xl.cpython-310.pyc
ADDED
Binary file (22.6 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
ADDED
|
@@ -0,0 +1,797 @@
|
# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKL
from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
from ...schedulers import DDIMScheduler
from ...utils import (
    BaseOutput,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import I2VGenXLPipeline
        >>> from diffusers.utils import export_to_gif, load_image

        >>> pipeline = I2VGenXLPipeline.from_pretrained(
        ...     "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
        ... )
        >>> pipeline.enable_model_cpu_offload()

        >>> image_url = (
        ...     "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
        ... )
        >>> image = load_image(image_url).convert("RGB")

        >>> prompt = "Papers were floating in the air on a table in the library"
        >>> negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
        >>> generator = torch.manual_seed(8888)

        >>> frames = pipeline(
        ...     prompt=prompt,
        ...     image=image,
        ...     num_inference_steps=50,
        ...     negative_prompt=negative_prompt,
        ...     guidance_scale=9.0,
        ...     generator=generator,
        ... ).frames[0]
        >>> video_path = export_to_gif(frames, "i2v.gif")
        ```
"""


@dataclass
class I2VGenXLPipelineOutput(BaseOutput):
    r"""
    Output class for image-to-video pipeline.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised
            PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`
    """

    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]


class I2VGenXLPipeline(
    DeprecatedPipelineMixin,
    DiffusionPipeline,
    StableDiffusionMixin,
):
    _last_supported_version = "0.33.1"
    r"""
    Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer (`CLIPTokenizer`):
            A [`~transformers.CLIPTokenizer`] to tokenize text.
        unet ([`I2VGenXLUNet`]):
            A [`I2VGenXLUNet`] to denoise the encoded video latents.
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        image_encoder: CLIPVisionModelWithProjection,
        feature_extractor: CLIPImageProcessor,
        unet: I2VGenXLUNet,
        scheduler: DDIMScheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            unet=unet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # `do_resize=False` as we do custom resizing.
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    def encode_prompt(
        self,
        prompt,
        device,
        num_videos_per_prompt,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if self.do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
+
)
|
| 293 |
+
|
| 294 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 295 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
| 296 |
+
else:
|
| 297 |
+
attention_mask = None
|
| 298 |
+
|
| 299 |
+
# Apply clip_skip to negative prompt embeds
|
| 300 |
+
if clip_skip is None:
|
| 301 |
+
negative_prompt_embeds = self.text_encoder(
|
| 302 |
+
uncond_input.input_ids.to(device),
|
| 303 |
+
attention_mask=attention_mask,
|
| 304 |
+
)
|
| 305 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 306 |
+
else:
|
| 307 |
+
negative_prompt_embeds = self.text_encoder(
|
| 308 |
+
uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
| 309 |
+
)
|
| 310 |
+
# Access the `hidden_states` first, that contains a tuple of
|
| 311 |
+
# all the hidden states from the encoder layers. Then index into
|
| 312 |
+
# the tuple to access the hidden states from the desired layer.
|
| 313 |
+
negative_prompt_embeds = negative_prompt_embeds[-1][-(clip_skip + 1)]
|
| 314 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
| 315 |
+
# representations. The `last_hidden_states` that we typically use for
|
| 316 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
| 317 |
+
# layer.
|
| 318 |
+
negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm(negative_prompt_embeds)
|
| 319 |
+
|
| 320 |
+
if self.do_classifier_free_guidance:
|
| 321 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 322 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 323 |
+
|
| 324 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
| 325 |
+
|
| 326 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 327 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 328 |
+
|
| 329 |
+
return prompt_embeds, negative_prompt_embeds
|
| 330 |
+
|
| 331 |
+
def _encode_image(self, image, device, num_videos_per_prompt):
|
| 332 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 333 |
+
|
| 334 |
+
if not isinstance(image, torch.Tensor):
|
| 335 |
+
image = self.video_processor.pil_to_numpy(image)
|
| 336 |
+
image = self.video_processor.numpy_to_pt(image)
|
| 337 |
+
|
| 338 |
+
# Normalize the image with CLIP training stats.
|
| 339 |
+
image = self.feature_extractor(
|
| 340 |
+
images=image,
|
| 341 |
+
do_normalize=True,
|
| 342 |
+
do_center_crop=False,
|
| 343 |
+
do_resize=False,
|
| 344 |
+
do_rescale=False,
|
| 345 |
+
return_tensors="pt",
|
| 346 |
+
).pixel_values
|
| 347 |
+
|
| 348 |
+
image = image.to(device=device, dtype=dtype)
|
| 349 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
| 350 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
| 351 |
+
|
| 352 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
| 353 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
| 354 |
+
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
| 355 |
+
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
| 356 |
+
|
| 357 |
+
if self.do_classifier_free_guidance:
|
| 358 |
+
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
| 359 |
+
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
| 360 |
+
|
| 361 |
+
return image_embeddings
|
| 362 |
+
|
| 363 |
+
def decode_latents(self, latents, decode_chunk_size=None):
|
| 364 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 365 |
+
|
| 366 |
+
batch_size, channels, num_frames, height, width = latents.shape
|
| 367 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
| 368 |
+
|
| 369 |
+
if decode_chunk_size is not None:
|
| 370 |
+
frames = []
|
| 371 |
+
for i in range(0, latents.shape[0], decode_chunk_size):
|
| 372 |
+
frame = self.vae.decode(latents[i : i + decode_chunk_size]).sample
|
| 373 |
+
frames.append(frame)
|
| 374 |
+
image = torch.cat(frames, dim=0)
|
| 375 |
+
else:
|
| 376 |
+
image = self.vae.decode(latents).sample
|
| 377 |
+
|
| 378 |
+
decode_shape = (batch_size, num_frames, -1) + image.shape[2:]
|
| 379 |
+
video = image[None, :].reshape(decode_shape).permute(0, 2, 1, 3, 4)
|
| 380 |
+
|
| 381 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 382 |
+
video = video.float()
|
| 383 |
+
return video
|
| 384 |
+
|
| 385 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 386 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 387 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 388 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 389 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 390 |
+
# and should be between [0, 1]
|
| 391 |
+
|
| 392 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 393 |
+
extra_step_kwargs = {}
|
| 394 |
+
if accepts_eta:
|
| 395 |
+
extra_step_kwargs["eta"] = eta
|
| 396 |
+
|
| 397 |
+
# check if the scheduler accepts generator
|
| 398 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 399 |
+
if accepts_generator:
|
| 400 |
+
extra_step_kwargs["generator"] = generator
|
| 401 |
+
return extra_step_kwargs
|
| 402 |
+
|
| 403 |
+
def check_inputs(
|
| 404 |
+
self,
|
| 405 |
+
prompt,
|
| 406 |
+
image,
|
| 407 |
+
height,
|
| 408 |
+
width,
|
| 409 |
+
negative_prompt=None,
|
| 410 |
+
prompt_embeds=None,
|
| 411 |
+
negative_prompt_embeds=None,
|
| 412 |
+
):
|
| 413 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 414 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 415 |
+
|
| 416 |
+
if prompt is not None and prompt_embeds is not None:
|
| 417 |
+
raise ValueError(
|
| 418 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 419 |
+
" only forward one of the two."
|
| 420 |
+
)
|
| 421 |
+
elif prompt is None and prompt_embeds is None:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 424 |
+
)
|
| 425 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 426 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 427 |
+
|
| 428 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 429 |
+
raise ValueError(
|
| 430 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 431 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 435 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 436 |
+
raise ValueError(
|
| 437 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 438 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 439 |
+
f" {negative_prompt_embeds.shape}."
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
if (
|
| 443 |
+
not isinstance(image, torch.Tensor)
|
| 444 |
+
and not isinstance(image, PIL.Image.Image)
|
| 445 |
+
and not isinstance(image, list)
|
| 446 |
+
):
|
| 447 |
+
raise ValueError(
|
| 448 |
+
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 449 |
+
f" {type(image)}"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
def prepare_image_latents(
|
| 453 |
+
self,
|
| 454 |
+
image,
|
| 455 |
+
device,
|
| 456 |
+
num_frames,
|
| 457 |
+
num_videos_per_prompt,
|
| 458 |
+
):
|
| 459 |
+
image = image.to(device=device)
|
| 460 |
+
image_latents = self.vae.encode(image).latent_dist.sample()
|
| 461 |
+
image_latents = image_latents * self.vae.config.scaling_factor
|
| 462 |
+
|
| 463 |
+
# Add frames dimension to image latents
|
| 464 |
+
image_latents = image_latents.unsqueeze(2)
|
| 465 |
+
|
| 466 |
+
# Append a position mask for each subsequent frame
|
| 467 |
+
# after the initial image latent frame
|
| 468 |
+
frame_position_mask = []
|
| 469 |
+
for frame_idx in range(num_frames - 1):
|
| 470 |
+
scale = (frame_idx + 1) / (num_frames - 1)
|
| 471 |
+
frame_position_mask.append(torch.ones_like(image_latents[:, :, :1]) * scale)
|
| 472 |
+
if frame_position_mask:
|
| 473 |
+
frame_position_mask = torch.cat(frame_position_mask, dim=2)
|
| 474 |
+
image_latents = torch.cat([image_latents, frame_position_mask], dim=2)
|
| 475 |
+
|
| 476 |
+
# duplicate image_latents for each generation per prompt, using mps friendly method
|
| 477 |
+
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1)
|
| 478 |
+
|
| 479 |
+
if self.do_classifier_free_guidance:
|
| 480 |
+
image_latents = torch.cat([image_latents] * 2)
|
| 481 |
+
|
| 482 |
+
return image_latents
|
| 483 |
+
|
| 484 |
+
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
| 485 |
+
def prepare_latents(
|
| 486 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 487 |
+
):
|
| 488 |
+
shape = (
|
| 489 |
+
batch_size,
|
| 490 |
+
num_channels_latents,
|
| 491 |
+
num_frames,
|
| 492 |
+
height // self.vae_scale_factor,
|
| 493 |
+
width // self.vae_scale_factor,
|
| 494 |
+
)
|
| 495 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 496 |
+
raise ValueError(
|
| 497 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 498 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
if latents is None:
|
| 502 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 503 |
+
else:
|
| 504 |
+
latents = latents.to(device)
|
| 505 |
+
|
| 506 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 507 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 508 |
+
return latents
|
| 509 |
+
|
| 510 |
+
@torch.no_grad()
|
| 511 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 512 |
+
def __call__(
|
| 513 |
+
self,
|
| 514 |
+
prompt: Union[str, List[str]] = None,
|
| 515 |
+
image: PipelineImageInput = None,
|
| 516 |
+
height: Optional[int] = 704,
|
| 517 |
+
width: Optional[int] = 1280,
|
| 518 |
+
target_fps: Optional[int] = 16,
|
| 519 |
+
num_frames: int = 16,
|
| 520 |
+
num_inference_steps: int = 50,
|
| 521 |
+
guidance_scale: float = 9.0,
|
| 522 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 523 |
+
eta: float = 0.0,
|
| 524 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 525 |
+
decode_chunk_size: Optional[int] = 1,
|
| 526 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 527 |
+
latents: Optional[torch.Tensor] = None,
|
| 528 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 529 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 530 |
+
output_type: Optional[str] = "pil",
|
| 531 |
+
return_dict: bool = True,
|
| 532 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 533 |
+
clip_skip: Optional[int] = 1,
|
| 534 |
+
):
|
| 535 |
+
r"""
|
| 536 |
+
The call function to the pipeline for image-to-video generation with [`I2VGenXLPipeline`].
|
| 537 |
+
|
| 538 |
+
Args:
|
| 539 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 540 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 541 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
|
| 542 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
| 543 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
| 544 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 545 |
+
The height in pixels of the generated image.
|
| 546 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 547 |
+
The width in pixels of the generated image.
|
| 548 |
+
target_fps (`int`, *optional*):
|
| 549 |
+
Frames per second. The rate at which the generated images shall be exported to a video after
|
| 550 |
+
generation. This is also used as a "micro-condition" while generation.
|
| 551 |
+
num_frames (`int`, *optional*):
|
| 552 |
+
The number of video frames to generate.
|
| 553 |
+
num_inference_steps (`int`, *optional*):
|
| 554 |
+
The number of denoising steps.
|
| 555 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 556 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 557 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 558 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 559 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 560 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 561 |
+
eta (`float`, *optional*):
|
| 562 |
+
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
|
| 563 |
+
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 564 |
+
num_videos_per_prompt (`int`, *optional*):
|
| 565 |
+
The number of images to generate per prompt.
|
| 566 |
+
decode_chunk_size (`int`, *optional*):
|
| 567 |
+
The number of frames to decode at a time. The higher the chunk size, the higher the temporal
|
| 568 |
+
consistency between frames, but also the higher the memory consumption. By default, the decoder will
|
| 569 |
+
decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
|
| 570 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 571 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 572 |
+
generation deterministic.
|
| 573 |
+
latents (`torch.Tensor`, *optional*):
|
| 574 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 575 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 576 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 577 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 578 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 579 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 580 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 581 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 582 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 583 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 584 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 585 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 586 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 587 |
+
plain tuple.
|
| 588 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 589 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 590 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 591 |
+
clip_skip (`int`, *optional*):
|
| 592 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 593 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 594 |
+
|
| 595 |
+
Examples:
|
| 596 |
+
|
| 597 |
+
Returns:
|
| 598 |
+
[`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`:
|
| 599 |
+
If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is
|
| 600 |
+
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
| 601 |
+
"""
|
| 602 |
+
# 0. Default height and width to unet
|
| 603 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 604 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 605 |
+
|
| 606 |
+
# 1. Check inputs. Raise error if not correct
|
| 607 |
+
self.check_inputs(prompt, image, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
| 608 |
+
|
| 609 |
+
# 2. Define call parameters
|
| 610 |
+
if prompt is not None and isinstance(prompt, str):
|
| 611 |
+
batch_size = 1
|
| 612 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 613 |
+
batch_size = len(prompt)
|
| 614 |
+
else:
|
| 615 |
+
batch_size = prompt_embeds.shape[0]
|
| 616 |
+
|
| 617 |
+
device = self._execution_device
|
| 618 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 619 |
+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
|
| 620 |
+
# corresponds to doing no classifier free guidance.
|
| 621 |
+
self._guidance_scale = guidance_scale
|
| 622 |
+
|
| 623 |
+
# 3.1 Encode input text prompt
|
| 624 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 625 |
+
prompt,
|
| 626 |
+
device,
|
| 627 |
+
num_videos_per_prompt,
|
| 628 |
+
negative_prompt,
|
| 629 |
+
prompt_embeds=prompt_embeds,
|
| 630 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 631 |
+
clip_skip=clip_skip,
|
| 632 |
+
)
|
| 633 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 634 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 635 |
+
# to avoid doing two forward passes
|
| 636 |
+
if self.do_classifier_free_guidance:
|
| 637 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 638 |
+
|
| 639 |
+
# 3.2 Encode image prompt
|
| 640 |
+
# 3.2.1 Image encodings.
|
| 641 |
+
# https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114
|
| 642 |
+
cropped_image = _center_crop_wide(image, (width, width))
|
| 643 |
+
cropped_image = _resize_bilinear(
|
| 644 |
+
cropped_image, (self.feature_extractor.crop_size["width"], self.feature_extractor.crop_size["height"])
|
| 645 |
+
)
|
| 646 |
+
image_embeddings = self._encode_image(cropped_image, device, num_videos_per_prompt)
|
| 647 |
+
|
| 648 |
+
# 3.2.2 Image latents.
|
| 649 |
+
resized_image = _center_crop_wide(image, (width, height))
|
| 650 |
+
image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype)
|
| 651 |
+
image_latents = self.prepare_image_latents(
|
| 652 |
+
image,
|
| 653 |
+
device=device,
|
| 654 |
+
num_frames=num_frames,
|
| 655 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
# 3.3 Prepare additional conditions for the UNet.
|
| 659 |
+
if self.do_classifier_free_guidance:
|
| 660 |
+
fps_tensor = torch.tensor([target_fps, target_fps]).to(device)
|
| 661 |
+
else:
|
| 662 |
+
fps_tensor = torch.tensor([target_fps]).to(device)
|
| 663 |
+
fps_tensor = fps_tensor.repeat(batch_size * num_videos_per_prompt, 1).ravel()
|
| 664 |
+
|
| 665 |
+
# 4. Prepare timesteps
|
| 666 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 667 |
+
timesteps = self.scheduler.timesteps
|
| 668 |
+
|
| 669 |
+
# 5. Prepare latent variables
|
| 670 |
+
num_channels_latents = self.unet.config.in_channels
|
| 671 |
+
latents = self.prepare_latents(
|
| 672 |
+
batch_size * num_videos_per_prompt,
|
| 673 |
+
num_channels_latents,
|
| 674 |
+
num_frames,
|
| 675 |
+
height,
|
| 676 |
+
width,
|
| 677 |
+
prompt_embeds.dtype,
|
| 678 |
+
device,
|
| 679 |
+
generator,
|
| 680 |
+
latents,
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 684 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 685 |
+
|
| 686 |
+
# 7. Denoising loop
|
| 687 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 688 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 689 |
+
for i, t in enumerate(timesteps):
|
| 690 |
+
# expand the latents if we are doing classifier free guidance
|
| 691 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 692 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 693 |
+
|
| 694 |
+
# predict the noise residual
|
| 695 |
+
noise_pred = self.unet(
|
| 696 |
+
latent_model_input,
|
| 697 |
+
t,
|
| 698 |
+
encoder_hidden_states=prompt_embeds,
|
| 699 |
+
fps=fps_tensor,
|
| 700 |
+
image_latents=image_latents,
|
| 701 |
+
image_embeddings=image_embeddings,
|
| 702 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 703 |
+
return_dict=False,
|
| 704 |
+
)[0]
|
| 705 |
+
|
| 706 |
+
# perform guidance
|
| 707 |
+
if self.do_classifier_free_guidance:
|
| 708 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 709 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 710 |
+
|
| 711 |
+
# reshape latents
|
| 712 |
+
batch_size, channel, frames, width, height = latents.shape
|
| 713 |
+
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height)
|
| 714 |
+
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height)
|
| 715 |
+
|
| 716 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 717 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 718 |
+
|
| 719 |
+
# reshape latents back
|
| 720 |
+
latents = latents[None, :].reshape(batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4)
|
| 721 |
+
# call the callback, if provided
|
| 722 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 723 |
+
progress_bar.update()
|
| 724 |
+
|
| 725 |
+
if XLA_AVAILABLE:
|
| 726 |
+
xm.mark_step()
|
| 727 |
+
|
| 728 |
+
# 8. Post processing
|
| 729 |
+
if output_type == "latent":
|
| 730 |
+
video = latents
|
| 731 |
+
else:
|
| 732 |
+
video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
|
| 733 |
+
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
|
| 734 |
+
|
| 735 |
+
# 9. Offload all models
|
| 736 |
+
self.maybe_free_model_hooks()
|
| 737 |
+
|
| 738 |
+
if not return_dict:
|
| 739 |
+
return (video,)
|
| 740 |
+
|
| 741 |
+
return I2VGenXLPipelineOutput(frames=video)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
# The following utilities are taken and adapted from
|
| 745 |
+
# https://github.com/ali-vilab/i2vgen-xl/blob/main/utils/transforms.py.
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def _convert_pt_to_pil(image: Union[torch.Tensor, List[torch.Tensor]]):
|
| 749 |
+
if isinstance(image, list) and isinstance(image[0], torch.Tensor):
|
| 750 |
+
image = torch.cat(image, 0)
|
| 751 |
+
|
| 752 |
+
if isinstance(image, torch.Tensor):
|
| 753 |
+
if image.ndim == 3:
|
| 754 |
+
image = image.unsqueeze(0)
|
| 755 |
+
|
| 756 |
+
image_numpy = VaeImageProcessor.pt_to_numpy(image)
|
| 757 |
+
image_pil = VaeImageProcessor.numpy_to_pil(image_numpy)
|
| 758 |
+
image = image_pil
|
| 759 |
+
|
| 760 |
+
return image
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
def _resize_bilinear(
|
| 764 |
+
image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
|
| 765 |
+
):
|
| 766 |
+
# First convert the images to PIL in case they are float tensors (only relevant for tests now).
|
| 767 |
+
image = _convert_pt_to_pil(image)
|
| 768 |
+
|
| 769 |
+
if isinstance(image, list):
|
| 770 |
+
image = [u.resize(resolution, PIL.Image.BILINEAR) for u in image]
|
| 771 |
+
else:
|
| 772 |
+
image = image.resize(resolution, PIL.Image.BILINEAR)
|
| 773 |
+
return image
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def _center_crop_wide(
|
| 777 |
+
image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]], resolution: Tuple[int, int]
|
| 778 |
+
):
|
| 779 |
+
# First convert the images to PIL in case they are float tensors (only relevant for tests now).
|
| 780 |
+
image = _convert_pt_to_pil(image)
|
| 781 |
+
|
| 782 |
+
if isinstance(image, list):
|
| 783 |
+
scale = min(image[0].size[0] / resolution[0], image[0].size[1] / resolution[1])
|
| 784 |
+
image = [u.resize((round(u.width // scale), round(u.height // scale)), resample=PIL.Image.BOX) for u in image]
|
| 785 |
+
|
| 786 |
+
# center crop
|
| 787 |
+
x1 = (image[0].width - resolution[0]) // 2
|
| 788 |
+
y1 = (image[0].height - resolution[1]) // 2
|
| 789 |
+
image = [u.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) for u in image]
|
| 790 |
+
return image
|
| 791 |
+
else:
|
| 792 |
+
scale = min(image.size[0] / resolution[0], image.size[1] / resolution[1])
|
| 793 |
+
image = image.resize((round(image.width // scale), round(image.height // scale)), resample=PIL.Image.BOX)
|
| 794 |
+
x1 = (image.width - resolution[0]) // 2
|
| 795 |
+
y1 = (image.height - resolution[1]) // 2
|
| 796 |
+
image = image.crop((x1, y1, x1 + resolution[0], y1 + resolution[1]))
|
| 797 |
+
return image
|
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/__init__.py
ADDED
@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_kandinsky"] = ["KandinskyPipeline"]
    _import_structure["pipeline_kandinsky_combined"] = [
        "KandinskyCombinedPipeline",
        "KandinskyImg2ImgCombinedPipeline",
        "KandinskyInpaintCombinedPipeline",
    ]
    _import_structure["pipeline_kandinsky_img2img"] = ["KandinskyImg2ImgPipeline"]
    _import_structure["pipeline_kandinsky_inpaint"] = ["KandinskyInpaintPipeline"]
    _import_structure["pipeline_kandinsky_prior"] = ["KandinskyPriorPipeline", "KandinskyPriorPipelineOutput"]
    _import_structure["text_encoder"] = ["MultilingualCLIP"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        from .pipeline_kandinsky import KandinskyPipeline
        from .pipeline_kandinsky_combined import (
            KandinskyCombinedPipeline,
            KandinskyImg2ImgCombinedPipeline,
            KandinskyInpaintCombinedPipeline,
        )
        from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
        from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
        from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput
        from .text_encoder import MultilingualCLIP

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky.cpython-310.pyc
ADDED
Binary file (12.3 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/__pycache__/pipeline_kandinsky_combined.cpython-310.pyc
ADDED
Binary file (32 kB).
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
ADDED
@@ -0,0 +1,419 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, List, Optional, Union

import torch
from transformers import (
    XLMRobertaTokenizer,
)

from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import (
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
        >>> import torch

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior")
        >>> pipe_prior.to("cuda")

        >>> prompt = "red cat, 4k photo"
        >>> out = pipe_prior(prompt)
        >>> image_emb = out.image_embeds
        >>> negative_image_emb = out.negative_image_embeds

        >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
        >>> pipe.to("cuda")

        >>> image = pipe(
        ...     prompt,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=negative_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=100,
        ... ).images

        >>> image[0].save("cat.png")
        ```
"""


def get_new_h_w(h, w, scale_factor=8):
    new_h = h // scale_factor**2
    if h % scale_factor**2 != 0:
        new_h += 1
    new_w = w // scale_factor**2
    if w % scale_factor**2 != 0:
        new_w += 1
    return new_h * scale_factor, new_w * scale_factor


class KandinskyPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`MultilingualCLIP`]):
            Frozen text-encoder.
        tokenizer ([`XLMRobertaTokenizer`]):
            Tokenizer of class
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    model_cpu_offload_seq = "text_encoder->unet->movq"

    def __init__(
        self,
        text_encoder: MultilingualCLIP,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_input_ids = text_input_ids.to(device)
        text_mask = text_inputs.attention_mask.to(device)

        prompt_embeds, text_encoder_hidden_states = self.text_encoder(
            input_ids=text_input_ids, attention_mask=text_mask
        )

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            uncond_text_input_ids = uncond_input.input_ids.to(device)
            uncond_text_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
                input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image_embeds: Union[torch.Tensor, List[torch.Tensor]],
        negative_image_embeds: Union[torch.Tensor, List[torch.Tensor]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
                dtype=prompt_embeds.dtype, device=device
            )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels

        height, width = get_new_h_w(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            ).prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        self.maybe_free_model_hooks()

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
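A quick worked example of the latent-size rounding performed by `get_new_h_w` in the file above: requested pixel sizes are rounded up to the next multiple of `scale_factor**2` and returned as latent-grid sizes. This is a hypothetical standalone snippet; it only assumes the import path of this vendored module.
```py
from diffusers.pipelines.kandinsky.pipeline_kandinsky import get_new_h_w

# 500 px is not a multiple of 64, so it rounds up: 500 -> 8 blocks -> 64 latent rows,
# which decode back to 512 px through the 8x MoVQ decoder; 768 px maps exactly to 96.
print(get_new_h_w(500, 768, scale_factor=8))  # -> (64, 96)
```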
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
ADDED
@@ -0,0 +1,817 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Callable, List, Optional, Union
|
| 15 |
+
|
| 16 |
+
import PIL.Image
|
| 17 |
+
import torch
|
| 18 |
+
from transformers import (
|
| 19 |
+
CLIPImageProcessor,
|
| 20 |
+
CLIPTextModelWithProjection,
|
| 21 |
+
CLIPTokenizer,
|
| 22 |
+
CLIPVisionModelWithProjection,
|
| 23 |
+
XLMRobertaTokenizer,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from ...models import PriorTransformer, UNet2DConditionModel, VQModel
|
| 27 |
+
from ...schedulers import DDIMScheduler, DDPMScheduler, UnCLIPScheduler
|
| 28 |
+
from ...utils import (
|
| 29 |
+
replace_example_docstring,
|
| 30 |
+
)
|
| 31 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 32 |
+
from .pipeline_kandinsky import KandinskyPipeline
|
| 33 |
+
from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
|
| 34 |
+
from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
|
| 35 |
+
from .pipeline_kandinsky_prior import KandinskyPriorPipeline
|
| 36 |
+
from .text_encoder import MultilingualCLIP
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
| 40 |
+
Examples:
|
| 41 |
+
```py
|
| 42 |
+
from diffusers import AutoPipelineForText2Image
|
| 43 |
+
import torch
|
| 44 |
+
|
| 45 |
+
pipe = AutoPipelineForText2Image.from_pretrained(
|
| 46 |
+
"kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
|
| 47 |
+
)
|
| 48 |
+
pipe.enable_model_cpu_offload()
|
| 49 |
+
|
| 50 |
+
prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k"
|
| 51 |
+
|
| 52 |
+
image = pipe(prompt=prompt, num_inference_steps=25).images[0]
|
| 53 |
+
```
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
IMAGE2IMAGE_EXAMPLE_DOC_STRING = """
|
| 57 |
+
Examples:
|
| 58 |
+
```py
|
| 59 |
+
from diffusers import AutoPipelineForImage2Image
|
| 60 |
+
import torch
|
| 61 |
+
import requests
|
| 62 |
+
from io import BytesIO
|
| 63 |
+
from PIL import Image
|
| 64 |
+
import os
|
| 65 |
+
|
| 66 |
+
pipe = AutoPipelineForImage2Image.from_pretrained(
|
| 67 |
+
"kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
|
| 68 |
+
)
|
| 69 |
+
pipe.enable_model_cpu_offload()
|
| 70 |
+
|
| 71 |
+
prompt = "A fantasy landscape, Cinematic lighting"
|
| 72 |
+
negative_prompt = "low quality, bad quality"
|
| 73 |
+
|
| 74 |
+
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
| 75 |
+
|
| 76 |
+
response = requests.get(url)
|
| 77 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 78 |
+
image.thumbnail((768, 768))
|
| 79 |
+
|
| 80 |
+
image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0]
|
| 81 |
+
```
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
INPAINT_EXAMPLE_DOC_STRING = """
|
| 85 |
+
Examples:
|
| 86 |
+
```py
|
| 87 |
+
from diffusers import AutoPipelineForInpainting
|
| 88 |
+
from diffusers.utils import load_image
|
| 89 |
+
import torch
|
| 90 |
+
import numpy as np
|
| 91 |
+
|
| 92 |
+
pipe = AutoPipelineForInpainting.from_pretrained(
|
| 93 |
+
"kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16
|
| 94 |
+
)
|
| 95 |
+
pipe.enable_model_cpu_offload()
|
| 96 |
+
|
| 97 |
+
prompt = "A fantasy landscape, Cinematic lighting"
|
| 98 |
+
negative_prompt = "low quality, bad quality"
|
| 99 |
+
|
| 100 |
+
original_image = load_image(
|
| 101 |
+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
mask = np.zeros((768, 768), dtype=np.float32)
|
| 105 |
+
# Let's mask out an area above the cat's head
|
| 106 |
+
mask[:250, 250:-250] = 1
|
| 107 |
+
|
| 108 |
+
image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0]
|
| 109 |
+
```
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class KandinskyCombinedPipeline(DiffusionPipeline):
|
| 114 |
+
"""
|
| 115 |
+
Combined Pipeline for text-to-image generation using Kandinsky
|
| 116 |
+
|
| 117 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 118 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
text_encoder ([`MultilingualCLIP`]):
|
| 122 |
+
Frozen text-encoder.
|
| 123 |
+
tokenizer ([`XLMRobertaTokenizer`]):
|
| 124 |
+
Tokenizer of class
|
| 125 |
+
scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
|
| 126 |
+
A scheduler to be used in combination with `unet` to generate image latents.
|
| 127 |
+
unet ([`UNet2DConditionModel`]):
|
| 128 |
+
Conditional U-Net architecture to denoise the image embedding.
|
| 129 |
+
movq ([`VQModel`]):
|
| 130 |
+
MoVQ Decoder to generate the image from the latents.
|
| 131 |
+
prior_prior ([`PriorTransformer`]):
|
| 132 |
+
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
| 133 |
+
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
|
| 134 |
+
Frozen image-encoder.
|
| 135 |
+
prior_text_encoder ([`CLIPTextModelWithProjection`]):
|
| 136 |
+
Frozen text-encoder.
|
| 137 |
+
prior_tokenizer (`CLIPTokenizer`):
|
| 138 |
+
Tokenizer of class
|
| 139 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 140 |
+
prior_scheduler ([`UnCLIPScheduler`]):
|
| 141 |
+
A scheduler to be used in combination with `prior` to generate image embedding.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
_load_connected_pipes = True
|
| 145 |
+
model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder"
|
| 146 |
+
_exclude_from_cpu_offload = ["prior_prior"]
|
| 147 |
+
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
text_encoder: MultilingualCLIP,
|
| 151 |
+
tokenizer: XLMRobertaTokenizer,
|
| 152 |
+
unet: UNet2DConditionModel,
|
| 153 |
+
scheduler: Union[DDIMScheduler, DDPMScheduler],
|
| 154 |
+
movq: VQModel,
|
| 155 |
+
prior_prior: PriorTransformer,
|
| 156 |
+
prior_image_encoder: CLIPVisionModelWithProjection,
|
| 157 |
+
prior_text_encoder: CLIPTextModelWithProjection,
|
| 158 |
+
prior_tokenizer: CLIPTokenizer,
|
| 159 |
+
prior_scheduler: UnCLIPScheduler,
|
| 160 |
+
prior_image_processor: CLIPImageProcessor,
|
| 161 |
+
):
|
| 162 |
+
super().__init__()
|
| 163 |
+
|
| 164 |
+
self.register_modules(
|
| 165 |
+
text_encoder=text_encoder,
|
| 166 |
+
tokenizer=tokenizer,
|
| 167 |
+
unet=unet,
|
| 168 |
+
scheduler=scheduler,
|
| 169 |
+
movq=movq,
|
| 170 |
+
prior_prior=prior_prior,
|
| 171 |
+
prior_image_encoder=prior_image_encoder,
|
| 172 |
+
prior_text_encoder=prior_text_encoder,
|
| 173 |
+
prior_tokenizer=prior_tokenizer,
|
| 174 |
+
prior_scheduler=prior_scheduler,
|
| 175 |
+
prior_image_processor=prior_image_processor,
|
| 176 |
+
)
|
| 177 |
+
self.prior_pipe = KandinskyPriorPipeline(
|
| 178 |
+
prior=prior_prior,
|
| 179 |
+
image_encoder=prior_image_encoder,
|
| 180 |
+
text_encoder=prior_text_encoder,
|
| 181 |
+
tokenizer=prior_tokenizer,
|
| 182 |
+
scheduler=prior_scheduler,
|
| 183 |
+
image_processor=prior_image_processor,
|
| 184 |
+
)
|
| 185 |
+
self.decoder_pipe = KandinskyPipeline(
|
| 186 |
+
text_encoder=text_encoder,
|
| 187 |
+
tokenizer=tokenizer,
|
| 188 |
+
unet=unet,
|
| 189 |
+
scheduler=scheduler,
|
| 190 |
+
movq=movq,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
| 194 |
+
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
| 195 |
+
|
| 196 |
+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
|
| 197 |
+
r"""
|
| 198 |
+
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
|
| 199 |
+
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
|
| 200 |
+
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
|
| 201 |
+
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
|
| 202 |
+
"""
|
| 203 |
+
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
| 204 |
+
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
| 205 |
+
|
| 206 |
+
def progress_bar(self, iterable=None, total=None):
|
| 207 |
+
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
| 208 |
+
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
|
| 209 |
+
self.decoder_pipe.enable_model_cpu_offload()
|
| 210 |
+
|
| 211 |
+
def set_progress_bar_config(self, **kwargs):
|
| 212 |
+
self.prior_pipe.set_progress_bar_config(**kwargs)
|
| 213 |
+
self.decoder_pipe.set_progress_bar_config(**kwargs)
|
| 214 |
+
|
| 215 |
+
@torch.no_grad()
|
| 216 |
+
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
|
| 217 |
+
def __call__(
|
| 218 |
+
self,
|
| 219 |
+
prompt: Union[str, List[str]],
|
| 220 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 221 |
+
num_inference_steps: int = 100,
|
| 222 |
+
guidance_scale: float = 4.0,
|
| 223 |
+
num_images_per_prompt: int = 1,
|
| 224 |
+
height: int = 512,
|
| 225 |
+
width: int = 512,
|
| 226 |
+
prior_guidance_scale: float = 4.0,
|
| 227 |
+
prior_num_inference_steps: int = 25,
|
| 228 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 229 |
+
latents: Optional[torch.Tensor] = None,
|
| 230 |
+
output_type: Optional[str] = "pil",
|
| 231 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 232 |
+
callback_steps: int = 1,
|
| 233 |
+
return_dict: bool = True,
|
| 234 |
+
):
|
| 235 |
+
"""
|
| 236 |
+
Function invoked when calling the pipeline for generation.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
prompt (`str` or `List[str]`):
|
| 240 |
+
The prompt or prompts to guide the image generation.
|
| 241 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 242 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 243 |
+
if `guidance_scale` is less than `1`).
|
| 244 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 245 |
+
The number of images to generate per prompt.
|
| 246 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
| 247 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 248 |
+
expense of slower inference.
|
| 249 |
+
height (`int`, *optional*, defaults to 512):
|
| 250 |
+
The height in pixels of the generated image.
|
| 251 |
+
width (`int`, *optional*, defaults to 512):
|
| 252 |
+
The width in pixels of the generated image.
|
| 253 |
+
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 254 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 255 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 256 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 257 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 258 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 259 |
+
prior_num_inference_steps (`int`, *optional*, defaults to 100):
|
| 260 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 261 |
+
expense of slower inference.
|
| 262 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 263 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 264 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 265 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 266 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 267 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 268 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 269 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 270 |
+
to make generation deterministic.
|
| 271 |
+
latents (`torch.Tensor`, *optional*):
|
| 272 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 273 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 274 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 275 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 276 |
+
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
| 277 |
+
(`np.array`) or `"pt"` (`torch.Tensor`).
|
| 278 |
+
callback (`Callable`, *optional*):
|
| 279 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 280 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 281 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 282 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 283 |
+
every step.
|
| 284 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 285 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
| 286 |
+
|
| 287 |
+
Examples:
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
| 291 |
+
"""
|
| 292 |
+
prior_outputs = self.prior_pipe(
|
| 293 |
+
prompt=prompt,
|
| 294 |
+
negative_prompt=negative_prompt,
|
| 295 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 296 |
+
num_inference_steps=prior_num_inference_steps,
|
| 297 |
+
generator=generator,
|
| 298 |
+
latents=latents,
|
| 299 |
+
guidance_scale=prior_guidance_scale,
|
| 300 |
+
output_type="pt",
|
| 301 |
+
return_dict=False,
|
| 302 |
+
)
|
| 303 |
+
image_embeds = prior_outputs[0]
|
| 304 |
+
negative_image_embeds = prior_outputs[1]
|
| 305 |
+
|
| 306 |
+
prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
|
| 307 |
+
|
| 308 |
+
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
|
| 309 |
+
prompt = (image_embeds.shape[0] // len(prompt)) * prompt
|
| 310 |
+
|
| 311 |
+
outputs = self.decoder_pipe(
|
| 312 |
+
prompt=prompt,
|
| 313 |
+
image_embeds=image_embeds,
|
| 314 |
+
negative_image_embeds=negative_image_embeds,
|
| 315 |
+
width=width,
|
| 316 |
+
height=height,
|
| 317 |
+
num_inference_steps=num_inference_steps,
|
| 318 |
+
generator=generator,
|
| 319 |
+
guidance_scale=guidance_scale,
|
| 320 |
+
output_type=output_type,
|
| 321 |
+
callback=callback,
|
| 322 |
+
callback_steps=callback_steps,
|
| 323 |
+
return_dict=return_dict,
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
self.maybe_free_model_hooks()
|
| 327 |
+
|
| 328 |
+
return outputs
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
| 332 |
+
"""
|
| 333 |
+
Combined Pipeline for image-to-image generation using Kandinsky
|
| 334 |
+
|
| 335 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 336 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
text_encoder ([`MultilingualCLIP`]):
|
| 340 |
+
Frozen text-encoder.
|
| 341 |
+
tokenizer ([`XLMRobertaTokenizer`]):
|
| 342 |
+
Tokenizer of class
|
| 343 |
+
scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
|
| 344 |
+
A scheduler to be used in combination with `unet` to generate image latents.
|
| 345 |
+
unet ([`UNet2DConditionModel`]):
|
| 346 |
+
Conditional U-Net architecture to denoise the image embedding.
|
| 347 |
+
movq ([`VQModel`]):
|
| 348 |
+
MoVQ Decoder to generate the image from the latents.
|
| 349 |
+
prior_prior ([`PriorTransformer`]):
|
| 350 |
+
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
| 351 |
+
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
|
| 352 |
+
Frozen image-encoder.
|
| 353 |
+
prior_text_encoder ([`CLIPTextModelWithProjection`]):
|
| 354 |
+
Frozen text-encoder.
|
| 355 |
+
prior_tokenizer (`CLIPTokenizer`):
|
| 356 |
+
Tokenizer of class
|
| 357 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 358 |
+
prior_scheduler ([`UnCLIPScheduler`]):
|
| 359 |
+
A scheduler to be used in combination with `prior` to generate image embedding.
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
_load_connected_pipes = True
|
| 363 |
+
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
|
| 364 |
+
_exclude_from_cpu_offload = ["prior_prior"]
|
| 365 |
+
|
| 366 |
+
def __init__(
|
| 367 |
+
self,
|
| 368 |
+
text_encoder: MultilingualCLIP,
|
| 369 |
+
tokenizer: XLMRobertaTokenizer,
|
| 370 |
+
unet: UNet2DConditionModel,
|
| 371 |
+
scheduler: Union[DDIMScheduler, DDPMScheduler],
|
| 372 |
+
movq: VQModel,
|
| 373 |
+
prior_prior: PriorTransformer,
|
| 374 |
+
prior_image_encoder: CLIPVisionModelWithProjection,
|
| 375 |
+
prior_text_encoder: CLIPTextModelWithProjection,
|
| 376 |
+
prior_tokenizer: CLIPTokenizer,
|
| 377 |
+
prior_scheduler: UnCLIPScheduler,
|
| 378 |
+
prior_image_processor: CLIPImageProcessor,
|
| 379 |
+
):
|
| 380 |
+
super().__init__()
|
| 381 |
+
|
| 382 |
+
self.register_modules(
|
| 383 |
+
text_encoder=text_encoder,
|
| 384 |
+
tokenizer=tokenizer,
|
| 385 |
+
unet=unet,
|
| 386 |
+
scheduler=scheduler,
|
| 387 |
+
movq=movq,
|
| 388 |
+
prior_prior=prior_prior,
|
| 389 |
+
prior_image_encoder=prior_image_encoder,
|
| 390 |
+
prior_text_encoder=prior_text_encoder,
|
| 391 |
+
prior_tokenizer=prior_tokenizer,
|
| 392 |
+
prior_scheduler=prior_scheduler,
|
| 393 |
+
prior_image_processor=prior_image_processor,
|
| 394 |
+
)
|
| 395 |
+
self.prior_pipe = KandinskyPriorPipeline(
|
| 396 |
+
prior=prior_prior,
|
| 397 |
+
image_encoder=prior_image_encoder,
|
| 398 |
+
text_encoder=prior_text_encoder,
|
| 399 |
+
tokenizer=prior_tokenizer,
|
| 400 |
+
scheduler=prior_scheduler,
|
| 401 |
+
image_processor=prior_image_processor,
|
| 402 |
+
)
|
| 403 |
+
self.decoder_pipe = KandinskyImg2ImgPipeline(
|
| 404 |
+
text_encoder=text_encoder,
|
| 405 |
+
tokenizer=tokenizer,
|
| 406 |
+
unet=unet,
|
| 407 |
+
scheduler=scheduler,
|
| 408 |
+
movq=movq,
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
| 412 |
+
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
| 413 |
+
|
| 414 |
+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
|
| 415 |
+
r"""
|
| 416 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
| 417 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
| 418 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
| 419 |
+
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
| 420 |
+
`enable_model_cpu_offload`, but performance is lower.
|
| 421 |
+
"""
|
| 422 |
+
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
| 423 |
+
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
| 424 |
+
|
| 425 |
+
def progress_bar(self, iterable=None, total=None):
|
| 426 |
+
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
| 427 |
+
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
|
| 428 |
+
self.decoder_pipe.enable_model_cpu_offload()
|
| 429 |
+
|
| 430 |
+
def set_progress_bar_config(self, **kwargs):
|
| 431 |
+
self.prior_pipe.set_progress_bar_config(**kwargs)
|
| 432 |
+
self.decoder_pipe.set_progress_bar_config(**kwargs)
|
| 433 |
+
|
| 434 |
+
@torch.no_grad()
|
| 435 |
+
@replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING)
|
| 436 |
+
def __call__(
|
| 437 |
+
self,
|
| 438 |
+
prompt: Union[str, List[str]],
|
| 439 |
+
image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
|
| 440 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 441 |
+
num_inference_steps: int = 100,
|
| 442 |
+
guidance_scale: float = 4.0,
|
| 443 |
+
num_images_per_prompt: int = 1,
|
| 444 |
+
strength: float = 0.3,
|
| 445 |
+
height: int = 512,
|
| 446 |
+
width: int = 512,
|
| 447 |
+
prior_guidance_scale: float = 4.0,
|
| 448 |
+
prior_num_inference_steps: int = 25,
|
| 449 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 450 |
+
latents: Optional[torch.Tensor] = None,
|
| 451 |
+
output_type: Optional[str] = "pil",
|
| 452 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 453 |
+
callback_steps: int = 1,
|
| 454 |
+
return_dict: bool = True,
|
| 455 |
+
):
|
| 456 |
+
"""
|
| 457 |
+
Function invoked when calling the pipeline for generation.
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
prompt (`str` or `List[str]`):
|
| 461 |
+
The prompt or prompts to guide the image generation.
|
| 462 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 463 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 464 |
+
process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded
|
| 465 |
+
again.
|
| 466 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 467 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 468 |
+
if `guidance_scale` is less than `1`).
|
| 469 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 470 |
+
The number of images to generate per prompt.
|
| 471 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
| 472 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 473 |
+
expense of slower inference.
|
| 474 |
+
height (`int`, *optional*, defaults to 512):
|
| 475 |
+
The height in pixels of the generated image.
|
| 476 |
+
width (`int`, *optional*, defaults to 512):
|
| 477 |
+
The width in pixels of the generated image.
|
| 478 |
+
strength (`float`, *optional*, defaults to 0.3):
|
| 479 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
| 480 |
+
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
| 481 |
+
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
| 482 |
+
be maximum and the denoising process will run for the full number of iterations specified in
|
| 483 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 484 |
+
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 485 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 486 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 487 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 488 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 489 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 490 |
+
prior_num_inference_steps (`int`, *optional*, defaults to 100):
|
| 491 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 492 |
+
expense of slower inference.
|
| 493 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 494 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 495 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 496 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 497 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 498 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 499 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 500 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 501 |
+
to make generation deterministic.
|
| 502 |
+
latents (`torch.Tensor`, *optional*):
|
| 503 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 504 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 505 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 506 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 507 |
+
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
| 508 |
+
(`np.array`) or `"pt"` (`torch.Tensor`).
|
| 509 |
+
callback (`Callable`, *optional*):
|
| 510 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 511 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 512 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 513 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 514 |
+
every step.
|
| 515 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 516 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
| 517 |
+
|
| 518 |
+
Examples:
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
| 522 |
+
"""
|
| 523 |
+
prior_outputs = self.prior_pipe(
|
| 524 |
+
prompt=prompt,
|
| 525 |
+
negative_prompt=negative_prompt,
|
| 526 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 527 |
+
num_inference_steps=prior_num_inference_steps,
|
| 528 |
+
generator=generator,
|
| 529 |
+
latents=latents,
|
| 530 |
+
guidance_scale=prior_guidance_scale,
|
| 531 |
+
output_type="pt",
|
| 532 |
+
return_dict=False,
|
| 533 |
+
)
|
| 534 |
+
image_embeds = prior_outputs[0]
|
| 535 |
+
negative_image_embeds = prior_outputs[1]
|
| 536 |
+
|
| 537 |
+
prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
|
| 538 |
+
image = [image] if isinstance(prompt, PIL.Image.Image) else image
|
| 539 |
+
|
| 540 |
+
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
|
| 541 |
+
prompt = (image_embeds.shape[0] // len(prompt)) * prompt
|
| 542 |
+
|
| 543 |
+
if (
|
| 544 |
+
isinstance(image, (list, tuple))
|
| 545 |
+
and len(image) < image_embeds.shape[0]
|
| 546 |
+
and image_embeds.shape[0] % len(image) == 0
|
| 547 |
+
):
|
| 548 |
+
image = (image_embeds.shape[0] // len(image)) * image
|
| 549 |
+
|
| 550 |
+
outputs = self.decoder_pipe(
|
| 551 |
+
prompt=prompt,
|
| 552 |
+
image=image,
|
| 553 |
+
image_embeds=image_embeds,
|
| 554 |
+
negative_image_embeds=negative_image_embeds,
|
| 555 |
+
strength=strength,
|
| 556 |
+
width=width,
|
| 557 |
+
height=height,
|
| 558 |
+
num_inference_steps=num_inference_steps,
|
| 559 |
+
generator=generator,
|
| 560 |
+
guidance_scale=guidance_scale,
|
| 561 |
+
output_type=output_type,
|
| 562 |
+
callback=callback,
|
| 563 |
+
callback_steps=callback_steps,
|
| 564 |
+
return_dict=return_dict,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
self.maybe_free_model_hooks()
|
| 568 |
+
|
| 569 |
+
return outputs
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
| 573 |
+
"""
|
| 574 |
+
Combined Pipeline for generation using Kandinsky
|
| 575 |
+
|
| 576 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 577 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
text_encoder ([`MultilingualCLIP`]):
|
| 581 |
+
Frozen text-encoder.
|
| 582 |
+
tokenizer ([`XLMRobertaTokenizer`]):
|
| 583 |
+
Tokenizer of class
|
| 584 |
+
scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
|
| 585 |
+
A scheduler to be used in combination with `unet` to generate image latents.
|
| 586 |
+
unet ([`UNet2DConditionModel`]):
|
| 587 |
+
Conditional U-Net architecture to denoise the image embedding.
|
| 588 |
+
movq ([`VQModel`]):
|
| 589 |
+
MoVQ Decoder to generate the image from the latents.
|
| 590 |
+
prior_prior ([`PriorTransformer`]):
|
| 591 |
+
The canonical unCLIP prior to approximate the image embedding from the text embedding.
|
| 592 |
+
prior_image_encoder ([`CLIPVisionModelWithProjection`]):
|
| 593 |
+
Frozen image-encoder.
|
| 594 |
+
prior_text_encoder ([`CLIPTextModelWithProjection`]):
|
| 595 |
+
Frozen text-encoder.
|
| 596 |
+
prior_tokenizer (`CLIPTokenizer`):
|
| 597 |
+
Tokenizer of class
|
| 598 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 599 |
+
prior_scheduler ([`UnCLIPScheduler`]):
|
| 600 |
+
A scheduler to be used in combination with `prior` to generate image embedding.
|
| 601 |
+
"""
|
| 602 |
+
|
| 603 |
+
_load_connected_pipes = True
|
| 604 |
+
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
|
| 605 |
+
_exclude_from_cpu_offload = ["prior_prior"]
|
| 606 |
+
|
| 607 |
+
def __init__(
|
| 608 |
+
self,
|
| 609 |
+
text_encoder: MultilingualCLIP,
|
| 610 |
+
tokenizer: XLMRobertaTokenizer,
|
| 611 |
+
unet: UNet2DConditionModel,
|
| 612 |
+
scheduler: Union[DDIMScheduler, DDPMScheduler],
|
| 613 |
+
movq: VQModel,
|
| 614 |
+
prior_prior: PriorTransformer,
|
| 615 |
+
prior_image_encoder: CLIPVisionModelWithProjection,
|
| 616 |
+
prior_text_encoder: CLIPTextModelWithProjection,
|
| 617 |
+
prior_tokenizer: CLIPTokenizer,
|
| 618 |
+
prior_scheduler: UnCLIPScheduler,
|
| 619 |
+
prior_image_processor: CLIPImageProcessor,
|
| 620 |
+
):
|
| 621 |
+
super().__init__()
|
| 622 |
+
|
| 623 |
+
self.register_modules(
|
| 624 |
+
text_encoder=text_encoder,
|
| 625 |
+
tokenizer=tokenizer,
|
| 626 |
+
unet=unet,
|
| 627 |
+
scheduler=scheduler,
|
| 628 |
+
movq=movq,
|
| 629 |
+
prior_prior=prior_prior,
|
| 630 |
+
prior_image_encoder=prior_image_encoder,
|
| 631 |
+
prior_text_encoder=prior_text_encoder,
|
| 632 |
+
prior_tokenizer=prior_tokenizer,
|
| 633 |
+
prior_scheduler=prior_scheduler,
|
| 634 |
+
prior_image_processor=prior_image_processor,
|
| 635 |
+
)
|
| 636 |
+
self.prior_pipe = KandinskyPriorPipeline(
|
| 637 |
+
prior=prior_prior,
|
| 638 |
+
image_encoder=prior_image_encoder,
|
| 639 |
+
text_encoder=prior_text_encoder,
|
| 640 |
+
tokenizer=prior_tokenizer,
|
| 641 |
+
scheduler=prior_scheduler,
|
| 642 |
+
image_processor=prior_image_processor,
|
| 643 |
+
)
|
| 644 |
+
self.decoder_pipe = KandinskyInpaintPipeline(
|
| 645 |
+
text_encoder=text_encoder,
|
| 646 |
+
tokenizer=tokenizer,
|
| 647 |
+
unet=unet,
|
| 648 |
+
scheduler=scheduler,
|
| 649 |
+
movq=movq,
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
| 653 |
+
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
|
| 654 |
+
|
| 655 |
+
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
|
| 656 |
+
r"""
|
| 657 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
| 658 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
| 659 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
| 660 |
+
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
| 661 |
+
`enable_model_cpu_offload`, but performance is lower.
|
| 662 |
+
"""
|
| 663 |
+
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
| 664 |
+
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
|
| 665 |
+
|
| 666 |
+
def progress_bar(self, iterable=None, total=None):
|
| 667 |
+
self.prior_pipe.progress_bar(iterable=iterable, total=total)
|
| 668 |
+
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
|
| 669 |
+
self.decoder_pipe.enable_model_cpu_offload()
|
| 670 |
+
|
| 671 |
+
def set_progress_bar_config(self, **kwargs):
|
| 672 |
+
self.prior_pipe.set_progress_bar_config(**kwargs)
|
| 673 |
+
self.decoder_pipe.set_progress_bar_config(**kwargs)
|
| 674 |
+
|
| 675 |
+
@torch.no_grad()
|
| 676 |
+
@replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING)
|
| 677 |
+
def __call__(
|
| 678 |
+
self,
|
| 679 |
+
prompt: Union[str, List[str]],
|
| 680 |
+
image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
|
| 681 |
+
mask_image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
|
| 682 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 683 |
+
num_inference_steps: int = 100,
|
| 684 |
+
guidance_scale: float = 4.0,
|
| 685 |
+
num_images_per_prompt: int = 1,
|
| 686 |
+
height: int = 512,
|
| 687 |
+
width: int = 512,
|
| 688 |
+
prior_guidance_scale: float = 4.0,
|
| 689 |
+
prior_num_inference_steps: int = 25,
|
| 690 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 691 |
+
latents: Optional[torch.Tensor] = None,
|
| 692 |
+
output_type: Optional[str] = "pil",
|
| 693 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 694 |
+
callback_steps: int = 1,
|
| 695 |
+
return_dict: bool = True,
|
| 696 |
+
):
|
| 697 |
+
"""
|
| 698 |
+
Function invoked when calling the pipeline for generation.
|
| 699 |
+
|
| 700 |
+
Args:
|
| 701 |
+
prompt (`str` or `List[str]`):
|
| 702 |
+
The prompt or prompts to guide the image generation.
|
| 703 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
| 704 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 705 |
+
process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded
|
| 706 |
+
again.
|
| 707 |
+
mask_image (`np.array`):
|
| 708 |
+
Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while
|
| 709 |
+
black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single
|
| 710 |
+
channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3,
|
| 711 |
+
so the expected shape would be `(B, H, W, 1)`.
|
| 712 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 713 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 714 |
+
if `guidance_scale` is less than `1`).
|
| 715 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 716 |
+
The number of images to generate per prompt.
|
| 717 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
| 718 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 719 |
+
expense of slower inference.
|
| 720 |
+
height (`int`, *optional*, defaults to 512):
|
| 721 |
+
The height in pixels of the generated image.
|
| 722 |
+
width (`int`, *optional*, defaults to 512):
|
| 723 |
+
The width in pixels of the generated image.
|
| 724 |
+
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 725 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 726 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 727 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 728 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 729 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 730 |
+
prior_num_inference_steps (`int`, *optional*, defaults to 100):
|
| 731 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 732 |
+
expense of slower inference.
|
| 733 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 734 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 735 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 736 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 737 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 738 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 739 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 740 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 741 |
+
to make generation deterministic.
|
| 742 |
+
latents (`torch.Tensor`, *optional*):
|
| 743 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 744 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 745 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 746 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 747 |
+
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
| 748 |
+
(`np.array`) or `"pt"` (`torch.Tensor`).
|
| 749 |
+
callback (`Callable`, *optional*):
|
| 750 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 751 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 752 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 753 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 754 |
+
every step.
|
| 755 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 756 |
+
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
| 757 |
+
|
| 758 |
+
Examples:
|
| 759 |
+
|
| 760 |
+
Returns:
|
| 761 |
+
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
| 762 |
+
"""
|
| 763 |
+
prior_outputs = self.prior_pipe(
|
| 764 |
+
prompt=prompt,
|
| 765 |
+
negative_prompt=negative_prompt,
|
| 766 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 767 |
+
num_inference_steps=prior_num_inference_steps,
|
| 768 |
+
generator=generator,
|
| 769 |
+
latents=latents,
|
| 770 |
+
guidance_scale=prior_guidance_scale,
|
| 771 |
+
output_type="pt",
|
| 772 |
+
return_dict=False,
|
| 773 |
+
)
|
| 774 |
+
image_embeds = prior_outputs[0]
|
| 775 |
+
negative_image_embeds = prior_outputs[1]
|
| 776 |
+
|
| 777 |
+
prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
|
| 778 |
+
image = [image] if isinstance(prompt, PIL.Image.Image) else image
|
| 779 |
+
mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image
|
| 780 |
+
|
| 781 |
+
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
|
| 782 |
+
prompt = (image_embeds.shape[0] // len(prompt)) * prompt
|
| 783 |
+
|
| 784 |
+
if (
|
| 785 |
+
isinstance(image, (list, tuple))
|
| 786 |
+
and len(image) < image_embeds.shape[0]
|
| 787 |
+
and image_embeds.shape[0] % len(image) == 0
|
| 788 |
+
):
|
| 789 |
+
image = (image_embeds.shape[0] // len(image)) * image
|
| 790 |
+
|
| 791 |
+
if (
|
| 792 |
+
isinstance(mask_image, (list, tuple))
|
| 793 |
+
and len(mask_image) < image_embeds.shape[0]
|
| 794 |
+
and image_embeds.shape[0] % len(mask_image) == 0
|
| 795 |
+
):
|
| 796 |
+
mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image
|
| 797 |
+
|
| 798 |
+
outputs = self.decoder_pipe(
|
| 799 |
+
prompt=prompt,
|
| 800 |
+
image=image,
|
| 801 |
+
mask_image=mask_image,
|
| 802 |
+
image_embeds=image_embeds,
|
| 803 |
+
negative_image_embeds=negative_image_embeds,
|
| 804 |
+
width=width,
|
| 805 |
+
height=height,
|
| 806 |
+
num_inference_steps=num_inference_steps,
|
| 807 |
+
generator=generator,
|
| 808 |
+
guidance_scale=guidance_scale,
|
| 809 |
+
output_type=output_type,
|
| 810 |
+
callback=callback,
|
| 811 |
+
callback_steps=callback_steps,
|
| 812 |
+
return_dict=return_dict,
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
self.maybe_free_model_hooks()
|
| 816 |
+
|
| 817 |
+
return outputs
|
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Callable, List, Optional, Union
|
| 16 |
+
|
| 17 |
+
import PIL.Image
|
| 18 |
+
import torch
|
| 19 |
+
from transformers import (
|
| 20 |
+
XLMRobertaTokenizer,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from ...image_processor import VaeImageProcessor
|
| 24 |
+
from ...models import UNet2DConditionModel, VQModel
|
| 25 |
+
from ...schedulers import DDIMScheduler
|
| 26 |
+
from ...utils import (
|
| 27 |
+
is_torch_xla_available,
|
| 28 |
+
logging,
|
| 29 |
+
replace_example_docstring,
|
| 30 |
+
)
|
| 31 |
+
from ...utils.torch_utils import randn_tensor
|
| 32 |
+
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 33 |
+
from .text_encoder import MultilingualCLIP
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if is_torch_xla_available():
|
| 37 |
+
import torch_xla.core.xla_model as xm
|
| 38 |
+
|
| 39 |
+
XLA_AVAILABLE = True
|
| 40 |
+
else:
|
| 41 |
+
XLA_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
EXAMPLE_DOC_STRING = """
|
| 47 |
+
Examples:
|
| 48 |
+
```py
|
| 49 |
+
>>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
|
| 50 |
+
>>> from diffusers.utils import load_image
|
| 51 |
+
>>> import torch
|
| 52 |
+
|
| 53 |
+
>>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
|
| 54 |
+
... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
|
| 55 |
+
... )
|
| 56 |
+
>>> pipe_prior.to("cuda")
|
| 57 |
+
|
| 58 |
+
>>> prompt = "A red cartoon frog, 4k"
|
| 59 |
+
>>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
|
| 60 |
+
|
| 61 |
+
>>> pipe = KandinskyImg2ImgPipeline.from_pretrained(
|
| 62 |
+
... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
|
| 63 |
+
... )
|
| 64 |
+
>>> pipe.to("cuda")
|
| 65 |
+
|
| 66 |
+
>>> init_image = load_image(
|
| 67 |
+
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
| 68 |
+
... "/kandinsky/frog.png"
|
| 69 |
+
... )
|
| 70 |
+
|
| 71 |
+
>>> image = pipe(
|
| 72 |
+
... prompt,
|
| 73 |
+
... image=init_image,
|
| 74 |
+
... image_embeds=image_emb,
|
| 75 |
+
... negative_image_embeds=zero_image_emb,
|
| 76 |
+
... height=768,
|
| 77 |
+
... width=768,
|
| 78 |
+
... num_inference_steps=100,
|
| 79 |
+
... strength=0.2,
|
| 80 |
+
... ).images
|
| 81 |
+
|
| 82 |
+
>>> image[0].save("red_frog.png")
|
| 83 |
+
```
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_new_h_w(h, w, scale_factor=8):
|
| 88 |
+
new_h = h // scale_factor**2
|
| 89 |
+
if h % scale_factor**2 != 0:
|
| 90 |
+
new_h += 1
|
| 91 |
+
new_w = w // scale_factor**2
|
| 92 |
+
if w % scale_factor**2 != 0:
|
| 93 |
+
new_w += 1
|
| 94 |
+
return new_h * scale_factor, new_w * scale_factor
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class KandinskyImg2ImgPipeline(DiffusionPipeline):
|
| 98 |
+
"""
|
| 99 |
+
Pipeline for image-to-image generation using Kandinsky
|
| 100 |
+
|
| 101 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 102 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
text_encoder ([`MultilingualCLIP`]):
|
| 106 |
+
Frozen text-encoder.
|
| 107 |
+
tokenizer ([`XLMRobertaTokenizer`]):
|
| 108 |
+
Tokenizer of class
|
| 109 |
+
scheduler ([`DDIMScheduler`]):
|
| 110 |
+
A scheduler to be used in combination with `unet` to generate image latents.
|
| 111 |
+
unet ([`UNet2DConditionModel`]):
|
| 112 |
+
Conditional U-Net architecture to denoise the image embedding.
|
| 113 |
+
movq ([`VQModel`]):
|
| 114 |
+
MoVQ image encoder and decoder
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
model_cpu_offload_seq = "text_encoder->unet->movq"
|
| 118 |
+
|
| 119 |
+
    def __init__(
        self,
        text_encoder: MultilingualCLIP,
        movq: VQModel,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DDIMScheduler,
    ):
        super().__init__()

        self.register_modules(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        self.movq_scale_factor = (
            2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
        )
        movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.movq_scale_factor,
            vae_latent_channels=movq_latent_channels,
            resample="bicubic",
            reducing_gap=1,
        )

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma

        shape = latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        latents = self.add_noise(latents, noise, latent_timestep)
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_input_ids = text_input_ids.to(device)
        text_mask = text_inputs.attention_mask.to(device)

        prompt_embeds, text_encoder_hidden_states = self.text_encoder(
            input_ids=text_input_ids, attention_mask=text_mask
        )

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            uncond_text_input_ids = uncond_input.input_ids.to(device)
            uncond_text_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
                input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    # add_noise method to overwrite the one in the scheduler because it uses a different beta schedule for adding noise vs sampling
    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

        return noisy_samples

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]],
        image_embeds: torch.Tensor,
        negative_image_embeds: torch.Tensor,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        strength: float = 0.3,
        guidance_scale: float = 7.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor`, `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # 1. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0

        # 2. get text and image embeddings
        prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
                dtype=prompt_embeds.dtype, device=device
            )

        # 3. pre-processing initial image
        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )

        image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
        image = image.to(dtype=prompt_embeds.dtype, device=device)

        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)

        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        # the formula to calculate the timestep for add_noise is taken from the original kandinsky repo
        latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2

        latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device)

        num_channels_latents = self.unet.config.in_channels

        height, width = get_new_h_w(height, width, self.movq_scale_factor)

        # 5. Create initial latent
        latents = self.prepare_latents(
            latents,
            latent_timestep,
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            self.scheduler,
        )

        # 6. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            ).prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

        # 7. post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        self.maybe_free_model_hooks()

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        image = self.image_processor.postprocess(image, output_type)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
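For reference, a minimal usage sketch of the image-to-image pipeline whose `__call__` signature is shown above, chained with the Kandinsky prior. This snippet is not part of the uploaded file; the checkpoint names and the sample image URL are reused from the example docstrings of the sibling Kandinsky files in this same commit, and the parameter names match the signature above.

```py
import torch
from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
from diffusers.utils import load_image

# the prior produces the CLIP image embeddings that __call__ above expects
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to("cuda")

pipe = KandinskyImg2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
)
pipe.to("cuda")

prompt = "a cat wearing a wizard hat, highly detailed"
negative_prompt = "low quality, bad quality"
image_emb, zero_image_emb = pipe_prior(prompt, negative_prompt=negative_prompt, return_dict=False)

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)

image = pipe(
    prompt,
    image=init_image,
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=768,
    width=768,
    strength=0.3,  # default from the signature above; larger values depart further from init_image
).images[0]
image.save("cat_img2img.png")
```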
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
ADDED
@@ -0,0 +1,647 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from packaging import version
from PIL import Image
from transformers import (
    XLMRobertaTokenizer,
)

from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
        >>> from diffusers.utils import load_image
        >>> import torch
        >>> import numpy as np

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        ... )
        >>> pipe_prior.to("cuda")

        >>> prompt = "a hat"
        >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

        >>> pipe = KandinskyInpaintPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> init_image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/cat.png"
        ... )

        >>> mask = np.zeros((768, 768), dtype=np.float32)
        >>> mask[:250, 250:-250] = 1

        >>> out = pipe(
        ...     prompt,
        ...     image=init_image,
        ...     mask_image=mask,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=50,
        ... )

        >>> image = out.images[0]
        >>> image.save("cat_with_hat.png")
        ```
"""


def get_new_h_w(h, w, scale_factor=8):
    new_h = h // scale_factor**2
    if h % scale_factor**2 != 0:
        new_h += 1
    new_w = w // scale_factor**2
    if w % scale_factor**2 != 0:
        new_w += 1
    return new_h * scale_factor, new_w * scale_factor


def prepare_mask(masks):
    prepared_masks = []
    for mask in masks:
        old_mask = deepcopy(mask)
        for i in range(mask.shape[1]):
            for j in range(mask.shape[2]):
                if old_mask[0][i][j] == 1:
                    continue
                if i != 0:
                    mask[:, i - 1, j] = 0
                if j != 0:
                    mask[:, i, j - 1] = 0
                if i != 0 and j != 0:
                    mask[:, i - 1, j - 1] = 0
                if i != mask.shape[1] - 1:
                    mask[:, i + 1, j] = 0
                if j != mask.shape[2] - 1:
                    mask[:, i, j + 1] = 0
                if i != mask.shape[1] - 1 and j != mask.shape[2] - 1:
                    mask[:, i + 1, j + 1] = 0
        prepared_masks.append(mask)
    return torch.stack(prepared_masks, dim=0)


def prepare_mask_and_masked_image(image, mask, height, width):
    r"""
    Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.


    Raises:
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
            (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
            dimensions: ``batch x channels x height x width``.
    """

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t passed height and width
            image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    mask = 1 - mask

    return mask, image


class KandinskyInpaintPipeline(DiffusionPipeline):
    """
    Pipeline for text-guided image inpainting using Kandinsky2.1

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`MultilingualCLIP`]):
            Frozen text-encoder.
        tokenizer ([`XLMRobertaTokenizer`]):
            Tokenizer of class
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ image encoder and decoder
    """

    model_cpu_offload_seq = "text_encoder->unet->movq"

    def __init__(
        self,
        text_encoder: MultilingualCLIP,
        movq: VQModel,
        tokenizer: XLMRobertaTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DDIMScheduler,
    ):
        super().__init__()

        self.register_modules(
            text_encoder=text_encoder,
            movq=movq,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
        self._warn_has_been_called = False

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_input_ids = text_input_ids.to(device)
        text_mask = text_inputs.attention_mask.to(device)

        prompt_embeds, text_encoder_hidden_states = self.text_encoder(
            input_ids=text_input_ids, attention_mask=text_mask
        )

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            uncond_text_input_ids = uncond_input.input_ids.to(device)
            uncond_text_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
                input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image],
        mask_image: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        image_embeds: torch.Tensor,
        negative_image_embeds: torch.Tensor,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor`, `PIL.Image.Image` or `np.ndarray`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`PIL.Image.Image`, `torch.Tensor` or `np.ndarray`):
                `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be
                repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the
                image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the
                expected shape would be either `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)` or `(H, W)`. If image is a PIL
                image or numpy array, mask should also be either a PIL image or numpy array. If it is a PIL image, it
                will be converted to a single channel (luminance) before use. If it is a numpy array, the expected
                shape is `(H, W)`.
            image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse(
            "0.23.0.dev0"
        ):
            logger.warning(
                "Please note that the expected format of `mask_image` has recently been changed. "
                "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved white pixels. "
                "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
                "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. "
                "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. "
                "This warning will be suppressed after the first inference call and will be removed in diffusers>0.23.0"
            )
            self._warn_has_been_called = True

        # Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
                dtype=prompt_embeds.dtype, device=device
            )

        # preprocess image and mask
        mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)

        image = image.to(dtype=prompt_embeds.dtype, device=device)
        image = self.movq.encode(image)["latents"]

        mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device)

        image_shape = tuple(image.shape[-2:])
        mask_image = F.interpolate(
            mask_image,
            image_shape,
            mode="nearest",
        )
        mask_image = prepare_mask(mask_image)
        masked_image = image * mask_image

        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
        masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            mask_image = mask_image.repeat(2, 1, 1, 1)
            masked_image = masked_image.repeat(2, 1, 1, 1)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.movq.config.latent_channels

        # get h, w for latents
        sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, sample_height, sample_width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # Check that sizes of mask, masked image and latents match with expected
        num_channels_mask = mask_image.shape[1]
        num_channels_masked_image = masked_image.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)

            added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            ).prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        self.maybe_free_model_hooks()

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
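A small illustrative sketch (not part of the uploaded file) of the mask convention handled by `prepare_mask_and_masked_image` above: white (1) pixels mark the region to repaint, and the helper binarizes the mask and returns it inverted (`1 - mask`), so zeros mark the area the UNet will repaint. The import path simply mirrors the file location shown in this commit.

```py
import numpy as np
import PIL.Image

from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint import prepare_mask_and_masked_image

# a flat gray image and a mask whose top half is marked for repainting
init_image = PIL.Image.new("RGB", (64, 64), color=(128, 128, 128))
mask = np.zeros((64, 64), dtype=np.float32)
mask[:32, :] = 1  # white (1) = repaint this region

mask_t, image_t = prepare_mask_and_masked_image(init_image, mask, height=64, width=64)

print(image_t.shape)  # torch.Size([1, 3, 64, 64]), values normalized to [-1, 1]
print(mask_t.shape)  # torch.Size([1, 1, 64, 64])
print(mask_t[0, 0, 0, 0].item())  # 0.0 -> this pixel will be repainted (mask was inverted)
print(mask_t[0, 0, -1, 0].item())  # 1.0 -> this pixel is preserved
```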
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
ADDED
@@ -0,0 +1,559 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
    BaseOutput,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
        >>> import torch

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
        >>> pipe_prior.to("cuda")

        >>> prompt = "red cat, 4k photo"
        >>> out = pipe_prior(prompt)
        >>> image_emb = out.image_embeds
        >>> negative_image_emb = out.negative_image_embeds

        >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
        >>> pipe.to("cuda")

        >>> image = pipe(
        ...     prompt,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=negative_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=100,
        ... ).images

        >>> image[0].save("cat.png")
        ```
"""

EXAMPLE_INTERPOLATE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
        >>> from diffusers.utils import load_image
        >>> import PIL

        >>> import torch
        >>> from torchvision import transforms

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        ... )
        >>> pipe_prior.to("cuda")

        >>> img1 = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/cat.png"
        ... )

        >>> img2 = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/starry_night.jpeg"
        ... )

        >>> images_texts = ["a cat", img1, img2]
        >>> weights = [0.3, 0.3, 0.4]
        >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)

        >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
        >>> pipe.to("cuda")

        >>> image = pipe(
        ...     "",
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=150,
        ... ).images[0]

        >>> image.save("starry_cat.png")
        ```
"""


@dataclass
class KandinskyPriorPipelineOutput(BaseOutput):
    """
    Output class for KandinskyPriorPipeline.

    Args:
        image_embeds (`torch.Tensor`)
            clip image embeddings for text prompt
        negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`)
            clip image embeddings for unconditional tokens
    """

    image_embeds: Union[torch.Tensor, np.ndarray]
    negative_image_embeds: Union[torch.Tensor, np.ndarray]


class KandinskyPriorPipeline(DiffusionPipeline):
    """
    Pipeline for generating image prior for Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
    """

    _exclude_from_cpu_offload = ["prior"]
    model_cpu_offload_seq = "text_encoder->prior"

    def __init__(
        self,
        prior: PriorTransformer,
        image_encoder: CLIPVisionModelWithProjection,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: UnCLIPScheduler,
        image_processor: CLIPImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            image_processor=image_processor,
        )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
    def interpolate(
        self,
        images_and_prompts: List[Union[str, PIL.Image.Image, torch.Tensor]],
        weights: List[float],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        negative_prior_prompt: Optional[str] = None,
        negative_prompt: str = "",
        guidance_scale: float = 4.0,
        device=None,
    ):
        """
        Function invoked when using the prior pipeline for interpolation.

        Args:
            images_and_prompts (`List[Union[str, PIL.Image.Image, torch.Tensor]]`):
                list of prompts and images to guide the image generation.
            weights: (`List[float]`):
                list of weights for each condition in `images_and_prompts`
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            negative_prior_prompt (`str`, *optional*):
                The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.

        Examples:

        Returns:
            [`KandinskyPriorPipelineOutput`] or `tuple`
        """

        device = device or self.device

        if len(images_and_prompts) != len(weights):
            raise ValueError(
                f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
            )

        image_embeddings = []
        for cond, weight in zip(images_and_prompts, weights):
            if isinstance(cond, str):
                image_emb = self(
                    cond,
                    num_inference_steps=num_inference_steps,
                    num_images_per_prompt=num_images_per_prompt,
                    generator=generator,
                    latents=latents,
                    negative_prompt=negative_prior_prompt,
                    guidance_scale=guidance_scale,
                ).image_embeds

            elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
                if isinstance(cond, PIL.Image.Image):
                    cond = (
                        self.image_processor(cond, return_tensors="pt")
                        .pixel_values[0]
                        .unsqueeze(0)
                        .to(dtype=self.image_encoder.dtype, device=device)
                    )

                image_emb = self.image_encoder(cond)["image_embeds"]

            else:
                raise ValueError(
                    f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
                )

            image_embeddings.append(image_emb * weight)

        image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)

        out_zero = self(
            negative_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            generator=generator,
            latents=latents,
            negative_prompt=negative_prior_prompt,
            guidance_scale=guidance_scale,
        )
        zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds

        return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def get_zero_embed(self, batch_size=1, device=None):
        device = device or self.device
        zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
            device=device, dtype=self.image_encoder.dtype
        )
        zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
        zero_image_emb = zero_image_emb.repeat(batch_size, 1)
        return zero_image_emb

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
    ):
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_mask = text_inputs.attention_mask.bool().to(device)

        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

        text_encoder_output = self.text_encoder(text_input_ids.to(device))

        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        guidance_scale: float = 4.0,
        output_type: Optional[str] = "pt",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
                (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`KandinskyPriorPipelineOutput`] or `tuple`
        """

        if isinstance(prompt, str):
            prompt = [prompt]
        elif not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        elif not isinstance(negative_prompt, list) and negative_prompt is not None:
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

        # if the negative prompt is defined we double the batch size to
        # directly retrieve the negative prompt embedding
        if negative_prompt is not None:
            prompt = prompt + negative_prompt
            negative_prompt = 2 * negative_prompt

        device = self._execution_device

        batch_size = len(prompt)
        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # prior
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        prior_timesteps_tensor = self.scheduler.timesteps

        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
                encoder_hidden_states=text_encoder_hidden_states,
                attention_mask=text_mask,
            ).predicted_image_embedding

            if do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            if i + 1 == prior_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = prior_timesteps_tensor[i + 1]

            latents = self.scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=latents,
                generator=generator,
                prev_timestep=prev_timestep,
            ).prev_sample

            if XLA_AVAILABLE:
                xm.mark_step()

        latents = self.prior.post_process_latents(latents)

        image_embeddings = latents

        # if negative prompt has been defined, we retrieve split the image embedding into two
        if negative_prompt is None:
            zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)

            self.maybe_free_model_hooks()
        else:
            image_embeddings, zero_embeds = image_embeddings.chunk(2)

            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.prior_hook.offload()

        if output_type not in ["pt", "np"]:
            raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

        if output_type == "np":
            image_embeddings = image_embeddings.cpu().numpy()
            zero_embeds = zero_embeds.cpu().numpy()

        if not return_dict:
            return (image_embeddings, zero_embeds)

        return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
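Note on the two core computations in the prior pipeline above: `interpolate` reduces the per-condition CLIP image embeddings to a single weighted sum, and `__call__` combines the prior's unconditional and text-conditioned predictions with the standard classifier-free guidance formula. The sketch below reproduces both operations on toy tensors (made-up shapes and values, independent of any checkpoint), assuming only `torch` is installed.

```py
import torch

# Stand-ins for per-condition CLIP image embeddings (batch 1, dim 4 here; the real prior works in a larger space).
embeds = [torch.ones(1, 4), 2 * torch.ones(1, 4), 3 * torch.ones(1, 4)]
weights = [0.3, 0.3, 0.4]

# interpolate(): weighted sum over the conditions, keeping a batch dimension.
image_emb = torch.cat([e * w for e, w in zip(embeds, weights)]).sum(dim=0, keepdim=True)
print(image_emb)  # tensor([[2.1000, 2.1000, 2.1000, 2.1000]]) since 0.3*1 + 0.3*2 + 0.4*3 = 2.1

# __call__(): classifier-free guidance, uncond + w * (text - uncond), enabled when guidance_scale > 1.
guidance_scale = 4.0
pred_uncond, pred_text = torch.zeros(1, 4), torch.ones(1, 4)
predicted_image_embedding = pred_uncond + guidance_scale * (pred_text - pred_uncond)
print(predicted_image_embedding)  # tensor([[4., 4., 4., 4.]])
```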
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/kandinsky/text_encoder.py
ADDED
@@ -0,0 +1,27 @@
import torch
from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel


class MCLIPConfig(XLMRobertaConfig):
    model_type = "M-CLIP"

    def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)


class MultilingualCLIP(PreTrainedModel):
    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions, out_features=config.numDims
        )

    def forward(self, input_ids, attention_mask):
        embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
        embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        return self.LinearTransformation(embs2), embs
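The `forward` of `MultilingualCLIP` above does a masked mean pool over the XLM-R token embeddings before the linear projection from `transformerDimensions` (1024) to `numDims` (768). A minimal sketch of just that pooling step on toy tensors (no real model weights involved), assuming only `torch`:

```py
import torch

# Toy token embeddings: batch 1, 3 tokens, hidden size 2; the third token is padding.
embs = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[1, 1, 0]])

# Masked mean pooling as in MultilingualCLIP.forward: zero out padded positions,
# then divide by the number of real tokens in each sequence.
pooled = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
print(pooled)  # tensor([[2., 3.]]) -- the padding token does not contribute
```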
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.3 kB)

pythonProject/.venv/Lib/site-packages/diffusers/pipelines/sana/__pycache__/pipeline_sana_sprint_img2img.cpython-310.pyc
ADDED
Binary file (31.6 kB)

pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__init__.py
ADDED
@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"]
    _import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
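This `__init__.py` follows the usual diffusers lazy-import pattern: at import time only `_import_structure` is registered with `_LazyModule`, and the heavy pipeline module is imported the first time one of its names is accessed. A minimal usage sketch, assuming a diffusers install with `torch` and `transformers` available:

```py
# Accessing the attribute is what triggers _LazyModule to import
# pipeline_semantic_stable_diffusion on demand.
from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline

print(SemanticStableDiffusionPipeline.__name__)
```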
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.15 kB)

pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_output.cpython-310.pyc
ADDED
Binary file (1.27 kB)

pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/__pycache__/pipeline_semantic_stable_diffusion.cpython-310.pyc
ADDED
Binary file (22.3 kB)

pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py
ADDED
@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class SemanticStableDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        nsfw_content_detected (`List[bool]`)
            List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
            `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
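`SemanticStableDiffusionPipelineOutput` above inherits from `BaseOutput`, so it behaves both like a dataclass and like an ordered mapping: fields can be read by attribute, by key, or positionally via `to_tuple()`. A small sketch with placeholder data (no pipeline run involved), assuming `diffusers` and `Pillow` are installed:

```py
import PIL.Image

from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipelineOutput

# Build the output from placeholder values; a real pipeline call would fill these in.
out = SemanticStableDiffusionPipelineOutput(
    images=[PIL.Image.new("RGB", (8, 8))],
    nsfw_content_detected=[False],
)

print(out.images[0].size)               # (8, 8)
print(out["nsfw_content_detected"])     # [False]
print(out.to_tuple()[0] == out.images)  # True -- tuple-style access is also supported
```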
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
ADDED
@@ -0,0 +1,733 @@
| 1 |
+
import inspect
|
| 2 |
+
from itertools import repeat
|
| 3 |
+
from typing import Callable, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
| 7 |
+
|
| 8 |
+
from ...image_processor import VaeImageProcessor
|
| 9 |
+
from ...models import AutoencoderKL, UNet2DConditionModel
|
| 10 |
+
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 11 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 12 |
+
from ...utils import deprecate, is_torch_xla_available, logging
|
| 13 |
+
from ...utils.torch_utils import randn_tensor
|
| 14 |
+
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
| 15 |
+
from .pipeline_output import SemanticStableDiffusionPipelineOutput
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if is_torch_xla_available():
|
| 19 |
+
import torch_xla.core.xla_model as xm
|
| 20 |
+
|
| 21 |
+
XLA_AVAILABLE = True
|
| 22 |
+
else:
|
| 23 |
+
XLA_AVAILABLE = False
|
| 24 |
+
|
| 25 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
| 29 |
+
_last_supported_version = "0.33.1"
|
| 30 |
+
r"""
|
| 31 |
+
Pipeline for text-to-image generation using Stable Diffusion with latent editing.
|
| 32 |
+
|
| 33 |
+
This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionPipeline`]. Check the superclass
|
| 34 |
+
documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular
|
| 35 |
+
device, etc.).
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
vae ([`AutoencoderKL`]):
|
| 39 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 40 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 41 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 42 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
| 43 |
+
A `CLIPTokenizer` to tokenize text.
|
| 44 |
+
unet ([`UNet2DConditionModel`]):
|
| 45 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 46 |
+
scheduler ([`SchedulerMixin`]):
|
| 47 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 48 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 49 |
+
safety_checker ([`Q16SafetyChecker`]):
|
| 50 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 51 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
| 52 |
+
about a model's potential harms.
|
| 53 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 54 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
model_cpu_offload_seq = "text_encoder->unet->vae"
|
| 58 |
+
_optional_components = ["safety_checker", "feature_extractor"]
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
vae: AutoencoderKL,
|
| 63 |
+
text_encoder: CLIPTextModel,
|
| 64 |
+
tokenizer: CLIPTokenizer,
|
| 65 |
+
unet: UNet2DConditionModel,
|
| 66 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 67 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 68 |
+
feature_extractor: CLIPImageProcessor,
|
| 69 |
+
requires_safety_checker: bool = True,
|
| 70 |
+
):
|
| 71 |
+
super().__init__()
|
| 72 |
+
|
| 73 |
+
if safety_checker is None and requires_safety_checker:
|
| 74 |
+
logger.warning(
|
| 75 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 76 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| 77 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| 78 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 79 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 80 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if safety_checker is not None and feature_extractor is None:
|
| 84 |
+
raise ValueError(
|
| 85 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 86 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
self.register_modules(
|
| 90 |
+
vae=vae,
|
| 91 |
+
text_encoder=text_encoder,
|
| 92 |
+
tokenizer=tokenizer,
|
| 93 |
+
unet=unet,
|
| 94 |
+
scheduler=scheduler,
|
| 95 |
+
safety_checker=safety_checker,
|
| 96 |
+
feature_extractor=feature_extractor,
|
| 97 |
+
)
|
| 98 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 99 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 100 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 101 |
+
|
| 102 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
| 103 |
+
def run_safety_checker(self, image, device, dtype):
|
| 104 |
+
if self.safety_checker is None:
|
| 105 |
+
has_nsfw_concept = None
|
| 106 |
+
else:
|
| 107 |
+
if torch.is_tensor(image):
|
| 108 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
| 109 |
+
else:
|
| 110 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 111 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
| 112 |
+
image, has_nsfw_concept = self.safety_checker(
|
| 113 |
+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
| 114 |
+
)
|
| 115 |
+
return image, has_nsfw_concept
|
| 116 |
+
|
| 117 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
| 118 |
+
def decode_latents(self, latents):
|
| 119 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| 120 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
| 121 |
+
|
| 122 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 123 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 124 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 125 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 126 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 127 |
+
return image
|
| 128 |
+
|
| 129 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 130 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 131 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 132 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 133 |
+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
| 134 |
+
# and should be between [0, 1]
|
| 135 |
+
|
| 136 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 137 |
+
extra_step_kwargs = {}
|
| 138 |
+
if accepts_eta:
|
| 139 |
+
extra_step_kwargs["eta"] = eta
|
| 140 |
+
|
| 141 |
+
# check if the scheduler accepts generator
|
| 142 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 143 |
+
if accepts_generator:
|
| 144 |
+
extra_step_kwargs["generator"] = generator
|
| 145 |
+
return extra_step_kwargs
|
| 146 |
+
|
| 147 |
+
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
|
| 148 |
+
def check_inputs(
|
| 149 |
+
self,
|
| 150 |
+
prompt,
|
| 151 |
+
height,
|
| 152 |
+
width,
|
| 153 |
+
callback_steps,
|
| 154 |
+
negative_prompt=None,
|
| 155 |
+
prompt_embeds=None,
|
| 156 |
+
negative_prompt_embeds=None,
|
| 157 |
+
callback_on_step_end_tensor_inputs=None,
|
| 158 |
+
):
|
| 159 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 160 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 161 |
+
|
| 162 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| 163 |
+
raise ValueError(
|
| 164 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 165 |
+
f" {type(callback_steps)}."
|
| 166 |
+
)
|
| 167 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 168 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 169 |
+
):
|
| 170 |
+
raise ValueError(
|
| 171 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if prompt is not None and prompt_embeds is not None:
|
| 175 |
+
raise ValueError(
|
| 176 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 177 |
+
" only forward one of the two."
|
| 178 |
+
)
|
| 179 |
+
elif prompt is None and prompt_embeds is None:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 182 |
+
)
|
| 183 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 184 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 185 |
+
|
| 186 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 189 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 193 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 196 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 197 |
+
f" {negative_prompt_embeds.shape}."
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
| 201 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
| 202 |
+
shape = (
|
| 203 |
+
batch_size,
|
| 204 |
+
num_channels_latents,
|
| 205 |
+
int(height) // self.vae_scale_factor,
|
| 206 |
+
int(width) // self.vae_scale_factor,
|
| 207 |
+
)
|
| 208 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 211 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
if latents is None:
|
| 215 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 216 |
+
else:
|
| 217 |
+
latents = latents.to(device)
|
| 218 |
+
|
| 219 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 220 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 221 |
+
return latents
|
| 222 |
+
|
| 223 |
+
@torch.no_grad()
|
| 224 |
+
def __call__(
|
| 225 |
+
self,
|
| 226 |
+
prompt: Union[str, List[str]],
|
| 227 |
+
height: Optional[int] = None,
|
| 228 |
+
width: Optional[int] = None,
|
| 229 |
+
num_inference_steps: int = 50,
|
| 230 |
+
guidance_scale: float = 7.5,
|
| 231 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 232 |
+
num_images_per_prompt: int = 1,
|
| 233 |
+
eta: float = 0.0,
|
| 234 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 235 |
+
latents: Optional[torch.Tensor] = None,
|
| 236 |
+
output_type: Optional[str] = "pil",
|
| 237 |
+
return_dict: bool = True,
|
| 238 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
| 239 |
+
callback_steps: int = 1,
|
| 240 |
+
editing_prompt: Optional[Union[str, List[str]]] = None,
|
| 241 |
+
editing_prompt_embeddings: Optional[torch.Tensor] = None,
|
| 242 |
+
reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
|
| 243 |
+
edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
|
| 244 |
+
edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
|
| 245 |
+
edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
|
| 246 |
+
edit_threshold: Optional[Union[float, List[float]]] = 0.9,
|
| 247 |
+
edit_momentum_scale: Optional[float] = 0.1,
|
| 248 |
+
edit_mom_beta: Optional[float] = 0.4,
|
| 249 |
+
edit_weights: Optional[List[float]] = None,
|
| 250 |
+
sem_guidance: Optional[List[torch.Tensor]] = None,
|
| 251 |
+
):
|
| 252 |
+
r"""
|
| 253 |
+
The call function to the pipeline for generation.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
prompt (`str` or `List[str]`):
|
| 257 |
+
The prompt or prompts to guide image generation.
|
| 258 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 259 |
+
The height in pixels of the generated image.
|
| 260 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 261 |
+
The width in pixels of the generated image.
|
| 262 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 263 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 264 |
+
expense of slower inference.
|
| 265 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 266 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 267 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 268 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 269 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 270 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 271 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 272 |
+
The number of images to generate per prompt.
|
| 273 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 274 |
+
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
|
| 275 |
+
applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 276 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 277 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 278 |
+
generation deterministic.
|
| 279 |
+
latents (`torch.Tensor`, *optional*):
|
| 280 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 281 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 282 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 283 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 284 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 285 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 286 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 287 |
+
plain tuple.
|
| 288 |
+
callback (`Callable`, *optional*):
|
| 289 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
| 290 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
| 291 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 292 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
| 293 |
+
every step.
|
| 294 |
+
editing_prompt (`str` or `List[str]`, *optional*):
|
| 295 |
+
The prompt or prompts to use for semantic guidance. Semantic guidance is disabled by setting
|
| 296 |
+
`editing_prompt = None`. Guidance direction of prompt should be specified via
|
| 297 |
+
`reverse_editing_direction`.
|
| 298 |
+
editing_prompt_embeddings (`torch.Tensor`, *optional*):
|
| 299 |
+
Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be
|
| 300 |
+
specified via `reverse_editing_direction`.
|
| 301 |
+
reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
|
| 302 |
+
Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
|
| 303 |
+
edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
|
| 304 |
+
Guidance scale for semantic guidance. If provided as a list, values should correspond to
|
| 305 |
+
`editing_prompt`.
|
| 306 |
+
edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
|
| 307 |
+
Number of diffusion steps (for each prompt) for which semantic guidance is not applied. Momentum is
|
| 308 |
+
calculated for those steps and applied once all warmup periods are over.
|
| 309 |
+
edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
|
| 310 |
+
Number of diffusion steps (for each prompt) after which semantic guidance is longer applied.
|
| 311 |
+
edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
|
| 312 |
+
Threshold of semantic guidance.
|
| 313 |
+
edit_momentum_scale (`float`, *optional*, defaults to 0.1):
|
| 314 |
+
Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0,
|
| 315 |
+
momentum is disabled. Momentum is already built up during warmup (for diffusion steps smaller than
|
| 316 |
+
`sld_warmup_steps`). Momentum is only added to latent guidance once all warmup periods are finished.
|
| 317 |
+
edit_mom_beta (`float`, *optional*, defaults to 0.4):
|
| 318 |
+
Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous
|
| 319 |
+
momentum is kept. Momentum is already built up during warmup (for diffusion steps smaller than
|
| 320 |
+
`edit_warmup_steps`).
|
| 321 |
+
edit_weights (`List[float]`, *optional*, defaults to `None`):
|
| 322 |
+
Indicates how much each individual concept should influence the overall guidance. If no weights are
|
| 323 |
+
provided all concepts are applied equally.
|
| 324 |
+
sem_guidance (`List[torch.Tensor]`, *optional*):
|
| 325 |
+
List of pre-generated guidance vectors to be applied at generation. Length of the list has to
|
| 326 |
+
correspond to `num_inference_steps`.
|
| 327 |
+
|
| 328 |
+
Examples:
|
| 329 |
+
|
| 330 |
+
```py
|
| 331 |
+
>>> import torch
|
| 332 |
+
>>> from diffusers import SemanticStableDiffusionPipeline
|
| 333 |
+
|
| 334 |
+
>>> pipe = SemanticStableDiffusionPipeline.from_pretrained(
|
| 335 |
+
... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
| 336 |
+
... )
|
| 337 |
+
>>> pipe = pipe.to("cuda")
|
| 338 |
+
|
| 339 |
+
>>> out = pipe(
|
| 340 |
+
... prompt="a photo of the face of a woman",
|
| 341 |
+
... num_images_per_prompt=1,
|
| 342 |
+
... guidance_scale=7,
|
| 343 |
+
... editing_prompt=[
|
| 344 |
+
... "smiling, smile", # Concepts to apply
|
| 345 |
+
... "glasses, wearing glasses",
|
| 346 |
+
... "curls, wavy hair, curly hair",
|
| 347 |
+
... "beard, full beard, mustache",
|
| 348 |
+
... ],
|
| 349 |
+
... reverse_editing_direction=[
|
| 350 |
+
... False,
|
| 351 |
+
... False,
|
| 352 |
+
... False,
|
| 353 |
+
... False,
|
| 354 |
+
... ], # Direction of guidance i.e. increase all concepts
|
| 355 |
+
... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept
|
| 356 |
+
... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept
|
| 357 |
+
... edit_threshold=[
|
| 358 |
+
... 0.99,
|
| 359 |
+
... 0.975,
|
| 360 |
+
... 0.925,
|
| 361 |
+
... 0.96,
|
| 362 |
+
... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions
|
| 363 |
+
... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance
|
| 364 |
+
... edit_mom_beta=0.6, # Momentum beta
|
| 365 |
+
... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other
|
| 366 |
+
... )
|
| 367 |
+
>>> image = out.images[0]
|
| 368 |
+
```
|
| 369 |
+
|
| 370 |
+
Returns:
|
| 371 |
+
[`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`:
|
| 372 |
+
If `return_dict` is `True`,
|
| 373 |
+
[`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] is returned, otherwise a
|
| 374 |
+
`tuple` is returned where the first element is a list with the generated images and the second element
|
| 375 |
+
is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work"
|
| 376 |
+
(nsfw) content.
|
| 377 |
+
"""
|
| 378 |
+
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device

        if editing_prompt:
            enable_edit_guidance = True
            if isinstance(editing_prompt, str):
                editing_prompt = [editing_prompt]
            enabled_editing_prompts = len(editing_prompt)
        elif editing_prompt_embeddings is not None:
            enable_edit_guidance = True
            enabled_editing_prompts = editing_prompt_embeddings.shape[0]
        else:
            enabled_editing_prompts = 0
            enable_edit_guidance = False

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if enable_edit_guidance:
            # get safety text embeddings
            if editing_prompt_embeddings is None:
                edit_concepts_input = self.tokenizer(
                    [x for item in editing_prompt for x in repeat(item, batch_size)],
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    return_tensors="pt",
                )

                edit_concepts_input_ids = edit_concepts_input.input_ids

                if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length:
                    removed_text = self.tokenizer.batch_decode(
                        edit_concepts_input_ids[:, self.tokenizer.model_max_length :]
                    )
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                    )
                    edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
                edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
            else:
                edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)

            # duplicate text embeddings for each generation per prompt, using mps friendly method
            bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
            edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1)
            edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            if enable_edit_guidance:
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
            else:
                text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        # get the initial random noise unless the user supplied it

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Initialize edit_momentum to None
        edit_momentum = None

        self.uncond_estimates = None
        self.text_estimates = None
        self.edit_estimates = None
        self.sem_guidance = None

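        # Denoising loop (editor's note): at every step the latent batch is expanded to
        # [uncond, text, edit_1, ..., edit_K] so a single UNet forward pass yields the
        # unconditional, prompt-conditioned and per-concept noise estimates combined below.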
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts)  # [b,4, 64, 64]
                noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
                noise_pred_edit_concepts = noise_pred_out[2:]

                # default text guidance
                noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond)
                # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0])

                if self.uncond_estimates is None:
                    self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape))
                self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()

                if self.text_estimates is None:
                    self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))
                self.text_estimates[i] = noise_pred_text.detach().cpu()

                if self.edit_estimates is None and enable_edit_guidance:
                    self.edit_estimates = torch.zeros(
                        (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
                    )

                if self.sem_guidance is None:
                    self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape))

                if edit_momentum is None:
                    edit_momentum = torch.zeros_like(noise_guidance)

                if enable_edit_guidance:
                    concept_weights = torch.zeros(
                        (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
                        device=device,
                        dtype=noise_guidance.dtype,
                    )
                    noise_guidance_edit = torch.zeros(
                        (len(noise_pred_edit_concepts), *noise_guidance.shape),
                        device=device,
                        dtype=noise_guidance.dtype,
                    )
                    # noise_guidance_edit = torch.zeros_like(noise_guidance)
                    warmup_inds = []
                    for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
                        self.edit_estimates[i, c] = noise_pred_edit_concept
                        if isinstance(edit_guidance_scale, list):
                            edit_guidance_scale_c = edit_guidance_scale[c]
                        else:
                            edit_guidance_scale_c = edit_guidance_scale

                        if isinstance(edit_threshold, list):
                            edit_threshold_c = edit_threshold[c]
                        else:
                            edit_threshold_c = edit_threshold
                        if isinstance(reverse_editing_direction, list):
                            reverse_editing_direction_c = reverse_editing_direction[c]
                        else:
                            reverse_editing_direction_c = reverse_editing_direction
                        if edit_weights:
                            edit_weight_c = edit_weights[c]
                        else:
                            edit_weight_c = 1.0
                        if isinstance(edit_warmup_steps, list):
                            edit_warmup_steps_c = edit_warmup_steps[c]
                        else:
                            edit_warmup_steps_c = edit_warmup_steps

                        if isinstance(edit_cooldown_steps, list):
                            edit_cooldown_steps_c = edit_cooldown_steps[c]
                        elif edit_cooldown_steps is None:
                            edit_cooldown_steps_c = i + 1
                        else:
                            edit_cooldown_steps_c = edit_cooldown_steps
                        if i >= edit_warmup_steps_c:
                            warmup_inds.append(c)
                        if i >= edit_cooldown_steps_c:
                            noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
                            continue

                        noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
                        # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
                        tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))

                        tmp_weights = torch.full_like(tmp_weights, edit_weight_c)  # * (1 / enabled_editing_prompts)
                        if reverse_editing_direction_c:
                            noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
                        concept_weights[c, :] = tmp_weights

                        noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c

                        # torch.quantile function expects float32
                        if noise_guidance_edit_tmp.dtype == torch.float32:
                            tmp = torch.quantile(
                                torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2),
                                edit_threshold_c,
                                dim=2,
                                keepdim=False,
                            )
                        else:
                            tmp = torch.quantile(
                                torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32),
                                edit_threshold_c,
                                dim=2,
                                keepdim=False,
                            ).to(noise_guidance_edit_tmp.dtype)

                        noise_guidance_edit_tmp = torch.where(
                            torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None],
                            noise_guidance_edit_tmp,
                            torch.zeros_like(noise_guidance_edit_tmp),
                        )
                        noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp

                        # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp

                    warmup_inds = torch.tensor(warmup_inds).to(device)
                    if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
                        concept_weights = concept_weights.to("cpu")  # Offload to cpu
                        noise_guidance_edit = noise_guidance_edit.to("cpu")

                        concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
                        concept_weights_tmp = torch.where(
                            concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
                        )
                        concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
                        # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)

                        noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
                        noise_guidance_edit_tmp = torch.einsum(
                            "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
                        )
                        noise_guidance = noise_guidance + noise_guidance_edit_tmp

                        self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()

                        del noise_guidance_edit_tmp
                        del concept_weights_tmp
                        concept_weights = concept_weights.to(device)
                        noise_guidance_edit = noise_guidance_edit.to(device)

                    concept_weights = torch.where(
                        concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
                    )

                    concept_weights = torch.nan_to_num(concept_weights)

                    noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
                    noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)

                    noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum

                    edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit

                    if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
                        noise_guidance = noise_guidance + noise_guidance_edit
                        self.sem_guidance[i] = noise_guidance_edit.detach().cpu()

                if sem_guidance is not None:
                    edit_guidance = sem_guidance[i].to(device)
                    noise_guidance = noise_guidance + edit_guidance

                noise_pred = noise_pred_uncond + noise_guidance

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

        # 8. Post-processing
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
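For reference, the per-concept term assembled inside the denoising loop above reduces to quantile thresholding of the concept direction followed by momentum smoothing. The standalone sketch below is not part of the file; tensor names, shapes, and default values are illustrative assumptions, and it covers a single concept only.

```python
import torch


def sega_edit_guidance(
    noise_pred_uncond: torch.Tensor,        # (batch, channels, height, width)
    noise_pred_edit_concept: torch.Tensor,  # (batch, channels, height, width)
    edit_momentum: torch.Tensor,            # running buffer, same shape as above
    edit_guidance_scale: float = 5.0,
    edit_threshold: float = 0.95,
    edit_momentum_scale: float = 0.3,
    edit_mom_beta: float = 0.6,
    reverse_editing_direction: bool = False,
):
    """Return (guidance term, updated momentum) for a single editing concept (illustrative sketch)."""
    guidance = noise_pred_edit_concept - noise_pred_uncond
    if reverse_editing_direction:
        guidance = -guidance
    guidance = guidance * edit_guidance_scale

    # Keep only entries whose magnitude exceeds the per-sample, per-channel quantile;
    # everything below the threshold is zeroed out (torch.quantile expects float32).
    cutoff = torch.quantile(guidance.abs().flatten(start_dim=2).float(), edit_threshold, dim=2)
    guidance = torch.where(guidance.abs() >= cutoff[:, :, None, None], guidance, torch.zeros_like(guidance))

    # Momentum: add the accumulated buffer, then update it as an exponential moving average.
    guidance = guidance + edit_momentum_scale * edit_momentum
    edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * guidance
    return guidance, edit_momentum


# Example with random tensors standing in for UNet outputs.
uncond = torch.randn(1, 4, 64, 64)
concept = torch.randn(1, 4, 64, 64)
momentum = torch.zeros_like(uncond)
term, momentum = sega_edit_guidance(uncond, concept, momentum)
print(term.shape, (term != 0).float().mean())  # roughly 1 - edit_threshold of entries survive
```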
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/shap_e/camera.py
ADDED
@@ -0,0 +1,147 @@
# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Tuple

import numpy as np
import torch


@dataclass
class DifferentiableProjectiveCamera:
    """
    Implements a batch, differentiable, standard pinhole camera
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float
    shape: Tuple[int]

    def __post_init__(self):
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2

    def resolution(self):
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def get_image_coords(self) -> torch.Tensor:
        """
        :return: coords of shape (width * height, 2)
        """
        pixel_indices = torch.arange(self.height * self.width)
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            axis=1,
        )
        return coords

    @property
    def camera_rays(self):
        batch_size, *inner_shape = self.shape
        inner_batch_size = int(np.prod(inner_shape))

        coords = self.get_image_coords()
        coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
        rays = self.get_camera_rays(coords)

        rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)

        return rays

    def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        flat = coords.view(batch_size, -1, 2)

        res = self.resolution()
        fov = self.fov()

        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        fracs = fracs.view(batch_size, -1, 2)
        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)
        rays = torch.stack(
            [
                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
        )


def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
    origins = []
    xs = []
    ys = []
    zs = []
    for theta in np.linspace(0, 2 * np.pi, num=20):
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        origin = -z * 4
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return DifferentiableProjectiveCamera(
        origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
        x=torch.from_numpy(np.stack(xs, axis=0)).float(),
        y=torch.from_numpy(np.stack(ys, axis=0)).float(),
        z=torch.from_numpy(np.stack(zs, axis=0)).float(),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
        shape=(1, len(xs)),
    )
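A short usage sketch for the camera helpers above (not part of the file; the import path is assumed from this package layout): `create_pan_cameras` builds 20 poses orbiting the origin, and `camera_rays` packs an (origin, direction) pair for every pixel of every pose.

```python
import torch

from diffusers.pipelines.shap_e.camera import create_pan_cameras  # assumed import path

camera = create_pan_cameras(size=64)
rays = camera.camera_rays  # (1, 20 * 64 * 64, 2, 3): 20 poses, 64x64 pixels, (origin, direction), xyz
origins, directions = rays[..., 0, :], rays[..., 1, :]
print(rays.shape, directions.norm(dim=-1).mean())  # direction vectors are unit length
```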