Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- DeepSeek-VL2/deepseek_vl2/serve/__pycache__/__init__.cpython-312.pyc +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/__pycache__/inference.cpython-312.pyc +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/__init__.cpython-312.pyc +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/overwrites.cpython-312.pyc +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/presets.cpython-312.pyc +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- DeepSeek-VL2/deepseek_vl2/serve/logs/20241230-104509_gradio_log.log +0 -0
- DeepSeek-VL2/deepseek_vl2/serve/logs/20241231-035201_gradio_log.log +0 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/__init__.py +69 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/modeling_roberta_series.py +141 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +993 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/pipeline_output.py +40 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/animatediff/__init__.py +59 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/audioldm2/__init__.py +66 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/audioldm2/modeling_audioldm2.py +1520 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/audioldm2/pipeline_audioldm2.py +914 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py +74 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py +716 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_vctrl.py +799 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_vctrl.py +739 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +392 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky3/__init__.py +62 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +439 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +449 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/latent_consistency_models/__init__.py +65 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +796 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +729 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/lvdm/pipeline_latent_video_diffusion_model_uncond.py +277 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/lvdm/video_save.py +419 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/paint_by_example/__init__.py +68 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/paint_by_example/image_encoder.py +182 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +615 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/pixart_alpha/__init__.py +61 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +867 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/pndm/pipeline_pndm.py +119 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/score_sde_ve/__init__.py +32 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +111 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/__init__.py +92 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/continous_encoder.py +89 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/midi_utils.py +694 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/notes_encoder.py +87 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +262 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/__init__.py +294 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/convert_from_ckpt.py +1915 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_fastdeploy_stable_diffusion_inpaint.py +582 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_output.py +39 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_paddleinfer_stable_diffusion.py +360 -0
- VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_paddleinfer_stable_diffusion_img2img.py +381 -0
DeepSeek-VL2/deepseek_vl2/serve/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (176 Bytes).

DeepSeek-VL2/deepseek_vl2/serve/__pycache__/inference.cpython-312.pyc
ADDED
Binary file (6.66 kB).

DeepSeek-VL2/deepseek_vl2/serve/app_modules/__init__.py
ADDED
File without changes (empty file).

DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (188 Bytes).

DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/overwrites.cpython-312.pyc
ADDED
Binary file (3.67 kB).

DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/presets.cpython-312.pyc
ADDED
Binary file (2.8 kB).
DeepSeek-VL2/deepseek_vl2/serve/assets/Kelpy-Codos.js
ADDED
@@ -0,0 +1,100 @@
+/**
+ * Copyright (c) 2023-2024 DeepSeek.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of
+ * this software and associated documentation files (the "Software"), to deal in
+ * the Software without restriction, including without limitation the rights to
+ * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ * the Software, and to permit persons to whom the Software is furnished to do so,
+ * subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+// ==UserScript==
+// @name         Kelpy Codos
+// @namespace    https://github.com/Keldos-Li/Kelpy-Codos
+// @version      1.0.5
+// @author       Keldos; https://keldos.me/
+// @description  Add copy button to PRE tags before CODE tag, for Chuanhu ChatGPT especially.
+//               Based on Chuanhu ChatGPT version: ac04408 (2023-3-22)
+// @license      GPL-3.0
+// @grant        none
+// ==/UserScript==
+
+(function () {
+  "use strict";
+
+  function addCopyButton(pre) {
+    var code = pre.querySelector("code");
+    if (!code) {
+      return; // if no <code> element is found, do not add a button
+    }
+    var firstChild = code.firstChild;
+    if (!firstChild) {
+      return; // if the <code> element has no child nodes, do not add a button
+    }
+    var button = document.createElement("button");
+    button.textContent = "\uD83D\uDCCE"; // use the 📎 symbol as the "copy" button label
+    button.style.position = "relative";
+    button.style.float = "right";
+    button.style.fontSize = "1em"; // optional: adjust button size
+    button.style.background = "none"; // optional: remove background color
+    button.style.border = "none"; // optional: remove border
+    button.style.cursor = "pointer"; // optional: show pointer cursor
+    button.addEventListener("click", function () {
+      var range = document.createRange();
+      range.selectNodeContents(code);
+      range.setStartBefore(firstChild); // start the range before the first child node
+      var selection = window.getSelection();
+      selection.removeAllRanges();
+      selection.addRange(range);
+
+      try {
+        var success = document.execCommand("copy");
+        if (success) {
+          button.textContent = "\u2714";
+          setTimeout(function () {
+            button.textContent = "\uD83D\uDCCE"; // restore the "copy" label
+          }, 2000);
+        } else {
+          button.textContent = "\u2716";
+        }
+      } catch (e) {
+        console.error(e);
+        button.textContent = "\u2716";
+      }
+
+      selection.removeAllRanges();
+    });
+    code.insertBefore(button, firstChild); // insert the button before the first child element
+  }
+
+  function handleNewElements(mutationsList, observer) {
+    for (var mutation of mutationsList) {
+      if (mutation.type === "childList") {
+        for (var node of mutation.addedNodes) {
+          if (node.nodeName === "PRE") {
+            addCopyButton(node);
+          }
+        }
+      }
+    }
+  }
+
+  var observer = new MutationObserver(handleNewElements);
+  observer.observe(document.documentElement, {
+    childList: true,
+    subtree: true,
+  });
+
+  document.querySelectorAll("pre").forEach(addCopyButton);
+})();
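The userscript above attaches copy buttons to every `<pre>` present at load time and uses a `MutationObserver` to catch code blocks rendered later (e.g., streamed chat output). A minimal sketch of how a Gradio app like the DeepSeek-VL2 demo could inject it — the `js=` argument on `gr.Blocks` is an assumption about the Gradio version in use, not necessarily how the serve app actually wires it up:

```python
# Hypothetical wiring sketch; assumes a Gradio release whose gr.Blocks accepts a js= string.
from pathlib import Path

import gradio as gr

# Read the userscript shipped under serve/assets.
custom_js = Path("deepseek_vl2/serve/assets/Kelpy-Codos.js").read_text(encoding="utf-8")

with gr.Blocks(js=custom_js) as demo:  # the script runs once the page loads
    gr.Markdown("```python\nprint('hello')\n```")  # rendered <pre><code> gets a copy button

demo.launch()
```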
DeepSeek-VL2/deepseek_vl2/serve/logs/20241230-104509_gradio_log.log
ADDED
File without changes (empty file).

DeepSeek-VL2/deepseek_vl2/serve/logs/20241231-035201_gradio_log.log
ADDED
File without changes (empty file).
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/__init__.py
ADDED
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    PPDIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_paddle_available,
+    is_paddlenlp_available,
+)
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_paddlenlp_available() and is_paddle_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_paddle_and_paddlenlp_objects
+
+    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
+else:
+    _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation", "RobertaSeriesConfig"]
+    _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"]
+    _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"]
+
+    _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"]
+
+if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_paddlenlp_available() and is_paddle_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_paddle_and_paddlenlp_objects import *
+
+    else:
+        from .modeling_roberta_series import (
+            RobertaSeriesConfig,
+            RobertaSeriesModelWithTransformation,
+        )
+        from .pipeline_alt_diffusion import AltDiffusionPipeline
+        from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
+        from .pipeline_output import AltDiffusionPipelineOutput
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/modeling_roberta_series.py
ADDED
@@ -0,0 +1,141 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import paddle
+from paddle import nn
+from paddlenlp.transformers.model_outputs import ModelOutput
+
+from ppdiffusers.transformers import (
+    XLMRobertaConfig,
+    XLMRobertaModel,
+    XLMRobertaPretrainedModel,
+)
+
+
+@dataclass
+class TransformationModelOutput(ModelOutput):
+    """
+    Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        text_embeds (`paddle.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+            The text embeddings obtained by applying the projection layer to the pooler_output.
+        last_hidden_state (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    projection_state: Optional[paddle.Tensor] = None
+    last_hidden_state: paddle.Tensor = None
+    hidden_states: Optional[Tuple[paddle.Tensor]] = None
+    attentions: Optional[Tuple[paddle.Tensor]] = None
+
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+    def __init__(
+        self,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        project_dim=512,
+        pooler_fn="cls",
+        learn_encoder=False,
+        use_attention_mask=True,
+        **kwargs,
+    ):
+        kwargs["return_dict"] = kwargs.pop("return_dict", True)
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.project_dim = project_dim
+        self.pooler_fn = pooler_fn
+        self.learn_encoder = learn_encoder
+        self.use_attention_mask = use_attention_mask
+
+
+class RobertaSeriesModelWithTransformation(XLMRobertaPretrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    base_model_prefix = "roberta"
+    config_class = RobertaSeriesConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.roberta = XLMRobertaModel(config)
+        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
+        if self.has_pre_transformation:
+            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
+            self.pre_LN = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        token_type_ids: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        encoder_hidden_states: Optional[paddle.Tensor] = None,
+        encoder_attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ):
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.roberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.has_pre_transformation:
+            sequence_output2 = outputs["hidden_states"][-2]
+            sequence_output2 = self.pre_LN(sequence_output2)
+            projection_state2 = self.transformation_pre(sequence_output2)
+
+            return TransformationModelOutput(
+                projection_state=projection_state2,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+        else:
+            projection_state = self.transformation(outputs.last_hidden_state)
+            return TransformationModelOutput(
+                projection_state=projection_state,
+                last_hidden_state=outputs.last_hidden_state,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
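`RobertaSeriesModelWithTransformation` wraps an XLM-RoBERTa encoder and projects its `last_hidden_state` (or, when `has_pre_transformation` is set, the LayerNorm-ed second-to-last hidden state) down to `project_dim`, so AltDiffusion can consume it in place of a CLIP text encoder. A rough usage sketch — the checkpoint name and subfolder layout are assumptions modeled on the pipeline example below, not verified against the hub:

```python
# Rough sketch: encode a prompt with the projected XLM-RoBERTa text encoder.
# "BAAI/AltDiffusion-m9" and the subfolder names are assumed, not confirmed.
import paddle
from ppdiffusers.transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained("BAAI/AltDiffusion-m9", subfolder="tokenizer")
text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(
    "BAAI/AltDiffusion-m9", subfolder="text_encoder"
)

inputs = tokenizer("幻想风景, artstation", padding="max_length", return_tensors="pd")
with paddle.no_grad():
    out = text_encoder(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
print(out.projection_state.shape)  # [batch_size, seq_len, project_dim]
```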
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
ADDED
|
@@ -0,0 +1,993 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import paddle
|
| 20 |
+
import PIL.Image
|
| 21 |
+
from packaging import version
|
| 22 |
+
|
| 23 |
+
from ppdiffusers.transformers import (
|
| 24 |
+
CLIPImageProcessor,
|
| 25 |
+
CLIPVisionModelWithProjection,
|
| 26 |
+
XLMRobertaTokenizer,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from ...configuration_utils import FrozenDict
|
| 30 |
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
| 31 |
+
from ...loaders import (
|
| 32 |
+
FromSingleFileMixin,
|
| 33 |
+
IPAdapterMixin,
|
| 34 |
+
LoraLoaderMixin,
|
| 35 |
+
TextualInversionLoaderMixin,
|
| 36 |
+
)
|
| 37 |
+
from ...models import AutoencoderKL, UNet2DConditionModel
|
| 38 |
+
from ...models.lora import adjust_lora_scale_text_encoder
|
| 39 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 40 |
+
from ...utils import (
|
| 41 |
+
PIL_INTERPOLATION,
|
| 42 |
+
USE_PEFT_BACKEND,
|
| 43 |
+
deprecate,
|
| 44 |
+
logging,
|
| 45 |
+
replace_example_docstring,
|
| 46 |
+
scale_lora_layers,
|
| 47 |
+
unscale_lora_layers,
|
| 48 |
+
)
|
| 49 |
+
from ...utils.paddle_utils import randn_tensor
|
| 50 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 51 |
+
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 52 |
+
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
|
| 53 |
+
from .pipeline_output import AltDiffusionPipelineOutput
|
| 54 |
+
|
| 55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 56 |
+
|
| 57 |
+
EXAMPLE_DOC_STRING = """
|
| 58 |
+
Examples:
|
| 59 |
+
```py
|
| 60 |
+
>>> import requests
|
| 61 |
+
>>> import paddle
|
| 62 |
+
>>> from PIL import Image
|
| 63 |
+
>>> from io import BytesIO
|
| 64 |
+
|
| 65 |
+
>>> from ppdiffusers import AltDiffusionImg2ImgPipeline
|
| 66 |
+
|
| 67 |
+
>>> model_id_or_path = "BAAI/AltDiffusion-m9"
|
| 68 |
+
>>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, paddle_dtype=paddle.float16)
|
| 69 |
+
|
| 70 |
+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
| 71 |
+
|
| 72 |
+
>>> response = requests.get(url)
|
| 73 |
+
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
| 74 |
+
>>> init_image = init_image.resize((768, 512))
|
| 75 |
+
|
| 76 |
+
>>> # "A fantasy landscape, trending on artstation"
|
| 77 |
+
>>> prompt = "幻想风景, artstation"
|
| 78 |
+
|
| 79 |
+
>>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
|
| 80 |
+
>>> images[0].save("幻想风景.png")
|
| 81 |
+
```
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 86 |
+
def retrieve_latents(
|
| 87 |
+
encoder_output: paddle.Tensor, generator: Optional[paddle.Generator] = None, sample_mode: str = "sample"
|
| 88 |
+
):
|
| 89 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 90 |
+
return encoder_output.latent_dist.sample(generator)
|
| 91 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 92 |
+
return encoder_output.latent_dist.mode()
|
| 93 |
+
elif hasattr(encoder_output, "latents"):
|
| 94 |
+
return encoder_output.latents
|
| 95 |
+
else:
|
| 96 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
|
| 100 |
+
def preprocess(image):
|
| 101 |
+
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
|
| 102 |
+
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
|
| 103 |
+
if isinstance(image, paddle.Tensor):
|
| 104 |
+
return image
|
| 105 |
+
elif isinstance(image, PIL.Image.Image):
|
| 106 |
+
image = [image]
|
| 107 |
+
|
| 108 |
+
if isinstance(image[0], PIL.Image.Image):
|
| 109 |
+
w, h = image[0].size
|
| 110 |
+
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
| 111 |
+
|
| 112 |
+
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
|
| 113 |
+
image = np.concatenate(image, axis=0)
|
| 114 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 115 |
+
image = image.transpose(0, 3, 1, 2)
|
| 116 |
+
image = 2.0 * image - 1.0
|
| 117 |
+
image = paddle.to_tensor(image)
|
| 118 |
+
elif isinstance(image[0], paddle.Tensor):
|
| 119 |
+
image = paddle.concat(image, axis=0)
|
| 120 |
+
return image
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 124 |
+
def retrieve_timesteps(
|
| 125 |
+
scheduler,
|
| 126 |
+
num_inference_steps: Optional[int] = None,
|
| 127 |
+
timesteps: Optional[List[int]] = None,
|
| 128 |
+
**kwargs,
|
| 129 |
+
):
|
| 130 |
+
"""
|
| 131 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 132 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
scheduler (`SchedulerMixin`):
|
| 136 |
+
The scheduler to get timesteps from.
|
| 137 |
+
num_inference_steps (`int`):
|
| 138 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
| 139 |
+
`timesteps` must be `None`.
|
| 140 |
+
timesteps (`List[int]`, *optional*):
|
| 141 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
| 142 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
| 143 |
+
must be `None`.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
`Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 147 |
+
second element is the number of inference steps.
|
| 148 |
+
"""
|
| 149 |
+
if timesteps is not None:
|
| 150 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 151 |
+
if not accepts_timesteps:
|
| 152 |
+
raise ValueError(
|
| 153 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 154 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 155 |
+
)
|
| 156 |
+
scheduler.set_timesteps(timesteps=timesteps, **kwargs)
|
| 157 |
+
timesteps = scheduler.timesteps
|
| 158 |
+
num_inference_steps = len(timesteps)
|
| 159 |
+
else:
|
| 160 |
+
scheduler.set_timesteps(num_inference_steps, **kwargs)
|
| 161 |
+
timesteps = scheduler.timesteps
|
| 162 |
+
return timesteps, num_inference_steps
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
|
| 166 |
+
class AltDiffusionImg2ImgPipeline(
|
| 167 |
+
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
|
| 168 |
+
):
|
| 169 |
+
r"""
|
| 170 |
+
Pipeline for text-guided image-to-image generation using Alt Diffusion.
|
| 171 |
+
|
| 172 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 173 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 174 |
+
|
| 175 |
+
The pipeline also inherits the following loading methods:
|
| 176 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 177 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 178 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 179 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 180 |
+
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
vae ([`AutoencoderKL`]):
|
| 184 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 185 |
+
text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]):
|
| 186 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 187 |
+
tokenizer ([`~transformers.XLMRobertaTokenizer`]):
|
| 188 |
+
A `XLMRobertaTokenizer` to tokenize text.
|
| 189 |
+
unet ([`UNet2DConditionModel`]):
|
| 190 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 191 |
+
scheduler ([`SchedulerMixin`]):
|
| 192 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 193 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 194 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 195 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 196 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
| 197 |
+
about a model's potential harms.
|
| 198 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 199 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
model_cpu_offload_seq = "text_encoder->unet->vae"
|
| 203 |
+
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
| 204 |
+
_exclude_from_cpu_offload = ["safety_checker"]
|
| 205 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 206 |
+
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
vae: AutoencoderKL,
|
| 210 |
+
text_encoder: RobertaSeriesModelWithTransformation,
|
| 211 |
+
tokenizer: XLMRobertaTokenizer,
|
| 212 |
+
unet: UNet2DConditionModel,
|
| 213 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 214 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 215 |
+
feature_extractor: CLIPImageProcessor,
|
| 216 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 217 |
+
requires_safety_checker: bool = True,
|
| 218 |
+
):
|
| 219 |
+
super().__init__()
|
| 220 |
+
|
| 221 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
| 222 |
+
deprecation_message = (
|
| 223 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
| 224 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
| 225 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
| 226 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
| 227 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
| 228 |
+
" file"
|
| 229 |
+
)
|
| 230 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
| 231 |
+
new_config = dict(scheduler.config)
|
| 232 |
+
new_config["steps_offset"] = 1
|
| 233 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 234 |
+
|
| 235 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
| 236 |
+
deprecation_message = (
|
| 237 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
| 238 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
| 239 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
| 240 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
| 241 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
| 242 |
+
)
|
| 243 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
| 244 |
+
new_config = dict(scheduler.config)
|
| 245 |
+
new_config["clip_sample"] = False
|
| 246 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 247 |
+
|
| 248 |
+
if safety_checker is None and requires_safety_checker:
|
| 249 |
+
logger.warning(
|
| 250 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 251 |
+
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
|
| 252 |
+
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
| 253 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 254 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 255 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
if safety_checker is not None and feature_extractor is None:
|
| 259 |
+
raise ValueError(
|
| 260 |
+
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 261 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_ppdiffusers_version") and version.parse(
|
| 265 |
+
version.parse(unet.config._ppdiffusers_version).base_version
|
| 266 |
+
) < version.parse("0.9.0.dev0")
|
| 267 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
| 268 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
| 269 |
+
deprecation_message = (
|
| 270 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
| 271 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
| 272 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
| 273 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
| 274 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
| 275 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
| 276 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
| 277 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
| 278 |
+
" the `unet/config.json` file"
|
| 279 |
+
)
|
| 280 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
| 281 |
+
new_config = dict(unet.config)
|
| 282 |
+
new_config["sample_size"] = 64
|
| 283 |
+
unet._internal_dict = FrozenDict(new_config)
|
| 284 |
+
|
| 285 |
+
self.register_modules(
|
| 286 |
+
vae=vae,
|
| 287 |
+
text_encoder=text_encoder,
|
| 288 |
+
tokenizer=tokenizer,
|
| 289 |
+
unet=unet,
|
| 290 |
+
scheduler=scheduler,
|
| 291 |
+
safety_checker=safety_checker,
|
| 292 |
+
feature_extractor=feature_extractor,
|
| 293 |
+
image_encoder=image_encoder,
|
| 294 |
+
)
|
| 295 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 296 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 297 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 298 |
+
|
| 299 |
+
def _encode_prompt(
|
| 300 |
+
self,
|
| 301 |
+
prompt,
|
| 302 |
+
num_images_per_prompt,
|
| 303 |
+
do_classifier_free_guidance,
|
| 304 |
+
negative_prompt=None,
|
| 305 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 306 |
+
negative_prompt_embeds: Optional[paddle.Tensor] = None,
|
| 307 |
+
lora_scale: Optional[float] = None,
|
| 308 |
+
**kwargs,
|
| 309 |
+
):
|
| 310 |
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
| 311 |
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
| 312 |
+
|
| 313 |
+
prompt_embeds_tuple = self.encode_prompt(
|
| 314 |
+
prompt=prompt,
|
| 315 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 316 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 317 |
+
negative_prompt=negative_prompt,
|
| 318 |
+
prompt_embeds=prompt_embeds,
|
| 319 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 320 |
+
lora_scale=lora_scale,
|
| 321 |
+
**kwargs,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# concatenate for backwards comp
|
| 325 |
+
prompt_embeds = paddle.concat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
| 326 |
+
|
| 327 |
+
return prompt_embeds
|
| 328 |
+
|
| 329 |
+
def encode_prompt(
|
| 330 |
+
self,
|
| 331 |
+
prompt,
|
| 332 |
+
num_images_per_prompt,
|
| 333 |
+
do_classifier_free_guidance,
|
| 334 |
+
negative_prompt=None,
|
| 335 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 336 |
+
negative_prompt_embeds: Optional[paddle.Tensor] = None,
|
| 337 |
+
lora_scale: Optional[float] = None,
|
| 338 |
+
clip_skip: Optional[int] = None,
|
| 339 |
+
):
|
| 340 |
+
r"""
|
| 341 |
+
Encodes the prompt into text encoder hidden states.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 345 |
+
prompt to be encoded
|
| 346 |
+
num_images_per_prompt (`int`):
|
| 347 |
+
number of images that should be generated per prompt
|
| 348 |
+
do_classifier_free_guidance (`bool`):
|
| 349 |
+
whether to use classifier free guidance or not
|
| 350 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 351 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 352 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 353 |
+
less than `1`).
|
| 354 |
+
prompt_embeds (`paddle.Tensor`, *optional*):
|
| 355 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 356 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 357 |
+
negative_prompt_embeds (`paddle.Tensor`, *optional*):
|
| 358 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 359 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 360 |
+
argument.
|
| 361 |
+
lora_scale (`float`, *optional*):
|
| 362 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 363 |
+
clip_skip (`int`, *optional*):
|
| 364 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 365 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 366 |
+
"""
|
| 367 |
+
# set lora scale so that monkey patched LoRA
|
| 368 |
+
# function of text encoder can correctly access it
|
| 369 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
| 370 |
+
self._lora_scale = lora_scale
|
| 371 |
+
|
| 372 |
+
# dynamically adjust the LoRA scale
|
| 373 |
+
if not USE_PEFT_BACKEND:
|
| 374 |
+
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
| 375 |
+
else:
|
| 376 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 377 |
+
|
| 378 |
+
if prompt is not None and isinstance(prompt, str):
|
| 379 |
+
batch_size = 1
|
| 380 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 381 |
+
batch_size = len(prompt)
|
| 382 |
+
else:
|
| 383 |
+
batch_size = prompt_embeds.shape[0]
|
| 384 |
+
|
| 385 |
+
if prompt_embeds is None:
|
| 386 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 387 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 388 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 389 |
+
|
| 390 |
+
text_inputs = self.tokenizer(
|
| 391 |
+
prompt,
|
| 392 |
+
padding="max_length",
|
| 393 |
+
max_length=self.tokenizer.model_max_length,
|
| 394 |
+
truncation=True,
|
| 395 |
+
return_tensors="pd",
|
| 396 |
+
)
|
| 397 |
+
text_input_ids = text_inputs.input_ids
|
| 398 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids
|
| 399 |
+
|
| 400 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(
|
| 401 |
+
text_input_ids, untruncated_ids
|
| 402 |
+
):
|
| 403 |
+
removed_text = self.tokenizer.batch_decode(
|
| 404 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
| 405 |
+
)
|
| 406 |
+
logger.warning(
|
| 407 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 408 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 412 |
+
attention_mask = text_inputs.attention_mask
|
| 413 |
+
else:
|
| 414 |
+
attention_mask = None
|
| 415 |
+
|
| 416 |
+
if clip_skip is None:
|
| 417 |
+
prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask)
|
| 418 |
+
prompt_embeds = prompt_embeds[0]
|
| 419 |
+
else:
|
| 420 |
+
prompt_embeds = self.text_encoder(
|
| 421 |
+
text_input_ids, attention_mask=attention_mask, output_hidden_states=True
|
| 422 |
+
)
|
| 423 |
+
# Access the `hidden_states` first, that contains a tuple of
|
| 424 |
+
# all the hidden states from the encoder layers. Then index into
|
| 425 |
+
# the tuple to access the hidden states from the desired layer.
|
| 426 |
+
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
| 427 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
| 428 |
+
# representations. The `last_hidden_states` that we typically use for
|
| 429 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
| 430 |
+
# layer.
|
| 431 |
+
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
| 432 |
+
|
| 433 |
+
if self.text_encoder is not None:
|
| 434 |
+
prompt_embeds_dtype = self.text_encoder.dtype
|
| 435 |
+
elif self.unet is not None:
|
| 436 |
+
prompt_embeds_dtype = self.unet.dtype
|
| 437 |
+
else:
|
| 438 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
| 439 |
+
|
| 440 |
+
prompt_embeds = prompt_embeds.cast(dtype=prompt_embeds_dtype)
|
| 441 |
+
|
| 442 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
| 443 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 444 |
+
prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1])
|
| 445 |
+
prompt_embeds = prompt_embeds.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
|
| 446 |
+
|
| 447 |
+
# get unconditional embeddings for classifier free guidance
|
| 448 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 449 |
+
uncond_tokens: List[str]
|
| 450 |
+
if negative_prompt is None:
|
| 451 |
+
uncond_tokens = [""] * batch_size
|
| 452 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
| 453 |
+
raise TypeError(
|
| 454 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 455 |
+
f" {type(prompt)}."
|
| 456 |
+
)
|
| 457 |
+
elif isinstance(negative_prompt, str):
|
| 458 |
+
uncond_tokens = [negative_prompt]
|
| 459 |
+
elif batch_size != len(negative_prompt):
|
| 460 |
+
raise ValueError(
|
| 461 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 462 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 463 |
+
" the batch size of `prompt`."
|
| 464 |
+
)
|
| 465 |
+
else:
|
| 466 |
+
uncond_tokens = negative_prompt
|
| 467 |
+
|
| 468 |
+
# textual inversion: process multi-vector tokens if necessary
|
| 469 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 470 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
| 471 |
+
|
| 472 |
+
max_length = prompt_embeds.shape[1]
|
| 473 |
+
uncond_input = self.tokenizer(
|
| 474 |
+
uncond_tokens,
|
| 475 |
+
padding="max_length",
|
| 476 |
+
max_length=max_length,
|
| 477 |
+
truncation=True,
|
| 478 |
+
return_tensors="pd",
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
| 482 |
+
attention_mask = uncond_input.attention_mask
|
| 483 |
+
else:
|
| 484 |
+
attention_mask = None
|
| 485 |
+
|
| 486 |
+
negative_prompt_embeds = self.text_encoder(
|
| 487 |
+
uncond_input.input_ids,
|
| 488 |
+
attention_mask=attention_mask,
|
| 489 |
+
)
|
| 490 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 491 |
+
|
| 492 |
+
if do_classifier_free_guidance:
|
| 493 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 494 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 495 |
+
|
| 496 |
+
negative_prompt_embeds = negative_prompt_embeds.cast(dtype=prompt_embeds_dtype)
|
| 497 |
+
|
| 498 |
+
negative_prompt_embeds = negative_prompt_embeds.tile([1, num_images_per_prompt, 1])
|
| 499 |
+
negative_prompt_embeds = negative_prompt_embeds.reshape([batch_size * num_images_per_prompt, seq_len, -1])
|
| 500 |
+
|
| 501 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 502 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 503 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 504 |
+
|
| 505 |
+
return prompt_embeds, negative_prompt_embeds
|
| 506 |
+
|
| 507 |
+
def encode_image(self, image, num_images_per_prompt):
|
| 508 |
+
dtype = next(self.image_encoder.named_parameters())[1].dtype
|
| 509 |
+
|
| 510 |
+
if not isinstance(image, paddle.Tensor):
|
| 511 |
+
image = self.feature_extractor(image, return_tensors="pd").pixel_values
|
| 512 |
+
|
| 513 |
+
image = image.cast(dtype=dtype)
|
| 514 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 515 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, axis=0)
|
| 516 |
+
|
| 517 |
+
uncond_image_embeds = paddle.zeros_like(image_embeds)
|
| 518 |
+
return image_embeds, uncond_image_embeds
|
| 519 |
+
|
| 520 |
+
def run_safety_checker(self, image, dtype):
|
| 521 |
+
if self.safety_checker is None:
|
| 522 |
+
has_nsfw_concept = None
|
| 523 |
+
else:
|
| 524 |
+
if paddle.is_tensor(image):
|
| 525 |
+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
| 526 |
+
else:
|
| 527 |
+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
| 528 |
+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pd")
|
| 529 |
+
image, has_nsfw_concept = self.safety_checker(
|
| 530 |
+
images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
|
| 531 |
+
)
|
| 532 |
+
return image, has_nsfw_concept
|
| 533 |
+
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cast(dtype=paddle.float32).transpose([0, 2, 3, 1]).cpu().numpy()
        return image
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
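The signature introspection above is a reusable pattern: forward only the keyword arguments the callee actually accepts. A standalone sketch (the helper name is illustrative, not part of the library):

    import inspect

    def filter_kwargs(fn, **kwargs):
        # keep only the keyword arguments that appear in fn's signature
        accepted = set(inspect.signature(fn).parameters)
        return {k: v for k, v in kwargs.items() if k in accepted}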
    def check_inputs(
        self,
        prompt,
        strength,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start
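Concretely, with the defaults num_inference_steps=50 and strength=0.8, the arithmetic above skips the first 10 scheduler steps and leaves 40 denoising steps:

    num_inference_steps, strength = 50, 0.8
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
    t_start = max(num_inference_steps - init_timestep, 0)                          # 10
    print(num_inference_steps - t_start)                                           # 40 steps actually run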
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
        if not isinstance(image, (paddle.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `paddle.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.cast(dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                init_latents = paddle.concat(init_latents, axis=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicated to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = paddle.concat([init_latents] * additional_image_per_prompt, axis=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = paddle.concat([init_latents], axis=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents
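The closing add_noise call is where strength takes effect: the clean VAE latents are blended toward pure noise according to the starting timestep. For DDPM-style schedulers it computes, conceptually (a sketch of the formula, not the scheduler's actual code path):

    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    def add_noise_sketch(x0, eps, alpha_bar_t):
        return alpha_bar_t**0.5 * x0 + (1.0 - alpha_bar_t) ** 0.5 * eps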
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=paddle.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`paddle.Tensor`):
                generate embedding vectors for these guidance scale values
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `paddle.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = paddle.log(paddle.to_tensor(10000.0)) / (half_dim - 1)
        emb = paddle.exp(paddle.arange(half_dim, dtype=dtype) * -emb)
        emb = w.cast(dtype)[:, None] * emb[None, :]
        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = paddle.concat([emb, paddle.zeros([emb.shape[0], 1])], axis=-1)
        assert tuple(emb.shape) == (w.shape[0], embedding_dim)
        return emb
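A quick shape check of the embedding above (assuming pipe is an instance of this pipeline; values are illustrative):

    import paddle

    w = paddle.to_tensor([1.0, 7.5])  # two guidance scales
    emb = pipe.get_guidance_scale_embedding(w, embedding_dim=8)
    print(emb.shape)  # [2, 8]: four sin features and four cos features per scale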
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps
    @paddle.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`paddle.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[paddle.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and paddle Tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                latents as `image`, but if passing latents directly they are not encoded again.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                A [`paddle.Generator`] to make generation deterministic.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image (`PipelineImageInput`, *optional*):
                Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        num_images_per_prompt = 1 if ip_adapter_image is None else num_images_per_prompt

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, num_images_per_prompt)
            if self.do_classifier_free_guidance:
                image_embeds = paddle.concat([negative_image_embeds, image_embeds])

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. set timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            generator,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 7.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = paddle.to_tensor(self.guidance_scale - 1).tile(
                [batch_size * num_images_per_prompt]
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).cast(dtype=latents.dtype)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
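A minimal usage sketch for the pipeline above (the checkpoint id and file names are illustrative, not taken from this diff):

    from ppdiffusers import AltDiffusionImg2ImgPipeline
    from ppdiffusers.utils import load_image

    pipe = AltDiffusionImg2ImgPipeline.from_pretrained("BAAI/AltDiffusion")
    init_image = load_image("input.png").resize((512, 512))
    result = pipe(prompt="a fantasy landscape", image=init_image, strength=0.75, guidance_scale=7.5)
    result.images[0].save("output.png")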
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/alt_diffusion/pipeline_output.py
ADDED
@@ -0,0 +1,40 @@

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Alt Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        nsfw_content_detected (`List[bool]`)
            List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
            `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
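Because AltDiffusionPipelineOutput derives from BaseOutput, its fields can be read as attributes or by position; a small sketch (img stands for any PIL image):

    out = AltDiffusionPipelineOutput(images=[img], nsfw_content_detected=[False])
    assert out.images == out[0]   # BaseOutput also supports tuple-style indexing
    images, nsfw = out.to_tuple()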
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/animatediff/__init__.py
ADDED
@@ -0,0 +1,59 @@

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]

if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *

    else:
        from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
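The _LazyModule indirection above defers importing the heavy pipeline submodule until one of its attributes is first accessed. A simplified sketch of the idea (not the actual implementation):

    import importlib
    import types

    class LazyModule(types.ModuleType):
        def __init__(self, name, import_structure):
            super().__init__(name)
            # map each public attribute to the submodule that defines it
            self._attr_to_module = {a: m for m, attrs in import_structure.items() for a in attrs}

        def __getattr__(self, attr):
            if attr not in self._attr_to_module:
                raise AttributeError(attr)
            submodule = importlib.import_module(f"{self.__name__}.{self._attr_to_module[attr]}")
            value = getattr(submodule, attr)
            setattr(self, attr, value)  # cache so later lookups bypass __getattr__
            return value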
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/audioldm2/__init__.py
ADDED
@@ -0,0 +1,66 @@

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
    is_paddlenlp_version,
)

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_paddlenlp_available() and is_paddle_available() and is_paddlenlp_version(">=", "2.6.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"]
    _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"]


if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available() and is_paddlenlp_version(">=", "2.6.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *

    else:
        from .modeling_audioldm2 import (
            AudioLDM2ProjectionModel,
            AudioLDM2UNet2DConditionModel,
        )
        from .pipeline_audioldm2 import AudioLDM2Pipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
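When the paddlenlp>=2.6.0 gate fails, the exported names are bound to dummy objects that raise only at use time, so importing the package itself stays cheap. A simplified sketch of that pattern (the real dummies are generated from a template, not written by hand):

    class AudioLDM2Pipeline:  # dummy stand-in when dependencies are missing
        def __init__(self, *args, **kwargs):
            raise ImportError("AudioLDM2Pipeline requires paddle and paddlenlp>=2.6.0")

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            raise ImportError("AudioLDM2Pipeline requires paddle and paddlenlp>=2.6.0")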
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/audioldm2/modeling_audioldm2.py
ADDED
@@ -0,0 +1,1520 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...models.activations import get_activation
from ...models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from ...models.embeddings import TimestepEmbedding, Timesteps
from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, logging, recompute_use_reentrant

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
    batch_size = hidden_states.shape[0]

    if attention_mask is not None:
        # Add two more steps to attn mask
        new_attn_mask_step = paddle.ones([batch_size, 1], dtype=attention_mask.dtype)
        attention_mask = paddle.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], axis=-1)

    # Add the SOS / EOS tokens at the start / end of the sequence respectively
    sos_token = sos_token.expand([batch_size, 1, -1])
    eos_token = eos_token.expand([batch_size, 1, -1])
    hidden_states = paddle.concat([sos_token, hidden_states, eos_token], axis=1)
    return hidden_states, attention_mask
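A quick shape check for add_special_tokens (sizes illustrative; zeros stand in for the learned SOS/EOS embeddings):

    B, L, D = 2, 5, 8
    hs = paddle.randn([B, L, D])
    mask = paddle.ones([B, L], dtype="int64")
    hs2, mask2 = add_special_tokens(hs, mask, sos_token=paddle.zeros([1, 1, D]), eos_token=paddle.zeros([1, 1, D]))
    assert hs2.shape == [B, L + 2, D] and mask2.shape == [B, L + 2]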
@dataclass
class AudioLDM2ProjectionModelOutput(BaseOutput):
    """
    Class for AudioLDM2 projection layer's outputs.

    Args:
        hidden_states (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
            encoders and subsequently concatenating them together.
        attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks
            for the two text encoders together. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
    """

    hidden_states: paddle.Tensor
    attention_mask: Optional[paddle.Tensor] = None
class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned
    embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with
    `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the first text encoder (CLAP).
        text_encoder_1_dim (`int`):
            Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
        langauge_model_dim (`int`):
            Dimensionality of the text embeddings from the language model (GPT2).
    """

    @register_to_config
    def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
        super().__init__()
        # additional projection layers for each text encoder
        self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
        self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim)

        # learnable SOS / EOS token embeddings for each text encoder
        self.sos_embed = nn.Parameter(paddle.ones((langauge_model_dim,)))
        self.eos_embed = nn.Parameter(paddle.ones((langauge_model_dim,)))

        self.sos_embed_1 = nn.Parameter(paddle.ones((langauge_model_dim,)))
        self.eos_embed_1 = nn.Parameter(paddle.ones((langauge_model_dim,)))

    def forward(
        self,
        hidden_states: Optional[paddle.Tensor] = None,
        hidden_states_1: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        attention_mask_1: Optional[paddle.Tensor] = None,
    ):
        hidden_states = self.projection(hidden_states)
        hidden_states, attention_mask = add_special_tokens(
            hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
        )

        hidden_states_1 = self.projection_1(hidden_states_1)
        hidden_states_1, attention_mask_1 = add_special_tokens(
            hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
        )

        # if only one attention mask is provided, fill in an all-ones mask for the other
        # encoder (shapes taken before the sequences are concatenated)
        if attention_mask is None and attention_mask_1 is not None:
            attention_mask = paddle.ones(hidden_states.shape[:2], dtype=attention_mask_1.dtype)
        elif attention_mask is not None and attention_mask_1 is None:
            attention_mask_1 = paddle.ones(hidden_states_1.shape[:2], dtype=attention_mask.dtype)

        if attention_mask is not None and attention_mask_1 is not None:
            attention_mask = paddle.concat([attention_mask, attention_mask_1], axis=-1)
        else:
            attention_mask = None

        # concatenate clap and t5 text encoding
        hidden_states = paddle.concat([hidden_states, hidden_states_1], axis=1)

        return AudioLDM2ProjectionModelOutput(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
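Shape sketch for the projection model (dims illustrative): each encoder output is projected to the language-model width, gains its SOS/EOS tokens, and the two sequences are concatenated:

    proj = AudioLDM2ProjectionModel(text_encoder_dim=512, text_encoder_1_dim=1024, langauge_model_dim=768)
    out = proj(
        hidden_states=paddle.randn([2, 1, 512]),     # e.g. pooled CLAP features
        hidden_states_1=paddle.randn([2, 7, 1024]),  # e.g. T5 sequence features
    )
    print(out.hidden_states.shape)  # [2, (1 + 2) + (7 + 2), 768] -> [2, 12, 768]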
class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
| 146 |
+
r"""
|
| 147 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
| 148 |
+
shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
|
| 149 |
+
self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up
|
| 150 |
+
to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`.
|
| 151 |
+
|
| 152 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 153 |
+
for all models (such as downloading or saving).
|
| 154 |
+
|
| 155 |
+
Parameters:
|
| 156 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
| 157 |
+
Height and width of input/output sample.
|
| 158 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
| 159 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
| 160 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
| 161 |
+
Whether to flip the sin to cos in the time embedding.
|
| 162 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
| 163 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
| 164 |
+
The tuple of downsample blocks to use.
|
| 165 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
| 166 |
+
Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2.
|
| 167 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
| 168 |
+
The tuple of upsample blocks to use.
|
| 169 |
+
only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
| 170 |
+
Whether to include self-attention in the basic transformer blocks, see
|
| 171 |
+
[`~models.attention.BasicTransformerBlock`].
|
| 172 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
| 173 |
+
The tuple of output channels for each block.
|
| 174 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
| 175 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
| 176 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
| 177 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 178 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
| 179 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
| 180 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
| 181 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
| 182 |
+
The dimension of the cross attention features.
|
| 183 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
| 184 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 185 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 186 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 187 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
| 188 |
+
num_attention_heads (`int`, *optional*):
|
| 189 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
| 190 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
| 191 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
| 192 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 193 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
| 194 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
| 195 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
| 196 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
| 197 |
+
class conditioning with `class_embed_type` equal to `None`.
|
| 198 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
| 199 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
| 200 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
| 201 |
+
An optional override for the dimension of the projected time embedding.
|
| 202 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
| 203 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
| 204 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
| 205 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
| 206 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
| 207 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
| 208 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
| 209 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
| 210 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
| 211 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
| 212 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
| 213 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
| 214 |
+
embeddings with the class embeddings.
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2D(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.")

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.LayerList([])
        self.up_blocks = nn.LayerList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        else:
            raise ValueError(
                f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2."
            )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        reversed_only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                use_linear_projection=use_linear_projection,
                only_cross_attention=reversed_only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, epsilon=norm_eps
            )
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2D(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: nn.Layer, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: nn.Layer, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)

    # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: nn.Layer):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: nn.Layer, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    # Copied from ppdiffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: paddle.Tensor,
        timestep: Union[paddle.Tensor, float, int],
        encoder_hidden_states: paddle.Tensor,
        class_labels: Optional[paddle.Tensor] = None,
        timestep_cond: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[paddle.Tensor] = None,
        return_dict: bool = True,
        encoder_hidden_states_1: Optional[paddle.Tensor] = None,
        encoder_attention_mask_1: Optional[paddle.Tensor] = None,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`AudioLDM2UNet2DConditionModel`] forward method.

        Args:
            sample (`paddle.Tensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`paddle.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`paddle.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`paddle.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            encoder_hidden_states_1 (`paddle.Tensor`, *optional*):
                A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
                used to condition the model on a different set of embeddings to `encoder_hidden_states`.
            encoder_attention_mask_1 (`paddle.Tensor`, *optional*):
                A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`.
                If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # TODO junnyu, add this to support pure fp16
        sample = sample.cast(self.dtype)

        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch, heads, query_tokens, key_tokens] (e.g. paddle sdp attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.cast(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.cast(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if encoder_attention_mask_1 is not None:
            encoder_attention_mask_1 = (1 - encoder_attention_mask_1.cast(sample.dtype)) * -10000.0
            encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not paddle.is_tensor(timesteps):
            if isinstance(timestep, float):
                dtype = paddle.float32
            else:
                dtype = paddle.int64
            timesteps = paddle.to_tensor([timesteps], dtype=dtype)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None]

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand([sample.shape[0]])
        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.cast(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            # NEW ADD maybe cast it to float16
            class_labels = class_labels.cast(self.dtype)
            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.cast(dtype=sample.dtype)

            # NEW ADD maybe cast it to int64
            if isinstance(self.class_embedding, nn.Embedding):
                class_labels = class_labels.cast(paddle.int64)
            class_emb = self.class_embedding(class_labels).cast(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = paddle.concat([emb, class_emb], axis=-1)
            else:
                emb = emb + class_emb

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    encoder_hidden_states_1=encoder_hidden_states_1,
                    encoder_attention_mask_1=encoder_attention_mask_1,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                encoder_hidden_states_1=encoder_hidden_states_1,
                encoder_attention_mask_1=encoder_attention_mask_1,
            )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = paddle.shape(down_block_res_samples[-1])[2:]  # (NOTE,junnyu) make export happier

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    encoder_hidden_states_1=encoder_hidden_states_1,
                    encoder_attention_mask_1=encoder_attention_mask_1,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
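
    # -----------------------------------------------------------------------
    # Usage sketch: a minimal smoke test of the model defined above, kept as a
    # comment so it does not run at import time. The tiny config values and
    # tensor shapes are illustrative assumptions only, not the shipped config.
    #
    #   import paddle
    #   from ppdiffusers.pipelines.audioldm2.modeling_audioldm2 import (
    #       AudioLDM2UNet2DConditionModel,
    #   )
    #
    #   unet = AudioLDM2UNet2DConditionModel(
    #       block_out_channels=(8, 16),
    #       down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    #       up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    #       layers_per_block=1,
    #       attention_head_dim=2,
    #       cross_attention_dim=16,
    #       norm_num_groups=4,
    #   )
    #
    #   sample = paddle.randn([1, 4, 16, 16])     # (batch, channel, height, width) latent
    #   text_states = paddle.randn([1, 7, 16])    # first conditioning stream
    #   text_states_1 = paddle.randn([1, 5, 16])  # second conditioning stream
    #
    #   out = unet(
    #       sample,
    #       timestep=paddle.to_tensor([10]),
    #       encoder_hidden_states=text_states,
    #       encoder_hidden_states_1=text_states_1,
    #   ).sample                                  # same shape as `sample`
    # -----------------------------------------------------------------------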


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{up_block_type} does not exist.")


class CrossAttnDownBlock2D(nn.Layer):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            for j in range(len(cross_attention_dim)):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim[j],
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        double_self_attention=True if cross_attention_dim[j] is None else False,
                    )
                )
        self.attentions = nn.LayerList(attentions)
        self.resnets = nn.LayerList(resnets)

        if add_downsample:
            self.downsamplers = nn.LayerList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: paddle.Tensor,
        temb: Optional[paddle.Tensor] = None,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[paddle.Tensor] = None,
        encoder_hidden_states_1: Optional[paddle.Tensor] = None,
        encoder_attention_mask_1: Optional[paddle.Tensor] = None,
    ):
        output_states = ()
        num_layers = len(self.resnets)
        num_attention_per_layer = len(self.attentions) // num_layers

        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        encoder_attention_mask_1 = (
            encoder_attention_mask_1 if encoder_attention_mask_1 is not None else encoder_attention_mask
        )

        for i in range(num_layers):
            if self.gradient_checkpointing and not hidden_states.stop_gradient:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                hidden_states = recompute(
                    create_custom_forward(self.resnets[i]),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = recompute(
                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # added_cond_kwargs
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
            else:
                hidden_states = self.resnets[i](hidden_states, temb)
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self.attentions[i * num_attention_per_layer + idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class UNetMidBlock2DCrossAttn(nn.Layer):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        for i in range(num_layers):
            for j in range(len(cross_attention_dim)):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim[j],
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        double_self_attention=True if cross_attention_dim[j] is None else False,
                    )
                )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.LayerList(attentions)
        self.resnets = nn.LayerList(resnets)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: paddle.Tensor,
        temb: Optional[paddle.Tensor] = None,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[paddle.Tensor] = None,
        encoder_hidden_states_1: Optional[paddle.Tensor] = None,
        encoder_attention_mask_1: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:
        hidden_states = self.resnets[0](hidden_states, temb)
        num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1)

        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        encoder_attention_mask_1 = (
            encoder_attention_mask_1 if encoder_attention_mask_1 is not None else encoder_attention_mask
        )

        for i in range(len(self.resnets[1:])):
            if self.gradient_checkpointing and not hidden_states.stop_gradient:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = recompute(
                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # added_cond_kwargs
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                hidden_states = recompute(
                    create_custom_forward(self.resnets[i + 1]),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self.attentions[i * num_attention_per_layer + idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

                hidden_states = self.resnets[i + 1](hidden_states, temb)

        return hidden_states


class CrossAttnUpBlock2D(nn.Layer):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,)
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
            raise ValueError(
                "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
                f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
            )
        self.cross_attention_dim = cross_attention_dim

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            for j in range(len(cross_attention_dim)):
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim[j],
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        double_self_attention=True if cross_attention_dim[j] is None else False,
                    )
                )
        self.attentions = nn.LayerList(attentions)
        self.resnets = nn.LayerList(resnets)

        if add_upsample:
            self.upsamplers = nn.LayerList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: paddle.Tensor,
        res_hidden_states_tuple: Tuple[paddle.Tensor, ...],
        temb: Optional[paddle.Tensor] = None,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_attention_mask: Optional[paddle.Tensor] = None,
        encoder_hidden_states_1: Optional[paddle.Tensor] = None,
        encoder_attention_mask_1: Optional[paddle.Tensor] = None,
    ):
        num_layers = len(self.resnets)
        num_attention_per_layer = len(self.attentions) // num_layers

        encoder_hidden_states_1 = (
            encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
        )
        encoder_attention_mask_1 = (
            encoder_attention_mask_1 if encoder_attention_mask_1 is not None else encoder_attention_mask
        )

        for i in range(num_layers):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)

            if self.gradient_checkpointing and not hidden_states.stop_gradient:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
                hidden_states = recompute(
                    create_custom_forward(self.resnets[i]),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = recompute(
                        create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
                        hidden_states,
                        forward_encoder_hidden_states,
                        None,  # timestep
                        None,  # added_cond_kwargs
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        forward_encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
            else:
                hidden_states = self.resnets[i](hidden_states, temb)
                for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
                    if cross_attention_dim is not None and idx <= 1:
                        forward_encoder_hidden_states = encoder_hidden_states
                        forward_encoder_attention_mask = encoder_attention_mask
                    elif cross_attention_dim is not None and idx > 1:
                        forward_encoder_hidden_states = encoder_hidden_states_1
                        forward_encoder_attention_mask = encoder_attention_mask_1
                    else:
                        forward_encoder_hidden_states = None
                        forward_encoder_attention_mask = None
                    hidden_states = self.attentions[i * num_attention_per_layer + idx](
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=forward_encoder_hidden_states,
                        encoder_attention_mask=forward_encoder_attention_mask,
                        return_dict=False,
                    )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
|
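A quick way to sanity-check the routing above: each resnet layer owns `len(cross_attention_dim)` transformer blocks, addressed as `i * num_attention_per_layer + idx`; blocks with `idx <= 1` attend over the first conditioning stream and blocks with `idx > 1` over the second, while a `None` dim falls back to self-attention. A minimal sketch of that index arithmetic (plain Python, with made-up sizes; not part of the model code):

# Sketch only: mirrors the attention-to-layer indexing in the forward pass above.
num_layers = 3                                 # number of resnets in the block (hypothetical)
cross_attention_dim = (768, 768, 1024, None)   # one transformer per entry, per layer (hypothetical)
num_attention_per_layer = len(cross_attention_dim)

for i in range(num_layers):
    for idx, dim in enumerate(cross_attention_dim):
        flat_index = i * num_attention_per_layer + idx
        if dim is not None and idx <= 1:
            stream = "encoder_hidden_states"    # first conditioning stream
        elif dim is not None:
            stream = "encoder_hidden_states_1"  # second conditioning stream
        else:
            stream = "self-attention only"      # double_self_attention=True case
        print(f"layer {i}, attention {flat_index}: {stream}")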
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/audioldm2/pipeline_audioldm2.py
ADDED
@@ -0,0 +1,914 @@
# Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import paddle

from ppdiffusers.transformers import (
    ClapFeatureExtractor,
    ClapModel,
    GPT2Model,
    RobertaTokenizer,
    SpeechT5HifiGan,
    T5EncoderModel,
    T5Tokenizer,
)

from ...models import AutoencoderKL
from ...models.modeling_utils import get_parameter_dtype
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_librosa_available, logging, replace_example_docstring
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel

if is_librosa_available():
    import librosa

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import scipy
        >>> import paddle
        >>> from ppdiffusers import AudioLDM2Pipeline

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, paddle_dtype=paddle.float16)

        >>> # define the prompts
        >>> prompt = "The sound of a hammer hitting a wooden surface."
        >>> negative_prompt = "Low quality."

        >>> # set the seed for generator
        >>> generator = paddle.Generator().manual_seed(0)

        >>> # run the generation
        >>> audio = pipe(
        ...     prompt,
        ...     negative_prompt=negative_prompt,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ...     num_waveforms_per_prompt=3,
        ...     generator=generator,
        ... ).audios

        >>> # save the best audio sample (index 0) as a .wav file
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
        ```
"""


def prepare_inputs_for_generation(
    inputs_embeds,
    attention_mask=None,
    past_key_values=None,
    **kwargs,
):
    if past_key_values is not None:
        # only last token for inputs_embeds if past is defined in kwargs
        inputs_embeds = inputs_embeds[:, -1:]

    return {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
    }
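This helper implements the usual KV-cache contract: once `past_key_values` is populated, only the newest embedding position needs to be forwarded through the language model. A minimal sketch with dummy tensors (assumed shapes, not tied to any checkpoint; the function from the listing above is assumed in scope):

import paddle

embeds = paddle.randn([2, 5, 32])  # (batch, seq_len, hidden) -- illustrative sizes

# First step: no cache, the full sequence is forwarded.
first = prepare_inputs_for_generation(embeds, past_key_values=None)
assert first["inputs_embeds"].shape == [2, 5, 32]

# Later steps: with a (dummy) cache present, only the last position is forwarded.
later = prepare_inputs_for_generation(embeds, past_key_values=object())
assert later["inputs_embeds"].shape == [2, 1, 32]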
class AudioLDM2Pipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-audio generation using AudioLDM2.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.ClapModel`]):
            First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
            [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
            specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
            text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
            rank generated waveforms against the text prompt by computing similarity scores.
        text_encoder_2 ([`~transformers.T5EncoderModel`]):
            Second frozen text-encoder. AudioLDM2 uses the encoder of
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
        projection_model ([`AudioLDM2ProjectionModel`]):
            A trained model used to linearly project the hidden-states from the first and second text encoder models
            and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders
            are concatenated to give the input to the language model.
        language_model ([`~transformers.GPT2Model`]):
            An auto-regressive language model used to generate a sequence of hidden-states conditioned on the
            projected outputs from the two text encoders.
        tokenizer ([`~transformers.RobertaTokenizer`]):
            Tokenizer to tokenize text for the first frozen text-encoder.
        tokenizer_2 ([`~transformers.T5Tokenizer`]):
            Tokenizer to tokenize text for the second frozen text-encoder.
        feature_extractor ([`~transformers.ClapFeatureExtractor`]):
            Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded audio latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        vocoder ([`~transformers.SpeechT5HifiGan`]):
            Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: ClapModel,
        text_encoder_2: T5EncoderModel,
        projection_model: AudioLDM2ProjectionModel,
        language_model: GPT2Model,
        tokenizer: RobertaTokenizer,
        tokenizer_2: T5Tokenizer,
        feature_extractor: ClapFeatureExtractor,
        unet: AudioLDM2UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vocoder: SpeechT5HifiGan,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            projection_model=projection_model,
            language_model=language_model,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            feature_extractor=feature_extractor,
            unet=unet,
            scheduler=scheduler,
            vocoder=vocoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def generate_language_model(
        self,
        inputs_embeds: paddle.Tensor = None,
        max_new_tokens: int = 8,
        **model_kwargs,
    ):
        """
        Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.

        Parameters:
            inputs_embeds (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence used as a prompt for the generation.
            max_new_tokens (`int`):
                Number of new tokens to generate.
            model_kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
                function of the model.

        Return:
            inputs_embeds (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence of generated hidden-states.
        """
        max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
        for _ in range(max_new_tokens):
            # prepare model inputs
            model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

            # forward pass to get next hidden states
            output = self.language_model(**model_inputs, return_dict=True)

            next_hidden_states = output.last_hidden_state

            # Update the model input
            inputs_embeds = paddle.concat([inputs_embeds, next_hidden_states[:, -1:, :]], axis=1)

            # Update generated hidden states, model inputs, and length for next step
            model_kwargs = self.language_model.update_model_kwargs_for_generation(output, model_kwargs)

        return inputs_embeds[:, -max_new_tokens:, :]
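`generate_language_model` is a greedy hidden-state rollout: each iteration appends the model's last hidden state as the next input embedding, and only the final `max_new_tokens` positions are returned. A toy stand-in that shows the same bookkeeping (the "model" here just echoes its last input; illustrative only):

# Sketch of the rollout's append-then-slice pattern, with ints standing in for hidden states.
seq = list(range(12))            # pretend each int is one hidden-state position
max_new_tokens = 8
for _ in range(max_new_tokens):
    seq.append(seq[-1] + 1)      # append one new "hidden state" per step
generated = seq[-max_new_tokens:]
print(len(generated))            # 8, matching inputs_embeds[:, -max_new_tokens:, :]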
    def encode_prompt(
        self,
        prompt,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        generated_prompt_embeds: Optional[paddle.Tensor] = None,
        negative_generated_prompt_embeds: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        negative_attention_mask: Optional[paddle.Tensor] = None,
        max_new_tokens: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
                prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            generated_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`paddle.Tensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`paddle.Tensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                The number of new tokens to generate with the GPT2 language model.
        Returns:
            prompt_embeds (`paddle.Tensor`):
                Text embeddings from the Flan T5 model.
            attention_mask (`paddle.Tensor`):
                Attention mask to be applied to the `prompt_embeds`.
            generated_prompt_embeds (`paddle.Tensor`):
                Text embeddings generated from the GPT2 language model.

        Example:

        ```python
        >>> import scipy
        >>> import paddle
        >>> from ppdiffusers import AudioLDM2Pipeline

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, paddle_dtype=paddle.float16)

        >>> # Get text embedding vectors
        >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
        ...     prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
        ...     num_waveforms_per_prompt=1,
        ...     do_classifier_free_guidance=True,
        ... )

        >>> # Pass text embeddings to pipeline for text-conditional audio generation
        >>> audio = pipe(
        ...     prompt_embeds=prompt_embeds,
        ...     attention_mask=attention_mask,
        ...     generated_prompt_embeds=generated_prompt_embeds,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ... ).audios[0]

        >>> # save generated audio sample
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
        ```"""
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        if prompt_embeds is None:
            prompt_embeds_list = []
            attention_mask_list = []

            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                text_inputs = tokenizer(
                    prompt,
                    padding="max_length" if isinstance(tokenizer, RobertaTokenizer) else True,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pd",
                )
                text_input_ids = text_inputs.input_ids
                attention_mask = text_inputs.attention_mask
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pd").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        f"The following part of your input was truncated because {text_encoder.config.model_type} can "
                        f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                if text_encoder.config.model_type == "clap":
                    prompt_embeds = text_encoder.get_text_features(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    prompt_embeds = prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    attention_mask = paddle.ones((batch_size, 1), dtype=attention_mask.dtype)
                else:
                    prompt_embeds = text_encoder(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    prompt_embeds = prompt_embeds[0]

                prompt_embeds_list.append(prompt_embeds)
                attention_mask_list.append(attention_mask)

            projection_output = self.projection_model(
                hidden_states=prompt_embeds_list[0],
                hidden_states_1=prompt_embeds_list[1],
                attention_mask=attention_mask_list[0],
                attention_mask_1=attention_mask_list[1],
            )
            projected_prompt_embeds = projection_output.hidden_states
            projected_attention_mask = projection_output.attention_mask

            generated_prompt_embeds = self.generate_language_model(
                projected_prompt_embeds,
                attention_mask=projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        prompt_embeds = prompt_embeds.cast(dtype=get_parameter_dtype(self.text_encoder_2))
        attention_mask = (
            attention_mask if attention_mask is not None else paddle.ones(prompt_embeds.shape[:2], dtype=paddle.int64)
        )
        generated_prompt_embeds = generated_prompt_embeds.cast(dtype=get_parameter_dtype(self.language_model))

        bs_embed, seq_len, hidden_size = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.tile([1, num_waveforms_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([bs_embed * num_waveforms_per_prompt, seq_len, hidden_size])

        # duplicate attention mask for each generation per prompt
        attention_mask = attention_mask.tile([1, num_waveforms_per_prompt])
        attention_mask = attention_mask.reshape([bs_embed * num_waveforms_per_prompt, seq_len])

        bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
        # duplicate generated embeddings for each generation per prompt, using mps friendly method
        generated_prompt_embeds = generated_prompt_embeds.tile([1, num_waveforms_per_prompt, 1])
        generated_prompt_embeds = generated_prompt_embeds.reshape(
            [bs_embed * num_waveforms_per_prompt, seq_len, hidden_size]
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            negative_prompt_embeds_list = []
            negative_attention_mask_list = []
            max_length = prompt_embeds.shape[1]
            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                uncond_input = tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=tokenizer.model_max_length if isinstance(tokenizer, RobertaTokenizer) else max_length,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pd",
                )

                uncond_input_ids = uncond_input.input_ids
                negative_attention_mask = uncond_input.attention_mask

                if text_encoder.config.model_type == "clap":
                    negative_prompt_embeds = text_encoder.get_text_features(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    negative_attention_mask = paddle.ones((batch_size, 1), dtype=negative_attention_mask.dtype)
                else:
                    negative_prompt_embeds = text_encoder(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    negative_prompt_embeds = negative_prompt_embeds[0]

                negative_prompt_embeds_list.append(negative_prompt_embeds)
                negative_attention_mask_list.append(negative_attention_mask)

            projection_output = self.projection_model(
                hidden_states=negative_prompt_embeds_list[0],
                hidden_states_1=negative_prompt_embeds_list[1],
                attention_mask=negative_attention_mask_list[0],
                attention_mask_1=negative_attention_mask_list[1],
            )
            negative_projected_prompt_embeds = projection_output.hidden_states
            negative_projected_attention_mask = projection_output.attention_mask

            negative_generated_prompt_embeds = self.generate_language_model(
                negative_projected_prompt_embeds,
                attention_mask=negative_projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.cast(dtype=get_parameter_dtype(self.text_encoder_2))
            negative_attention_mask = (
                negative_attention_mask
                if negative_attention_mask is not None
                else paddle.ones(negative_prompt_embeds.shape[:2], dtype=paddle.int64)
            )
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.cast(
                dtype=get_parameter_dtype(self.language_model)
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.tile([1, num_waveforms_per_prompt, 1])
            negative_prompt_embeds = negative_prompt_embeds.reshape(
                [batch_size * num_waveforms_per_prompt, seq_len, -1]
            )

            # duplicate unconditional attention mask for each generation per prompt
            negative_attention_mask = negative_attention_mask.tile([1, num_waveforms_per_prompt])
            negative_attention_mask = negative_attention_mask.reshape([batch_size * num_waveforms_per_prompt, seq_len])

            # duplicate unconditional generated embeddings for each generation per prompt
            seq_len = negative_generated_prompt_embeds.shape[1]
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.tile([1, num_waveforms_per_prompt, 1])
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.reshape(
                [batch_size * num_waveforms_per_prompt, seq_len, -1]
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds])
            attention_mask = paddle.concat([negative_attention_mask, attention_mask])
            generated_prompt_embeds = paddle.concat([negative_generated_prompt_embeds, generated_prompt_embeds])

        return prompt_embeds, attention_mask, generated_prompt_embeds
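Note the batch layout this returns under classifier-free guidance: the negative (unconditional) half is concatenated before the positive half, which is what lets the denoising loop later split the UNet output with `chunk(2)` into `(uncond, text)` in that order. A shape-only sketch, with sizes picked for illustration:

import paddle

# Illustrative sizes: batch of 2 prompts, 3 waveforms each, seq_len 8, hidden 512.
neg = paddle.zeros([2 * 3, 8, 512])   # stands in for negative_prompt_embeds after tiling
pos = paddle.ones([2 * 3, 8, 512])    # stands in for prompt_embeds after tiling

both = paddle.concat([neg, pos])      # negative half first, positive half second
uncond, text = both.chunk(2)
assert bool((uncond == 0).all()) and bool((text == 1).all())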
    # Copied from ppdiffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cast("float32").cpu()
        return waveform
    def score_waveforms(self, text, audio, num_waveforms_per_prompt, dtype):
        if not is_librosa_available():
            logger.info(
                "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
                "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
                "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
            )
            return audio
        inputs = self.tokenizer(text, return_tensors="pd", padding=True)
        resampled_audio = librosa.resample(
            audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
        )
        inputs["input_features"] = self.feature_extractor(
            list(resampled_audio), return_tensors="pd", sampling_rate=self.feature_extractor.sampling_rate
        ).input_features.cast(dtype=dtype)

        # compute the audio-text similarity score using the CLAP model
        logits_per_text = self.text_encoder(**inputs).logits_per_text
        # sort by the highest matching generations per prompt
        indices = paddle.argsort(logits_per_text, axis=1, descending=True)[:, :num_waveforms_per_prompt]
        audio = paddle.index_select(audio, axis=0, index=indices.reshape([-1]))
        return audio
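The ranking step above is just an argsort over CLAP text-audio logits followed by a gather. A self-contained sketch of the same selection with dummy scores (hypothetical numbers, no models involved):

import paddle

# Dummy similarity logits for 1 prompt x 4 candidate waveforms.
logits_per_text = paddle.to_tensor([[0.2, 0.9, 0.1, 0.5]])
num_waveforms_per_prompt = 2

indices = paddle.argsort(logits_per_text, axis=1, descending=True)[:, :num_waveforms_per_prompt]
print(indices.numpy())  # [[1 3]] -> the best two candidates, best first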
    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
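The same introspection trick is easy to reuse for any scheduler-like callable: check the signature once, then pass only the kwargs it accepts. A generic sketch (the `step` function here is hypothetical, not a real scheduler):

import inspect

def step(sample, t, generator=None):  # hypothetical scheduler step
    return sample

accepted = set(inspect.signature(step).parameters.keys())
extra = {k: v for k, v in {"eta": 0.0, "generator": None}.items() if k in accepted}
print(extra)  # {'generator': None} -- 'eta' is dropped because step() does not take it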
    def check_inputs(
        self,
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        generated_prompt_embeds=None,
        negative_generated_prompt_embeds=None,
        attention_mask=None,
        negative_attention_mask=None,
    ):
        min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
        if audio_length_in_s < min_audio_length_in_s:
            raise ValueError(
                f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
                f"is {audio_length_in_s}."
            )

        if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
            raise ValueError(
                f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
                f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
                f"{self.vae_scale_factor}."
            )

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
            raise ValueError(
                "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
                "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
            raise ValueError(
                "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that "
                "both arguments are specified."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
                raise ValueError(
                    "`attention_mask` should have the same batch size and sequence length as `prompt_embeds`, but got:"
                    f" `attention_mask`: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
                )

        if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
            if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
                raise ValueError(
                    "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
                    f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
                    f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
                )
            if (
                negative_attention_mask is not None
                and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
            ):
                raise ValueError(
                    "`negative_attention_mask` should have the same batch size and sequence length as"
                    f" `negative_prompt_embeds`, but got: `negative_attention_mask`: {negative_attention_mask.shape}"
                    f" != `negative_prompt_embeds` {negative_prompt_embeds.shape}"
                )
    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            self.vocoder.config.model_in_dim // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        else:
            latents = latents.cast(dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
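Putting `prepare_latents` together with the seconds-to-frames conversion done in `__call__` below: the requested duration becomes a spectrogram height, which is rounded up to a multiple of the VAE scale factor and then divided down into latent space. A numeric sketch with hypothetical config values (none of these numbers are read from a real checkpoint):

import numpy as np

# Hypothetical numbers for illustration only:
sampling_rate = 16000
upsample_rates = [5, 4, 4, 2]        # prod = 160 -> 10 ms of audio per spectrogram frame
vae_scale_factor = 8                 # 2 ** (len(block_out_channels) - 1) with 4 blocks
model_in_dim = 64                    # mel bins expected by the vocoder
num_channels_latents = 8             # assumed unet.config.in_channels

vocoder_upsample_factor = np.prod(upsample_rates) / sampling_rate   # 0.01 s per frame
audio_length_in_s = 10.24
height = int(audio_length_in_s / vocoder_upsample_factor)           # 1024 frames
if height % vae_scale_factor != 0:
    height = int(np.ceil(height / vae_scale_factor)) * vae_scale_factor

latent_shape = (1, num_channels_latents, height // vae_scale_factor, model_in_dim // vae_scale_factor)
print(latent_shape)  # (1, 8, 128, 8) under these assumptions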
    @paddle.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 200,
        guidance_scale: float = 3.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        generated_prompt_embeds: Optional[paddle.Tensor] = None,
        negative_generated_prompt_embeds: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        negative_attention_mask: Optional[paddle.Tensor] = None,
        max_new_tokens: Optional[int] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            audio_length_in_s (`float`, *optional*, defaults to 10.24):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 200):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                A higher guidance scale value encourages the model to generate audio that is closely linked to the
                text `prompt` at the expense of lower sound quality. Guidance scale is enabled when
                `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
                scoring is performed between the generated outputs and the text prompt. This scoring ranks the
                generated waveforms based on their cosine similarity with the text input in the joint text-audio
                embedding space.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                A [`paddle.Generator`] to make generation deterministic.
            latents (`paddle.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            generated_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`paddle.Tensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`paddle.Tensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                Number of new tokens to generate with the GPT2 language model. If not provided, the number of tokens
                will be taken from the config of the model.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
                `"pd"` to return a Paddle `paddle.Tensor` object. Set to `"latent"` to return the latent diffusion
                model (LDM) output.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            generated_prompt_embeds,
            negative_generated_prompt_embeds,
            attention_mask,
            negative_attention_mask,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
            prompt,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            generated_prompt_embeds=generated_prompt_embeds,
            negative_generated_prompt_embeds=negative_generated_prompt_embeds,
            attention_mask=attention_mask,
            negative_attention_mask=negative_attention_mask,
            max_new_tokens=max_new_tokens,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=generated_prompt_embeds,
                    encoder_hidden_states_1=prompt_embeds,
                    encoder_attention_mask_1=attention_mask,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # 8. Post-processing
        if output_type != "latent":
            latents = 1 / self.vae.config.scaling_factor * latents
            mel_spectrogram = self.vae.decode(latents).sample
        else:
            return AudioPipelineOutput(audios=latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        audio = audio[:, :original_waveform_length]

        # 9. Automatic scoring
        if num_waveforms_per_prompt > 1 and prompt is not None:
            audio = self.score_waveforms(
                text=prompt,
                audio=audio,
                num_waveforms_per_prompt=num_waveforms_per_prompt,
                dtype=prompt_embeds.dtype,
            )

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
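The guidance step in the denoising loop above is plain vector arithmetic: with the batch ordered `(uncond, text)` as produced by `encode_prompt`, `chunk(2)` splits the prediction and the update pushes the estimate away from the unconditional branch. A one-line numeric check with toy values:

import paddle

noise_pred = paddle.to_tensor([[1.0], [3.0]])   # stacked (uncond, text) predictions
uncond, text = noise_pred.chunk(2)
guided = uncond + 3.5 * (text - uncond)
print(float(guided))  # 1.0 + 3.5 * 2.0 = 8.0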
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py
ADDED
@@ -0,0 +1,74 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}


try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
    _import_structure["pipeline_cogvideox_vctrl"] = ["CogVideoXVCtrlPipeline"]
    _import_structure["pipeline_cogvideox_image2video_vctrl"] = ["CogVideoXVCtrlImageToVideoPipeline"]
    # _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
    # _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
    # _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]

if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *
    else:
        from .pipeline_cogvideox import CogVideoXPipeline
        from .pipeline_cogvideox_image2video_vctrl import (
            CogVideoXVCtrlImageToVideoPipeline,
        )
        from .pipeline_cogvideox_vctrl import CogVideoXVCtrlPipeline

        # from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
        # from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
        # from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
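
The module above defers its heavy pipeline imports: names are registered in `_import_structure`, and `sys.modules[__name__]` is replaced with a `_LazyModule` that resolves them on first attribute access. A minimal, standard-library sketch of the same idea (hypothetical `LazyModule` class, not the ppdiffusers implementation):

import importlib
import types


class LazyModule(types.ModuleType):
    """Resolve attributes to submodule objects on first access, then cache them."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map each exported attribute to the submodule that defines it
        self._attr_to_module = {attr: mod for mod, attrs in import_structure.items() for attr in attrs}

    def __getattr__(self, attr):
        if attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(f"{self.__name__}.{self._attr_to_module[attr]}")
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value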
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py
ADDED
@@ -0,0 +1,716 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from typing import Callable, Dict, List, Optional, Tuple, Union

import paddle

from ppdiffusers.transformers import T5EncoderModel, T5Tokenizer

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import logging, replace_example_docstring
from ...utils.paddle_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CogVideoXPipelineOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import paddle
        >>> from diffusers import CogVideoXPipeline
        >>> from diffusers.utils import export_to_video

        >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
        >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", paddle_dtype=paddle.float16)
        >>> prompt = (
        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        ...     "atmosphere of this unique musical performance."
        ... )
        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
        >>> export_to_video(video, "output.mp4", fps=8)
        ```
"""


# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > (th / tw):
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
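
# Hedged usage note (illustrative assumption, not upstream logic): when the
# source grid already matches the target aspect ratio, the crop is a no-op:
#     get_resize_crop_region_for_grid((30, 45), 45, 30)
#     # -> ((0, 0), (30, 45))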


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
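
# Hedged usage note: either call style below returns (timesteps, num_inference_steps);
# the custom-timesteps form requires a scheduler whose `set_timesteps` accepts it.
# The step count and timestep values are illustrative assumptions.
#     timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=50)
#     timesteps, n = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0])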


class CogVideoXPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
    ]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )
        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        dtype: Optional[paddle.dtype] = None,
    ):
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pd",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids

        if (
            untruncated_ids.shape[-1] >= text_input_ids.shape[-1]
            and not paddle.equal_all(text_input_ids, untruncated_ids).item()
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(text_input_ids)[0]
        prompt_embeds = prompt_embeds.cast(dtype)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.tile([1, num_videos_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([batch_size * num_videos_per_prompt, seq_len, -1])

        return prompt_embeds
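
    # Hedged shape walk-through for the duplication above (assumed sizes): with
    # prompt_embeds of shape [2, 226, 4096] and num_videos_per_prompt=2,
    # tile([1, 2, 1]) yields [2, 452, 4096] and the reshape yields [4, 226, 4096],
    # i.e. each prompt's embedding repeated once per requested video.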

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        max_sequence_length: int = 226,
        dtype: Optional[paddle.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            dtype: (`paddle.dtype`, *optional*):
                paddle dtype
        """

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None
    ):
        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
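
    # Hedged shape example for `prepare_latents` (assumed defaults): with
    # num_frames=49, vae_scale_factor_temporal=4, height=480, width=720 and
    # vae_scale_factor_spatial=8, the latent shape is
    # (batch_size, (49 - 1) // 4 + 1, C, 480 // 8, 720 // 8) = (batch_size, 13, C, 60, 90).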

    def decode_latents(self, latents: paddle.Tensor) -> paddle.Tensor:
        latents = latents.transpose([0, 2, 1, 3, 4])  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae.config.scaling_factor * latents

        frames = self.vae.decode(latents).sample
        return frames

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def fuse_qkv_projections(self) -> None:
        r"""Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        r"""Disable QKV projection fusion if enabled."""
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
            use_real=True,
        )

        return freqs_cos, freqs_sin
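
    # Hedged numeric example (assumed defaults): with height=480, width=720,
    # vae_scale_factor_spatial=8 and patch_size=2, grid_height = 480 // 16 = 30
    # and grid_width = 720 // 16 = 45; the base grid is likewise 45 x 30, so the
    # crop region covers the full grid at the default resolution.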

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @paddle.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to `480`):
                The height in pixels of the generated video.
            width (`int`, *optional*, defaults to `720`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `49`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 6):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                One or a list of [paddle generator(s)](https://pypaddle.org/docs/stable/generated/paddle.Generator.html)
                to make generation deterministic.
            latents (`paddle.float32`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`paddle.float32`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.float32`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        if num_frames > 49:
            raise ValueError(
                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        num_videos_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_on_step_end_tensor_inputs,
            prompt_embeds,
            negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._interrupt = False

        # 2. Default call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            negative_prompt,
            do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )
        if do_classifier_free_guidance:
            prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds], axis=0)

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
        self._num_timesteps = len(timesteps)

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Create rotary embeds if required
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.size(1))
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            # for DPM-solver++
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # predict noise model_output
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                )[0]
                noise_pred = noise_pred.cast("float32")

                # perform guidance
                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.cast(prompt_embeds.dtype)

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CogVideoXPipelineOutput(frames=video)
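
When `use_dynamic_cfg` is set, the loop above replaces the fixed guidance weight with a cosine ramp over the raw scheduler timestep. A small, self-contained sketch of that schedule (plain Python; the 1000-step range and scale of 6 are assumptions chosen so the ramp is easy to read):

import math


def dynamic_cfg(guidance_scale: float, t: float, num_steps: float) -> float:
    # Same ramp as above: near 1 early in sampling (large t), rising to
    # 1 + guidance_scale at the final step (t = 0).
    return 1 + guidance_scale * (
        (1 - math.cos(math.pi * ((num_steps - t) / num_steps) ** 5.0)) / 2
    )


for t in (999, 750, 500, 250, 0):
    print(t, round(dynamic_cfg(6.0, t, 1000), 3))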
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_image2video_vctrl.py
ADDED
@@ -0,0 +1,799 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import paddle
import paddlenlp
import PIL

from ppdiffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from ppdiffusers.image_processor import PipelineImageInput, VaeImageProcessor
from ppdiffusers.models import (
    AutoencoderKLCogVideoX,
    CogVideoXTransformer3DVCtrlModel,
    VCtrlModel,
)
from ppdiffusers.models.embeddings import get_3d_rotary_pos_embed
from ppdiffusers.pipelines.pipeline_utils import DiffusionPipeline
from ppdiffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ppdiffusers.utils import BaseOutput, logging
from ppdiffusers.utils.paddle_utils import randn_tensor
from ppdiffusers.video_processor import VideoProcessor

logger = logging.get_logger(__name__)


def tensor2vid(video: paddle.Tensor, processor, output_type="np"):
    batch_size, channels, num_frames, height, width = tuple(video.shape)
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].transpose(perm=[1, 0, 2, 3])
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)
    return outputs
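
# Hedged shape note: `video` enters as [batch, channels, frames, height, width];
# each per-batch slice is permuted to [frames, channels, height, width] so the
# image processor can treat the frames as a batch of images.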


def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > th / tw:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))
    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))
    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def retrieve_latents(
    encoder_output: paddle.Tensor, generator: Optional[paddle.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
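
# Hedged usage note: for a VAE encode result exposing `latent_dist`,
# sample_mode="sample" draws a stochastic sample (optionally seeded via
# `generator`), while sample_mode="argmax" takes the distribution mode,
# which makes image conditioning deterministic across runs.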
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class CogVideoXPipelineOutput(BaseOutput):
|
| 136 |
+
"""
|
| 137 |
+
Output class for CogVideo pipelines.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
frames (`paddle.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 141 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 142 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or paddle tensor of shape
|
| 143 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
frames: paddle.Tensor
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class CogVideoXVCtrlImageToVideoPipeline(DiffusionPipeline):
    """
    Pipeline for image-to-video generation using CogVideoX with VCTRL.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        tokenizer: paddlenlp.transformers.T5Tokenizer,
        text_encoder: paddlenlp.transformers.T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DVCtrlModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
        vctrl: VCtrlModel,
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
            vctrl=vctrl,
        )

        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )
        self.vae_scaling_factor_image = (
            self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        self.vctrl_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor_spatial, do_convert_rgb=True, do_normalize=True
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        dtype: Optional[paddle.dtype] = None,
    ):
        dtype = dtype or self.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pd",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids
        if (
            tuple(untruncated_ids.shape)[-1] >= tuple(text_input_ids.shape)[-1]
            and not paddle.equal_all(x=text_input_ids, y=untruncated_ids).item()
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                f"The following part of your input was truncated because `max_sequence_length` is set to {max_sequence_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype)
        _, seq_len, _ = tuple(prompt_embeds.shape)
        prompt_embeds = prompt_embeds.tile(repeat_times=[1, num_videos_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([batch_size * num_videos_per_prompt, seq_len, -1])
        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        max_sequence_length: int = 226,
        dtype: Optional[paddle.dtype] = None,
    ):
        """
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier-free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
                argument.
            dtype (`paddle.dtype`, *optional*):
                paddle dtype
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = tuple(prompt_embeds.shape)[0]
        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches the batch size of `prompt`."
                )
            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )
        return prompt_embeds, negative_prompt_embeds

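    # Illustrative sketch: with classifier-free guidance the two embedding
    # batches returned here are later concatenated along the batch axis
    # (negative first), so a single transformer forward pass scores both branches:
    #
    #   pe, npe = pipe.encode_prompt("a cat", do_classifier_free_guidance=True)
    #   prompt_embeds = paddle.concat(x=[npe, pe], axis=0)  # shape [2 * batch, 226, dim]
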
    def prepare_latents(
        self, image, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None
    ):
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_frames,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        image = image.unsqueeze(axis=2)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i].unsqueeze(axis=0)), generator[i]) for i in range(batch_size)
            ]
        else:
            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(axis=0)), generator) for img in image]

        image_latents = paddle.concat(x=image_latents, axis=0).to(dtype).transpose(perm=[0, 2, 1, 3, 4])
        image_latents = self.vae_scaling_factor_image * image_latents
        padding_shape = (
            batch_size,
            num_frames - 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        latent_padding = paddle.zeros(shape=padding_shape, dtype=dtype)
        image_latents = paddle.concat(x=[image_latents, latent_padding], axis=1)
        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        latents = latents * self.scheduler.init_noise_sigma
        return latents, image_latents

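    # Shape arithmetic sketch (values assume the default 480x720, 49-frame
    # setup with temporal compression 4 and spatial compression 8):
    #
    #   latent_frames = (49 - 1) // 4 + 1                  # -> 13
    #   latent_height, latent_width = 480 // 8, 720 // 8   # -> 60, 90
    #   # latents: [batch, 13, C, 60, 90]; image_latents encodes the first frame
    #   # and pads it with 12 zero frames so both tensors align frame-for-frame.
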
    def decode_latents(self, latents: paddle.Tensor) -> paddle.Tensor:
        latents = latents.transpose(perm=[0, 2, 1, 3, 4])
        latents = 1 / self.vae_scaling_factor_image * latents
        frames = self.vae.decode(latents).sample
        return frames

    def prepare_extra_step_kwargs(self, generator, eta):
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

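    # Illustrative note: `prepare_extra_step_kwargs` feature-detects the
    # scheduler through `inspect.signature`, since only DDIM-style steppers
    # accept `eta`. A minimal standalone equivalent:
    #
    #   accepts_eta = "eta" in inspect.signature(scheduler.step).parameters
    #   kwargs = {"eta": eta} if accepts_eta else {}
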
    def check_inputs(
        self,
        image,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (
            not isinstance(image, paddle.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                f"`image` has to be of type `paddle.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is {type(image)}"
            )
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if tuple(prompt_embeds.shape) != tuple(negative_prompt_embeds.shape):
                raise ValueError(
                    f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` {tuple(prompt_embeds.shape)} != `negative_prompt_embeds` {tuple(negative_prompt_embeds.shape)}."
                )

    def fuse_qkv_projections(self) -> None:
        """Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        """Disable QKV projection fusion if enabled."""
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )

        return freqs_cos, freqs_sin

    def _prepare_v_ctrl_rotary_positional_embeddings(
        self, height: int, width: int, num_frames: int
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
        base_size_width = 720 // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
        base_size_height = 480 // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.vctrl.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )

        return freqs_cos, freqs_sin

    def prepare_v_cond_frames(self, image, width, height, dtype):
        image = self.vctrl_image_processor.preprocess(image, height=height, width=width)
        control_images = image.unsqueeze(axis=0).to(dtype=dtype)
        batch_size, num_frames, channels, height, width = tuple(control_images.shape)
        conditioning_frames = control_images
        conditioning_frames = conditioning_frames.transpose(perm=[0, 2, 1, 3, 4])
        return conditioning_frames

    def prepare_v_cond_video(
        self, conditioning_frames: paddle.Tensor, num_frames: int, conditioning_frame_indices: List[int], dtype: paddle.dtype
    ) -> paddle.Tensor:
        assert tuple(conditioning_frames.shape)[2] >= len(conditioning_frame_indices)
        batch_size, channels, _, height, width = tuple(conditioning_frames.shape)
        vctrl_cond = paddle.zeros(shape=(batch_size, channels, num_frames, height, width), dtype=dtype)
        vctrl_cond[:, :, conditioning_frame_indices] = conditioning_frames[:, :, conditioning_frame_indices]
        return vctrl_cond

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

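    # Illustrative sketch: `prepare_v_cond_video` scatters the provided control
    # frames into an otherwise all-zero clip, so unconditioned frames stay zero.
    # With num_frames=5 and conditioning_frame_indices=[0, 2]:
    #
    #   frames = paddle.ones([1, 3, 5, 8, 8])
    #   cond = pipe.prepare_v_cond_video(frames, 5, [0, 2], frames.dtype)
    #   # cond[:, :, [0, 2]] == 1 while cond[:, :, [1, 3, 4]] == 0
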
    @paddle.no_grad()
    def __call__(
        self,
        image: PipelineImageInput,
        prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 480,
        width: int = 720,
        num_frames: int = 49,
        num_inference_steps: int = 50,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 6,
        use_dynamic_cfg: bool = False,
        num_videos_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 226,
        conditioning_frames: Optional[List[PipelineImageInput]] = None,
        conditioning_frame_indices: List[int] = [0],
        conditioning_scale: Union[float, List[float]] = 1.0,
        task: Optional[str] = None,
        conditioning_masks: Optional[List[PipelineImageInput]] = None,
        vctrl_layout_type: str = "even",
    ) -> Union[CogVideoXPipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PipelineImageInput`):
                The input image to condition the generation on. Must be an image, a list of images or a `paddle.Tensor`.
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to `480`):
                The height in pixels of the generated video frames.
            width (`int`, *optional*, defaults to `720`):
                The width in pixels of the generated video frames.
            num_frames (`int`, defaults to `49`):
                Number of frames to generate. Must be divisible by `self.vae_scale_factor_temporal`, plus one extra
                frame, because CogVideoX is conditioned with `(num_seconds * fps + 1)` frames. Since videos can be
                saved at any fps, the only condition that needs to be satisfied is that of divisibility mentioned
                above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to `6`):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating videos that are closely linked to the text `prompt`,
                usually at the expense of lower video quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                One or a list of paddle generator(s) to make generation deterministic.
            latents (`paddle.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
                of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
                the `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in the encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length`, otherwise it may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated frames.
        """
        if num_frames > 49:
            raise ValueError(
                "The number of frames must be at most 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
            )
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
        num_videos_per_prompt = 1
        self.check_inputs(
            image=image,
            prompt=prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        self._guidance_scale = guidance_scale
        self._interrupt = False
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = tuple(prompt_embeds.shape)[0]

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )
        if do_classifier_free_guidance:
            prompt_embeds = paddle.concat(x=[negative_prompt_embeds, prompt_embeds], axis=0)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
        self._num_timesteps = len(timesteps)
        image = self.video_processor.preprocess(image, height=height, width=width).to(dtype=prompt_embeds.dtype)

        latent_channels = self.transformer.config.in_channels // 2
        latents, image_latents = self.prepare_latents(
            image,
            batch_size * num_videos_per_prompt,
            latent_channels,
            num_frames,
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )

        if isinstance(self.vctrl, VCtrlModel):
            v_cond_frames = self.prepare_v_cond_frames(conditioning_frames, width, height, self.vctrl.dtype)
            v_cond = self.prepare_v_cond_video(v_cond_frames, num_frames, conditioning_frame_indices, self.vctrl.dtype)

            v_cond = self.vae.encode(v_cond).latent_dist.sample()

            (cond_batch_size, cond_channels, cond_frames, cond_height, cond_width) = tuple(v_cond.shape)

            def map_frame_latent(indices):
                latent_indices = []
                visited = set()
                for indice in indices:
                    if indice == 0:
                        value = 0
                    else:
                        value = (indice - 1) // 4 + 1
                    if value not in visited:
                        latent_indices.append(value)
                        visited.add(value)
                return latent_indices

            if task == "mask":
                assert conditioning_masks is not None, "conditioning_masks must be provided when task is mask"
                v_cond_mask = self.prepare_v_cond_frames(conditioning_masks, width, height, self.vctrl.dtype)
                v_cond_mask = self.prepare_v_cond_video(
                    v_cond_mask, num_frames, conditioning_frame_indices, self.vctrl.dtype
                )
                v_cond_mask = self.vae.encode(v_cond_mask).latent_dist.sample()
                v_cond_mask = v_cond_mask.mean(axis=1, keepdim=True)
            else:
                v_cond_mask = paddle.zeros(
                    shape=(cond_batch_size, 1, cond_frames, cond_height, cond_width), dtype=v_cond.dtype
                )
                v_cond_mask[:, :, map_frame_latent(conditioning_frame_indices)] = 1
            v_cond = paddle.concat(x=[v_cond, v_cond_mask], axis=1)
            v_cond = v_cond.transpose(perm=[0, 2, 1, 3, 4])
            assert (
                tuple(v_cond.shape)[3:] == tuple(latents.shape)[3:]
                and tuple(v_cond.shape)[1] == tuple(latents.shape)[1]
            ), ("v_cond.shape:" + str(tuple(v_cond.shape)) + " latents.shape:" + str(tuple(latents.shape)))
        else:
            raise TypeError(f"`self.vctrl` must be a `VCtrlModel`, but got {type(self.vctrl)}")
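        # Illustrative note: `map_frame_latent` converts pixel-frame indices to
        # latent-frame indices under the VAE's 4x temporal compression, with
        # frame 0 kept separate because the first frame is encoded on its own:
        #
        #   map_frame_latent([0, 1, 2, 3, 4, 5])  # -> [0, 1, 2]
        #   # frames 1-4 all land in latent frame 1; frame 5 lands in latent frame 2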
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        image_rotary_emb = (
            self._prepare_rotary_positional_embeddings(height, width, latents.shape[1])
            if self.transformer.config.use_rotary_positional_embeddings
            else None
        )
        v_cond_rotary_emb = (
            self._prepare_v_ctrl_rotary_positional_embeddings(height, width, v_cond.shape[1])
            if self.vctrl.config.use_rotary_positional_embeddings
            else None
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            old_pred_original_sample = None
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue
                latent_model_input = paddle.concat(x=[latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_image_input = (
                    paddle.concat(x=[image_latents] * 2) if do_classifier_free_guidance else image_latents
                )
                latent_model_input = paddle.concat(x=[latent_model_input, latent_image_input], axis=2)
                timestep = t.expand(shape=tuple(latent_model_input.shape)[0])
                control_model_input = latent_model_input

                vctrl_block_samples = self.vctrl(
                    control_model_input,
                    timestep,
                    v_cond=v_cond,
                    v_cond_scale=conditioning_scale,
                    image_rotary_emb=v_cond_rotary_emb,
                    return_dict=False,
                )

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    timestep=timestep,
                    block_vctrl_residuals=vctrl_block_samples,
                    vctrl_layout_type=vctrl_layout_type,
                    image_rotary_emb=image_rotary_emb,
                    return_dict=False,
                )
                noise_pred = noise_pred.astype(dtype="float32")

                if use_dynamic_cfg:
                    self._guidance_scale = 1 + guidance_scale * (
                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                    )
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(chunks=2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                else:
                    latents, old_pred_original_sample = self.scheduler.step(
                        noise_pred,
                        old_pred_original_sample,
                        t,
                        timesteps[i - 1] if i > 0 else None,
                        latents,
                        **extra_step_kwargs,
                        return_dict=False,
                    )
                latents = latents.to(prompt_embeds.dtype)
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                if i == len(timesteps) - 1 or (i + 1 > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

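        # Illustrative note on `use_dynamic_cfg`: instead of a fixed guidance
        # weight, each step recomputes the scale via the cosine expression
        # 1 + g * (1 - cos(pi * ((N - t) / N) ** 5)) / 2 (with N the step count
        # and t the current timestep value), so the effective classifier-free
        # guidance strength varies over the course of sampling.
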
        if not output_type == "latent":
            video = self.decode_latents(latents)
            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
        else:
            video = latents
        self.maybe_free_model_hooks()
        if not return_dict:
            return (video,)
        return CogVideoXPipelineOutput(frames=video)

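A minimal usage sketch of the pipeline above (illustrative only: the import path, checkpoint id, file names, and the `load_image`/`export_to_video` helpers are assumptions, not part of this diff):

    import paddle
    from ppdiffusers import CogVideoXVCtrlImageToVideoPipeline
    from ppdiffusers.utils import load_image, export_to_video

    pipe = CogVideoXVCtrlImageToVideoPipeline.from_pretrained(
        "paddlemix/cogvideox-5b-i2v-vctrl", paddle_dtype=paddle.bfloat16  # hypothetical checkpoint id
    )
    image = load_image("input.png")
    control_frames = [load_image("cond_0.png")]  # VCtrl control frames
    video = pipe(
        image=image,
        prompt="a corgi running on the beach",
        conditioning_frames=control_frames,
        conditioning_frame_indices=[0],
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=6.0,
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)
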
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox_vctrl.py
ADDED
@@ -0,0 +1,739 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

import paddle
import paddlenlp

from ppdiffusers import CogVideoXTransformer3DVCtrlModel, VCtrlModel
from ppdiffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from ppdiffusers.image_processor import PipelineImageInput, VaeImageProcessor
from ppdiffusers.models import AutoencoderKLCogVideoX
from ppdiffusers.models.embeddings import get_3d_rotary_pos_embed
from ppdiffusers.pipelines.pipeline_utils import DiffusionPipeline
from ppdiffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ppdiffusers.utils import BaseOutput, logging
from ppdiffusers.utils.paddle_utils import randn_tensor
from ppdiffusers.video_processor import VideoProcessor

logger = logging.get_logger(__name__)

def tensor2vid(video: paddle.Tensor, processor, output_type="np"):
    batch_size, channels, num_frames, height, width = tuple(video.shape)
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].transpose(perm=[1, 0, 2, 3])
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)
    return outputs


def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > th / tw:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))
    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))
    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

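# Worked example (illustrative): with the default 480x720 base resolution,
# patch_size 2, and spatial scale 8, the base grid is 30 (h) x 45 (w). For a
# source grid of (40, 45), r = 40/45 > 30/45, so the grid is resized to
# 30 x 34 and centered horizontally; the function returns ((0, 6), (30, 40)).
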
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

@dataclass
class CogVideoXPipelineOutput(BaseOutput):
    """
    Output class for CogVideo pipelines.

    Args:
        frames (`paddle.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Paddle tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: paddle.Tensor


class CogVideoXVCtrlPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-video generation using CogVideoX with VCTRL.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    _optional_components = []
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        tokenizer: paddlenlp.transformers.T5Tokenizer,
        text_encoder: paddlenlp.transformers.T5EncoderModel,
        vae: AutoencoderKLCogVideoX,
        transformer: CogVideoXTransformer3DVCtrlModel,
        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
        vctrl: VCtrlModel,
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
            vctrl=vctrl,
        )

        self.vae_scale_factor_spatial = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
        )
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        self.vctrl_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor_spatial, do_convert_rgb=True, do_normalize=True
        )

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        dtype: Optional[paddle.dtype] = None,
    ):
        dtype = dtype or self.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pd",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids
        if (
            tuple(untruncated_ids.shape)[-1] >= tuple(text_input_ids.shape)[-1]
            and not paddle.equal_all(x=text_input_ids, y=untruncated_ids).item()
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                f"The following part of your input was truncated because `max_sequence_length` is set to {max_sequence_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype)
        _, seq_len, _ = tuple(prompt_embeds.shape)
        prompt_embeds = prompt_embeds.tile(repeat_times=[1, num_videos_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([batch_size * num_videos_per_prompt, seq_len, -1])
        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        max_sequence_length: int = 226,
        dtype: Optional[paddle.dtype] = None,
    ):
        """
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier-free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input
                argument.
            dtype (`paddle.dtype`, *optional*):
                paddle dtype
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = tuple(prompt_embeds.shape)[0]
        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} != {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches the batch size of `prompt`."
                )
            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )
        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None
    ):
        shape = (
            batch_size,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            num_channels_latents,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)

        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def decode_latents(self, latents: paddle.Tensor) -> paddle.Tensor:
        latents = latents.transpose(perm=[0, 2, 1, 3, 4])
        latents = 1 / self.vae.config.scaling_factor * latents
        frames = self.vae.decode(latents).sample
        return frames

    def prepare_extra_step_kwargs(self, generator, eta):
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_on_step_end_tensor_inputs,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if tuple(prompt_embeds.shape) != tuple(negative_prompt_embeds.shape):
                raise ValueError(
                    f"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` {tuple(prompt_embeds.shape)} != `negative_prompt_embeds` {tuple(negative_prompt_embeds.shape)}."
                )

    def fuse_qkv_projections(self) -> None:
        """Enables fused QKV projections."""
        self.fusing_transformer = True
        self.transformer.fuse_qkv_projections()

    def unfuse_qkv_projections(self) -> None:
        """Disable QKV projection fusion if enabled."""
        if not self.fusing_transformer:
            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
        else:
            self.transformer.unfuse_qkv_projections()
            self.fusing_transformer = False

    def _prepare_rotary_positional_embeddings(
        self, height: int, width: int, num_frames: int
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=self.transformer.config.attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )

        return freqs_cos, freqs_sin

    def _prepare_v_ctrl_rotary_positional_embeddings(
        self, height: int, width: int, num_frames: int
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
| 406 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
|
| 407 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
|
| 408 |
+
base_size_width = 720 // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
|
| 409 |
+
base_size_height = 480 // (self.vae_scale_factor_spatial * self.vctrl.config.patch_size)
|
| 410 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 411 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 412 |
+
)
|
| 413 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 414 |
+
embed_dim=self.vctrl.config.attention_head_dim,
|
| 415 |
+
crops_coords=grid_crops_coords,
|
| 416 |
+
grid_size=(grid_height, grid_width),
|
| 417 |
+
temporal_size=num_frames,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
return freqs_cos, freqs_sin
|
| 421 |
+
|
| 422 |
+
def prepare_v_cond_frames(self, image, width, height, dtype):
|
| 423 |
+
image = self.vctrl_image_processor.preprocess(image, height=height, width=width)
|
| 424 |
+
control_images = image.unsqueeze(axis=0).to(dtype)
|
| 425 |
+
batch_size, num_frames, channels, height, width = tuple(control_images.shape)
|
| 426 |
+
conditioning_frames = control_images
|
| 427 |
+
conditioning_frames = conditioning_frames.transpose(perm=[0, 2, 1, 3, 4])
|
| 428 |
+
return conditioning_frames
|
| 429 |
+
|
| 430 |
+
def prepare_v_cond_video(
|
| 431 |
+
self, conditioning_frames: paddle.Tensor, num_frames: int, conditioning_frame_indices: int, dtype: paddle.dtype
|
| 432 |
+
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
| 433 |
+
assert tuple(conditioning_frames.shape)[2] >= len(conditioning_frame_indices)
|
| 434 |
+
batch_size, channels, _, height, width = tuple(conditioning_frames.shape)
|
| 435 |
+
vctrl_cond = paddle.zeros(shape=(batch_size, channels, num_frames, height, width), dtype=dtype)
|
| 436 |
+
vctrl_cond[:, :, conditioning_frame_indices] = conditioning_frames[:, :, conditioning_frame_indices]
|
| 437 |
+
return vctrl_cond
|
| 438 |
+
|
| 439 |
+
@property
|
| 440 |
+
def guidance_scale(self):
|
| 441 |
+
return self._guidance_scale
|
| 442 |
+
|
| 443 |
+
@property
|
| 444 |
+
def num_timesteps(self):
|
| 445 |
+
return self._num_timesteps
|
| 446 |
+
|
| 447 |
+
@property
|
| 448 |
+
def interrupt(self):
|
| 449 |
+
return self._interrupt
|
| 450 |
+
|
| 451 |
+
@paddle.no_grad()
|
| 452 |
+
def __call__(
|
| 453 |
+
self,
|
| 454 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 455 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 456 |
+
height: int = 480,
|
| 457 |
+
width: int = 720,
|
| 458 |
+
num_frames: int = 49,
|
| 459 |
+
num_inference_steps: int = 50,
|
| 460 |
+
timesteps: Optional[List[int]] = None,
|
| 461 |
+
guidance_scale: float = 6,
|
| 462 |
+
use_dynamic_cfg: bool = False,
|
| 463 |
+
num_videos_per_prompt: int = 1,
|
| 464 |
+
eta: float = 0.0,
|
| 465 |
+
generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
|
| 466 |
+
latents: Optional[paddle.Tensor] = None,
|
| 467 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 468 |
+
negative_prompt_embeds: Optional[paddle.Tensor] = None,
|
| 469 |
+
output_type: str = "pil",
|
| 470 |
+
return_dict: bool = True,
|
| 471 |
+
callback_on_step_end: Optional[
|
| 472 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 473 |
+
] = None,
|
| 474 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 475 |
+
max_sequence_length: int = 226,
|
| 476 |
+
conditioning_frames: Optional[List[PipelineImageInput]] = None,
|
| 477 |
+
conditioning_frame_indices: List[int] = [0],
|
| 478 |
+
conditioning_scale: Union[float, List[float]] = 1.0,
|
| 479 |
+
task: Optional[str] = None,
|
| 480 |
+
conditioning_masks: Optional[List[PipelineImageInput]] = None,
|
| 481 |
+
vctrl_layout_type: str = "even",
|
| 482 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 483 |
+
"""
|
| 484 |
+
Function invoked when calling the pipeline for generation.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 488 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 489 |
+
instead.
|
| 490 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 491 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 492 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 493 |
+
less than `1`).
|
| 494 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 495 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 496 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 497 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 498 |
+
num_frames (`int`, defaults to `48`):
|
| 499 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 500 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
| 501 |
+
num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
|
| 502 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 503 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 504 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 505 |
+
expense of slower inference.
|
| 506 |
+
timesteps (`List[int]`, *optional*):
|
| 507 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 508 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 509 |
+
passed will be used. Must be in descending order.
|
| 510 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 511 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 512 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 513 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 514 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 515 |
+
usually at the expense of lower image quality.
|
| 516 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 517 |
+
The number of videos to generate per prompt.
|
| 518 |
+
generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
|
| 519 |
+
One or a list of [paddle generator(s)](https://pytorch.org/docs/stable/generated/paddle.Generator.html)
|
| 520 |
+
to make generation deterministic.
|
| 521 |
+
latents (`paddle.FloatTensor`, *optional*):
|
| 522 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 523 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 524 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 525 |
+
prompt_embeds (`paddle.FloatTensor`, *optional*):
|
| 526 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 527 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 528 |
+
negative_prompt_embeds (`paddle.FloatTensor`, *optional*):
|
| 529 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 530 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 531 |
+
argument.
|
| 532 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 533 |
+
The output format of the generate image. Choose between
|
| 534 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 535 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 536 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 537 |
+
of a plain tuple.
|
| 538 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 539 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 540 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 541 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 542 |
+
`callback_on_step_end_tensor_inputs`.
|
| 543 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 544 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 545 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 546 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 547 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 548 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 549 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 550 |
+
|
| 551 |
+
Examples:
|
| 552 |
+
|
| 553 |
+
Returns:
|
| 554 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
| 555 |
+
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
| 556 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 557 |
+
"""
|
| 558 |
+
if num_frames > 49:
|
| 559 |
+
raise ValueError(
|
| 560 |
+
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
| 561 |
+
)
|
| 562 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 563 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 564 |
+
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
| 565 |
+
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
|
| 566 |
+
num_videos_per_prompt = 1
|
| 567 |
+
self.check_inputs(
|
| 568 |
+
prompt,
|
| 569 |
+
height,
|
| 570 |
+
width,
|
| 571 |
+
negative_prompt,
|
| 572 |
+
callback_on_step_end_tensor_inputs,
|
| 573 |
+
prompt_embeds,
|
| 574 |
+
negative_prompt_embeds,
|
| 575 |
+
)
|
| 576 |
+
self._guidance_scale = guidance_scale
|
| 577 |
+
self._interrupt = False
|
| 578 |
+
if prompt is not None and isinstance(prompt, str):
|
| 579 |
+
batch_size = 1
|
| 580 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 581 |
+
batch_size = len(prompt)
|
| 582 |
+
else:
|
| 583 |
+
batch_size = tuple(prompt_embeds.shape)[0]
|
| 584 |
+
|
| 585 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 586 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 587 |
+
prompt,
|
| 588 |
+
negative_prompt,
|
| 589 |
+
do_classifier_free_guidance,
|
| 590 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 591 |
+
prompt_embeds=prompt_embeds,
|
| 592 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 593 |
+
max_sequence_length=max_sequence_length,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
if do_classifier_free_guidance:
|
| 597 |
+
prompt_embeds = paddle.concat(x=[negative_prompt_embeds, prompt_embeds], axis=0)
|
| 598 |
+
|
| 599 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
|
| 600 |
+
self._num_timesteps = len(timesteps)
|
| 601 |
+
latent_channels = self.transformer.config.in_channels
|
| 602 |
+
latents = self.prepare_latents(
|
| 603 |
+
batch_size * num_videos_per_prompt,
|
| 604 |
+
latent_channels,
|
| 605 |
+
num_frames,
|
| 606 |
+
height,
|
| 607 |
+
width,
|
| 608 |
+
prompt_embeds.dtype,
|
| 609 |
+
generator,
|
| 610 |
+
latents,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
if isinstance(self.vctrl, VCtrlModel):
|
| 614 |
+
v_cond_frames = self.prepare_v_cond_frames(conditioning_frames, width, height, self.vctrl.dtype)
|
| 615 |
+
v_cond = self.prepare_v_cond_video(v_cond_frames, num_frames, conditioning_frame_indices, self.vctrl.dtype)
|
| 616 |
+
v_cond = self.vae.encode(v_cond).latent_dist.sample()
|
| 617 |
+
|
| 618 |
+
(cond_batch_size, cond_channels, cond_frames, cond_height, cond_width) = tuple(v_cond.shape)
|
| 619 |
+
|
| 620 |
+
def map_frame_latent(indices):
|
| 621 |
+
latent_indices = []
|
| 622 |
+
visited = set()
|
| 623 |
+
for indice in indices:
|
| 624 |
+
if indice == 0:
|
| 625 |
+
value = 0
|
| 626 |
+
else:
|
| 627 |
+
value = (indice - 1) // 4 + 1
|
| 628 |
+
if value not in visited:
|
| 629 |
+
latent_indices.append(value)
|
| 630 |
+
return latent_indices
|
| 631 |
+
|
| 632 |
+
if task == "mask":
|
| 633 |
+
assert conditioning_masks is not None, "conditioning_masks must be provided when task is mask"
|
| 634 |
+
v_cond_mask = self.prepare_v_cond_frames(conditioning_masks, width, height, self.vctrl.dtype)
|
| 635 |
+
v_cond_mask = self.prepare_v_cond_video(
|
| 636 |
+
v_cond_mask, num_frames, conditioning_frame_indices, self.vctrl.dtype
|
| 637 |
+
)
|
| 638 |
+
v_cond_mask = self.vae.encode(v_cond_mask).latent_dist.sample()
|
| 639 |
+
v_cond_mask = v_cond_mask.mean(axis=1, keepdim=True)
|
| 640 |
+
else:
|
| 641 |
+
v_cond_mask = paddle.zeros(
|
| 642 |
+
shape=(cond_batch_size, 1, cond_frames, cond_height, cond_width), dtype=v_cond.dtype
|
| 643 |
+
)
|
| 644 |
+
v_cond_mask[:, :, map_frame_latent(conditioning_frame_indices)] = 1
|
| 645 |
+
v_cond = paddle.concat(x=[v_cond, v_cond_mask], axis=1)
|
| 646 |
+
v_cond = v_cond.transpose(perm=[0, 2, 1, 3, 4])
|
| 647 |
+
assert (
|
| 648 |
+
tuple(v_cond.shape)[3:] == tuple(latents.shape)[3:]
|
| 649 |
+
and tuple(v_cond.shape)[1] == tuple(latents.shape)[1]
|
| 650 |
+
)
|
| 651 |
+
else:
|
| 652 |
+
assert False
|
| 653 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 654 |
+
image_rotary_emb = (
|
| 655 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.shape[1])
|
| 656 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 657 |
+
else None
|
| 658 |
+
)
|
| 659 |
+
v_cond_rotary_emb = (
|
| 660 |
+
self._prepare_v_ctrl_rotary_positional_embeddings(height, width, v_cond.shape[1])
|
| 661 |
+
if self.vctrl.config.use_rotary_positional_embeddings
|
| 662 |
+
else None
|
| 663 |
+
)
|
| 664 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 665 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 666 |
+
old_pred_original_sample = None
|
| 667 |
+
for i, t in enumerate(timesteps):
|
| 668 |
+
if self.interrupt:
|
| 669 |
+
continue
|
| 670 |
+
latent_model_input = paddle.concat(x=[latents] * 2) if do_classifier_free_guidance else latents
|
| 671 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 672 |
+
timestep = t.expand(shape=tuple(latent_model_input.shape)[0])
|
| 673 |
+
control_model_input = latent_model_input
|
| 674 |
+
|
| 675 |
+
vctrl_block_samples = self.vctrl(
|
| 676 |
+
control_model_input,
|
| 677 |
+
timestep,
|
| 678 |
+
v_cond=v_cond,
|
| 679 |
+
v_cond_scale=conditioning_scale,
|
| 680 |
+
image_rotary_emb=v_cond_rotary_emb,
|
| 681 |
+
return_dict=False,
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
noise_pred = self.transformer(
|
| 685 |
+
hidden_states=latent_model_input,
|
| 686 |
+
encoder_hidden_states=prompt_embeds,
|
| 687 |
+
timestep=timestep,
|
| 688 |
+
block_vctrl_residuals=vctrl_block_samples,
|
| 689 |
+
vctrl_layout_type=vctrl_layout_type,
|
| 690 |
+
image_rotary_emb=image_rotary_emb,
|
| 691 |
+
return_dict=False,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
noise_pred = noise_pred.astype(dtype="float32")
|
| 695 |
+
|
| 696 |
+
if use_dynamic_cfg:
|
| 697 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 698 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 699 |
+
)
|
| 700 |
+
if do_classifier_free_guidance:
|
| 701 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(chunks=2)
|
| 702 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 703 |
+
|
| 704 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 705 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 706 |
+
else:
|
| 707 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 708 |
+
noise_pred,
|
| 709 |
+
old_pred_original_sample,
|
| 710 |
+
t,
|
| 711 |
+
timesteps[i - 1] if i > 0 else None,
|
| 712 |
+
latents,
|
| 713 |
+
**extra_step_kwargs,
|
| 714 |
+
return_dict=False,
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 718 |
+
|
| 719 |
+
if callback_on_step_end is not None:
|
| 720 |
+
callback_kwargs = {}
|
| 721 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 722 |
+
callback_kwargs[k] = locals()[k]
|
| 723 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 724 |
+
latents = callback_outputs.pop("latents", latents)
|
| 725 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 726 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 727 |
+
|
| 728 |
+
if i == len(timesteps) - 1 or i + 1 > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
|
| 729 |
+
progress_bar.update()
|
| 730 |
+
|
| 731 |
+
if not output_type == "latent":
|
| 732 |
+
video = self.decode_latents(latents)
|
| 733 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 734 |
+
else:
|
| 735 |
+
video = latents
|
| 736 |
+
self.maybe_free_model_hooks()
|
| 737 |
+
if not return_dict:
|
| 738 |
+
return (video,)
|
| 739 |
+
return CogVideoXPipelineOutput(frames=video)
|
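For orientation, here is a minimal usage sketch for the VCtrl-conditioned CogVideoX pipeline added above. The exported class name `CogVideoXVCtrlPipeline`, the checkpoint path, and the conditioning image are assumptions for illustration, not taken from this diff; `load_image` is the same ppdiffusers helper used in the Kandinsky example below.

```py
# Hypothetical usage sketch, assuming the class is exported as CogVideoXVCtrlPipeline.
import paddle

from ppdiffusers import CogVideoXVCtrlPipeline  # assumed export name
from ppdiffusers.utils import load_image

pipe = CogVideoXVCtrlPipeline.from_pretrained(
    "path/to/cogvideox-vctrl-checkpoint", paddle_dtype=paddle.float16  # hypothetical path
)
first_frame = load_image("path/to/first_frame.png")  # hypothetical conditioning frame

output = pipe(
    prompt="a red panda walking through a bamboo forest",
    conditioning_frames=[first_frame],  # injected at the latent indices below
    conditioning_frame_indices=[0],     # condition only on the first frame
    num_frames=49,                      # the pipeline rejects anything above 49
    height=480,                         # rotary embeddings use a 480x720 base grid
    width=720,
    guidance_scale=6.0,
)
video = output.frames  # CogVideoXPipelineOutput.frames
```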
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py
ADDED
@@ -0,0 +1,20 @@
from dataclasses import dataclass

import paddle

from ...utils import BaseOutput


@dataclass
class CogVideoXPipelineOutput(BaseOutput):
    r"""
    Output class for CogVideo pipelines.

    Args:
        frames (`paddle.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Paddle tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: paddle.Tensor
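The `frames` field is what callers typically hand to a video writer. A short sketch, assuming ppdiffusers mirrors the diffusers `export_to_video` helper (that helper name is an assumption here):

```py
# Sketch of consuming CogVideoXPipelineOutput; export_to_video is assumed to exist
# in ppdiffusers.utils, mirroring the diffusers helper of the same name.
from ppdiffusers.utils import export_to_video

result = pipe(prompt="...", num_frames=49)       # CogVideoXPipelineOutput when return_dict=True
frames = result.frames                           # or pipe(..., return_dict=False)[0]
export_to_video(frames[0], "sample.mp4", fps=8)  # write the first video in the batch
```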
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
ADDED
@@ -0,0 +1,392 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Optional, Union

import numpy as np
import paddle
import PIL.Image
from PIL import Image

from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, logging
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from ppdiffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline
        >>> from ppdiffusers.utils import load_image
        >>> import paddle

        >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-2-prior", paddle_dtype=paddle.float16
        ... )

        >>> prompt = "A red cartoon frog, 4k"
        >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

        >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-2-decoder", paddle_dtype=paddle.float16
        ... )

        >>> init_image = load_image(
        ...     "https://hf-mirror.com/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/frog.png"
        ... )

        >>> image = pipe(
        ...     image=init_image,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=100,
        ...     strength=0.2,
        ... ).images

        >>> image[0].save("red_frog.png")
        ```
"""


# Copied from ppdiffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


# Copied from ppdiffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = paddle.to_tensor(arr).unsqueeze(0)
    return image


class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for image-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    model_cpu_offload_seq = "unet->movq"
    _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"]

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from ppdiffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
        if not isinstance(image, (paddle.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `paddle.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.cast(dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = paddle.concat(init_latents, axis=0)
            else:
                init_latents = self.movq.encode(image).latent_dist.sample(generator)

            init_latents = self.movq.config.scaling_factor * init_latents

        init_latents = paddle.concat([init_latents], axis=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @paddle.no_grad()
    def __call__(
        self,
        image_embeds: Union[paddle.Tensor, List[paddle.Tensor]],
        image: Union[paddle.Tensor, PIL.Image.Image, List[paddle.Tensor], List[PIL.Image.Image]],
        negative_image_embeds: Union[paddle.Tensor, List[paddle.Tensor]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        strength: float = 0.3,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`paddle.Tensor` or `List[paddle.Tensor]`):
                The CLIP image embeddings for the text prompt, which will be used to condition the image generation.
            image (`paddle.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[paddle.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. Can also accept image latents as `image`; if passing latents directly, they will not be
                encoded again.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            negative_image_embeds (`paddle.Tensor` or `List[paddle.Tensor]`):
                The CLIP image embeddings for the negative text prompt, which will be used to condition the image
                generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`, usually at
                the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                One or a list of paddle generator(s) to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pd"` (`paddle.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        self._guidance_scale = guidance_scale

        if isinstance(image_embeds, list):
            image_embeds = paddle.concat(image_embeds, axis=0)
        batch_size = image_embeds.shape[0]
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = paddle.concat(negative_image_embeds, axis=0)

        if self.do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, axis=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, axis=0)

            image_embeds = paddle.concat([negative_image_embeds, image_embeds], axis=0).cast(dtype=self.unet.dtype)

        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, paddle.Tensor)) for i in image):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and paddle tensor"
            )

        image = paddle.concat([prepare_image(i, width, height) for i in image], axis=0)
        image = image.cast(dtype=image_embeds.dtype)

        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, axis=0)
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
        latents = self.prepare_latents(
            latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, generator
        )
        self._num_timesteps = len(timesteps)
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = paddle.concat([latents] * 2) if self.do_classifier_free_guidance else latents

            added_cond_kwargs = {"image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if self.do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(
                    [latents.shape[1], noise_pred.shape[1] - latents.shape[1]], axis=1
                )
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = paddle.concat([noise_pred, variance_pred_text], axis=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split([latents.shape[1], noise_pred.shape[1] - latents.shape[1]], axis=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                image_embeds = callback_outputs.pop("image_embeds", image_embeds)
                negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        if output_type not in ["pd", "np", "pil", "latent"]:
            raise ValueError(
                f"Only the output types `pd`, `pil`, `np` and `latent` are supported not output_type={output_type}"
            )

        if not output_type == "latent":
            # post-processing
            image = self.movq.decode(latents, force_not_quantize=True)["sample"]
            if output_type in ["np", "pil"]:
                image = image * 0.5 + 0.5
                image = image.clip(0, 1)
                image = image.transpose([0, 2, 3, 1]).cast("float32").cpu().numpy()

                if output_type == "pil":
                    image = self.numpy_to_pil(image)
        else:
            image = latents

        # Offload all models

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
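The `strength`/`get_timesteps` interaction above is easy to sanity-check by hand; the arithmetic below reproduces the `strength=0.2` setting used in this file's EXAMPLE_DOC_STRING.

```py
# Worked example of get_timesteps: with 100 scheduler steps and strength=0.2,
# only the final 20 timesteps of the schedule are actually run.
num_inference_steps, strength = 100, 0.2
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 20
t_start = max(num_inference_steps - init_timestep, 0)                          # 80
# timesteps = scheduler.timesteps[t_start:]  -> the last 20 denoising steps
print(init_timestep, t_start)  # 20 80
```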
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky3/__init__.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
    _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]


if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *
    else:
        from .kandinsky3_pipeline import Kandinsky3Pipeline
        from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
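The `_LazyModule` registration above keeps importing the subpackage cheap; the heavyweight pipeline modules are only loaded on first attribute access. A small sketch of the observable behavior:

```py
# Sketch: attribute access on the lazy module triggers the real import.
import ppdiffusers.pipelines.kandinsky3 as k3

# At this point kandinsky3_pipeline itself has not been imported yet.
Kandinsky3Pipeline = k3.Kandinsky3Pipeline  # _LazyModule imports the submodule here
```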
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky3/kandinsky3_pipeline.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Callable, List, Optional, Union
|
| 16 |
+
|
| 17 |
+
import paddle
|
| 18 |
+
|
| 19 |
+
from ppdiffusers.transformers import T5EncoderModel, T5Tokenizer
|
| 20 |
+
|
| 21 |
+
from ...loaders import LoraLoaderMixin
|
| 22 |
+
from ...models import Kandinsky3UNet, VQModel
|
| 23 |
+
from ...schedulers import DDPMScheduler
|
| 24 |
+
from ...utils import logging
|
| 25 |
+
from ...utils.paddle_utils import randn_tensor
|
| 26 |
+
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
| 27 |
+
|
| 28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def downscale_height_and_width(height, width, scale_factor=8):
|
| 32 |
+
new_height = height // scale_factor**2
|
| 33 |
+
if height % scale_factor**2 != 0:
|
| 34 |
+
new_height += 1
|
| 35 |
+
new_width = width // scale_factor**2
|
| 36 |
+
if width % scale_factor**2 != 0:
|
| 37 |
+
new_width += 1
|
| 38 |
+
return new_height * scale_factor, new_width * scale_factor
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
| 42 |
+
model_cpu_offload_seq = "text_encoder->unet->movq"
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
tokenizer: T5Tokenizer,
|
| 47 |
+
text_encoder: T5EncoderModel,
|
| 48 |
+
unet: Kandinsky3UNet,
|
| 49 |
+
scheduler: DDPMScheduler,
|
| 50 |
+
movq: VQModel,
|
| 51 |
+
):
|
| 52 |
+
super().__init__()
|
| 53 |
+
|
| 54 |
+
self.register_modules(
|
| 55 |
+
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def process_embeds(self, embeddings, attention_mask, cut_context):
|
| 59 |
+
if cut_context:
|
| 60 |
+
embeddings[attention_mask == 0] = paddle.zeros_like(embeddings[attention_mask == 0])
|
| 61 |
+
max_seq_length = attention_mask.sum(-1).max() + 1
|
| 62 |
+
embeddings = embeddings[:, :max_seq_length]
|
| 63 |
+
attention_mask = attention_mask[:, :max_seq_length]
|
| 64 |
+
return embeddings, attention_mask
|
| 65 |
+
|
| 66 |
+
@paddle.no_grad()
|
| 67 |
+
def encode_prompt(
|
| 68 |
+
self,
|
| 69 |
+
prompt,
|
| 70 |
+
do_classifier_free_guidance=True,
|
| 71 |
+
num_images_per_prompt=1,
|
| 72 |
+
negative_prompt=None,
|
| 73 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 74 |
+
negative_prompt_embeds: Optional[paddle.Tensor] = None,
|
| 75 |
+
_cut_context=False,
|
| 76 |
+
):
|
| 77 |
+
r"""
|
| 78 |
+
Encodes the prompt into text encoder hidden states.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 82 |
+
prompt to be encoded
|
| 83 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 84 |
+
number of images that should be generated per prompt
|
| 85 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 86 |
+
whether to use classifier free guidance or not
|
| 87 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 88 |
+
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        max_length = 128

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pd",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            prompt_embeds = self.text_encoder(
                text_input_ids,
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]
            prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context)
            prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None

        if dtype is not None:
            prompt_embeds = prompt_embeds.cast(dtype=dtype)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
        attention_mask = attention_mask.tile([num_images_per_prompt, 1])
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]

            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt
            if negative_prompt is not None:
                uncond_input = self.tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=128,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pd",
                )
                text_input_ids = uncond_input.input_ids
                negative_attention_mask = uncond_input.attention_mask

                negative_prompt_embeds = self.text_encoder(
                    text_input_ids,
                    attention_mask=negative_attention_mask,
                )
                negative_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
                negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
                negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)

            else:
                negative_prompt_embeds = paddle.zeros_like(prompt_embeds)
                negative_attention_mask = paddle.zeros_like(attention_mask)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            if dtype is not None:
                negative_prompt_embeds = negative_prompt_embeds.cast(dtype=dtype)
            if negative_prompt_embeds.shape != prompt_embeds.shape:
                negative_prompt_embeds = negative_prompt_embeds.tile([1, num_images_per_prompt, 1])
                negative_prompt_embeds = negative_prompt_embeds.reshape(
                    [batch_size * num_images_per_prompt, seq_len, -1]
                )
                negative_attention_mask = negative_attention_mask.tile([num_images_per_prompt, 1])

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None
            negative_attention_mask = None
        return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask

    def prepare_latents(self, shape, dtype, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        else:
            if latents.shape != list(shape):
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {list(shape)}")
            latents = latents.cast(dtype)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def check_inputs(
        self,
        prompt,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @paddle.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        num_inference_steps: int = 25,
        guidance_scale: float = 3.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: int = 1,
        latents=None,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 3.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                One or a list of [paddle generator(s)] to make generation deterministic.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        """
        cut_context = True

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            _cut_context=cut_context,
        )

        if do_classifier_free_guidance:
            prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds])
            attention_mask = paddle.concat([negative_attention_mask, attention_mask])
        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latents
        height, width = downscale_height_and_width(height, width, 8)

        latents = self.prepare_latents(
            [batch_size * num_images_per_prompt, 4, height, width],
            prompt_embeds.dtype,
            generator,
            latents,
            self.scheduler,
        )

        # 7. Denoising loop
        # TODO(Yiyi): Correct the following line and use correctly
        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=attention_mask,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                    noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
                    # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=generator,
                ).prev_sample
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type not in ["pd", "np", "pil"]:
            raise ValueError(
                f"Only the output types `pd`, `pil` and `np` are supported, not output_type={output_type}"
            )

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clip(0, 1)
            image = image.transpose([0, 2, 3, 1]).cast("float32").cpu().numpy()

            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
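For orientation, here is a minimal usage sketch of the text-to-image pipeline above. The top-level `Kandinsky3Pipeline` export and the `kandinsky-community/kandinsky-3` checkpoint id are assumptions, not something this diff pins down:

```python
# Hypothetical usage sketch for the pipeline defined above; the import path
# and checkpoint id are assumptions.
import paddle
from ppdiffusers import Kandinsky3Pipeline

pipe = Kandinsky3Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-3", paddle_dtype=paddle.float16
)
image = pipe(
    prompt="A watercolor painting of a lighthouse at dawn",
    num_inference_steps=25,  # matches the signature default
    guidance_scale=3.0,      # > 1.0 enables classifier-free guidance
    height=1024,
    width=1024,
).images[0]
image.save("kandinsky3_t2i.png")
```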
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py
ADDED
@@ -0,0 +1,449 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union

import numpy as np
import paddle
import PIL
import PIL.Image

from ppdiffusers.transformers import T5EncoderModel, T5Tokenizer

from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import logging
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def downscale_height_and_width(height, width, scale_factor=8):
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor
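
# Editor's sketch of the rounding above (illustrative, not in the original
# file): with scale_factor=8 the divisor is 8**2 = 64 and the result is
# re-multiplied by 8, so each dimension is rounded up to a multiple of 64 and
# then divided by 8:
#   downscale_height_and_width(1024, 1024)  ->  (128, 128)  # 1024 // 64 = 16, 16 * 8 = 128
#   downscale_height_and_width(1000, 1000)  ->  (128, 128)  # 1000 // 64 = 15 rem 40 -> 16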


def prepare_image(pil_image):
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = paddle.to_tensor(arr).unsqueeze(0)
    return image


class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
    model_cpu_offload_seq = "text_encoder->unet->movq"

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: Kandinsky3UNet,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
        )

    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start
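
    # Editor's sketch of the strength arithmetic above (illustrative, not in
    # the original file): with num_inference_steps=25 and strength=0.3,
    #   init_timestep = min(int(25 * 0.3), 25) = 7
    #   t_start       = 25 - 7 = 18
    # so only the last 7 scheduler timesteps are run: the input image is noised
    # to an intermediate level and denoised from there.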

    def _process_embeds(self, embeddings, attention_mask, cut_context):
        # return embeddings, attention_mask
        if cut_context:
            embeddings[attention_mask == 0] = paddle.zeros_like(embeddings[attention_mask == 0])
            max_seq_length = attention_mask.sum(-1).max() + 1
            embeddings = embeddings[:, :max_seq_length]
            attention_mask = attention_mask[:, :max_seq_length]
        return embeddings, attention_mask

    @paddle.no_grad()
    def encode_prompt(
        self,
        prompt,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
        negative_prompt=None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        _cut_context=False,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        max_length = 128

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pd",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            prompt_embeds = self.text_encoder(
                text_input_ids,
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]
            prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context)
            prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        else:
            dtype = None

        if dtype is not None:
            prompt_embeds = prompt_embeds.cast(dtype=dtype)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
        attention_mask = attention_mask.tile([num_images_per_prompt, 1])
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]

            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt
            if negative_prompt is not None:
                uncond_input = self.tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=128,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pd",
                )
                text_input_ids = uncond_input.input_ids
                negative_attention_mask = uncond_input.attention_mask

                negative_prompt_embeds = self.text_encoder(
                    text_input_ids,
                    attention_mask=negative_attention_mask,
                )
                negative_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
                negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
                negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)

            else:
                negative_prompt_embeds = paddle.zeros_like(prompt_embeds)
                negative_attention_mask = paddle.zeros_like(attention_mask)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            if dtype is not None:
                negative_prompt_embeds = negative_prompt_embeds.cast(dtype=dtype)
            if negative_prompt_embeds.shape != prompt_embeds.shape:
                negative_prompt_embeds = negative_prompt_embeds.tile([1, num_images_per_prompt, 1])
                negative_prompt_embeds = negative_prompt_embeds.reshape(
                    [batch_size * num_images_per_prompt, seq_len, -1]
                )
                negative_attention_mask = negative_attention_mask.tile([num_images_per_prompt, 1])

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None
            negative_attention_mask = None
        return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
        if not isinstance(image, (paddle.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `paddle.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.cast(dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = paddle.concat(init_latents, axis=0)
            else:
                init_latents = self.movq.encode(image).latent_dist.sample(generator)

            init_latents = self.movq.config.scaling_factor * init_latents

        init_latents = paddle.concat([init_latents], axis=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    @paddle.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[paddle.Tensor, PIL.Image.Image, List[paddle.Tensor], List[PIL.Image.Image]] = None,
        strength: float = 0.3,
        num_inference_steps: int = 25,
        guidance_scale: float = 3.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: int = 1,
        latents=None,
    ):
        cut_context = True
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            _cut_context=cut_context,
        )

        if do_classifier_free_guidance:
            prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds])
            attention_mask = paddle.concat([negative_attention_mask, attention_mask])
        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, paddle.Tensor)) for i in image):
            raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )

        image = paddle.concat([prepare_image(i) for i in image], axis=0)
        image = image.cast(dtype=prompt_embeds.dtype)
        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        # 5. Prepare latents
        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, axis=0)
        latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
        latents = self.prepare_latents(
            latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, generator
        )

        # 7. Denoising loop
        # TODO(Yiyi): Correct the following line and use correctly
        # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=attention_mask,
                )[0]
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                    noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=generator,
                ).prev_sample
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)
        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type not in ["pd", "np", "pil"]:
            raise ValueError(
                f"Only the output types `pd`, `pil` and `np` are supported, not output_type={output_type}"
            )

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clip(0, 1)
            image = image.transpose([0, 2, 3, 1]).cast("float32").cpu().numpy()

            if output_type == "pil":
                image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
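A minimal img2img usage sketch for the class above; as before, the checkpoint id and the top-level `Kandinsky3Img2ImgPipeline` export are assumptions:

```python
# Hypothetical usage sketch; import path and checkpoint id are assumptions.
import paddle
import PIL.Image
from ppdiffusers import Kandinsky3Img2ImgPipeline

pipe = Kandinsky3Img2ImgPipeline.from_pretrained(
    "kandinsky-community/kandinsky-3", paddle_dtype=paddle.float16
)
init_image = PIL.Image.open("input.png").convert("RGB")
image = pipe(
    prompt="Turn this photo into an oil painting",
    image=init_image,
    strength=0.3,            # fraction of the schedule actually run (see get_timesteps)
    num_inference_steps=25,
    guidance_scale=3.0,
).images[0]
image.save("kandinsky3_i2i.png")
```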
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/latent_consistency_models/__init__.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}


try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
    _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]

if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *
    else:
        from .pipeline_latent_consistency_img2img import (
            LatentConsistencyModelImg2ImgPipeline,
        )
        from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
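For context, the `_LazyModule` branch above means the heavy pipeline modules are only imported on first attribute access. A minimal sketch of what that looks like from the caller's side (the package path is assumed from this diff's layout):

```python
# Sketch of the lazy-import behavior set up by the __init__.py above; nothing
# from the pipeline files is imported until an attribute is first touched.
from ppdiffusers.pipelines import latent_consistency_models as lcm

# This attribute access is what triggers the real submodule import:
PipelineCls = lcm.LatentConsistencyModelPipeline
print(PipelineCls.__name__)  # "LatentConsistencyModelPipeline"
```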
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
ADDED
@@ -0,0 +1,796 @@
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import paddle
import PIL.Image

from ppdiffusers.transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler
from ...utils import USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import (
    StableDiffusionPipelineOutput,
    StableDiffusionSafetyChecker,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: paddle.Tensor, generator: Optional[paddle.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
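
# Editor's sketch of the two mutually exclusive ways to call the helper above
# (illustrative, not in the original file):
#   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=4)
#   timesteps, n = retrieve_timesteps(scheduler, timesteps=[999, 759, 499, 259])
# The second form only works when the scheduler's set_timesteps accepts a
# `timesteps` argument; otherwise the ValueError above is raised.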


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from ppdiffusers import AutoPipelineForImage2Image
        >>> import paddle
        >>> import PIL

        >>> pipe = AutoPipelineForImage2Image.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
        >>> # To save GPU memory, paddle.float16 can be used, but it may compromise image quality.
        >>> pipe.to(paddle_dtype=paddle.float32)

        >>> prompt = "High altitude snowy mountains"
        >>> image = PIL.Image.open("./snowy_mountains.png")

        >>> # Can be set to 1~50 steps. LCM supports fast inference even with <= 4 steps. Recommended: 1~8 steps.
|
| 113 |
+
>>> num_inference_steps = 4
|
| 114 |
+
>>> images = pipe(
|
| 115 |
+
... prompt=prompt, image=image, num_inference_steps=num_inference_steps, guidance_scale=8.0
|
| 116 |
+
... ).images
|
| 117 |
+
|
| 118 |
+
>>> images[0].save("image.png")
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class LatentConsistencyModelImg2ImgPipeline(
|
| 125 |
+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
| 126 |
+
):
|
| 127 |
+
r"""
|
| 128 |
+
Pipeline for image-to-image generation using a latent consistency model.
|
| 129 |
+
|
| 130 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 131 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 132 |
+
|
| 133 |
+
The pipeline also inherits the following loading methods:
|
| 134 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
| 135 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
| 136 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
| 137 |
+
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
vae ([`AutoencoderKL`]):
|
| 141 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 142 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
| 143 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
| 144 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
| 145 |
+
A `CLIPTokenizer` to tokenize text.
|
| 146 |
+
unet ([`UNet2DConditionModel`]):
|
| 147 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
| 148 |
+
scheduler ([`SchedulerMixin`]):
|
| 149 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
|
| 150 |
+
supports [`LCMScheduler`].
|
| 151 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 152 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 153 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
| 154 |
+
about a model's potential harms.
|
| 155 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
| 156 |
+
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
| 157 |
+
requires_safety_checker (`bool`, *optional*, defaults to `True`):
|
| 158 |
+
Whether the pipeline requires a safety checker component.
|
| 159 |
+
"""
    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: LCMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pd",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids, attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.cast(dtype=prompt_embeds_dtype)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([bs_embed * num_images_per_prompt, seq_len, -1])

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pd",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids,
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.cast(dtype=prompt_embeds_dtype)

            negative_prompt_embeds = negative_prompt_embeds.tile([1, num_images_per_prompt, 1])
            negative_prompt_embeds = negative_prompt_embeds.reshape([batch_size * num_images_per_prompt, seq_len, -1])

        return prompt_embeds, negative_prompt_embeds
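
Because LCMs are distilled with the empty-string unconditional prompt, `__call__` below invokes `encode_prompt` with `do_classifier_free_guidance=False` and discards the negative embeddings. A minimal sketch of calling it directly (assuming `pipe` is an instantiated pipeline; the prompt is illustrative):

```py
# Sketch only: `pipe` is a LatentConsistencyModelImg2ImgPipeline instance.
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    "a photo of an astronaut",
    num_images_per_prompt=2,
    do_classifier_free_guidance=False,  # LCMs do not use negative prompts
)
# prompt_embeds has shape [batch_size * num_images_per_prompt, seq_len, hidden_dim];
# negative_prompt_embeds stays None here because CFG is disabled.
```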
    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if paddle.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pd")
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.cast(dtype=dtype)
            )
        return image, has_nsfw_concept

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None):
        if not isinstance(image, (paddle.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `paddle.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.cast(dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                init_latents = paddle.concat(init_latents, axis=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = paddle.concat([init_latents] * additional_image_per_prompt, axis=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = paddle.concat([init_latents], axis=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    # Copied from ppdiffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=paddle.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`paddle.Tensor`):
                guidance scale values at which to generate embedding vectors
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `paddle.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = paddle.log(paddle.to_tensor(10000.0)) / (half_dim - 1)
        emb = paddle.exp(paddle.arange(half_dim, dtype=dtype) * -emb)
        emb = w.cast(dtype=dtype)[:, None] * emb[None, :]
        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=1)
        if embedding_dim % 2 == 1:
            # zero-pad to the requested (odd) embedding size; the tensors must be
            # passed to `paddle.concat` as a list
            emb = paddle.concat([emb, paddle.zeros([emb.shape[0], 1])], axis=-1)
        assert emb.shape == [w.shape[0], embedding_dim]
        return emb
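
`get_guidance_scale_embedding` is the standard sinusoidal timestep embedding applied to `w * 1000`, which is how the distilled UNet receives the guidance scale through `timestep_cond`. A small self-contained shape check, reproducing the computation outside the pipeline (values are illustrative):

```py
import paddle

w = paddle.to_tensor([7.5])  # guidance_scale - 1 for one sample
half_dim = 256               # embedding_dim // 2 for embedding_dim=512
scale = paddle.log(paddle.to_tensor(10000.0)) / (half_dim - 1)
freqs = paddle.exp(paddle.arange(half_dim, dtype=paddle.float32) * -scale)
angles = (w * 1000.0)[:, None] * freqs[None, :]
emb = paddle.concat([paddle.sin(angles), paddle.cos(angles)], axis=1)
assert emb.shape == [1, 512]
```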
    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start
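
`get_timesteps` implements the usual img2img convention: `strength` controls how much of the schedule is actually run, so `strength=1.0` denoises from pure noise while small values keep the input image mostly intact. A quick worked example (assuming `scheduler.order == 1`):

```py
# Sketch: num_inference_steps=4, strength=0.8
init_timestep = min(int(4 * 0.8), 4)  # 3
t_start = max(4 - init_timestep, 0)   # 1
# The first scheduler timestep is skipped and only the last 3 are used,
# so the effective number of inference steps becomes 4 - 1 = 3.
```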
    def check_inputs(
        self,
        prompt: Union[str, List[str]],
        strength: float,
        callback_steps: int,
        prompt_embeds: Optional[paddle.Tensor] = None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def clip_skip(self):
        return self._clip_skip

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @paddle.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        num_inference_steps: int = 4,
        strength: float = 0.8,
        original_inference_steps: int = None,
        timesteps: List[int] = None,
        guidance_scale: float = 8.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`PipelineImageInput`):
                The image or images to use as the starting point for the generation.
            num_inference_steps (`int`, *optional*, defaults to 4):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            strength (`float`, *optional*, defaults to 0.8):
                Indicates how much to transform the reference `image`. Must be between 0 and 1.
            original_inference_steps (`int`, *optional*):
                The original number of inference steps used to generate a linearly-spaced timestep schedule, from which
                we will draw `num_inference_steps` evenly spaced timesteps as our final timestep schedule,
                following the Skipping-Step method in the paper (see Section 4.3). If not set, this will default to the
                scheduler's `original_inference_steps` attribute.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
                timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
                order.
            guidance_scale (`float`, *optional*, defaults to 8.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
                Note that the original latent consistency models paper uses a different CFG formulation where the
                guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale >
                0`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                A [`paddle.Generator`] to make generation deterministic.
            latents (`paddle.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, strength, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)
        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
        # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
        # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts.
        prompt_embeds, _ = self.encode_prompt(
            prompt,
            num_images_per_prompt,
            False,
            negative_prompt=None,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=None,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Encode image
        image = self.image_processor.preprocess(image)

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            timesteps,
            original_inference_steps=original_inference_steps,
            strength=strength,
        )

        # 6. Prepare latent variables
        original_inference_steps = (
            original_inference_steps
            if original_inference_steps is not None
            else self.scheduler.config.original_inference_steps
        )
        latent_timestep = timesteps[:1]
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, generator
        )
        bs = batch_size * num_images_per_prompt

        # 7. Get guidance scale embedding
        # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper
        # CFG formulation, so we need to subtract 1 from the input guidance_scale.
        # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
        w = paddle.to_tensor([self.guidance_scale - 1]).tile([bs])
        w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).cast(
            dtype=latents.dtype
        )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

        # 9. LCM multistep sampling loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latents = latents.cast(dtype=prompt_embeds.dtype)

                # model prediction (v-prediction, eps, x)
                model_pred = self.unet(
                    latents,
                    t,
                    timestep_cond=w_embedding,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False)
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    w_embedding = callback_outputs.pop("w_embedding", w_embedding)
                    denoised = callback_outputs.pop("denoised", denoised)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        denoised = denoised.cast(dtype=prompt_embeds.dtype)
        if not output_type == "latent":
            image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype)
        else:
            image = denoised
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
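
`__call__` exposes the intermediate tensors named in `_callback_tensor_inputs` to `callback_on_step_end`, which makes it possible to inspect or override `latents`/`denoised` between LCM steps. A hedged sketch (`pipe` and `init_image` are assumed from the usage example near the top of this file):

```py
import paddle

# Sketch: log the latent norm after each LCM step and keep outputs unchanged.
def on_step_end(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent L2 norm = {float(paddle.linalg.norm(latents)):.3f}")
    return callback_kwargs  # must return the (possibly modified) tensor dict

result = pipe(
    prompt="a watercolor landscape",
    image=init_image,
    num_inference_steps=4,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents", "denoised"],
)
```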
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
ADDED
@@ -0,0 +1,729 @@
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import paddle

from ppdiffusers.transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler
from ...utils import USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import (
    StableDiffusionPipelineOutput,
    StableDiffusionSafetyChecker,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from ppdiffusers import DiffusionPipeline
        >>> import paddle

        >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
        >>> # To save GPU memory, paddle.float16 can be used, but it may compromise image quality.
        >>> pipe.to(paddle_dtype=paddle.float32)

        >>> prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"

        >>> # Can be set to 1~50 steps. LCM supports fast inference even with <= 4 steps. Recommended: 1~8 steps.
        >>> num_inference_steps = 4
        >>> images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0).images
        >>> images[0].save("image.png")
        ```
"""


# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[paddle.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
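
`retrieve_timesteps` lets callers pass either a step count or an explicit schedule, and probes `set_timesteps` via `inspect` so schedulers without `timesteps=` support fail loudly. A minimal sketch (the custom schedule values are illustrative, and the second call only works if the scheduler's `set_timesteps` accepts a `timesteps` argument):

```py
# Sketch: assumes `pipe` is a LatentConsistencyModelPipeline instance.
# Default spacing:
timesteps, n = retrieve_timesteps(pipe.scheduler, num_inference_steps=4)

# Custom descending schedule:
timesteps, n = retrieve_timesteps(pipe.scheduler, timesteps=[759, 499, 259, 19])
assert n == 4  # num_inference_steps is recomputed from the custom schedule
```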

class LatentConsistencyModelPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-to-image generation using a latent consistency model.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
            supports [`LCMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Whether the pipeline requires a safety checker component.
    """

    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: LCMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pd",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids, attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.cast(dtype=prompt_embeds_dtype)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([bs_embed * num_images_per_prompt, seq_len, -1])

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pd",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids,
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.cast(dtype=prompt_embeds_dtype)

            negative_prompt_embeds = negative_prompt_embeds.tile([1, num_images_per_prompt, 1])
            negative_prompt_embeds = negative_prompt_embeds.reshape([batch_size * num_images_per_prompt, seq_len, -1])

        return prompt_embeds, negative_prompt_embeds

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if paddle.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pd")
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.cast(dtype=dtype)
            )
        return image, has_nsfw_concept

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        else:
            latents = latents.cast(dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
|
| 391 |
+
|
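For orientation, a minimal sketch of the shape arithmetic `prepare_latents` performs, assuming illustrative values (512x512 output, `vae_scale_factor=8`, 4 latent channels; none of these are read from a real config here):

    import paddle
    from ppdiffusers.utils import randn_tensor  # same helper the pipeline uses

    batch_size, num_channels_latents = 2, 4           # assumed for illustration
    height, width, vae_scale_factor = 512, 512, 8
    shape = (batch_size, num_channels_latents,
             height // vae_scale_factor, width // vae_scale_factor)
    latents = randn_tensor(shape, generator=None, dtype=paddle.float32)
    print(latents.shape)  # [2, 4, 64, 64]; the pipeline then scales by scheduler.init_noise_sigma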
| 392 |
+
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=paddle.float32):
|
| 393 |
+
"""
|
| 394 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
timesteps (`paddle.Tensor`):
|
| 398 |
+
generate embedding vectors at these timesteps
|
| 399 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 400 |
+
dimension of the embeddings to generate
|
| 401 |
+
dtype:
|
| 402 |
+
data type of the generated embeddings
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
`paddle.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
| 406 |
+
"""
|
| 407 |
+
assert len(w.shape) == 1
|
| 408 |
+
w = w * 1000.0
|
| 409 |
+
|
| 410 |
+
half_dim = embedding_dim // 2
|
| 411 |
+
emb = paddle.log(paddle.to_tensor(10000.0)) / (half_dim - 1)
|
| 412 |
+
emb = paddle.exp(paddle.arange(half_dim, dtype=dtype) * -emb)
|
| 413 |
+
emb = w.cast(dtype=dtype)[:, None] * emb[None, :]
|
| 414 |
+
emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=1)
|
| 415 |
+
if embedding_dim % 2 == 1:
|
| 416 |
+
emb = paddle.concat([emb, paddle.zeros([emb.shape[0], 1])], axis=-1)
|
| 417 |
+
assert emb.shape == [w.shape[0], embedding_dim]
|
| 418 |
+
return emb
|
| 419 |
+
|
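A self-contained sketch of the sinusoidal embedding above (`embedding_dim=8` is an arbitrary illustrative value; the pipeline uses `unet.config.time_cond_proj_dim`):

    import paddle

    def guidance_embedding(w, embedding_dim=8, dtype=paddle.float32):
        # Fourier features of w * 1000, mirroring get_guidance_scale_embedding
        half_dim = embedding_dim // 2
        emb = paddle.log(paddle.to_tensor(10000.0)) / (half_dim - 1)
        emb = paddle.exp(paddle.arange(half_dim, dtype=dtype) * -emb)
        emb = (w * 1000.0).cast(dtype)[:, None] * emb[None, :]
        return paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=1)

    w = paddle.to_tensor([8.5 - 1.0]).tile([4])  # guidance 8.5 shifted by 1, batch of 4
    print(guidance_embedding(w).shape)           # [4, 8]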
| 420 |
+
# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 421 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 422 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 423 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 424 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 425 |
+
# and should be between [0, 1]
|
| 426 |
+
|
| 427 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 428 |
+
extra_step_kwargs = {}
|
| 429 |
+
if accepts_eta:
|
| 430 |
+
extra_step_kwargs["eta"] = eta
|
| 431 |
+
|
| 432 |
+
# check if the scheduler accepts generator
|
| 433 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 434 |
+
if accepts_generator:
|
| 435 |
+
extra_step_kwargs["generator"] = generator
|
| 436 |
+
return extra_step_kwargs
|
| 437 |
+
|
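The filtering above uses only the standard library; a tiny sketch with a hypothetical stand-in for `scheduler.step`:

    import inspect

    def step(model_output, timestep, sample, eta=0.0):  # hypothetical stand-in
        return sample

    accepts_eta = "eta" in set(inspect.signature(step).parameters.keys())
    extra_step_kwargs = {"eta": 0.0} if accepts_eta else {}
    print(accepts_eta, extra_step_kwargs)  # True {'eta': 0.0}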
| 438 |
+
# Currently StableDiffusionPipeline.check_inputs with negative prompt stuff removed
|
| 439 |
+
def check_inputs(
|
| 440 |
+
self,
|
| 441 |
+
prompt: Union[str, List[str]],
|
| 442 |
+
height: int,
|
| 443 |
+
width: int,
|
| 444 |
+
callback_steps: int,
|
| 445 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 446 |
+
callback_on_step_end_tensor_inputs=None,
|
| 447 |
+
):
|
| 448 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 449 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 450 |
+
|
| 451 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
| 452 |
+
raise ValueError(
|
| 453 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 454 |
+
f" {type(callback_steps)}."
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 458 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 459 |
+
):
|
| 460 |
+
raise ValueError(
|
| 461 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
if prompt is not None and prompt_embeds is not None:
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 467 |
+
" only forward one of the two."
|
| 468 |
+
)
|
| 469 |
+
elif prompt is None and prompt_embeds is None:
|
| 470 |
+
raise ValueError(
|
| 471 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 472 |
+
)
|
| 473 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 474 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 475 |
+
|
| 476 |
+
@property
|
| 477 |
+
def guidance_scale(self):
|
| 478 |
+
return self._guidance_scale
|
| 479 |
+
|
| 480 |
+
@property
|
| 481 |
+
def cross_attention_kwargs(self):
|
| 482 |
+
return self._cross_attention_kwargs
|
| 483 |
+
|
| 484 |
+
@property
|
| 485 |
+
def clip_skip(self):
|
| 486 |
+
return self._clip_skip
|
| 487 |
+
|
| 488 |
+
@property
|
| 489 |
+
def num_timesteps(self):
|
| 490 |
+
return self._num_timesteps
|
| 491 |
+
|
| 492 |
+
@paddle.no_grad()
|
| 493 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 494 |
+
def __call__(
|
| 495 |
+
self,
|
| 496 |
+
prompt: Union[str, List[str]] = None,
|
| 497 |
+
height: Optional[int] = None,
|
| 498 |
+
width: Optional[int] = None,
|
| 499 |
+
num_inference_steps: int = 4,
|
| 500 |
+
original_inference_steps: int = None,
|
| 501 |
+
timesteps: List[int] = None,
|
| 502 |
+
guidance_scale: float = 8.5,
|
| 503 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 504 |
+
generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
|
| 505 |
+
latents: Optional[paddle.Tensor] = None,
|
| 506 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 507 |
+
output_type: Optional[str] = "pil",
|
| 508 |
+
return_dict: bool = True,
|
| 509 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 510 |
+
clip_skip: Optional[int] = None,
|
| 511 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 512 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 513 |
+
**kwargs,
|
| 514 |
+
):
|
| 515 |
+
r"""
|
| 516 |
+
The call function to the pipeline for generation.
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 520 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 521 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 522 |
+
The height in pixels of the generated image.
|
| 523 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 524 |
+
The width in pixels of the generated image.
|
| 525 |
+
num_inference_steps (`int`, *optional*, defaults to 4):
|
| 526 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 527 |
+
expense of slower inference.
|
| 528 |
+
original_inference_steps (`int`, *optional*):
|
| 529 |
+
The original number of inference steps used to generate a linearly-spaced timestep schedule, from which
|
| 530 |
+
we draw `num_inference_steps` evenly spaced timesteps as our final timestep schedule,
|
| 531 |
+
following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
|
| 532 |
+
scheduler's `original_inference_steps` attribute.
|
| 533 |
+
timesteps (`List[int]`, *optional*):
|
| 534 |
+
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
| 535 |
+
timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
|
| 536 |
+
order.
|
| 537 |
+
guidance_scale (`float`, *optional*, defaults to 8.5):
|
| 538 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 539 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 540 |
+
Note that the original latent consistency models paper uses a different CFG formulation where the
|
| 541 |
+
guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale >
|
| 542 |
+
0`).
|
| 543 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 544 |
+
The number of images to generate per prompt.
|
| 545 |
+
generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
|
| 546 |
+
A [`paddle.Generator`] to make generation deterministic.
|
| 547 |
+
latents (`paddle.Tensor`, *optional*):
|
| 548 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 549 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 550 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 551 |
+
prompt_embeds (`paddle.Tensor`, *optional*):
|
| 552 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 553 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 554 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 555 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 556 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 557 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 558 |
+
plain tuple.
|
| 559 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 560 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 561 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 562 |
+
clip_skip (`int`, *optional*):
|
| 563 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 564 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 565 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 566 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
| 567 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 568 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 569 |
+
`callback_on_step_end_tensor_inputs`.
|
| 570 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 571 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 572 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 573 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 574 |
+
|
| 575 |
+
Examples:
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 579 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 580 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 581 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 582 |
+
"not-safe-for-work" (nsfw) content.
|
| 583 |
+
"""
|
| 584 |
+
|
| 585 |
+
callback = kwargs.pop("callback", None)
|
| 586 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 587 |
+
|
| 588 |
+
if callback is not None:
|
| 589 |
+
deprecate(
|
| 590 |
+
"callback",
|
| 591 |
+
"1.0.0",
|
| 592 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 593 |
+
)
|
| 594 |
+
if callback_steps is not None:
|
| 595 |
+
deprecate(
|
| 596 |
+
"callback_steps",
|
| 597 |
+
"1.0.0",
|
| 598 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# 0. Default height and width to unet
|
| 602 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 603 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 604 |
+
|
| 605 |
+
# 1. Check inputs. Raise error if not correct
|
| 606 |
+
self.check_inputs(prompt, height, width, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)
|
| 607 |
+
self._guidance_scale = guidance_scale
|
| 608 |
+
self._clip_skip = clip_skip
|
| 609 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 610 |
+
|
| 611 |
+
# 2. Define call parameters
|
| 612 |
+
if prompt is not None and isinstance(prompt, str):
|
| 613 |
+
batch_size = 1
|
| 614 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 615 |
+
batch_size = len(prompt)
|
| 616 |
+
else:
|
| 617 |
+
batch_size = prompt_embeds.shape[0]
|
| 618 |
+
|
| 619 |
+
# do_classifier_free_guidance = guidance_scale > 1.0
|
| 620 |
+
|
| 621 |
+
# 3. Encode input prompt
|
| 622 |
+
lora_scale = (
|
| 623 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
# NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
|
| 627 |
+
# distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
|
| 628 |
+
# unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts.
|
| 629 |
+
prompt_embeds, _ = self.encode_prompt(
|
| 630 |
+
prompt,
|
| 631 |
+
num_images_per_prompt,
|
| 632 |
+
False,
|
| 633 |
+
negative_prompt=None,
|
| 634 |
+
prompt_embeds=prompt_embeds,
|
| 635 |
+
negative_prompt_embeds=None,
|
| 636 |
+
lora_scale=lora_scale,
|
| 637 |
+
clip_skip=self.clip_skip,
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
# 4. Prepare timesteps
|
| 641 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 642 |
+
self.scheduler, num_inference_steps, timesteps, original_inference_steps=original_inference_steps
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
# 5. Prepare latent variable
|
| 646 |
+
num_channels_latents = self.unet.config.in_channels
|
| 647 |
+
latents = self.prepare_latents(
|
| 648 |
+
batch_size * num_images_per_prompt,
|
| 649 |
+
num_channels_latents,
|
| 650 |
+
height,
|
| 651 |
+
width,
|
| 652 |
+
prompt_embeds.dtype,
|
| 653 |
+
generator,
|
| 654 |
+
latents,
|
| 655 |
+
)
|
| 656 |
+
bs = batch_size * num_images_per_prompt
|
| 657 |
+
|
| 658 |
+
# 6. Get Guidance Scale Embedding
|
| 659 |
+
# NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper
|
| 660 |
+
# CFG formulation, so we need to subtract 1 from the input guidance_scale.
|
| 661 |
+
# LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)
|
| 662 |
+
w = paddle.to_tensor([self.guidance_scale - 1]).tile(
|
| 663 |
+
[
|
| 664 |
+
bs,
|
| 665 |
+
]
|
| 666 |
+
)
|
| 667 |
+
w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).cast(
|
| 668 |
+
dtype=latents.dtype
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 672 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
|
| 673 |
+
|
| 674 |
+
# 8. LCM MultiStep Sampling Loop:
|
| 675 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 676 |
+
self._num_timesteps = len(timesteps)
|
| 677 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 678 |
+
for i, t in enumerate(timesteps):
|
| 679 |
+
latents = latents.cast(dtype=prompt_embeds.dtype)
|
| 680 |
+
|
| 681 |
+
# model prediction (v-prediction, eps, x)
|
| 682 |
+
model_pred = self.unet(
|
| 683 |
+
latents,
|
| 684 |
+
t,
|
| 685 |
+
timestep_cond=w_embedding,
|
| 686 |
+
encoder_hidden_states=prompt_embeds,
|
| 687 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 688 |
+
return_dict=False,
|
| 689 |
+
)[0]
|
| 690 |
+
|
| 691 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 692 |
+
latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False)
|
| 693 |
+
if callback_on_step_end is not None:
|
| 694 |
+
callback_kwargs = {}
|
| 695 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 696 |
+
callback_kwargs[k] = locals()[k]
|
| 697 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 698 |
+
|
| 699 |
+
latents = callback_outputs.pop("latents", latents)
|
| 700 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 701 |
+
w_embedding = callback_outputs.pop("w_embedding", w_embedding)
|
| 702 |
+
denoised = callback_outputs.pop("denoised", denoised)
|
| 703 |
+
|
| 704 |
+
# call the callback, if provided
|
| 705 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 706 |
+
progress_bar.update()
|
| 707 |
+
if callback is not None and i % callback_steps == 0:
|
| 708 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 709 |
+
callback(step_idx, t, latents)
|
| 710 |
+
|
| 711 |
+
denoised = denoised.cast(prompt_embeds.dtype)
|
| 712 |
+
if not output_type == "latent":
|
| 713 |
+
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 714 |
+
image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype)
|
| 715 |
+
else:
|
| 716 |
+
image = denoised
|
| 717 |
+
has_nsfw_concept = None
|
| 718 |
+
|
| 719 |
+
if has_nsfw_concept is None:
|
| 720 |
+
do_denormalize = [True] * image.shape[0]
|
| 721 |
+
else:
|
| 722 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 723 |
+
|
| 724 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 725 |
+
|
| 726 |
+
if not return_dict:
|
| 727 |
+
return (image, has_nsfw_concept)
|
| 728 |
+
|
| 729 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
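A minimal usage sketch for the sampling loop above, assuming this file exports a `LatentConsistencyModelPipeline` class (mirroring diffusers) and that the `SimianLuo/LCM_Dreamshaper_v7` checkpoint is available in a ppdiffusers-compatible format:

    from ppdiffusers import LatentConsistencyModelPipeline

    pipe = LatentConsistencyModelPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
    image = pipe(
        prompt="a photo of an astronaut riding a horse",
        num_inference_steps=4,  # LCMs need only a handful of steps
        guidance_scale=8.5,     # Imagen-style CFG; the embedded w is guidance_scale - 1
    ).images[0]
    image.save("lcm_sample.png")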
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/lvdm/pipeline_latent_video_diffusion_model_uncond.py
ADDED
|
@@ -0,0 +1,277 @@
|
| 1 |
+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
| 2 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import os
|
| 18 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import paddle
|
| 22 |
+
|
| 23 |
+
from ...configuration_utils import FrozenDict
|
| 24 |
+
from ...models import LVDMAutoencoderKL, LVDMUNet3DModel
|
| 25 |
+
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
| 26 |
+
from ...utils import deprecate, logging, randn_tensor
|
| 27 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 28 |
+
from . import VideoPipelineOutput
|
| 29 |
+
from .video_save import save_results
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class LVDMUncondPipeline(DiffusionPipeline):
|
| 35 |
+
r"""
|
| 36 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 37 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 38 |
+
|
| 39 |
+
Parameters:
|
| 40 |
+
vae ([`LVDMAutoencoderKL`]):
|
| 41 |
+
Autoencoder Model to encode and decode videos to and from latent representations.
|
| 42 |
+
unet ([`LVDMUNet3DModel`]): 3D conditional U-Net architecture to denoise the encoded video latents.
|
| 43 |
+
scheduler ([`SchedulerMixin`]):
|
| 44 |
+
A scheduler to be used in combination with `unet` to denoise the encoded video latents. Can be one of
|
| 45 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
vae: LVDMAutoencoderKL,
|
| 51 |
+
unet: LVDMUNet3DModel,
|
| 52 |
+
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
| 53 |
+
):
|
| 54 |
+
super().__init__()
|
| 55 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
| 56 |
+
deprecation_message = (
|
| 57 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
| 58 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
| 59 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
| 60 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
| 61 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
| 62 |
+
" file"
|
| 63 |
+
)
|
| 64 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
| 65 |
+
new_config = dict(scheduler.config)
|
| 66 |
+
new_config["steps_offset"] = 1
|
| 67 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
| 68 |
+
|
| 69 |
+
self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
|
| 70 |
+
|
| 71 |
+
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
| 72 |
+
r"""
|
| 73 |
+
Enable sliced attention computation.
|
| 74 |
+
|
| 75 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
| 76 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
| 80 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
| 81 |
+
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
| 82 |
+
`attention_head_dim` must be a multiple of `slice_size`.
|
| 83 |
+
"""
|
| 84 |
+
if slice_size == "auto":
|
| 85 |
+
# half the attention head size is usually a good trade-off between
|
| 86 |
+
# speed and memory
|
| 87 |
+
slice_size = self.unet.config.attention_head_dim // 2
|
| 88 |
+
self.unet.set_attention_slice(slice_size)
|
| 89 |
+
|
| 90 |
+
def disable_attention_slicing(self):
|
| 91 |
+
r"""
|
| 92 |
+
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
| 93 |
+
back to computing attention in one step.
|
| 94 |
+
"""
|
| 95 |
+
# set slice_size = `None` to disable `attention slicing`
|
| 96 |
+
self.enable_attention_slicing(None)
|
| 97 |
+
|
| 98 |
+
@paddle.no_grad()
|
| 99 |
+
def paddle_to_np(self, x):
|
| 100 |
+
sample = x.detach().cpu()
|
| 101 |
+
if sample.dim() == 5:
|
| 102 |
+
sample = sample.transpose(perm=[0, 2, 3, 4, 1])
|
| 103 |
+
else:
|
| 104 |
+
sample = sample.transpose(perm=[0, 2, 3, 1])
|
| 105 |
+
|
| 106 |
+
# a torch->paddle conversion artifact (an isinstance chain over the literal
# string "uint8") always resolved to "uint8"; cast to it directly
|
| 107 |
+
sample = ((sample + 1) * 127.5).clip(min=0, max=255).cast("uint8")
|
| 115 |
+
|
| 116 |
+
sample = sample.numpy()
|
| 117 |
+
return sample
|
| 118 |
+
|
| 119 |
+
@paddle.no_grad()
|
| 120 |
+
def __call__(
|
| 121 |
+
self,
|
| 122 |
+
batch_size: int = 1,
|
| 123 |
+
num_frames: Optional[int] = 16,
|
| 124 |
+
height: Optional[int] = 256,
|
| 125 |
+
width: Optional[int] = 256,
|
| 126 |
+
generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
|
| 127 |
+
eta: Optional[float] = 0.0,
|
| 128 |
+
num_inference_steps: Optional[int] = 50,
|
| 129 |
+
latents: Optional[paddle.Tensor] = None,
|
| 130 |
+
output_type: Optional[str] = "pil",
|
| 131 |
+
return_dict: bool = True,
|
| 132 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
| 133 |
+
callback_steps: Optional[int] = 1,
|
| 134 |
+
save_dir=None,
|
| 135 |
+
save_name=None,
|
| 136 |
+
scale_factor: Optional[float] = 0.33422927,
|
| 137 |
+
shift_factor: Optional[float] = 1.4606637,
|
| 138 |
+
save_fps: Optional[int] = 8,
|
| 139 |
+
**kwargs,
|
| 140 |
+
) -> Union[Tuple, VideoPipelineOutput]:
|
| 141 |
+
r"""
|
| 142 |
+
Args:
|
| 143 |
+
height (`int`, *optional*, defaults to 256):
|
| 144 |
+
The height in pixels of the generated image.
|
| 145 |
+
width (`int`, *optional*, defaults to 256):
|
| 146 |
+
The width in pixels of the generated image.
|
| 147 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 148 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 149 |
+
expense of slower inference.
|
| 150 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 151 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 152 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 153 |
+
generator (`paddle.Generator`, *optional*):
|
| 154 |
+
One or a list of paddle generator(s) to make generation deterministic.
|
| 155 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 156 |
+
The output format of the generated image. Choose between
|
| 157 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 158 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 159 |
+
Whether or not to return a [`~pipeline_utils.VideoPipelineOutput`] instead of a
|
| 160 |
+
plain tuple.
|
| 161 |
+
callback (`Callable`, *optional*):
|
| 162 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 163 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
|
| 164 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 165 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 166 |
+
called at every step.
|
| 167 |
+
save_dir (`str` or `List[str]`, *optional*):
|
| 168 |
+
If provided, will save videos generated to *save_dir*. Otherwise will save them to the current path.
|
| 169 |
+
save_name (`str` or `List[str]`, *optional*):
|
| 170 |
+
If provided, will save videos generated to *save_name*.
|
| 171 |
+
scale_factor (`float`, *optional*, defaults to 0.33422927):
|
| 172 |
+
A scale factor to apply to the generated video.
|
| 173 |
+
shift_factor (`float`, *optional*, defaults to 1.4606637):
|
| 174 |
+
A shift factor to apply to the generated video.
|
| 175 |
+
save_fps (`int`, *optional*, defaults to 8):
|
| 176 |
+
The number of frames per second to save.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
[`~pipeline_utils.VideoPipelineOutput`] or `tuple`: [`~pipeline_utils.VideoPipelineOutput`] if
|
| 180 |
+
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
| 181 |
+
generated images.
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 185 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 186 |
+
|
| 187 |
+
if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
|
| 188 |
+
# (simplified: the redundant `callback_steps is not None` guard was dropped)
|
| 189 |
+
|
| 190 |
+
raise ValueError(
|
| 191 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 192 |
+
f" {type(callback_steps)}."
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# get the initial random noise unless the user supplied it
|
| 196 |
+
latents_shape = [
|
| 197 |
+
batch_size,
|
| 198 |
+
self.unet.in_channels,
|
| 199 |
+
num_frames,
|
| 200 |
+
height // 8,
|
| 201 |
+
width // 8,
|
| 202 |
+
] # (batch_size, C, N, H, W)
|
| 203 |
+
|
| 204 |
+
if latents is None:
|
| 205 |
+
latents = randn_tensor(
|
| 206 |
+
latents_shape,
|
| 207 |
+
generator=generator,
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
if latents.shape != latents_shape:
|
| 211 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
| 212 |
+
# set timesteps
|
| 213 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
| 214 |
+
|
| 215 |
+
# Some schedulers like PNDM have timesteps as arrays
|
| 216 |
+
# It's more optimized to move all timesteps to correct device beforehand
|
| 217 |
+
timesteps_tensor = self.scheduler.timesteps
|
| 218 |
+
|
| 219 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 220 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 221 |
+
|
| 222 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 223 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 224 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 225 |
+
# and should be between [0, 1]
|
| 226 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 227 |
+
extra_step_kwargs = {}
|
| 228 |
+
if accepts_eta:
|
| 229 |
+
extra_step_kwargs["eta"] = eta
|
| 230 |
+
|
| 231 |
+
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
| 232 |
+
latent_model_input = latents
|
| 233 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 234 |
+
|
| 235 |
+
t_tensor = paddle.expand(
|
| 236 |
+
t,
|
| 237 |
+
[
|
| 238 |
+
latent_model_input.shape[0],
|
| 239 |
+
],
|
| 240 |
+
)
|
| 241 |
+
# predict the noise residual
|
| 242 |
+
noise_pred = self.unet(latent_model_input, t_tensor).sample
|
| 243 |
+
|
| 244 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 245 |
+
latents = self.scheduler.step(noise_pred, t, latents, generator=generator, **extra_step_kwargs).prev_sample
|
| 246 |
+
|
| 247 |
+
# call the callback, if provided
|
| 248 |
+
if callback is not None and i % callback_steps == 0:
|
| 249 |
+
callback(i, t, latents)
|
| 250 |
+
|
| 251 |
+
all_videos = []
|
| 252 |
+
latents = 1.0 / scale_factor * latents - shift_factor
|
| 253 |
+
sampled_videos = self.vae.decode(latents).sample
|
| 254 |
+
all_videos.append(self.paddle_to_np(sampled_videos))
|
| 255 |
+
all_videos = np.concatenate(all_videos, axis=0)
|
| 256 |
+
|
| 257 |
+
# return sampled_videos
|
| 258 |
+
videos_frames = []
|
| 259 |
+
for idx in range(sampled_videos.shape[0]):
|
| 260 |
+
video = sampled_videos[idx]
|
| 261 |
+
video_frames = []
|
| 262 |
+
for fidx in range(video.shape[1]):
|
| 263 |
+
frame = video[:, fidx]
|
| 264 |
+
frame = (frame / 2 + 0.5).clip(0, 1)
|
| 265 |
+
frame = frame.transpose([1, 2, 0]).astype("float32").numpy()
|
| 266 |
+
if output_type == "pil":
|
| 267 |
+
frame = self.numpy_to_pil(frame)
|
| 268 |
+
video_frames.append(frame)
|
| 269 |
+
videos_frames.append(video_frames)
|
| 270 |
+
|
| 271 |
+
if not save_name:
|
| 272 |
+
save_name = "defaul_video"
|
| 273 |
+
if not save_dir:
|
| 274 |
+
save_dir = "."
|
| 275 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 276 |
+
save_results(all_videos, save_dir=save_dir, save_name=save_name, save_fps=save_fps)
|
| 277 |
+
return VideoPipelineOutput(frames=videos_frames, samples=sampled_videos)
|
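A minimal usage sketch for this unconditional pipeline; the checkpoint path is a placeholder, and the package-level export of `LVDMUncondPipeline` is assumed:

    from ppdiffusers import LVDMUncondPipeline

    pipe = LVDMUncondPipeline.from_pretrained("path/to/lvdm_checkpoint")  # placeholder path
    out = pipe(batch_size=1, num_frames=16, height=256, width=256,
               num_inference_steps=50, save_dir="outputs", save_name="sample")
    # out.frames holds per-video frame lists; out.samples is the decoded video tensor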
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/lvdm/video_save.py
ADDED
|
@@ -0,0 +1,419 @@
|
| 1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import cv2
|
| 18 |
+
import numpy as np
|
| 19 |
+
import paddle
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import accimage
|
| 24 |
+
except ImportError:
|
| 25 |
+
accimage = None
|
| 26 |
+
import math
|
| 27 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 28 |
+
|
| 29 |
+
from PIL import Image
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
import av
|
| 33 |
+
|
| 34 |
+
av.logging.set_level(av.logging.ERROR)
|
| 35 |
+
if not hasattr(av.video.frame.VideoFrame, "pict_type"):
|
| 36 |
+
av = ImportError("""Your version of PyAV is too old for the necessary video operations.""")
|
| 37 |
+
except ImportError:
|
| 38 |
+
av = ImportError("""PyAV is not installed, and is necessary for the video operations.""")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _check_av_available() -> None:
|
| 42 |
+
if isinstance(av, Exception):
|
| 43 |
+
raise av
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def write_video(
|
| 47 |
+
filename: str,
|
| 48 |
+
video_array: paddle.Tensor,
|
| 49 |
+
fps: float,
|
| 50 |
+
video_codec: str = "libx264",
|
| 51 |
+
options: Optional[Dict[str, Any]] = None,
|
| 52 |
+
audio_array: Optional[paddle.Tensor] = None,
|
| 53 |
+
audio_fps: Optional[float] = None,
|
| 54 |
+
audio_codec: Optional[str] = None,
|
| 55 |
+
audio_options: Optional[Dict[str, Any]] = None,
|
| 56 |
+
) -> None:
|
| 57 |
+
"""
|
| 58 |
+
Writes a 4d tensor in [T, H, W, C] format in a video file
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
filename (str): path where the video will be saved
|
| 62 |
+
video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
|
| 63 |
+
as a uint8 tensor in [T, H, W, C] format
|
| 64 |
+
fps (Number): video frames per second
|
| 65 |
+
video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
|
| 66 |
+
options (Dict): dictionary containing options to be passed into the PyAV video stream
|
| 67 |
+
audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
|
| 68 |
+
and N is the number of samples
|
| 69 |
+
audio_fps (Number): audio sample rate, typically 44100 or 48000
|
| 70 |
+
audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
|
| 71 |
+
audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream
|
| 72 |
+
"""
|
| 73 |
+
_check_av_available()
|
| 74 |
+
video_array = paddle.to_tensor(data=video_array).astype("uint8").numpy()
|
| 75 |
+
if isinstance(fps, float):
|
| 76 |
+
fps = np.round(fps)
|
| 77 |
+
with av.open(filename, mode="w") as container:
|
| 78 |
+
stream = container.add_stream(video_codec, rate=fps)
|
| 79 |
+
stream.width = video_array.shape[2]
|
| 80 |
+
stream.height = video_array.shape[1]
|
| 81 |
+
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
|
| 82 |
+
stream.options = options or {}
|
| 83 |
+
if audio_array is not None:
|
| 84 |
+
audio_format_dtypes = {
|
| 85 |
+
"dbl": "<f8",
|
| 86 |
+
"dblp": "<f8",
|
| 87 |
+
"flt": "<f4",
|
| 88 |
+
"fltp": "<f4",
|
| 89 |
+
"s16": "<i2",
|
| 90 |
+
"s16p": "<i2",
|
| 91 |
+
"s32": "<i4",
|
| 92 |
+
"s32p": "<i4",
|
| 93 |
+
"u8": "u1",
|
| 94 |
+
"u8p": "u1",
|
| 95 |
+
}
|
| 96 |
+
a_stream = container.add_stream(audio_codec, rate=audio_fps)
|
| 97 |
+
a_stream.options = audio_options or {}
|
| 98 |
+
num_channels = audio_array.shape[0]
|
| 99 |
+
audio_layout = "stereo" if num_channels > 1 else "mono"
|
| 100 |
+
audio_sample_fmt = container.streams.audio[0].format.name
|
| 101 |
+
format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
|
| 102 |
+
audio_array = paddle.to_tensor(data=audio_array).numpy().astype(format_dtype)
|
| 103 |
+
frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
|
| 104 |
+
frame.sample_rate = audio_fps
|
| 105 |
+
for packet in a_stream.encode(frame):
|
| 106 |
+
container.mux(packet)
|
| 107 |
+
for packet in a_stream.encode():
|
| 108 |
+
container.mux(packet)
|
| 109 |
+
for img in video_array:
|
| 110 |
+
frame = av.VideoFrame.from_ndarray(img, format="rgb24")
|
| 111 |
+
frame.pict_type = "NONE"
|
| 112 |
+
for packet in stream.encode(frame):
|
| 113 |
+
container.mux(packet)
|
| 114 |
+
for packet in stream.encode():
|
| 115 |
+
container.mux(packet)
|
| 116 |
+
|
| 117 |
+
|
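A quick smoke test for write_video (requires PyAV; the random frames are illustrative):

    import numpy as np

    # 16 random RGB frames of 64x64, encoded at 8 fps
    frames = np.random.randint(0, 256, size=(16, 64, 64, 3), dtype=np.uint8)
    write_video("smoke_test.mp4", frames, fps=8)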
| 118 |
+
@paddle.no_grad()
|
| 119 |
+
def make_grid(
|
| 120 |
+
tensor: Union[paddle.Tensor, List[paddle.Tensor]],
|
| 121 |
+
nrow: int = 8,
|
| 122 |
+
padding: int = 2,
|
| 123 |
+
normalize: bool = False,
|
| 124 |
+
value_range: Optional[Tuple[int, int]] = None,
|
| 125 |
+
scale_each: bool = False,
|
| 126 |
+
pad_value: float = 0.0,
|
| 127 |
+
) -> paddle.Tensor:
|
| 128 |
+
"""
|
| 129 |
+
Make a grid of images.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
|
| 133 |
+
or a list of images all of the same size.
|
| 134 |
+
nrow (int, optional): Number of images displayed in each row of the grid.
|
| 135 |
+
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
|
| 136 |
+
padding (int, optional): amount of padding. Default: ``2``.
|
| 137 |
+
normalize (bool, optional): If True, shift the image to the range (0, 1),
|
| 138 |
+
by the min and max values specified by ``value_range``. Default: ``False``.
|
| 139 |
+
value_range (tuple, optional): tuple (min, max) where min and max are numbers,
|
| 140 |
+
then these numbers are used to normalize the image. By default, min and max
|
| 141 |
+
are computed from the tensor.
|
| 142 |
+
scale_each (bool, optional): If ``True``, scale each image in the batch of
|
| 143 |
+
images separately rather than the (min, max) over all images. Default: ``False``.
|
| 144 |
+
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
grid (Tensor): the tensor containing grid of images.
|
| 148 |
+
"""
|
| 149 |
+
if not paddle.is_tensor(x=tensor):
|
| 150 |
+
if isinstance(tensor, list):
|
| 151 |
+
for t in tensor:
|
| 152 |
+
if not paddle.is_tensor(x=t):
|
| 153 |
+
raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}")
|
| 154 |
+
else:
|
| 155 |
+
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
|
| 156 |
+
if isinstance(tensor, list):
|
| 157 |
+
tensor = paddle.stack(x=tensor, axis=0)
|
| 158 |
+
if tensor.dim() == 2:
|
| 159 |
+
tensor = tensor.unsqueeze(axis=0)
|
| 160 |
+
if tensor.dim() == 3:
|
| 161 |
+
if tensor.shape[0] == 1:
|
| 162 |
+
tensor = paddle.concat(x=(tensor, tensor, tensor), axis=0)
|
| 163 |
+
tensor = tensor.unsqueeze(axis=0)
|
| 164 |
+
if tensor.dim() == 4 and tensor.shape[1] == 1:
|
| 165 |
+
tensor = paddle.concat(x=(tensor, tensor, tensor), axis=1)
|
| 166 |
+
if normalize is True:
|
| 167 |
+
tensor = tensor.clone()
|
| 168 |
+
if value_range is not None and not isinstance(value_range, tuple):
|
| 169 |
+
raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
|
| 170 |
+
|
| 171 |
+
def norm_ip(img, low, high):
|
| 172 |
+
img.clip_(min=low, max=high)
|
| 173 |
+
img.subtract_(paddle.to_tensor(low, dtype=img.dtype)).scale_(1.0 / max(high - low, 1e-05))  # in-place; fixes the "substract" typo and makes the normalization stick
|
| 174 |
+
|
| 175 |
+
def norm_range(t, value_range):
|
| 176 |
+
if value_range is not None:
|
| 177 |
+
norm_ip(t, value_range[0], value_range[1])
|
| 178 |
+
else:
|
| 179 |
+
norm_ip(t, float(t.min()), float(t.max()))
|
| 180 |
+
|
| 181 |
+
if scale_each is True:
|
| 182 |
+
for t in tensor:
|
| 183 |
+
norm_range(t, value_range)
|
| 184 |
+
else:
|
| 185 |
+
norm_range(tensor, value_range)
|
| 186 |
+
if not isinstance(tensor, paddle.Tensor):
|
| 187 |
+
raise TypeError("tensor should be of type paddle.Tensor")
|
| 188 |
+
if tensor.shape[0] == 1:
|
| 189 |
+
return tensor.squeeze(axis=0)
|
| 190 |
+
nmaps = tensor.shape[0]
|
| 191 |
+
xmaps = min(nrow, nmaps)
|
| 192 |
+
ymaps = int(math.ceil(float(nmaps) / xmaps))
|
| 193 |
+
height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] + padding)
|
| 194 |
+
num_channels = tensor.shape[1]
|
| 195 |
+
grid = paddle.full(
|
| 196 |
+
shape=(num_channels, height * ymaps + padding, width * xmaps + padding),
|
| 197 |
+
fill_value=pad_value,
|
| 198 |
+
dtype=tensor.dtype,
|
| 199 |
+
)
|
| 200 |
+
k = 0
|
| 201 |
+
for y in range(ymaps):
|
| 202 |
+
for x in range(xmaps):
|
| 203 |
+
if k >= nmaps:
|
| 204 |
+
break
|
| 205 |
+
start_0 = grid.shape[1] + y * height + padding if y * height + padding < 0 else y * height + padding
|
| 206 |
+
start_1 = (
|
| 207 |
+
paddle.slice(grid, [1], [start_0], [start_0 + height - padding]).shape[2] + x * width + padding
|
| 208 |
+
if x * width + padding < 0
|
| 209 |
+
else x * width + padding
|
| 210 |
+
)
|
| 211 |
+
paddle.assign(
|
| 212 |
+
tensor[k],
|
| 213 |
+
output=paddle.slice(
|
| 214 |
+
paddle.slice(grid, [1], [start_0], [start_0 + height - padding]),
|
| 215 |
+
[2],
|
| 216 |
+
[start_1],
|
| 217 |
+
[start_1 + width - padding],
|
| 218 |
+
),
|
| 219 |
+
)
|
| 220 |
+
k = k + 1
|
| 221 |
+
return grid
|
| 222 |
+
|
| 223 |
+
|
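Example of the grid geometry make_grid produces (8 images, 4 per row, 2 px padding; each cell is 66 px with one extra leading pad per axis):

    import paddle

    batch = paddle.rand([8, 3, 64, 64])
    grid = make_grid(batch, nrow=4, padding=2)
    print(grid.shape)  # [3, 2*66 + 2, 4*66 + 2] == [3, 134, 266]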
| 224 |
+
def _is_pil_image(img: Any) -> bool:
|
| 225 |
+
if accimage is not None:
|
| 226 |
+
return isinstance(img, (Image.Image, accimage.Image))
|
| 227 |
+
else:
|
| 228 |
+
return isinstance(img, Image.Image)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def get_image_num_channels(img: Any) -> int:
|
| 232 |
+
if _is_pil_image(img):
|
| 233 |
+
if hasattr(img, "getbands"):
|
| 234 |
+
return len(img.getbands())
|
| 235 |
+
else:
|
| 236 |
+
return img.channels
|
| 237 |
+
raise TypeError(f"Unexpected type {type(img)}")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def to_tensor(pic) -> paddle.Tensor:
|
| 241 |
+
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
|
| 242 |
+
See :class:`~paddle.vision.transforms.ToTensor` for more details.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
|
| 246 |
+
|
| 247 |
+
Returns:
|
| 248 |
+
Tensor: Converted image.
|
| 249 |
+
"""
|
| 250 |
+
default_float_dtype = paddle.get_default_dtype()
|
| 251 |
+
if isinstance(pic, np.ndarray):
|
| 252 |
+
if pic.ndim == 2:
|
| 253 |
+
pic = pic[:, :, (None)]
|
| 254 |
+
img = paddle.to_tensor(data=pic.transpose((2, 0, 1)))
|
| 255 |
+
if img.dtype == paddle.uint8:
|
| 256 |
+
return paddle.divide(img.cast(default_float_dtype), paddle.to_tensor(255, dtype=paddle.float32))
|
| 257 |
+
else:
|
| 258 |
+
return img
|
| 259 |
+
mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32}
|
| 260 |
+
img = paddle.to_tensor(data=np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))
|
| 261 |
+
if pic.mode == "1":
|
| 262 |
+
img = 255 * img
|
| 263 |
+
img = img.reshape([pic.size[1], pic.size[0], get_image_num_channels(pic)])
|
| 264 |
+
img = img.transpose(perm=(2, 0, 1))
|
| 265 |
+
if img.dtype == paddle.uint8:
|
| 266 |
+
return paddle.divide(img.cast(default_float_dtype), paddle.to_tensor(255.0))
|
| 267 |
+
else:
|
| 268 |
+
return img
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def load_num_videos(data_path, num_videos):
|
| 272 |
+
if isinstance(data_path, str):
|
| 273 |
+
videos = np.load(data_path)["arr_0"]
|
| 274 |
+
elif isinstance(data_path, np.ndarray):
|
| 275 |
+
videos = data_path
|
| 276 |
+
else:
|
| 277 |
+
raise TypeError(f"unsupported data_path type: {type(data_path)}")
|
| 278 |
+
if num_videos is not None:
|
| 279 |
+
videos = videos[:num_videos, :, :, :, :]
|
| 280 |
+
return videos
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def fill_with_black_squares(video, desired_len: int) -> paddle.Tensor:
|
| 284 |
+
if len(video) >= desired_len:
|
| 285 |
+
return video
|
| 286 |
+
return paddle.concat(
|
| 287 |
+
x=[
|
| 288 |
+
video,
|
| 289 |
+
paddle.zeros_like(x=video[0]).unsqueeze(axis=0).tile(repeat_times=[desired_len - len(video), 1, 1, 1]),
|
| 290 |
+
],
|
| 291 |
+
axis=0,
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def npz_to_video_grid(data_path, out_path, num_frames=None, fps=8, num_videos=None, nrow=None, verbose=True):
|
| 296 |
+
if isinstance(data_path, str):
|
| 297 |
+
videos = load_num_videos(data_path, num_videos)
|
| 298 |
+
elif isinstance(data_path, np.ndarray):
|
| 299 |
+
videos = data_path
|
| 300 |
+
else:
|
| 301 |
+
raise TypeError(f"unsupported data_path type: {type(data_path)}")
|
| 302 |
+
n, t, h, w, c = videos.shape
|
| 303 |
+
videos_th = []
|
| 304 |
+
for i in range(n):
|
| 305 |
+
video = videos[(i), :, :, :, :]
|
| 306 |
+
images = [video[(j), :, :, :] for j in range(t)]
|
| 307 |
+
images = [to_tensor(img) for img in images]
|
| 308 |
+
|
| 309 |
+
video = paddle.stack(x=images)
|
| 310 |
+
videos_th.append(video)
|
| 311 |
+
if num_frames is None:
|
| 312 |
+
num_frames = videos.shape[1]
|
| 313 |
+
if verbose:
|
| 314 |
+
videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc="Adding empty frames")]
|
| 315 |
+
else:
|
| 316 |
+
videos = [fill_with_black_squares(v, num_frames) for v in videos_th]
|
| 317 |
+
frame_grids = paddle.stack(x=videos).transpose(perm=[1, 0, 2, 3, 4])
|
| 318 |
+
if nrow is None:
|
| 319 |
+
nrow = int(np.ceil(np.sqrt(n)))
|
| 320 |
+
if verbose:
|
| 321 |
+
frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc="Making grids")]
|
| 322 |
+
|
| 323 |
+
else:
|
| 324 |
+
frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
|
| 325 |
+
|
| 326 |
+
if os.path.dirname(out_path) != "":
|
| 327 |
+
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
| 328 |
+
# same torch->paddle conversion artifact as above: the isinstance chain always picked "uint8"
|
| 329 |
+
frame_grids = (paddle.stack(x=frame_grids) * 255).transpose(perm=[0, 2, 3, 1]).cast("uint8")
|
| 337 |
+
write_video(out_path, frame_grids, fps=fps, video_codec="h264", options={"crf": "10"})
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def savenp2sheet(imgs, savepath, nrow=None):
|
| 341 |
+
"""save multiple imgs (in numpy array type) to a img sheet.
|
| 342 |
+
img sheet is one row.
|
| 343 |
+
|
| 344 |
+
imgs:
|
| 345 |
+
np array of size [N, H, W, 3] or List[array] with array size = [H,W,3]
|
| 346 |
+
"""
|
| 347 |
+
if imgs.ndim == 4:
|
| 348 |
+
img_list = [imgs[i] for i in range(imgs.shape[0])]
|
| 349 |
+
imgs = img_list
|
| 350 |
+
imgs_new = []
|
| 351 |
+
for i, img in enumerate(imgs):
|
| 352 |
+
if img.ndim == 3 and img.shape[0] == 3:
|
| 353 |
+
img = np.transpose(img, (1, 2, 0))
|
| 354 |
+
assert img.ndim == 3 and img.shape[-1] == 3, img.shape
|
| 355 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 356 |
+
imgs_new.append(img)
|
| 357 |
+
n = len(imgs)
|
| 358 |
+
if nrow is not None:
|
| 359 |
+
n_cols = nrow
|
| 360 |
+
else:
|
| 361 |
+
n_cols = int(n**0.5)
|
| 362 |
+
n_rows = int(np.ceil(n / n_cols))
|
| 363 |
+
print(n_cols)
|
| 364 |
+
print(n_rows)
|
| 365 |
+
imgsheet = cv2.vconcat([cv2.hconcat(imgs_new[i * n_cols : (i + 1) * n_cols]) for i in range(n_rows)])
|
| 366 |
+
cv2.imwrite(savepath, imgsheet)
|
| 367 |
+
print(f"saved in {savepath}")
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def npz_to_imgsheet_5d(data_path, res_dir, nrow=None):
|
| 371 |
+
if isinstance(data_path, str):
|
| 372 |
+
imgs = np.load(data_path)["arr_0"]
|
| 373 |
+
elif isinstance(data_path, np.ndarray):
|
| 374 |
+
imgs = data_path
|
| 375 |
+
else:
|
| 376 |
+
raise Exception
|
| 377 |
+
if os.path.isdir(res_dir):
|
| 378 |
+
res_path = os.path.join(res_dir, "samples.jpg")
|
| 379 |
+
else:
|
| 380 |
+
assert res_dir.endswith(".jpg")
|
| 381 |
+
res_path = res_dir
|
| 382 |
+
imgs = np.concatenate([imgs[i] for i in range(imgs.shape[0])], axis=0)
|
| 383 |
+
savenp2sheet(imgs, res_path, nrow=nrow)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def save_results(
|
| 387 |
+
videos,
|
| 388 |
+
save_dir,
|
| 389 |
+
save_name="results",
|
| 390 |
+
save_fps=8,
|
| 391 |
+
save_mp4=True,
|
| 392 |
+
save_npz=False,
|
| 393 |
+
save_mp4_sheet=False,
|
| 394 |
+
save_jpg=False,
|
| 395 |
+
):
|
| 396 |
+
if save_mp4:
|
| 397 |
+
save_subdir = os.path.join(save_dir, "videos")
|
| 398 |
+
os.makedirs(save_subdir, exist_ok=True)
|
| 399 |
+
shape_str = "x".join([str(x) for x in videos[0:1, (...)].shape])
|
| 400 |
+
for i in range(videos.shape[0]):
|
| 401 |
+
npz_to_video_grid(
|
| 402 |
+
videos[i : i + 1, ...],
|
| 403 |
+
os.path.join(save_subdir, f"{save_name}_{i:03d}_{shape_str}.mp4"),
|
| 404 |
+
fps=save_fps,
|
| 405 |
+
)
|
| 406 |
+
print(f"Successfully saved videos in {save_subdir}")
|
| 407 |
+
shape_str = "x".join([str(x) for x in videos.shape])
|
| 408 |
+
if save_npz:
|
| 409 |
+
save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.npz")
|
| 410 |
+
np.savez(save_path, videos)
|
| 411 |
+
print(f"Successfully saved npz in {save_path}")
|
| 412 |
+
if save_mp4_sheet:
|
| 413 |
+
save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.mp4")
|
| 414 |
+
npz_to_video_grid(videos, save_path, fps=save_fps)
|
| 415 |
+
print(f"Successfully saved mp4 sheet in {save_path}")
|
| 416 |
+
if save_jpg:
|
| 417 |
+
save_path = os.path.join(save_dir, f"{save_name}_{shape_str}.jpg")
|
| 418 |
+
npz_to_imgsheet_5d(videos, save_path, nrow=videos.shape[1])
|
| 419 |
+
print(f"Successfully saved jpg sheet in {save_path}")
|
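Shape convention for save_results: `videos` is an [N, T, H, W, C] uint8 array, as produced by `LVDMUncondPipeline.paddle_to_np`. A sketch with random data:

    import numpy as np

    videos = np.random.randint(0, 256, size=(2, 16, 64, 64, 3), dtype=np.uint8)
    save_results(videos, save_dir="outputs", save_name="demo", save_fps=8)
    # -> outputs/videos/demo_000_1x16x64x64x3.mp4 and demo_001_1x16x64x64x3.mp4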
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/paint_by_example/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import PIL
from PIL import Image

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"]
    _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"]


if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *
    else:
        from .image_encoder import PaintByExampleImageEncoder
        from .pipeline_paint_by_example import PaintByExamplePipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/paint_by_example/image_encoder.py
ADDED
@@ -0,0 +1,182 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddlenlp.utils.converter import StateDictNameMapping
from paddlenlp.utils.log import logger as ppnlp_logger

from ppdiffusers.transformers import (
    CLIPPretrainedModel,
    CLIPVisionConfig,
    CLIPVisionModel,
)

from ...models.attention import BasicTransformerBlock
from ...utils import logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PaintByExampleImageEncoder(CLIPPretrainedModel):
    config_class = CLIPVisionConfig

    @classmethod
    def _update_deprecated_state_dict(cls, state_dict=None, loaded_keys=None, model=None):
        if state_dict is None:
            return loaded_keys
        _deprecated_dict = getattr(cls, "_deprecated_dict", None)
        from_deprecated_state_dict = _deprecated_dict is not None and any(
            cls._deprecated_dict.get("key", "NONE") in all_key for all_key in state_dict.keys()
        )
        if from_deprecated_state_dict:
            ppnlp_logger.warning(
                "Loading from deprecated state_dict, please load new state_dict via setting `use_safetensors=True`."
            )
            for name in list(state_dict.keys()):
                # if the name starts with the prefix "model.", we will convert it
                if name.startswith("model."):
                    deprecated_name = name
                    for old_name, new_name in cls._deprecated_dict.get("name_mapping", {}).items():
                        name = name.replace(old_name, new_name)

                    if ".attn.c_attn." in name and name in state_dict:
                        state_dict[name] = paddle.concat([state_dict[name], state_dict.pop(deprecated_name)], axis=-1)
                    else:
                        state_dict[name] = state_dict.pop(deprecated_name)
            loaded_keys = list(state_dict.keys())
        return loaded_keys

    @classmethod
    def _get_name_mappings(cls, config: CLIPVisionConfig):
        num_vision_layer = config.num_hidden_layers
        hard_mappings = [
            # other
            ["final_layer_norm.weight", "final_layer_norm.weight"],
            ["proj_out.weight", "proj_out.weight", "transpose"],
            ["proj_out.bias", "proj_out.bias"],
            ["uncond_vector", "uncond_vector"],
            # model prefix
            ["model.vision_model.embeddings.class_embedding", "model.vision_model.embeddings.class_embedding"],
            [
                "model.vision_model.embeddings.patch_embedding.weight",
                "model.vision_model.embeddings.patch_embedding.weight",
            ],
            [
                "model.vision_model.embeddings.position_embedding.weight",
                "model.vision_model.embeddings.position_embedding.weight",
            ],
            ["model.vision_model.pre_layrnorm.weight", "model.vision_model.pre_layrnorm.weight"],
            ["model.vision_model.pre_layrnorm.bias", "model.vision_model.pre_layrnorm.bias"],
            ["model.vision_model.post_layernorm.weight", "model.vision_model.post_layernorm.weight"],
            ["model.vision_model.post_layernorm.bias", "model.vision_model.post_layernorm.bias"],
        ]
        for layer_index in range(num_vision_layer):
            for name in [
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.out_proj",
                "mlp.fc1",
                "mlp.fc2",
                "layer_norm1",
                "layer_norm2",
            ]:
                action = None if "norm" in name else "transpose"
                # model prefix
                hard_mappings.extend(
                    [
                        [
                            f"model.vision_model.encoder.layers.{layer_index}.{name}.weight",
                            f"model.vision_model.encoder.layers.{layer_index}.{name}.weight",
                            action,
                        ],
                        [
                            f"model.vision_model.encoder.layers.{layer_index}.{name}.bias",
                            f"model.vision_model.encoder.layers.{layer_index}.{name}.bias",
                        ],
                    ]
                )
        num_mapper_layer = (config.num_hidden_layers + 1) // 5
        for layer_index in range(num_mapper_layer):
            # mapper prefix
            for name in [
                "attn1.to_q",
                "attn1.to_k",
                "attn1.to_v",
                "attn1.to_out",
                "ff.net.0.proj",
                "ff.net.2",
                "norm1",
                "norm3",
            ]:
                action = None if "norm" in name else "transpose"
                # model prefix
                hard_mappings.extend(
                    [
                        [
                            f"mapper.blocks.{layer_index}.{name}.weight",
                            f"mapper.blocks.{layer_index}.{name}.weight",
                            action,
                        ],
                        [
                            f"mapper.blocks.{layer_index}.{name}.bias",
                            f"mapper.blocks.{layer_index}.{name}.bias",
                        ],
                    ]
                )
        mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(hard_mappings)]
        return mappings

    def __init__(self, config: CLIPVisionConfig, proj_size=None):
        super().__init__(config)
        self.proj_size = proj_size or getattr(config, "projection_dim", 768)

        self.model = CLIPVisionModel(config)
        self.mapper = PaintByExampleMapper(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)
        self.proj_out = nn.Linear(config.hidden_size, self.proj_size)

        # uncondition for scaling
        self.uncond_vector = nn.Parameter(paddle.randn((1, 1, self.proj_size)))

    def forward(self, pixel_values, return_uncond_vector=False):
        clip_output = self.model(pixel_values=pixel_values)
        latent_states = clip_output.pooler_output
        latent_states = self.mapper(latent_states[:, None])
        latent_states = self.final_layer_norm(latent_states)
        latent_states = self.proj_out(latent_states)
        if return_uncond_vector:
            return latent_states, self.uncond_vector

        return latent_states


class PaintByExampleMapper(nn.Layer):
    def __init__(self, config):
        super().__init__()
        num_layers = (config.num_hidden_layers + 1) // 5
        hid_size = config.hidden_size
        num_heads = 1
        self.blocks = nn.LayerList(
            [
                BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True)
                for _ in range(num_layers)
            ]
        )

    def forward(self, hidden_states):
        for block in self.blocks:
            hidden_states = block(hidden_states)

        return hidden_states
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
ADDED
@@ -0,0 +1,615 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union

import numpy as np
import paddle
import PIL.Image

from ppdiffusers.transformers import CLIPImageProcessor

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .image_encoder import PaintByExampleImageEncoder

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: paddle.Tensor, generator: Optional[paddle.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def prepare_mask_and_masked_image(image, mask):
    """
    Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be
    converted to ``paddle.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
    ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``paddle.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (``mask > 0.5``) and cast to ``paddle.float32`` too.

    Args:
        image (Union[np.array, PIL.Image, paddle.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``paddle.Tensor`` or a ``batch x channels x height x width`` ``paddle.Tensor``.
        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``paddle.Tensor`` or a ``batch x 1 x height x width`` ``paddle.Tensor``.


    Raises:
        ValueError: ``paddle.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``paddle.Tensor`` mask
        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
        TypeError: ``mask`` is a ``paddle.Tensor`` but ``image`` is not
        (or the other way around).

    Returns:
        tuple[paddle.Tensor]: The pair (mask, masked_image) as ``paddle.Tensor`` with 4
        dimensions: ``batch x channels x height x width``.
    """
    if isinstance(image, paddle.Tensor):
        if not isinstance(mask, paddle.Tensor):
            raise TypeError(f"`image` is a paddle.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Batched mask
            if mask.shape[0] == image.shape[0]:
                mask = mask.unsqueeze(1)
            else:
                mask = mask.unsqueeze(0)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
        assert mask.shape[1] == 1, "Mask image must have a single channel"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # paint-by-example inverses the mask
        mask = 1 - mask

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.cast(dtype=paddle.float32)
    elif isinstance(mask, paddle.Tensor):
        raise TypeError(f"`mask` is a paddle.Tensor but `image` (type: {type(image)} is not")
    else:
        if isinstance(image, PIL.Image.Image):
            image = [image]

        image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
        image = image.transpose(0, 3, 1, 2)
        image = paddle.to_tensor(image).cast(dtype=paddle.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, PIL.Image.Image):
            mask = [mask]

        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
        mask = mask.astype(np.float32) / 255.0

        # paint-by-example inverses the mask
        mask = 1 - mask

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = paddle.to_tensor(mask)

    masked_image = image * mask

    return mask, masked_image


class PaintByExamplePipeline(DiffusionPipeline):
    r"""
    <Tip warning={true}>

    🧪 This is an experimental feature!

    </Tip>

    Pipeline for image-guided image inpainting using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`PaintByExampleImageEncoder`]):
            Encodes the example input image. The `unet` is conditioned on the example image instead of a text prompt.
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.

    """

    # TODO: feature_extractor is required to encode initial images (if they are in PIL format),
    # we should give a descriptive message if the pipeline doesn't have one.

    model_cpu_offload_seq = "unet->vae"
    _exclude_from_cpu_offload = ["image_encoder"]
    _optional_components = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: PaintByExampleImageEncoder,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = False,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if paddle.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pd")
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.cast(dtype=dtype)
            )
        return image, has_nsfw_concept

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clip(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cast(dtype=paddle.float32).transpose([0, 2, 3, 1]).cpu().numpy()
        return image

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
    def check_inputs(self, image, height, width, callback_steps):
        if (
            not isinstance(image, paddle.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `paddle.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        else:
            latents = latents.cast(dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, generator, do_classifier_free_guidance
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = paddle.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        mask = mask.cast(dtype=dtype)

        masked_image = masked_image.cast(dtype=dtype)

        if masked_image.shape[1] == 4:
            masked_image_latents = masked_image
        else:
            masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.tile([batch_size // mask.shape[0], 1, 1, 1])
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.tile([batch_size // masked_image_latents.shape[0], 1, 1, 1])

        mask = paddle.concat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            paddle.concat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.cast(dtype=dtype)
        return mask, masked_image_latents

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
    def _encode_vae_image(self, image: paddle.Tensor, generator: paddle.Generator):
        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = paddle.concat(image_latents, axis=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        image_latents = self.vae.config.scaling_factor * image_latents

        return image_latents

    def _encode_image(self, image, num_images_per_prompt, do_classifier_free_guidance):
        dtype = next(self.image_encoder.named_parameters())[1].dtype

        if not isinstance(image, paddle.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pd").pixel_values

        image = image.cast(dtype=dtype)
        image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.tile([1, num_images_per_prompt, 1])
        image_embeddings = image_embeddings.reshape([bs_embed * num_images_per_prompt, seq_len, -1])

        if do_classifier_free_guidance:
            negative_prompt_embeds = negative_prompt_embeds.tile([1, image_embeddings.shape[0], 1])
            negative_prompt_embeds = negative_prompt_embeds.reshape([bs_embed * num_images_per_prompt, 1, -1])

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = paddle.concat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    @paddle.no_grad()
    def __call__(
        self,
        example_image: Union[paddle.Tensor, PIL.Image.Image],
        image: Union[paddle.Tensor, PIL.Image.Image],
        mask_image: Union[paddle.Tensor, PIL.Image.Image],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            example_image (`paddle.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
                An example image to guide image generation.
            image (`paddle.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
                `Image` or tensor representing an image batch to be inpainted (parts of the image are masked out with
                `mask_image` and repainted according to `prompt`).
            mask_image (`paddle.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
                `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted,
                while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel
                (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
                expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                A [`paddle.Generator`] to make generation deterministic.
            latents (`paddle.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Example:

        ```py
        >>> import PIL
        >>> import requests
        >>> import paddle
        >>> from io import BytesIO
        >>> from ppdiffusers import PaintByExamplePipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = (
        ...     "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png"
        ... )
        >>> mask_url = (
        ...     "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png"
        ... )
        >>> example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg"

        >>> init_image = download_image(img_url).resize((512, 512))
        >>> mask_image = download_image(mask_url).resize((512, 512))
        >>> example_image = download_image(example_url).resize((512, 512))

        >>> pipe = PaintByExamplePipeline.from_pretrained(
        ...     "Fantasy-Studio/Paint-by-Example",
        ...     paddle_dtype=paddle.float16,
        ... )

        >>> image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]
        >>> image
        ```

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 1. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 2. Preprocess mask and image
        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
        height, width = masked_image.shape[-2:]

        # 3. Check inputs
        self.check_inputs(example_image, height, width, callback_steps)

        # 4. Encode input image
        image_embeddings = self._encode_image(example_image, num_images_per_prompt, do_classifier_free_guidance)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            generator,
            latents,
        )

        # 7. Prepare mask latent variables
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            image_embeddings.dtype,
            generator,
            do_classifier_free_guidance,
        )

        # 8. Check that sizes of mask, masked image and latents match
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )

        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 10. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents

                # concat latents, mask, masked_image_latents in the channel dimension
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = latent_model_input.cast(masked_image_latents.dtype)
                latent_model_input = paddle.concat([latent_model_input, masked_image_latents, mask], axis=1)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, image_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/pixart_alpha/__init__.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}


try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]

if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *
    else:
        from .pipeline_pixart_alpha import PixArtAlphaPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
ADDED
@@ -0,0 +1,867 @@
# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union

import paddle
import paddle.nn.functional as F

from ppdiffusers.transformers import T5EncoderModel, T5Tokenizer

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
    BACKENDS_MAPPING,
    deprecate,
    is_bs4_available,
    is_ftfy_available,
    logging,
    replace_example_docstring,
)
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_bs4_available():
    from bs4 import BeautifulSoup

if is_ftfy_available():
    import ftfy

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import paddle
        >>> from ppdiffusers import PixArtAlphaPipeline

        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
        >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", paddle_dtype=paddle.float16)

        >>> prompt = "A small cactus with a happy face in the Sahara desert."
        >>> image = pipe(prompt).images[0]
        ```
"""

ASPECT_RATIO_1024_BIN = {
    "0.25": [512.0, 2048.0],
    "0.28": [512.0, 1856.0],
    "0.32": [576.0, 1792.0],
    "0.33": [576.0, 1728.0],
    "0.35": [576.0, 1664.0],
    "0.4": [640.0, 1600.0],
    "0.42": [640.0, 1536.0],
    "0.48": [704.0, 1472.0],
    "0.5": [704.0, 1408.0],
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "0.68": [832.0, 1216.0],
    "0.72": [832.0, 1152.0],
    "0.78": [896.0, 1152.0],
    "0.82": [896.0, 1088.0],
    "0.88": [960.0, 1088.0],
    "0.94": [960.0, 1024.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [1024.0, 960.0],
    "1.13": [1088.0, 960.0],
    "1.21": [1088.0, 896.0],
    "1.29": [1152.0, 896.0],
    "1.38": [1152.0, 832.0],
    "1.46": [1216.0, 832.0],
    "1.67": [1280.0, 768.0],
    "1.75": [1344.0, 768.0],
    "2.0": [1408.0, 704.0],
    "2.09": [1472.0, 704.0],
    "2.4": [1536.0, 640.0],
    "2.5": [1600.0, 640.0],
    "3.0": [1728.0, 576.0],
    "4.0": [2048.0, 512.0],
}

ASPECT_RATIO_512_BIN = {
    "0.25": [256.0, 1024.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [512.0, 512.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "3.0": [864.0, 288.0],
    "4.0": [1024.0, 256.0],
}

class PixArtAlphaPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using PixArt-Alpha.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. PixArt-Alpha uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`Transformer2DModel`]):
            A text conditioned `Transformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    bad_punct_regex = re.compile(
        r"["
        + "#®•©™&@·º½¾¿¡§~"
        + r"\)"
        + r"\("
        + r"\]"
        + r"\["
        + r"\}"
        + r"\{"
        + r"\|"
        + "\\"
        + r"\/"
        + r"\*"
        + r"]{1,}"
    )  # noqa

    _optional_components = ["tokenizer", "text_encoder"]
    model_cpu_offload_seq = "text_encoder->transformer->vae"

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        vae: AutoencoderKL,
        transformer: Transformer2DModel,
        scheduler: DPMSolverMultistepScheduler,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
    def mask_text_embeddings(self, emb, mask):
        if emb.shape[0] == 1:
            keep_index = mask.sum().item()
            return emb[:, :, :keep_index, :], keep_index
        else:
            masked_feature = emb * mask[:, None, :, None]
            return masked_feature, emb.shape[2]

    # Adapted from ppdiffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool = True,
        negative_prompt: str = "",
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        prompt_attention_mask: Optional[paddle.Tensor] = None,
        negative_prompt_attention_mask: Optional[paddle.Tensor] = None,
        clean_caption: bool = False,
        **kwargs,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
                PixArt-Alpha, this should be "".
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier-free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Alpha, these should be the embeddings of the ""
                string.
            clean_caption (`bool`, defaults to `False`):
                If `True`, the function will preprocess and clean the provided caption before encoding.
        """

        if "mask_feature" in kwargs:
            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and doesn't affect the end results. It will be removed in a future version."
            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # See Section 3.1. of the paper.
        max_length = 120

        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pd",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because T5 can only handle sequences up to"
                    f" {max_length} tokens: {removed_text}"
                )

            prompt_attention_mask = text_inputs.attention_mask

            prompt_embeds = self.text_encoder(text_input_ids, attention_mask=prompt_attention_mask)
            prompt_embeds = prompt_embeds[0]

        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.transformer is not None:
            dtype = self.transformer.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.cast(dtype=dtype)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt
        prompt_embeds = prompt_embeds.tile([1, num_images_per_prompt, 1])
        prompt_embeds = prompt_embeds.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
        prompt_attention_mask = prompt_attention_mask.reshape([bs_embed, -1])
        prompt_attention_mask = prompt_attention_mask.tile([num_images_per_prompt, 1])

        # get unconditional embeddings for classifier-free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens = [negative_prompt] * batch_size
            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pd",
            )
            negative_prompt_attention_mask = uncond_input.attention_mask

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids, attention_mask=negative_prompt_attention_mask
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.cast(dtype=dtype)

            negative_prompt_embeds = negative_prompt_embeds.tile([1, num_images_per_prompt, 1])
            negative_prompt_embeds = negative_prompt_embeds.reshape([batch_size * num_images_per_prompt, seq_len, -1])

            negative_prompt_attention_mask = negative_prompt_attention_mask.reshape([bs_embed, -1])
            negative_prompt_attention_mask = negative_prompt_attention_mask.tile([num_images_per_prompt, 1])
        else:
            negative_prompt_embeds = None
            negative_prompt_attention_mask = None

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

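    # Shape sketch (illustrative, assuming the default t5-v1_1-xxl text encoder
    # with hidden size 4096): for one prompt, num_images_per_prompt=1 and
    # classifier-free guidance enabled, `encode_prompt` returns
    #     prompt_embeds                   [1, 120, 4096]
    #     prompt_attention_mask           [1, 120]
    #     negative_prompt_embeds          [1, 120, 4096]
    #     negative_prompt_attention_mask  [1, 120]
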
    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt,
        callback_steps,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
                raise ValueError(
                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                    f" {negative_prompt_attention_mask.shape}."
                )

    # Copied from ppdiffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
    def _text_preprocessing(self, text, clean_caption=False):
        if clean_caption and not is_bs4_available():
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        if clean_caption and not is_ftfy_available():
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        if not isinstance(text, (tuple, list)):
            text = [text]

        def process(text: str):
            if clean_caption:
                # `_clean_caption` is intentionally applied twice, so that artifacts
                # exposed by the first pass (e.g. unescaped entities) get cleaned too.
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                text = text.lower().strip()
            return text

        return [process(t) for t in text]

    # Copied from ppdiffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
    def _clean_caption(self, caption):
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # normalize quotation marks to a single standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)

        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        caption = caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()

    # Copied from ppdiffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, dtype=dtype)
        else:
            latents = latents.cast(dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        """Returns binned height and width."""
        ar = float(height / width)
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
        default_hw = ratios[closest_ratio]
        return int(default_hw[0]), int(default_hw[1])

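    # For instance (illustrative values): a 900x1600 request has aspect ratio
    # 900 / 1600 = 0.5625, whose closest key in ASPECT_RATIO_1024_BIN is "0.57",
    # so classify_height_width_bin(900, 1600, ASPECT_RATIO_1024_BIN) returns
    # (768, 1344).
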
    @staticmethod
    def resize_and_crop_tensor(samples: paddle.Tensor, new_width: int, new_height: int) -> paddle.Tensor:
        orig_height, orig_width = samples.shape[2], samples.shape[3]

        # Check if resizing is needed
        if orig_height != new_height or orig_width != new_width:
            ratio = max(new_height / orig_height, new_width / orig_width)
            resized_width = int(orig_width * ratio)
            resized_height = int(orig_height * ratio)

            # Resize
            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            # Center Crop
            start_x = (resized_width - new_width) // 2
            end_x = start_x + new_width
            start_y = (resized_height - new_height) // 2
            end_y = start_y + new_height
            samples = samples[:, :, start_y:end_y, start_x:end_x]

        return samples

    @paddle.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: str = "",
        num_inference_steps: int = 20,
        timesteps: List[int] = None,
        guidance_scale: float = 4.5,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        eta: float = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        prompt_embeds: Optional[paddle.Tensor] = None,
        prompt_attention_mask: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_attention_mask: Optional[paddle.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: int = 1,
        clean_caption: bool = True,
        use_resolution_binning: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equally spaced
                `num_inference_steps` timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
                One or a list of [paddle generator(s)] to make generation deterministic.
            latents (`paddle.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`paddle.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            negative_prompt_attention_mask (`paddle.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy`
                to be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            use_resolution_binning (`bool`, defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions using
                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
                the requested resolution. Useful for generating non-square images.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        if "mask_feature" in kwargs:
            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and doesn't affect the end results. It will be removed in a future version."
            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
        # 1. Check inputs. Raise error if not correct
        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor
        if use_resolution_binning:
            aspect_ratio_bin = (
                ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
            )
            orig_height, orig_width = height, width
            height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)

        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            callback_steps,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        )

        # 2. Determine the batch size
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            clean_caption=clean_caption,
        )
        if do_classifier_free_guidance:
            prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds], axis=0)
            prompt_attention_mask = paddle.concat([negative_prompt_attention_mask, prompt_attention_mask], axis=0)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            latent_channels,
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Prepare micro-conditions.
        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        if self.transformer.config.sample_size == 128:
            resolution = paddle.to_tensor([height, width]).tile([batch_size * num_images_per_prompt, 1])
            aspect_ratio = paddle.to_tensor([float(height / width)]).tile([batch_size * num_images_per_prompt, 1])
            resolution = resolution.cast(dtype=prompt_embeds.dtype)
            aspect_ratio = aspect_ratio.cast(dtype=prompt_embeds.dtype)
            added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # 7. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                current_timestep = t
                if not paddle.is_tensor(current_timestep):
                    if isinstance(current_timestep, float):
                        dtype = paddle.float32
                    else:
                        dtype = paddle.int64
                    current_timestep = paddle.to_tensor([current_timestep], dtype=dtype)
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None]
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(
                    [
                        latent_model_input.shape[0],
                    ]
                )

                # predict noise model_output
                noise_pred = self.transformer(
                    latent_model_input,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=current_timestep,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # learned sigma: the transformer predicts both noise and variance; keep only the noise half
                if self.transformer.config.out_channels // 2 == latent_channels:
                    noise_pred = noise_pred.chunk(2, axis=1)[0]

                # compute previous image: x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            if use_resolution_binning:
                image = self.resize_and_crop_tensor(image, orig_width, orig_height)
        else:
            image = latents

        if output_type != "latent":
            image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/pndm/pipeline_pndm.py
ADDED
@@ -0,0 +1,119 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import paddle

from ...models import UNet2DModel
from ...schedulers import PNDMScheduler
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class PNDMPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image latents.
        scheduler ([`PNDMScheduler`]):
            A `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
    """

    unet: UNet2DModel
    scheduler: PNDMScheduler

    def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
        super().__init__()

        scheduler = PNDMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    @paddle.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 50,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`paddle.Generator`, *optional*):
                A [`paddle.Generator`] to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

        Example:

        ```py
        >>> from ppdiffusers import PNDMPipeline

        >>> # load model and scheduler
        >>> pndm = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")

        >>> # run pipeline in inference (sample random noise and denoise)
        >>> image = pndm().images[0]

        >>> # save image
        >>> image.save("pndm_generated_image.png")
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        # For more information on the sampling method you can take a look at Algorithm 2 of
        # the official paper: https://arxiv.org/pdf/2202.09778.pdf

        # Sample gaussian noise to begin loop
        image = randn_tensor(
            (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
            generator=generator,
        )

        self.scheduler.set_timesteps(num_inference_steps)
        for t in self.progress_bar(self.scheduler.timesteps):
            model_output = self.unet(image, t).sample

            image = self.scheduler.step(model_output, t, image).prev_sample

        image = (image / 2 + 0.5).clip(0, 1)
        image = image.transpose([0, 2, 3, 1]).cast("float32").cpu().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
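
A short variant of the docstring example above, showing the two knobs the pipeline exposes: the step count and a seeded generator for reproducibility (checkpoint id reused from the docstring; the seed and step count are arbitrary):

```py
import paddle
from ppdiffusers import PNDMPipeline

pipe = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32")

# A seeded generator fixes the initial noise, making the samples reproducible.
generator = paddle.Generator().manual_seed(0)

# Fewer denoising steps trade image quality for speed.
images = pipe(batch_size=2, num_inference_steps=25, generator=generator).images
for idx, im in enumerate(images):
    im.save(f"pndm_sample_{idx}.png")
```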
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/score_sde_ve/__init__.py
ADDED
@@ -0,0 +1,32 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import PPDIFFUSERS_SLOW_IMPORT, _LazyModule

_import_structure = {"pipeline_score_sde_ve": ["ScoreSdeVePipeline"]}

if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    from .pipeline_score_sde_ve import ScoreSdeVePipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
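
The `_LazyModule` swap above defers the submodule's heavy imports until a symbol is first accessed, rather than running them at package-import time. A rough sketch of the observable behavior, assuming `PPDIFFUSERS_SLOW_IMPORT` is not set (the proxy's internals live in `ppdiffusers.utils` and may differ in detail):

```py
import sys

# Importing the subpackage only installs the lazy proxy; the pipeline module
# is not executed yet.
import ppdiffusers.pipelines.score_sde_ve as sde

assert "ppdiffusers.pipelines.score_sde_ve.pipeline_score_sde_ve" not in sys.modules

# The first attribute access triggers the real import through the proxy.
pipeline_cls = sde.ScoreSdeVePipeline

assert "ppdiffusers.pipelines.score_sde_ve.pipeline_score_sde_ve" in sys.modules
```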
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
ADDED
@@ -0,0 +1,111 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import paddle

from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler
from ...utils.paddle_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class ScoreSdeVePipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`ScoreSdeVeScheduler`]):
            A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image.
    """

    unet: UNet2DModel
    scheduler: ScoreSdeVeScheduler

    def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @paddle.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 2000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`paddle.Generator`, *optional*):
                A [`paddle.Generator`] to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """

        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet

        sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            sigma_t = self.scheduler.sigmas[i] * paddle.ones(
                [
                    shape[0],
                ]
            )

            # correction step
            for _ in range(self.scheduler.config.correct_steps):
                model_output = self.unet(sample, sigma_t).sample
                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

            # prediction step
            model_output = model(sample, sigma_t).sample
            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

            sample, sample_mean = output.prev_sample, output.prev_sample_mean

        sample = sample_mean.clip(0, 1)
        sample = sample.transpose([0, 2, 3, 1]).cast("float32").cpu().numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)
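
The loop above is a predictor-corrector sampler: each outer iteration first runs `correct_steps` Langevin-style correction steps at the current noise scale `sigma_t`, then one reverse-SDE prediction step, and the final image is the mean of the last prediction. A minimal usage sketch; the checkpoint id is an assumption for illustration, and any score-SDE VE (NCSN++) weights with a matching scheduler config would do:

```py
import paddle
from ppdiffusers import ScoreSdeVePipeline

# Assumed checkpoint id; substitute the score-SDE VE weights you actually have.
pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-church-256")

generator = paddle.Generator().manual_seed(0)
# 2000 steps is the paper's setting and is slow; reduce for quick smoke tests.
image = pipe(batch_size=1, num_inference_steps=2000, generator=generator).images[0]
image.save("sde_ve_sample.png")
```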
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/__init__.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_note_seq_available,
    is_paddle_available,
    is_paddlenlp_available,
)

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["continous_encoder"] = ["SpectrogramContEncoder"]
    _import_structure["notes_encoder"] = ["SpectrogramNotesEncoder"]
    _import_structure["pipeline_spectrogram_diffusion"] = [
        "SpectrogramContEncoder",
        "SpectrogramDiffusionPipeline",
        "T5FilmDecoder",
    ]
try:
    if not (is_paddlenlp_available() and is_paddle_available() and is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_and_note_seq_objects

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_and_note_seq_objects))
else:
    _import_structure["midi_utils"] = ["MidiProcessor"]


if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *
    else:
        from .pipeline_spectrogram_diffusion import (
            SpectrogramContEncoder,
            SpectrogramDiffusionPipeline,
            SpectrogramNotesEncoder,
            T5FilmDecoder,
        )

    try:
        if not (is_paddlenlp_available() and is_paddle_available() and is_note_seq_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_and_note_seq_objects import *

    else:
        from .midi_utils import MidiProcessor

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
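A minimal usage sketch of the lazy-import setup above, assuming paddle and paddlenlp are installed: the top-level import is cheap because `_LazyModule` defers loading each submodule in `_import_structure` until its attribute is first accessed.

# Lazy import: submodules resolve through _LazyModule on first attribute access.
from ppdiffusers.pipelines.spectrogram_diffusion import (
    MidiProcessor,  # additionally requires note-seq, else a dummy object is exported
    SpectrogramDiffusionPipeline,
)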
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/continous_encoder.py
ADDED
@@ -0,0 +1,89 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from ppdiffusers.transformers.model_utils import ModuleUtilsMixin
from ppdiffusers.transformers.t5.modeling import T5Block, T5Config, T5LayerNorm

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin


class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    @register_to_config
    def __init__(
        self,
        input_dims: int,
        targets_context_length: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

        self.input_proj = nn.Linear(input_dims, d_model, bias_attr=False)

        self.position_encoding = nn.Embedding(targets_context_length, d_model)
        self.position_encoding.weight.stop_gradient = True

        self.dropout_pre = nn.Dropout(dropout_rate)

        t5config = T5Config(
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            feed_forward_proj=feed_forward_proj,
            dropout_rate=dropout_rate,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )
        self.encoders = nn.LayerList()
        for lyr_num in range(num_layers):
            lyr = T5Block(t5config)
            self.encoders.append(lyr)

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_inputs, encoder_inputs_mask):
        x = self.input_proj(encoder_inputs)

        # terminal relative positional encodings
        max_positions = encoder_inputs.shape[1]
        input_positions = paddle.arange(max_positions)

        seq_lens = encoder_inputs_mask.sum(-1)
        input_positions = paddle.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), axis=0)
        x += self.position_encoding(input_positions)

        x = self.dropout_pre(x)

        # invert the attention mask
        input_shape = encoder_inputs.shape
        extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)

        for lyr in self.encoders:
            x = lyr(x, extended_attention_mask)[0]
        x = self.layer_norm(x)

        return self.dropout_post(x), encoder_inputs_mask
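For illustration, a minimal forward-pass sketch of `SpectrogramContEncoder`. The hyperparameter values below are assumptions chosen only to keep the example small; real values come from a checkpoint's saved config, and running this assumes the ppdiffusers T5 port is importable.

import paddle

from ppdiffusers.pipelines.spectrogram_diffusion.continous_encoder import SpectrogramContEncoder

# Hypothetical hyperparameters for demonstration, not a real checkpoint config.
encoder = SpectrogramContEncoder(
    input_dims=128,              # mel bins per frame
    targets_context_length=256,  # matches TARGET_FEATURE_LENGTH in the pipeline
    d_model=512, dropout_rate=0.1, num_layers=2,
    num_heads=8, d_kv=64, d_ff=1024, feed_forward_proj="gated-gelu",
)
mel = paddle.randn([1, 256, 128])            # (batch, frames, mel bins)
mask = paddle.ones([1, 256], dtype="int32")  # every frame is valid
encoded, out_mask = encoder(mel, mask)       # encoded: [1, 256, 512]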
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/midi_utils.py
ADDED
@@ -0,0 +1,694 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import math
import os
from typing import (
    Any,
    Callable,
    List,
    Mapping,
    MutableMapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
import paddle
import paddle.nn.functional as F

from ...utils import is_note_seq_available
from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH

if is_note_seq_available():
    import note_seq
else:
    raise ImportError("Please install note-seq via `pip install note-seq`")


INPUT_FEATURE_LENGTH = 2048

SAMPLE_RATE = 16000
HOP_SIZE = 320
FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE)

DEFAULT_STEPS_PER_SECOND = 100
DEFAULT_MAX_SHIFT_SECONDS = 10
DEFAULT_NUM_VELOCITY_BINS = 1

SLAKH_CLASS_PROGRAMS = {
    "Acoustic Piano": 0,
    "Electric Piano": 4,
    "Chromatic Percussion": 8,
    "Organ": 16,
    "Acoustic Guitar": 24,
    "Clean Electric Guitar": 26,
    "Distorted Electric Guitar": 29,
    "Acoustic Bass": 32,
    "Electric Bass": 33,
    "Violin": 40,
    "Viola": 41,
    "Cello": 42,
    "Contrabass": 43,
    "Orchestral Harp": 46,
    "Timpani": 47,
    "String Ensemble": 48,
    "Synth Strings": 50,
    "Choir and Voice": 52,
    "Orchestral Hit": 55,
    "Trumpet": 56,
    "Trombone": 57,
    "Tuba": 58,
    "French Horn": 60,
    "Brass Section": 61,
    "Soprano/Alto Sax": 64,
    "Tenor Sax": 66,
    "Baritone Sax": 67,
    "Oboe": 68,
    "English Horn": 69,
    "Bassoon": 70,
    "Clarinet": 71,
    "Pipe": 73,
    "Synth Lead": 80,
    "Synth Pad": 88,
}


@dataclasses.dataclass
class NoteRepresentationConfig:
    """Configuration for note representations."""

    onsets_only: bool
    include_ties: bool


@dataclasses.dataclass
class NoteEventData:
    pitch: int
    velocity: Optional[int] = None
    program: Optional[int] = None
    is_drum: Optional[bool] = None
    instrument: Optional[int] = None


@dataclasses.dataclass
class NoteEncodingState:
    """Encoding state for note transcription, keeping track of active pitches."""

    # velocity bin for active pitches and programs
    active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class EventRange:
    type: str
    min_value: int
    max_value: int


@dataclasses.dataclass
class Event:
    type: str
    value: int


class Tokenizer:
    def __init__(self, regular_ids: int):
        # The special tokens: 0=PAD, 1=EOS, and 2=UNK
        self._num_special_tokens = 3
        self._num_regular_tokens = regular_ids

    def encode(self, token_ids):
        encoded = []
        for token_id in token_ids:
            if not 0 <= token_id < self._num_regular_tokens:
                raise ValueError(
                    f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})"
                )
            encoded.append(token_id + self._num_special_tokens)

        # Add EOS token
        encoded.append(1)

        # Pad up to INPUT_FEATURE_LENGTH
        encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded))

        return encoded


class Codec:
    """Encode and decode events.

    Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from
    Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not
    include things like EOS or UNK token handling.

    To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required
    and specified separately.
    """

    def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]):
        """Define Codec.

        Args:
            max_shift_steps: Maximum number of shift steps that can be encoded.
            steps_per_second: Shift steps will be interpreted as having a duration of
                1 / steps_per_second.
            event_ranges: Other supported event types and their ranges.
        """
        self.steps_per_second = steps_per_second
        self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps)
        self._event_ranges = [self._shift_range] + event_ranges
        # Ensure all event types have unique names.
        assert len(self._event_ranges) == len({er.type for er in self._event_ranges})

    @property
    def num_classes(self) -> int:
        return sum(er.max_value - er.min_value + 1 for er in self._event_ranges)

    # The next couple methods are simplified special case methods just for shift
    # events that are intended to be used from within autograph functions.

    def is_shift_event_index(self, index: int) -> bool:
        return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value)

    @property
    def max_shift_steps(self) -> int:
        return self._shift_range.max_value

    def encode_event(self, event: Event) -> int:
        """Encode an event to an index."""
        offset = 0
        for er in self._event_ranges:
            if event.type == er.type:
                if not er.min_value <= event.value <= er.max_value:
                    raise ValueError(
                        f"Event value {event.value} is not within valid range "
                        f"[{er.min_value}, {er.max_value}] for type {event.type}"
                    )
                return offset + event.value - er.min_value
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"Unknown event type: {event.type}")

    def event_type_range(self, event_type: str) -> Tuple[int, int]:
        """Return [min_id, max_id] for an event type."""
        offset = 0
        for er in self._event_ranges:
            if event_type == er.type:
                return offset, offset + (er.max_value - er.min_value)
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"Unknown event type: {event_type}")

    def decode_event_index(self, index: int) -> Event:
        """Decode an event index to an Event."""
        offset = 0
        for er in self._event_ranges:
            if offset <= index <= offset + er.max_value - er.min_value:
                return Event(type=er.type, value=er.min_value + index - offset)
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"Unknown event index: {index}")


@dataclasses.dataclass
class ProgramGranularity:
    # both tokens_map_fn and program_map_fn should be idempotent
    tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]]
    program_map_fn: Callable[[int], int]


def drop_programs(tokens, codec: Codec):
    """Drops program change events from a token sequence."""
    min_program_id, max_program_id = codec.event_type_range("program")
    return tokens[(tokens < min_program_id) | (tokens > max_program_id)]


def programs_to_midi_classes(tokens, codec):
    """Modifies program events to be the first program in the MIDI class."""
    min_program_id, max_program_id = codec.event_type_range("program")
    is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
    return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens)


PROGRAM_GRANULARITIES = {
    # "flat" granularity; drop program change tokens and set NoteSequence
    # programs to zero
    "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0),
    # map each program to the first program in its MIDI class
    "midi_class": ProgramGranularity(
        tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8)
    ),
    # leave programs as is
    "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program),
}


def unfold(tensor, dimension, size, step=1):
    assert dimension < len(tensor.shape), "dimension must be less than tensor dimensions"
    assert tensor.shape[dimension] >= size, "size should not be greater than the dimension of tensor"

    slices = []
    for i in range(0, tensor.shape[dimension] - size + 1, step):
        start = [0] * len(tensor.shape)
        end = list(tensor.shape)
        start[dimension] = i
        end[dimension] = i + size
        axes = list(range(len(start)))
        sliced = paddle.slice(tensor, axes, start, end)
        slices.append(sliced)

    unfolded_tensor = paddle.stack(slices, axis=dimension)

    return unfolded_tensor


def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
    """
    equivalent of tf.signal.frame
    """
    signal_length = signal.shape[axis]
    if pad_end:
        frames_overlap = frame_length - frame_step
        rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap)
        pad_size = int(frame_length - rest_samples)
        if pad_size != 0:
            pad_axis = [0] * signal.ndim
            pad_axis[axis] = pad_size
            signal = F.pad(x=signal, pad=pad_axis, mode="constant", value=pad_value)
    frames = unfold(signal, axis, frame_length, frame_step)
    return frames


def program_to_slakh_program(program):
    # this is done very hackily, probably should use a custom mapping
    for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True):
        if program >= slakh_program:
            return slakh_program


def audio_to_frames(
    samples,
    hop_size: int,
    frame_rate: int,
) -> Tuple[Sequence[Sequence[int]], paddle.Tensor]:
    """Convert audio samples to non-overlapping frames and frame times."""
    frame_size = hop_size
    samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant")

    # Split audio into frames.
    frames = frame(
        paddle.to_tensor(samples).unsqueeze(0),
        frame_length=frame_size,
        frame_step=frame_size,
        pad_end=False,  # TODO check why it's off by 1 here when True
    )

    num_frames = len(samples) // frame_size

    times = np.arange(num_frames) / frame_rate
    return frames, times


def note_sequence_to_onsets_and_offsets_and_programs(
    ns: note_seq.NoteSequence,
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
    """Extract onset & offset times and pitches & programs from a NoteSequence.

    The onset & offset times will not necessarily be in sorted order.

    Args:
        ns: NoteSequence from which to extract onsets and offsets.

    Returns:
        times: A list of note onset and offset times.
        values: A list of NoteEventData objects where velocity is zero for note offsets.
    """
    # Sort by program and pitch and put offsets before onsets as a tiebreaker for
    # subsequent stable sort.
    notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch))
    times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes]
    values = [
        NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False)
        for note in notes
        if not note.is_drum
    ] + [
        NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum)
        for note in notes
    ]
    return times, values


def num_velocity_bins_from_codec(codec: Codec):
    """Get number of velocity bins from event codec."""
    lo, hi = codec.event_type_range("velocity")
    return hi - lo


# segment an array into segments of length n
def segment(a, n):
    return [a[i : i + n] for i in range(0, len(a), n)]


def velocity_to_bin(velocity, num_velocity_bins):
    if velocity == 0:
        return 0
    else:
        return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)


def note_event_data_to_events(
    state: Optional[NoteEncodingState],
    value: NoteEventData,
    codec: Codec,
) -> Sequence[Event]:
    """Convert note event data to a sequence of events."""
    if value.velocity is None:
        # onsets only, no program or velocity
        return [Event("pitch", value.pitch)]
    else:
        num_velocity_bins = num_velocity_bins_from_codec(codec)
        velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins)
        if value.program is None:
            # onsets + offsets + velocities only, no programs
            if state is not None:
                state.active_pitches[value.pitch, 0] = velocity_bin
            return [Event("velocity", velocity_bin), Event("pitch", value.pitch)]
        else:
            if value.is_drum:
                # drum events use a separate vocabulary
                return [Event("velocity", velocity_bin), Event("drum", value.pitch)]
            else:
                # program + velocity + pitch
                if state is not None:
                    state.active_pitches[(value.pitch, value.program)] = velocity_bin
                return [
                    Event("program", value.program),
                    Event("velocity", velocity_bin),
                    Event("pitch", value.pitch),
                ]


def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]:
    """Output program and pitch events for active notes plus a final tie event."""
    events = []
    for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]):
        if state.active_pitches[pitch, program]:
            events += [Event("program", program), Event("pitch", pitch)]
    events.append(Event("tie", 0))
    return events


def encode_and_index_events(
    state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None
):
    """Encode a sequence of timed events and index to audio frame times.

    Encodes time shifts as repeated single step shifts for later run length encoding.

    Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio
    frame. This can be used e.g. to prepend events representing the current state to a targets segment.

    Args:
        state: Initial event encoding state.
        event_times: Sequence of event times.
        event_values: Sequence of event values.
        encode_event_fn: Function that transforms event value into a sequence of one
            or more Event objects.
        codec: A Codec object that maps Event objects to indices.
        frame_times: Time for every audio frame.
        encoding_state_to_events_fn: Function that transforms encoding state into a
            sequence of one or more Event objects.

    Returns:
        events: Encoded events and shifts.
        event_start_indices: Corresponding start event index for every audio frame.
            Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes
            splitting sequences tricky because the same event can appear at the end of one sequence and the
            beginning of another.
        event_end_indices: Corresponding end event index for every audio frame. Used
            to ensure when slicing that one chunk ends where the next begins. Should always be true that
            event_end_indices[i] = event_start_indices[i + 1].
        state_events: Encoded "state" events representing the encoding state before
            each event.
        state_event_indices: Corresponding state event index for every audio frame.
    """
    indices = np.argsort(event_times, kind="stable")
    event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices]
    event_values = [event_values[i] for i in indices]

    events = []
    state_events = []
    event_start_indices = []
    state_event_indices = []

    cur_step = 0
    cur_event_idx = 0
    cur_state_event_idx = 0

    def fill_event_start_indices_to_cur_step():
        while (
            len(event_start_indices) < len(frame_times)
            and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second
        ):
            event_start_indices.append(cur_event_idx)
            state_event_indices.append(cur_state_event_idx)

    for event_step, event_value in zip(event_steps, event_values):
        while event_step > cur_step:
            events.append(codec.encode_event(Event(type="shift", value=1)))
            cur_step += 1
            fill_event_start_indices_to_cur_step()
            cur_event_idx = len(events)
            cur_state_event_idx = len(state_events)
        if encoding_state_to_events_fn:
            # Dump state to state events *before* processing the next event, because
            # we want to capture the state prior to the occurrence of the event.
            for e in encoding_state_to_events_fn(state):
                state_events.append(codec.encode_event(e))

        for e in encode_event_fn(state, event_value, codec):
            events.append(codec.encode_event(e))

    # After the last event, continue filling out the event_start_indices array.
    # The inequality is not strict because if our current step lines up exactly
    # with (the start of) an audio frame, we need to add an additional shift event
    # to "cover" that frame.
    while cur_step / codec.steps_per_second <= frame_times[-1]:
        events.append(codec.encode_event(Event(type="shift", value=1)))
        cur_step += 1
        fill_event_start_indices_to_cur_step()
        cur_event_idx = len(events)

    # Now fill in event_end_indices. We need this extra array to make sure that
    # when we slice events, each slice ends exactly where the subsequent slice
    # begins.
    event_end_indices = event_start_indices[1:] + [len(events)]

    events = np.array(events).astype(np.int32)
    state_events = np.array(state_events).astype(np.int32)
    event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
    event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
    state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH)

    outputs = []
    for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices):
        outputs.append(
            {
                "inputs": events,
                "event_start_indices": start_indices,
                "event_end_indices": end_indices,
                "state_events": state_events,
                "state_event_indices": event_indices,
            }
        )

    return outputs


def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"):
    """Extract target sequence corresponding to audio token segment."""
    features = features.copy()
    start_idx = features["event_start_indices"][0]
    end_idx = features["event_end_indices"][-1]

    features[feature_key] = features[feature_key][start_idx:end_idx]

    if state_events_end_token is not None:
        # Extract the state events corresponding to the audio start token, and
        # prepend them to the targets array.
        state_event_start_idx = features["state_event_indices"][0]
        state_event_end_idx = state_event_start_idx + 1
        while features["state_events"][state_event_end_idx - 1] != state_events_end_token:
            state_event_end_idx += 1
        features[feature_key] = np.concatenate(
            [
                features["state_events"][state_event_start_idx:state_event_end_idx],
                features[feature_key],
            ],
            axis=0,
        )

    return features


def map_midi_programs(
    feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs"
) -> Mapping[str, Any]:
    """Apply MIDI program map to token sequences."""
    granularity = PROGRAM_GRANULARITIES[granularity_type]

    feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec)
    return feature


def run_length_encode_shifts_fn(
    features,
    codec: Codec,
    feature_key: str = "inputs",
    state_change_event_types: Sequence[str] = (),
) -> Mapping[str, Any]:
    """Run-length encode shifts for a given codec.

    Args:
        codec: The Codec to use for shift events.
        feature_key: The feature key for which to run-length encode shifts.
        state_change_event_types: A list of event types that represent state
            changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones
            will be removed.

    Returns:
        The features dict with single-step shifts run-length encoded.
    """
    state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types]

    def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Combine leading/interior shifts, trim trailing shifts.

        Args:
            features: Dict of features to process.

        Returns:
            A dict of features.
        """
        events = features[feature_key]

        shift_steps = 0
        total_shift_steps = 0
        output = np.array([], dtype=np.int32)

        current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32)

        for event in events:
            if codec.is_shift_event_index(event):
                shift_steps += 1
                total_shift_steps += 1

            else:
                # If this event is a state change and has the same value as the current
                # state, we can skip it entirely.
                is_redundant = False
                for i, (min_index, max_index) in enumerate(state_change_event_ranges):
                    if (min_index <= event) and (event <= max_index):
                        if current_state[i] == event:
                            is_redundant = True
                        current_state[i] = event
                if is_redundant:
                    continue

                # Once we've reached a non-shift event, RLE all previous shift events
                # before outputting the non-shift event.
                if shift_steps > 0:
                    shift_steps = total_shift_steps
                    while shift_steps > 0:
                        output_steps = np.minimum(codec.max_shift_steps, shift_steps)
                        output = np.concatenate([output, [output_steps]], axis=0)
                        shift_steps -= output_steps
                output = np.concatenate([output, [event]], axis=0)

        features[feature_key] = output
        return features

    return run_length_encode_shifts(features)


def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig):
    tie_token = codec.encode_event(Event("tie", 0))
    state_events_end_token = tie_token if note_representation_config.include_ties else None

    features = extract_sequence_with_indices(
        features, state_events_end_token=state_events_end_token, feature_key="inputs"
    )

    features = map_midi_programs(features, codec)

    features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"])

    return features


class MidiProcessor:
    def __init__(self):
        self.codec = Codec(
            max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND,
            steps_per_second=DEFAULT_STEPS_PER_SECOND,
            event_ranges=[
                EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH),
                EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS),
                EventRange("tie", 0, 0),
                EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM),
                EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH),
            ],
        )
        self.tokenizer = Tokenizer(self.codec.num_classes)
        self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True)

    def __call__(self, midi: Union[bytes, os.PathLike, str]):
        if not isinstance(midi, bytes):
            with open(midi, "rb") as f:
                midi = f.read()

        ns = note_seq.midi_to_note_sequence(midi)
        ns_sus = note_seq.apply_sustain_control_changes(ns)

        for note in ns_sus.notes:
            if not note.is_drum:
                note.program = program_to_slakh_program(note.program)

        samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE))

        _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE)
        times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus)

        events = encode_and_index_events(
            state=NoteEncodingState(),
            event_times=times,
            event_values=values,
            frame_times=frame_times,
            codec=self.codec,
            encode_event_fn=note_event_data_to_events,
            encoding_state_to_events_fn=note_encoding_state_to_events,
        )

        events = [
            note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events
        ]
        input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events]

        return input_tokens
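As a quick sanity check of the `Codec` block layout above, here is a small round-trip sketch. The event ranges are a trimmed-down assumption, not the full set `MidiProcessor` registers, and importing `midi_utils` itself requires `note-seq` to be installed.

from ppdiffusers.pipelines.spectrogram_diffusion.midi_utils import Codec, Event, EventRange

# Shift events always occupy ids [0, max_shift_steps]; each later range is
# appended after the previous one, in declaration order.
codec = Codec(
    max_shift_steps=1000,
    steps_per_second=100,
    event_ranges=[
        EventRange("pitch", 0, 127),
        EventRange("velocity", 0, 1),
        EventRange("tie", 0, 0),
    ],
)
idx = codec.encode_event(Event(type="pitch", value=60))  # 1001 + 60 = 1061
assert codec.decode_event_index(idx) == Event(type="pitch", value=60)
assert codec.is_shift_event_index(500) and not codec.is_shift_event_index(idx)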
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/notes_encoder.py
ADDED
@@ -0,0 +1,87 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from ppdiffusers.transformers.model_utils import ModuleUtilsMixin
from ppdiffusers.transformers.t5.modeling import T5Block, T5Config, T5LayerNorm

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin


class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    @register_to_config
    def __init__(
        self,
        max_length: int,
        vocab_size: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

        self.token_embedder = nn.Embedding(vocab_size, d_model)

        self.position_encoding = nn.Embedding(max_length, d_model)
        self.position_encoding.weight.stop_gradient = True

        self.dropout_pre = nn.Dropout(p=dropout_rate)

        t5config = T5Config(
            vocab_size=vocab_size,
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            feed_forward_proj=feed_forward_proj,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )

        self.encoders = nn.LayerList()
        for lyr_num in range(num_layers):
            lyr = T5Block(t5config)
            self.encoders.append(lyr)

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_input_tokens, encoder_inputs_mask):
        x = self.token_embedder(encoder_input_tokens)

        seq_length = encoder_input_tokens.shape[1]
        inputs_positions = paddle.arange(seq_length)
        x += self.position_encoding(inputs_positions)

        x = self.dropout_pre(x)

        # invert the attention mask
        input_shape = encoder_input_tokens.shape
        extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape)

        for lyr in self.encoders:
            x = lyr(x, extended_attention_mask)[0]
        x = self.layer_norm(x)

        return self.dropout_post(x), encoder_inputs_mask
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
ADDED
@@ -0,0 +1,262 @@
| 1 |
+
# Copyright 2022 The Music Spectrogram Diffusion Authors.
|
| 2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import paddle
|
| 21 |
+
|
| 22 |
+
from ...models import T5FilmDecoder
|
| 23 |
+
from ...schedulers import DDPMScheduler
|
| 24 |
+
from ...utils import is_fastdeploy_available, logging
|
| 25 |
+
from ...utils.paddle_utils import randn_tensor
|
| 26 |
+
|
| 27 |
+
if is_fastdeploy_available():
|
| 28 |
+
from ..fastdeploy_utils import FastDeployRuntimeModel
|
| 29 |
+
|
| 30 |
+
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
| 31 |
+
from .continous_encoder import SpectrogramContEncoder
|
| 32 |
+
from .notes_encoder import SpectrogramNotesEncoder
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
TARGET_FEATURE_LENGTH = 256
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SpectrogramDiffusionPipeline(DiffusionPipeline):
|
| 40 |
+
r"""
|
| 41 |
+
Pipeline for unconditional audio generation.
|
| 42 |
+
|
| 43 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 44 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
notes_encoder ([`SpectrogramNotesEncoder`]):
|
| 48 |
+
continuous_encoder ([`SpectrogramContEncoder`]):
|
| 49 |
+
decoder ([`T5FilmDecoder`]):
|
| 50 |
+
A [`T5FilmDecoder`] to denoise the encoded audio latents.
|
| 51 |
+
scheduler ([`DDPMScheduler`]):
|
| 52 |
+
A scheduler to be used in combination with `decoder` to denoise the encoded audio latents.
|
| 53 |
+
melgan ([`FastDeployRuntimeModel`]):
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
_optional_components = ["melgan"]
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
notes_encoder: SpectrogramNotesEncoder,
|
| 61 |
+
continuous_encoder: SpectrogramContEncoder,
|
| 62 |
+
decoder: T5FilmDecoder,
|
| 63 |
+
scheduler: DDPMScheduler,
|
| 64 |
+
melgan: FastDeployRuntimeModel if is_fastdeploy_available() else Any,
|
| 65 |
+
) -> None:
|
| 66 |
+
super().__init__()
|
| 67 |
+
|
| 68 |
+
# From MELGAN
|
| 69 |
+
self.min_value = math.log(1e-5) # Matches MelGAN training.
|
| 70 |
+
self.max_value = 4.0 # Largest value for most examples
|
| 71 |
+
self.n_dims = 128
|
| 72 |
+
|
| 73 |
+
self.register_modules(
|
| 74 |
+
notes_encoder=notes_encoder,
|
| 75 |
+
continuous_encoder=continuous_encoder,
|
| 76 |
+
decoder=decoder,
|
| 77 |
+
scheduler=scheduler,
|
| 78 |
+
melgan=melgan,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def scale_features(self, features, output_range=(-1.0, 1.0), clip=False):
|
| 82 |
+
"""Linearly scale features to network outputs range."""
|
| 83 |
+
min_out, max_out = output_range
|
| 84 |
+
if clip:
|
| 85 |
+
features = paddle.clip(features, self.min_value, self.max_value)
|
| 86 |
+
# Scale to [0, 1].
|
| 87 |
+
zero_one = (features - self.min_value) / (self.max_value - self.min_value)
|
| 88 |
+
# Scale to [min_out, max_out].
|
| 89 |
+
return zero_one * (max_out - min_out) + min_out
|
| 90 |
+
|
| 91 |
+
def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False):
|
| 92 |
+
"""Invert by linearly scaling network outputs to features range."""
|
| 93 |
+
min_out, max_out = input_range
|
| 94 |
+
outputs = paddle.clip(outputs, min_out, max_out) if clip else outputs
|
| 95 |
+
# Scale to [0, 1].
|
| 96 |
+
zero_one = (outputs - min_out) / (max_out - min_out)
|
| 97 |
+
# Scale to [self.min_value, self.max_value].
|
| 98 |
+
return zero_one * (self.max_value - self.min_value) + self.min_value
|
| 99 |
+
|
| 100 |
+
def encode(self, input_tokens, continuous_inputs, continuous_mask):
|
| 101 |
+
tokens_mask = input_tokens > 0
|
| 102 |
+
tokens_encoded, tokens_mask = self.notes_encoder(
|
| 103 |
+
encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
continuous_encoded, continuous_mask = self.continuous_encoder(
|
| 107 |
+
encoder_inputs=continuous_inputs.cast(self.continuous_encoder.dtype), encoder_inputs_mask=continuous_mask
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
tokens_mask = tokens_mask.cast(tokens_encoded.dtype)
|
| 111 |
+
continuous_mask = continuous_mask.cast(continuous_encoded.dtype)
|
| 112 |
+
return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]
|
| 113 |
+
|
| 114 |
+
def decode(self, encodings_and_masks, input_tokens, noise_time):
|
| 115 |
+
timesteps = noise_time
|
| 116 |
+
if not paddle.is_tensor(x=timesteps):
|
| 117 |
+
timesteps = paddle.to_tensor(data=[timesteps], dtype="int64")
|
| 118 |
+
elif paddle.is_tensor(x=timesteps) and len(timesteps.shape) == 0:
|
| 119 |
+
timesteps = timesteps[None]
|
| 120 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 121 |
+
timesteps = timesteps * paddle.ones(shape=input_tokens.shape[0], dtype=timesteps.dtype)
|
| 122 |
+
logits = self.decoder(
|
| 123 |
+
encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps
|
| 124 |
+
)
|
| 125 |
+
return logits
|
| 126 |
+
|
| 127 |
+
@paddle.no_grad()
|
| 128 |
+
def __call__(
|
| 129 |
+
self,
|
| 130 |
+
input_tokens: List[List[int]],
|
| 131 |
+
generator: Optional[paddle.Generator] = None,
|
| 132 |
+
num_inference_steps: int = 100,
|
| 133 |
+
return_dict: bool = True,
|
| 134 |
+
output_type: str = "numpy",
|
| 135 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
| 136 |
+
callback_steps: int = 1,
|
| 137 |
+
) -> Union[AudioPipelineOutput, Tuple]:
|
| 138 |
+
if (callback_steps is None) or (
|
| 139 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 140 |
+
):
|
| 141 |
+
raise ValueError(
|
| 142 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 143 |
+
f" {type(callback_steps)}."
|
| 144 |
+
)
|
| 145 |
+
r"""
|
| 146 |
+
The call function to the pipeline for generation.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
input_tokens (`List[List[int]]`):
|
| 150 |
+
generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
|
| 151 |
+
A [`topaddlerch.Generator`] to make generation deterministic.
|
| 152 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
| 153 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
| 154 |
+
expense of slower inference.
|
| 155 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
            output_type (`str`, *optional*, defaults to `"numpy"`):
                The output format of the generated audio.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function is called
                with the following arguments: `callback(step: int, mel: np.ndarray)`, where `mel` is the full mel
                spectrogram generated so far.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Example:

        ```py
        >>> from ppdiffusers import SpectrogramDiffusionPipeline, MidiProcessor

        >>> pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
        >>> processor = MidiProcessor()

        >>> # Download MIDI from: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid
        >>> output = pipe(processor("beethoven_hammerklavier_2.mid"))

        >>> audio = output.audios[0]
        ```

        Returns:
            [`pipelines.AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated audio.
        """

        pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32)
        full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32)
        ones = paddle.ones((1, TARGET_FEATURE_LENGTH), dtype=paddle.bool)

        for i, encoder_input_tokens in enumerate(input_tokens):
            if i == 0:
                encoder_continuous_inputs = paddle.to_tensor(pred_mel[:1].copy()).cast(dtype=self.decoder.dtype)
                # The first chunk has no previous context.
                encoder_continuous_mask = paddle.zeros((1, TARGET_FEATURE_LENGTH), dtype=paddle.bool)
            else:
                # The full song pipeline does not feed in a context feature, so the mask
                # will be all 0s after the feature converter. Because we know we're
                # feeding in a full context chunk from the previous prediction, set it
                # to all 1s.
                encoder_continuous_mask = ones

            encoder_continuous_inputs = self.scale_features(
                encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True
            )

            encodings_and_masks = self.encode(
                input_tokens=paddle.to_tensor([encoder_input_tokens], dtype="int32"),
                continuous_inputs=encoder_continuous_inputs,
                continuous_mask=encoder_continuous_mask.cast(dtype="int32"),
            )

            # Sample gaussian noise shaped like encoder_continuous_inputs to begin the loop
            x = randn_tensor(
                shape=encoder_continuous_inputs.shape,
                generator=generator,
                dtype=self.decoder.dtype,
            )

            # set step values
            self.scheduler.set_timesteps(num_inference_steps)

            # Denoising diffusion loop
            for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
                output = self.decode(
                    encodings_and_masks=encodings_and_masks,
                    input_tokens=x,
                    noise_time=t / self.scheduler.config.num_train_timesteps,  # rescale to [0, 1)
                )

                # Compute previous output: x_t -> x_t-1
                x = self.scheduler.step(output, t, x, generator=generator).prev_sample

            mel = self.scale_to_features(x, input_range=[-1.0, 1.0])
            encoder_continuous_inputs = mel[:1]
            pred_mel = mel.cast(dtype="float32").cpu().numpy()

            full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1)

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, full_pred_mel)

            logger.info("Generated segment %d", i)

        if output_type == "numpy" and not is_fastdeploy_available():
            raise ValueError(
                "Cannot return output in 'numpy' format if FastDeploy is not available. Make sure to have FastDeploy installed or set 'output_type' to 'mel'."
            )
        elif output_type == "numpy" and self.melgan is None:
            raise ValueError(
                "Cannot return output in 'numpy' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'."
            )

        if output_type == "numpy":
            output = self.melgan(input_features=full_pred_mel.astype(np.float32))[0]
        else:
            output = full_pred_mel

        if not return_dict:
            return (output,)

        return AudioPipelineOutput(audios=output)
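For readers wiring up the `callback` hook above: a minimal sketch of a per-segment progress monitor. This assumes the checkpoint and MIDI file from the docstring example are available locally; `on_segment` is a hypothetical name, but its signature matches the `callback(i, full_pred_mel)` call in the loop above.

```py
import numpy as np

from ppdiffusers import MidiProcessor, SpectrogramDiffusionPipeline

pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
processor = MidiProcessor()

def on_segment(step: int, mel: np.ndarray) -> None:
    # `mel` is the full mel spectrogram accumulated so far: shape [1, frames, n_dims]
    print(f"segment {step}: {mel.shape[1]} mel frames generated")

output = pipe(
    processor("beethoven_hammerklavier_2.mid"),  # path from the docstring example
    callback=on_segment,
    callback_steps=1,
)
audio = output.audios[0]
```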
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/__init__.py
ADDED
@@ -0,0 +1,294 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    PPDIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_fastdeploy_available,
    is_k_diffusion_available,
    is_k_diffusion_version,
    is_paddle_available,
    is_paddlenlp_available,
    is_paddlenlp_version,
)

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["StableDiffusionPipelineOutput"]}

try:
    if not (is_paddlenlp_available() and is_paddle_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
else:
    _import_structure["clip_image_project_model"] = ["CLIPImageProjection"]
    _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
    _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
    _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
    _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
    _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
    _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
    _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
    _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
    _import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"]
    _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"]
    _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"]
    _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
    _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
    _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
    _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
    _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
    _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
    _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
    _import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"]
    _import_structure["stable_unclip_image_normalizer"] = ["StableUnCLIPImageNormalizer"]
    _import_structure["pipeline_paddleinfer_cycle_diffusion"] = ["PaddleInferCycleDiffusionPipeline"]
    _import_structure["pipeline_paddleinfer_stable_diffusion"] = ["PaddleInferStableDiffusionPipeline"]
    _import_structure["pipeline_paddleinfer_stable_diffusion_img2img"] = ["PaddleInferStableDiffusionImg2ImgPipeline"]
    _import_structure["pipeline_paddleinfer_stable_diffusion_inpaint"] = ["PaddleInferStableDiffusionInpaintPipeline"]
    _import_structure["pipeline_paddleinfer_stable_diffusion_inpaint_legacy"] = [
        "PaddleInferStableDiffusionInpaintPipelineLegacy"
    ]
    _import_structure["pipeline_paddleinfer_stable_diffusion_mega"] = ["PaddleInferStableDiffusionMegaPipeline"]

try:
    if not (is_paddlenlp_available() and is_paddle_available() and is_paddlenlp_version(">=", "2.6.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_paddle_and_paddlenlp_objects import (
        StableDiffusionImageVariationPipeline,
    )

    _dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline})
else:
    _import_structure["pipeline_stable_diffusion_image_variation"] = ["StableDiffusionImageVariationPipeline"]
try:
    if not (is_paddlenlp_available() and is_paddle_available() and is_paddlenlp_version(">=", "2.6.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_paddle_and_paddlenlp_objects import (
        StableDiffusionDepth2ImgPipeline,
        StableDiffusionDiffEditPipeline,
        StableDiffusionPix2PixZeroPipeline,
    )

    _dummy_objects.update(
        {
            "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
            "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline,
            "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
        }
    )
else:
    _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"]
    _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
    _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]
try:
    if not (
        is_paddle_available()
        and is_paddlenlp_available()
        and is_k_diffusion_available()
        and is_k_diffusion_version(">=", "0.0.12")
    ):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_paddle_and_paddlenlp_and_k_diffusion_objects

    _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_and_k_diffusion_objects))
else:
    _import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
try:
    if not (is_paddlenlp_available() and is_fastdeploy_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_fastdeploy_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_fastdeploy_objects))
else:
    _import_structure["pipeline_fastdeploy_stable_diffusion"] = ["FastDeployStableDiffusionPipeline"]
    _import_structure["pipeline_fastdeploy_stable_diffusion_img2img"] = ["FastDeployStableDiffusionImg2ImgPipeline"]
    _import_structure["pipeline_fastdeploy_stable_diffusion_inpaint"] = ["FastDeployStableDiffusionInpaintPipeline"]
    _import_structure["pipeline_fastdeploy_stable_diffusion_inpaint_legacy"] = [
        "FastDeployStableDiffusionInpaintPipelineLegacy"
    ]
    # new add
    _import_structure["pipeline_fastdeploy_stable_diffusion_mega"] = ["FastDeployStableDiffusionMegaPipeline"]
    _import_structure["pipeline_fastdeploy_cycle_diffusion"] = ["FastDeployCycleDiffusionPipeline"]
    _import_structure["pipeline_fastdeploy_stable_diffusion_upscale"] = ["FastDeployStableDiffusionUpscalePipeline"]


if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_paddlenlp_available() and is_paddle_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import *

    else:
        from .clip_image_project_model import CLIPImageProjection
        from .pipeline_cycle_diffusion import CycleDiffusionPipeline
        from .pipeline_output import StableDiffusionPipelineOutput

        # paddleinfer
        from .pipeline_paddleinfer_cycle_diffusion import (
            PaddleInferCycleDiffusionPipeline,
        )
        from .pipeline_paddleinfer_stable_diffusion import (
            PaddleInferStableDiffusionPipeline,
        )
        from .pipeline_paddleinfer_stable_diffusion_img2img import (
            PaddleInferStableDiffusionImg2ImgPipeline,
        )
        from .pipeline_paddleinfer_stable_diffusion_inpaint import (
            PaddleInferStableDiffusionInpaintPipeline,
        )
        from .pipeline_paddleinfer_stable_diffusion_inpaint_legacy import (
            PaddleInferStableDiffusionInpaintPipelineLegacy,
        )
        from .pipeline_paddleinfer_stable_diffusion_mega import (
            PaddleInferStableDiffusionMegaPipeline,
        )
        from .pipeline_stable_diffusion import StableDiffusionPipeline
        from .pipeline_stable_diffusion_attend_and_excite import (
            StableDiffusionAttendAndExcitePipeline,
        )
        from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
        from .pipeline_stable_diffusion_gligen_text_image import (
            StableDiffusionGLIGENTextImagePipeline,
        )
        from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
        from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
        from .pipeline_stable_diffusion_inpaint_legacy import (
            StableDiffusionInpaintPipelineLegacy,
        )
        from .pipeline_stable_diffusion_instruct_pix2pix import (
            StableDiffusionInstructPix2PixPipeline,
        )
        from .pipeline_stable_diffusion_latent_upscale import (
            StableDiffusionLatentUpscalePipeline,
        )
        from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
        from .pipeline_stable_diffusion_model_editing import (
            StableDiffusionModelEditingPipeline,
        )
        from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
        from .pipeline_stable_diffusion_paradigms import (
            StableDiffusionParadigmsPipeline,
        )
        from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
        from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
        from .pipeline_stable_unclip import StableUnCLIPPipeline
        from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
        from .safety_checker import StableDiffusionSafetyChecker
        from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

    try:
        if not (is_paddlenlp_available() and is_paddle_available() and is_paddlenlp_version(">=", "2.6.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import (
            StableDiffusionImageVariationPipeline,
        )
    else:
        from .pipeline_stable_diffusion_image_variation import (
            StableDiffusionImageVariationPipeline,
        )

    try:
        if not (is_paddlenlp_available() and is_paddle_available() and is_paddlenlp_version(">=", "2.6.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_objects import (
            StableDiffusionDepth2ImgPipeline,
            StableDiffusionDiffEditPipeline,
            StableDiffusionPix2PixZeroPipeline,
        )
    else:
        from .pipeline_stable_diffusion_depth2img import (
            StableDiffusionDepth2ImgPipeline,
        )
        from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
        from .pipeline_stable_diffusion_pix2pix_zero import (
            StableDiffusionPix2PixZeroPipeline,
        )

    try:
        if not (
            is_paddle_available()
            and is_paddlenlp_available()
            and is_k_diffusion_available()
            and is_k_diffusion_version(">=", "0.0.12")
        ):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_paddle_and_paddlenlp_and_k_diffusion_objects import *
    else:
        from .pipeline_stable_diffusion_k_diffusion import (
            StableDiffusionKDiffusionPipeline,
        )

    try:
        if not (is_paddlenlp_available() and is_fastdeploy_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_fastdeploy_objects import *
    else:
        from .pipeline_fastdeploy_cycle_diffusion import (
            FastDeployCycleDiffusionPipeline,
        )
        from .pipeline_fastdeploy_stable_diffusion import (
            FastDeployStableDiffusionPipeline,
        )
        from .pipeline_fastdeploy_stable_diffusion_img2img import (
            FastDeployStableDiffusionImg2ImgPipeline,
        )
        from .pipeline_fastdeploy_stable_diffusion_inpaint import (
            FastDeployStableDiffusionInpaintPipeline,
        )
        from .pipeline_fastdeploy_stable_diffusion_inpaint_legacy import (
            FastDeployStableDiffusionInpaintPipelineLegacy,
        )

        # new add
        from .pipeline_fastdeploy_stable_diffusion_mega import (
            FastDeployStableDiffusionMegaPipeline,
        )
        from .pipeline_fastdeploy_stable_diffusion_upscale import (
            FastDeployStableDiffusionUpscalePipeline,
        )


else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
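Because of the `_LazyModule` swap at the bottom of this `__init__`, importing the package is cheap: pipeline submodules only load on first attribute access, and names whose optional dependencies are missing resolve to dummy placeholders that raise a helpful error when used. A rough sketch of the resulting behavior (assumes paddle and paddlenlp are installed; module paths as above):

```py
import ppdiffusers.pipelines.stable_diffusion as sd  # fast: no pipeline module code runs yet

# First attribute access makes _LazyModule import pipeline_stable_diffusion.
pipe_cls = sd.StableDiffusionPipeline

# With FastDeploy absent, FastDeployStableDiffusionPipeline would instead resolve
# to a dummy object (from dummy_fastdeploy_objects) that errors on instantiation.
```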
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/convert_from_ckpt.py
ADDED
@@ -0,0 +1,1915 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the Stable Diffusion checkpoints."""

import re
from io import BytesIO
from typing import Dict, Optional, Union

import numpy as np
import paddle
import requests

from ppdiffusers.transformers import (
    BertTokenizer,
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from ...models import (
    AutoencoderKL,
    ControlNetModel,
    PriorTransformer,
    UNet2DConditionModel,
)
from ...schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UnCLIPScheduler,
)
from ...utils import is_omegaconf_available, is_paddlenlp_available, logging, smart_load
from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
from .safety_checker import StableDiffusionSafetyChecker
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

from ...models.modeling_utils import ContextManagers, faster_set_state_dict

if is_paddlenlp_available():
    try:
        from paddlenlp.transformers.model_utils import no_init_weights
    except ImportError:
        from ...utils.paddle_utils import no_init_weights

def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "to_q.weight")
        new_item = new_item.replace("q.bias", "to_q.bias")

        new_item = new_item.replace("k.weight", "to_k.weight")
        new_item = new_item.replace("k.bias", "to_k.bias")

        new_item = new_item.replace("v.weight", "to_v.weight")
        new_item = new_item.replace("v.bias", "to_v.bias")

        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping

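A quick illustration of the renaming helpers above on a made-up LDM UNet key (the key name is hypothetical; the results can be verified by hand from the code):

```py
# shave_segments drops leading (or, for negative counts, trailing) dot-separated segments:
shave_segments("input_blocks.1.0.in_layers.0.weight", 2)   # -> "0.in_layers.0.weight"
shave_segments("input_blocks.1.0.in_layers.0.weight", -1)  # -> "input_blocks.1.0.in_layers.0"

# renew_resnet_paths applies the local LDM -> diffusers renames:
renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
# -> [{"old": "input_blocks.1.0.in_layers.0.weight",
#      "new": "input_blocks.1.0.norm1.weight"}]
```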
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = np.split(old_tensor, 3, axis=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
        shape = old_checkpoint[path["old"]].shape
        if is_attn_weight and len(shape) == 3:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        elif is_attn_weight and len(shape) == 4:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]

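The `[:, :, 0]` / `[:, :, 0, 0]` slicing in `assign_to_checkpoint` and `conv_attn_to_linear` is just dropping degenerate spatial axes when a 1x1 conv projection is reinterpreted as a linear layer. A self-contained sketch with hypothetical shapes:

```py
import numpy as np

# A 1x1 conv projection stored as (C_out, C_in, 1, 1) ...
w_conv = np.random.randn(512, 512, 1, 1).astype("float32")

# ... carries exactly the same parameters as a (C_out, C_in) linear weight.
w_linear = w_conv[:, :, 0, 0]
assert w_linear.shape == (512, 512)
```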
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
            unet_params = original_config.model.params.unet_config.params
        else:
            unet_params = original_config.model.params.network_config.params

    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    if unet_params.transformer_depth is not None:
        transformer_layers_per_block = (
            unet_params.transformer_depth
            if isinstance(unet_params.transformer_depth, int)
            else list(unet_params.transformer_depth)
        )
    else:
        transformer_layers_per_block = 1

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None
    context_dim = None

    if unet_params.context_dim is not None:
        context_dim = (
            unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
        )

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            if context_dim in [2048, 1280]:
                # SDXL
                addition_embed_type = "text_time"
                addition_time_embed_dim = 256
            else:
                class_embed_type = "projection"
                assert "adm_in_channels" in unet_params
                projection_class_embeddings_input_dim = unet_params.adm_in_channels

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if "disable_self_attentions" in unet_params:
        config["only_cross_attention"] = unet_params.disable_self_attentions

    if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
        config["num_class_embeds"] = unet_params.num_classes

    if controlnet:
        config["conditioning_channels"] = unet_params.hint_channels
    else:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config


def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config

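As a worked example of the loop in `create_unet_diffusers_config`, with the values commonly found in a Stable Diffusion 1.x config (`model_channels=320`, `channel_mult=[1, 2, 4, 4]`, `attention_resolutions=[4, 2, 1]`, VAE `ch_mult=[1, 2, 4, 4]`, `image_size=512`; these numbers are illustrative, not read from this repo):

```py
# block_out_channels = [320 * m for m in (1, 2, 4, 4)] = [320, 640, 1280, 1280]
# `resolution` walks 1 -> 2 -> 4 -> 8, so with attention_resolutions = [4, 2, 1]:
#   down_block_types = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
# vae_scale_factor = 2 ** (len([1, 2, 4, 4]) - 1) = 8
# sample_size = 512 // 8 = 64
```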
def get_default(params, key, default):
    if key in params:
        return params[key]
    else:
        return default


def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


def create_ldm_bert_config(original_config):
    bert_params = dict(original_config.model.params.cond_stage_config.params)
    config = dict(
        vocab_size=get_default(bert_params, "vocab_size", 30522),
        max_position_embeddings=get_default(bert_params, "max_seq_len", 77),
        encoder_layers=get_default(bert_params, "n_layer", 32),
        encoder_ffn_dim=get_default(bert_params, "n_embed", 1280) * 4,
        encoder_attention_heads=8,
        head_dim=64,
        activation_function="gelu",
        d_model=get_default(bert_params, "n_embed", 1280),
        dropout=0.0,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        pad_token_id=0,
    )
    return LDMBertConfig(**config)

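One step in `convert_ldm_unet_checkpoint` below that is easy to misread is the flat-to-nested block index arithmetic. A worked check, assuming `layers_per_block=2` as in Stable Diffusion 1.x (the index values are illustrative):

```py
layers_per_block = 2
# LDM input_blocks are numbered 1..N after the conv_in at index 0; each diffusers
# down block groups layers_per_block resnets plus one downsampler slot.
for i in [1, 4, 7, 11]:
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    print(i, "->", f"down_blocks.{block_id}", "slot", layer_in_block_id)
# 1 -> down_blocks.0 slot 0; 4 -> down_blocks.1 slot 0; 11 -> down_blocks.3 slot 1
```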
+
def convert_ldm_unet_checkpoint(
|
| 404 |
+
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
|
| 405 |
+
):
|
| 406 |
+
"""
|
| 407 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 408 |
+
"""
|
| 409 |
+
|
| 410 |
+
if skip_extract_state_dict:
|
| 411 |
+
unet_state_dict = checkpoint
|
| 412 |
+
else:
|
| 413 |
+
# extract state_dict for UNet
|
| 414 |
+
unet_state_dict = {}
|
| 415 |
+
keys = list(checkpoint.keys())
|
| 416 |
+
|
| 417 |
+
if controlnet:
|
| 418 |
+
unet_key = "control_model."
|
| 419 |
+
else:
|
| 420 |
+
unet_key = "model.diffusion_model."
|
| 421 |
+
|
| 422 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
| 423 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
| 424 |
+
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
| 425 |
+
logger.warning(
|
| 426 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
| 427 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
| 428 |
+
)
|
| 429 |
+
for key in keys:
|
| 430 |
+
if key.startswith("model.diffusion_model"):
|
| 431 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
| 432 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
| 433 |
+
else:
|
| 434 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
| 435 |
+
logger.warning(
|
| 436 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
| 437 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
for key in keys:
|
| 441 |
+
if key.startswith(unet_key):
|
| 442 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 443 |
+
|
| 444 |
+
new_checkpoint = {}
|
| 445 |
+
|
| 446 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 447 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 448 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 449 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 450 |
+
|
| 451 |
+
if config["class_embed_type"] is None:
|
| 452 |
+
# No parameters to port
|
| 453 |
+
...
|
| 454 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
| 455 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
| 456 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
| 457 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
| 458 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
| 459 |
+
else:
|
| 460 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
| 461 |
+
|
| 462 |
+
if config["addition_embed_type"] == "text_time":
|
| 463 |
+
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
| 464 |
+
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
| 465 |
+
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
| 466 |
+
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
| 467 |
+
|
| 468 |
+
# Relevant to StableDiffusionUpscalePipeline
|
| 469 |
+
if "num_class_embeds" in config:
|
| 470 |
+
new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
|
| 471 |
+
|
| 472 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 473 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 474 |
+
|
| 475 |
+
if not controlnet:
|
| 476 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 477 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 478 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 479 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 480 |
+
|
| 481 |
+
# Retrieves the keys for the input blocks only
|
| 482 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 483 |
+
input_blocks = {
|
| 484 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
| 485 |
+
for layer_id in range(num_input_blocks)
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
# Retrieves the keys for the middle blocks only
|
| 489 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 490 |
+
middle_blocks = {
|
| 491 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
| 492 |
+
for layer_id in range(num_middle_blocks)
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
# Retrieves the keys for the output blocks only
|
| 496 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 497 |
+
output_blocks = {
|
| 498 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
| 499 |
+
for layer_id in range(num_output_blocks)
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
for i in range(1, num_input_blocks):
|
| 503 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 504 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 505 |
+
|
| 506 |
+
resnets = [
|
| 507 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
| 508 |
+
]
|
| 509 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 510 |
+
|
| 511 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
| 512 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 513 |
+
f"input_blocks.{i}.0.op.weight"
|
| 514 |
+
)
|
| 515 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
| 516 |
+
f"input_blocks.{i}.0.op.bias"
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
paths = renew_resnet_paths(resnets)
|
| 520 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 521 |
+
assign_to_checkpoint(
|
| 522 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if len(attentions):
|
| 526 |
+
paths = renew_attention_paths(attentions)
|
| 527 |
+
|
| 528 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 529 |
+
assign_to_checkpoint(
|
| 530 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
resnet_0 = middle_blocks[0]
|
| 534 |
+
attentions = middle_blocks[1]
|
| 535 |
+
resnet_1 = middle_blocks[2]
|
| 536 |
+
|
| 537 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 538 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
| 539 |
+
|
| 540 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 541 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
| 542 |
+
|
| 543 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 544 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 545 |
+
assign_to_checkpoint(
|
| 546 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
for i in range(num_output_blocks):
|
| 550 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 551 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 552 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 553 |
+
output_block_list = {}
|
| 554 |
+
|
| 555 |
+
for layer in output_block_layers:
|
| 556 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 557 |
+
if layer_id in output_block_list:
|
| 558 |
+
output_block_list[layer_id].append(layer_name)
|
| 559 |
+
else:
|
| 560 |
+
output_block_list[layer_id] = [layer_name]
|
| 561 |
+
|
| 562 |
+
if len(output_block_list) > 1:
|
| 563 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 564 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 565 |
+
|
| 566 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 567 |
+
paths = renew_resnet_paths(resnets)
|
| 568 |
+
|
| 569 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 570 |
+
assign_to_checkpoint(
|
| 571 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
| 575 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 576 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 577 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 578 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 579 |
+
]
|
| 580 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 581 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 582 |
+
]
|
| 583 |
+
|
| 584 |
+
# Clear attentions as they have been attributed above.
|
| 585 |
+
if len(attentions) == 2:
|
| 586 |
+
attentions = []
|
| 587 |
+
|
| 588 |
+
if len(attentions):
|
| 589 |
+
paths = renew_attention_paths(attentions)
|
| 590 |
+
meta_path = {
|
| 591 |
+
"old": f"output_blocks.{i}.1",
|
| 592 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 593 |
+
}
|
| 594 |
+
assign_to_checkpoint(
|
| 595 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 596 |
+
)
|
| 597 |
+
else:
|
| 598 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 599 |
+
for path in resnet_0_paths:
|
| 600 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 601 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 602 |
+
|
| 603 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 604 |
+
|
| 605 |
+
if controlnet:
|
| 606 |
+
# conditioning embedding
|
| 607 |
+
|
| 608 |
+
orig_index = 0
|
| 609 |
+
|
| 610 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
| 611 |
+
f"input_hint_block.{orig_index}.weight"
|
| 612 |
+
)
|
| 613 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
| 614 |
+
f"input_hint_block.{orig_index}.bias"
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
orig_index += 2
|
| 618 |
+
|
| 619 |
+
diffusers_index = 0
|
| 620 |
+
|
| 621 |
+
while diffusers_index < 6:
|
| 622 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
| 623 |
+
f"input_hint_block.{orig_index}.weight"
|
| 624 |
+
)
|
| 625 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
| 626 |
+
f"input_hint_block.{orig_index}.bias"
|
| 627 |
+
)
|
| 628 |
+
diffusers_index += 1
|
| 629 |
+
orig_index += 2
|
| 630 |
+
|
| 631 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
| 632 |
+
f"input_hint_block.{orig_index}.weight"
|
| 633 |
+
)
|
| 634 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
| 635 |
+
f"input_hint_block.{orig_index}.bias"
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# down blocks
|
| 639 |
+
for i in range(num_input_blocks):
|
| 640 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
| 641 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
| 642 |
+
|
| 643 |
+
# mid block
|
| 644 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
| 645 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
| 646 |
+
|
| 647 |
+
return new_checkpoint
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
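# A quick sketch of the index arithmetic above (illustrative, assuming the common
# SD v1 config with config["layers_per_block"] == 2): each resolution level owns
# layers_per_block + 1 == 3 consecutive `input_blocks` entries, so
#
#     i = 4  ->  block_id = (4 - 1) // 3 = 1, layer_in_block_id = (4 - 1) % 3 = 0
#
# i.e. "input_blocks.4.0" lands in "down_blocks.1.resnets.0"; the ".1" sublayer,
# if present, is the matching attention, and a ".0.op" key is the downsampler.
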
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


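# For intuition, the renaming performed above (an illustrative example, following
# the meta_path mappings in the loops): an LDM VAE key such as
#
#     "encoder.down.0.block.1.conv1.weight"
#
# becomes the ppdiffusers key "encoder.down_blocks.0.resnets.1.conv1.weight",
# while decoder "up.{block_id}" indices run in the opposite order and are
# re-indexed into "decoder.up_blocks.{i}".
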
def convert_diffusers_vae_unet_to_ppdiffusers(vae_or_unet, diffusers_vae_unet_checkpoint):
    import paddle.nn as nn

    need_transpose = []
    for k, v in vae_or_unet.named_sublayers(include_self=True):
        if isinstance(v, nn.Linear):
            need_transpose.append(k + ".weight")
    new_vae_or_unet = {}
    for k in list(diffusers_vae_unet_checkpoint.keys()):
        v = diffusers_vae_unet_checkpoint.pop(k)
        if k not in need_transpose:
            new_vae_or_unet[k] = v
        else:
            new_vae_or_unet[k] = v.T
    return new_vae_or_unet


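# Why the transpose: torch-style checkpoints store nn.Linear weights as
# [out_features, in_features], while paddle.nn.Linear expects [in_features,
# out_features]. Conv and norm weights share the same layout in both frameworks,
# so only the Linear ".weight" entries collected in `need_transpose` are flipped.
# A minimal usage sketch (variable names are illustrative):
#
#     pp_state = convert_diffusers_vae_unet_to_ppdiffusers(unet, torch_style_state)
#     faster_set_state_dict(unet, pp_state)
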
def convert_ldm_bert_checkpoint(checkpoint, config):
    # extract state dict for bert
    bert_state_dict = {}
    bert_key = "cond_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(bert_key):
            bert_state_dict[key.replace(bert_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    new_checkpoint["model.embeddings.embed_tokens.weight"] = bert_state_dict["transformer.token_emb.weight"]
    new_checkpoint["model.embeddings.embed_positions.weight"] = bert_state_dict["transformer.pos_emb.emb.weight"]
    for i in range(config.encoder_layers):
        double_i = 2 * i
        double_i_plus1 = 2 * i + 1
        # convert norm
        new_checkpoint[f"model.layers.{i}.self_attn_layer_norm.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.0.weight"
        ]
        new_checkpoint[f"model.layers.{i}.self_attn_layer_norm.bias"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.0.bias"
        ]

        new_checkpoint[f"model.layers.{i}.self_attn.q_proj.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.1.to_q.weight"
        ].T
        new_checkpoint[f"model.layers.{i}.self_attn.k_proj.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.1.to_k.weight"
        ].T
        new_checkpoint[f"model.layers.{i}.self_attn.v_proj.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.1.to_v.weight"
        ].T
        new_checkpoint[f"model.layers.{i}.self_attn.out_proj.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.1.to_out.weight"
        ].T
        new_checkpoint[f"model.layers.{i}.self_attn.out_proj.bias"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i}.1.to_out.bias"
        ]

        new_checkpoint[f"model.layers.{i}.final_layer_norm.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i_plus1}.0.weight"
        ]
        new_checkpoint[f"model.layers.{i}.final_layer_norm.bias"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i_plus1}.0.bias"
        ]
        new_checkpoint[f"model.layers.{i}.fc1.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i_plus1}.1.net.0.0.weight"
        ].T
        new_checkpoint[f"model.layers.{i}.fc1.bias"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i_plus1}.1.net.0.0.bias"
        ]
        new_checkpoint[f"model.layers.{i}.fc2.weight"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i_plus1}.1.net.2.weight"
        ].T
        # biases are 1-D, so no transpose is needed here
        new_checkpoint[f"model.layers.{i}.fc2.bias"] = bert_state_dict[
            f"transformer.attn_layers.layers.{double_i_plus1}.1.net.2.bias"
        ]

    new_checkpoint["model.layer_norm.weight"] = bert_state_dict["transformer.norm.weight"]
    new_checkpoint["model.layer_norm.bias"] = bert_state_dict["transformer.norm.bias"]
    ldmbert = LDMBertModel(config)
    ldmbert.eval()
    faster_set_state_dict(ldmbert, new_checkpoint)
    return ldmbert


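# Layout assumed above: `transformer.attn_layers.layers` interleaves sublayers,
# so even indices (2 * i) hold a (norm, self-attention) pair and odd indices
# (2 * i + 1) hold a (norm, feed-forward) pair; attention and FF weights are
# stored torch-style and are therefore transposed with ".T" before assignment.
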
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
    if text_encoder is None:
        config_name = "openai/clip-vit-large-patch14"
        try:
            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
            )
        init_contexts = []
        init_contexts.append(paddle.dtype_guard(paddle.float32))
        init_contexts.append(no_init_weights(_enable=True))
        if hasattr(paddle, "LazyGuard"):
            init_contexts.append(paddle.LazyGuard())
        with ContextManagers(init_contexts):
            text_model = CLIPTextModel(config)
    else:
        text_model = text_encoder

    keys = list(checkpoint.keys())

    text_model_dict = {}

    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

    for key in keys:
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

    if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
        text_model_dict.pop("text_model.embeddings.position_ids", None)

    faster_set_state_dict(text_model, convert_diffusers_vae_unet_to_ppdiffusers(text_model, text_model_dict))

    return text_model


textenc_conversion_lst = [
    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("ln_final.weight", "text_model.final_layer_norm.weight"),
    ("ln_final.bias", "text_model.final_layer_norm.bias"),
    ("text_projection", "text_projection.weight"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

textenc_transformer_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))


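# An illustrative substitution with the pattern above:
#
#     textenc_pattern.sub(
#         lambda m: protected[re.escape(m.group(0))], "resblocks.0.ln_1.weight"
#     )
#     # -> "text_model.encoder.layers.0.layer_norm1.weight"
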
def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # load clip vision
    faster_set_state_dict(model.model, convert_diffusers_vae_unet_to_ppdiffusers(model.model, text_model_dict))

    # load mapper
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        prefix = key[: len("blocks.i")]
        suffix = key.split(prefix)[-1].split(".")[-1]
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    faster_set_state_dict(model.mapper, convert_diffusers_vae_unet_to_ppdiffusers(model.mapper, mapped_weights))

    # load final layer norm
    faster_set_state_dict(
        model.final_layer_norm,
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        },
    )

    # load final proj
    faster_set_state_dict(
        model.proj_out,
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"].T,
        },
    )

    # load uncond vector
    model.uncond_vector.set_value(checkpoint["learnable_vector"])
    return model


def convert_open_clip_checkpoint(
    checkpoint,
    config_name,
    prefix="cond_stage_model.model.",
    has_projection=False,
    local_files_only=False,
    **config_kwargs,
):
    # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
    # text_model = CLIPTextModelWithProjection.from_pretrained(
    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
    # )
    try:
        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
    except Exception:
        raise ValueError(
            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
        )

    init_contexts = []
    init_contexts.append(paddle.dtype_guard(paddle.float32))
    init_contexts.append(no_init_weights(_enable=True))
    if hasattr(paddle, "LazyGuard"):
        init_contexts.append(paddle.LazyGuard())
    with ContextManagers(init_contexts):
        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

    keys = list(checkpoint.keys())

    keys_to_ignore = []
    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
        # make sure to remove all keys > 22
        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
        keys_to_ignore += ["cond_stage_model.model.text_projection"]

    text_model_dict = {}

    if prefix + "text_projection" in checkpoint:
        d_model = int(checkpoint[prefix + "text_projection"].shape[0])
    else:
        d_model = 1024

    # text_model_dict["text_model.embeddings.position_embedding.weight"] = text_model.text_model.embeddings.get_buffer("position_ids")

    for key in keys:
        if key in keys_to_ignore:
            continue

        if key[len(prefix) :] in textenc_conversion_map:
            if key.endswith("text_projection"):
                value = checkpoint[key].T
            else:
                value = checkpoint[key]

            text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value

        if key.startswith(prefix + "transformer."):
            new_key = key[len(prefix + "transformer.") :]
            if new_key.endswith(".in_proj_weight"):
                new_key = new_key[: -len(".in_proj_weight")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
            elif new_key.endswith(".in_proj_bias"):
                new_key = new_key[: -len(".in_proj_bias")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
            else:
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)

                text_model_dict[new_key] = checkpoint[key]

    if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
        text_model_dict.pop("text_model.embeddings.position_ids", None)
    faster_set_state_dict(text_model, convert_diffusers_vae_unet_to_ppdiffusers(text_model, text_model_dict))

    return text_model


def stable_unclip_image_encoder(original_config, local_files_only=False):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.
    """

    image_embedder_config = original_config.model.params.embedder_config

    sd_clip_image_embedder_class = image_embedder_config.target
    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]

    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model

        if clip_model_name == "ViT-L/14":
            feature_extractor = CLIPImageProcessor()
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                "openai/clip-vit-large-patch14", local_files_only=local_files_only
            )
        else:
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")

    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        feature_extractor = CLIPImageProcessor()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only
        )
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    return feature_extractor, image_encoder


def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
    """
    noise_aug_config = original_config.model.params.noise_aug_config
    noise_aug_class = noise_aug_config.target
    noise_aug_class = noise_aug_class.split(".")[-1]

    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
        noise_aug_config = noise_aug_config.params
        embedding_dim = noise_aug_config.timestep_dim
        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

        if "clip_stats_path" in noise_aug_config:
            if clip_stats_path is None:
                raise ValueError("This stable unclip config requires a `clip_stats_path`")

            from ...utils import torch_load

            clip_mean, clip_std = torch_load(clip_stats_path)
            if hasattr(clip_mean, "numpy"):
                clip_mean = clip_mean.numpy()
            if hasattr(clip_std, "numpy"):
                clip_std = clip_std.numpy()
            clip_mean = clip_mean[None, :]
            clip_std = clip_std[None, :]

            clip_stats_state_dict = {
                "mean": clip_mean,
                "std": clip_std,
            }
            faster_set_state_dict(image_normalizer, clip_stats_state_dict)
    else:
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    return image_normalizer, image_noising_scheduler


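# A minimal usage sketch (both file paths are hypothetical):
#
#     image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
#         OmegaConf.load("stable_unclip.yaml"), clip_stats_path="clip_stats.pt"
#     )
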
def convert_controlnet_checkpoint(
    checkpoint,
    original_config,
    checkpoint_path,
    image_size,
    upcast_attention,
    extract_ema,
    use_linear_projection=None,
    cross_attention_dim=None,
):
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention

    ctrlnet_config.pop("sample_size")

    if use_linear_projection is not None:
        ctrlnet_config["use_linear_projection"] = use_linear_projection

    if cross_attention_dim is not None:
        ctrlnet_config["cross_attention_dim"] = cross_attention_dim

    init_contexts = []
    init_contexts.append(paddle.dtype_guard(paddle.float32))
    init_contexts.append(no_init_weights(_enable=True))
    if hasattr(paddle, "LazyGuard"):
        init_contexts.append(paddle.LazyGuard())
    with ContextManagers(init_contexts):
        controlnet = ControlNetModel(**ctrlnet_config)

    # Some controlnet ckpt files are distributed independently from the rest of the
    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
    if "time_embed.0.weight" in checkpoint:
        skip_extract_state_dict = True
    else:
        skip_extract_state_dict = False

    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint,
        ctrlnet_config,
        path=checkpoint_path,
        extract_ema=extract_ema,
        controlnet=True,
        skip_extract_state_dict=skip_extract_state_dict,
    )

    faster_set_state_dict(controlnet, convert_diffusers_vae_unet_to_ppdiffusers(controlnet, converted_ctrl_checkpoint))

    return controlnet


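# Standalone controlnet checkpoints (e.g. the thibaud/controlnet-sd21 weights
# referenced above) already start at "time_embed.0.weight" rather than carrying
# the full "model.diffusion_model." prefix, which is why
# `convert_controlnet_checkpoint` skips state-dict extraction for them.
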
def download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict: Union[str, Dict[str, paddle.Tensor]],
    original_config_file: str = None,
    image_size: Optional[int] = None,
    prediction_type: str = None,
    model_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "pndm",
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: str = None,
    from_safetensors: bool = False,
    stable_unclip: Optional[str] = None,
    stable_unclip_prior: Optional[str] = None,
    clip_stats_path: Optional[str] = None,
    controlnet: Optional[bool] = None,
    adapter: Optional[bool] = None,
    load_safety_checker: bool = True,
    pipeline_class: DiffusionPipeline = None,
    local_files_only=False,
    vae_path=None,
    vae=None,
    text_encoder=None,
    tokenizer=None,
    config_files=None,
    paddle_dtype=None,
    **kwargs,
) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
    config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            inferred by looking for a key that only exists in SD2.0 models.
        image_size (`int`, *optional*, defaults to 512):
            The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
            Base. Use 768 for Stable Diffusion v2.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
            Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
        num_in_channels (`int`, *optional*, defaults to None):
            The number of input channels. If `None`, it will be automatically inferred.
        scheduler_type (`str`, *optional*, defaults to 'pndm'):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`.
        model_type (`str`, *optional*, defaults to `None`):
            The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
            "FrozenCLIPEmbedder", "PaintByExample"]`.
        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
            inference. Non-EMA weights are usually better for continued fine-tuning.
        upcast_attention (`bool`, *optional*, defaults to `None`):
            Whether the attention computation should always be upcasted. This is necessary when running Stable
            Diffusion 2.1.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`bool`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of Paddle.
        load_safety_checker (`bool`, *optional*, defaults to `True`):
            Whether to load the safety checker or not. Defaults to `True`.
        pipeline_class (`str`, *optional*, defaults to `None`):
            The pipeline class to use. Pass `None` to determine automatically.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        vae (`AutoencoderKL`, *optional*, defaults to `None`):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
            this parameter is `None`, the function will construct one from the checkpoint by itself, if needed.
        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
            An instance of
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
            needed.
        config_files (`Dict[str, str]`, *optional*, defaults to `None`):
            A dictionary mapping from config file names to their contents. If this parameter is `None`, the function
            will load the config files by itself, if needed. Valid keys are:
            - `v1`: Config file for Stable Diffusion v1
            - `v2`: Config file for Stable Diffusion v2
            - `xl`: Config file for Stable Diffusion XL
            - `xl_refiner`: Config file for Stable Diffusion XL Refiner
    return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """

    # import pipelines here to avoid circular import error when using from_single_file method
    from ppdiffusers import (
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionInpaintPipeline,
        StableDiffusionPipeline,
        StableDiffusionUpscalePipeline,
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLPipeline,
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
    )

    if prediction_type == "v-prediction":
        prediction_type = "v_prediction"

    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if isinstance(checkpoint_path_or_dict, str):
        checkpoint = smart_load(checkpoint_path_or_dict, return_numpy=True, return_global_step=True)
    elif isinstance(checkpoint_path_or_dict, dict):
        checkpoint = checkpoint_path_or_dict

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    global_step = int(checkpoint.pop("global_step", -1))

    if global_step == -1:
        print("global_step key not found in model")

    # must cast them to float32
    newcheckpoint = {}
    for k, v in checkpoint.items():
        try:
            if "int" in str(v.dtype):
                continue
        except Exception:
            continue
        newcheckpoint[k] = v.astype("float32")
    checkpoint = newcheckpoint

    if original_config_file is None:
        key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
        key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
        is_upscale = pipeline_class == StableDiffusionUpscalePipeline

        config_url = None

        # model_type = "v1"
        if config_files is not None and "v1" in config_files:
            original_config_file = config_files["v1"]
        else:
            config_url = "https://paddlenlp.bj.bcebos.com/models/community/junnyu/develop/v1-inference.yaml"

        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
            # model_type = "v2"
            if config_files is not None and "v2" in config_files:
                original_config_file = config_files["v2"]
            else:
                config_url = "https://paddlenlp.bj.bcebos.com/models/community/junnyu/develop/v2-inference-v.yaml"
            if global_step == 110000:
                # v2.1 needs to upcast attention
                upcast_attention = True
        elif key_name_sd_xl_base in checkpoint:
            # only base xl has two text embedders
            if config_files is not None and "xl" in config_files:
                original_config_file = config_files["xl"]
            else:
                config_url = "https://paddlenlp.bj.bcebos.com/models/community/junnyu/develop/sd_xl_base.yaml"
        elif key_name_sd_xl_refiner in checkpoint:
            # only refiner xl has an embedder and one text embedder
            if config_files is not None and "xl_refiner" in config_files:
                original_config_file = config_files["xl_refiner"]
            else:
                config_url = "https://paddlenlp.bj.bcebos.com/models/community/junnyu/develop/sd_xl_refiner.yaml"

        if is_upscale:
            config_url = "https://paddlenlp.bj.bcebos.com/models/community/junnyu/develop/x4-upscaling.yaml"

        if config_url is not None:
            original_config_file = BytesIO(requests.get(config_url).content)

    original_config = OmegaConf.load(original_config_file)

    # Convert the text model.
    if (
        model_type is None
        and "cond_stage_config" in original_config.model.params
        and original_config.model.params.cond_stage_config is not None
    ):
        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
    elif model_type is None and original_config.model.params.network_config is not None:
        if original_config.model.params.network_config.params.context_dim == 2048:
            model_type = "SDXL"
        else:
            model_type = "SDXL-Refiner"
        if image_size is None:
            image_size = 1024

    if pipeline_class is None:
        # Check if we have a SDXL or SD model and initialize default pipeline
        if model_type not in ["SDXL", "SDXL-Refiner"]:
            pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
        else:
            pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline

    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
        num_in_channels = 9
    if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
        num_in_channels = 7
    elif num_in_channels is None:
        num_in_channels = 4

    if "unet_config" in original_config.model.params:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
            # as it relies on a brittle global step parameter here
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
        if image_size is None:
            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
            # as it relies on a brittle global step parameter here
            image_size = 512 if global_step == 875000 else 768
    else:
        if prediction_type is None:
            prediction_type = "epsilon"
        if image_size is None:
            image_size = 512

    if controlnet is None and "control_stage_config" in original_config.model.params:
        path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
        controlnet = convert_controlnet_checkpoint(
            checkpoint, original_config, path, image_size, upcast_attention, extract_ema
        )

    num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

    if model_type in ["SDXL", "SDXL-Refiner"]:
        scheduler_dict = {
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "interpolation_type": "linear",
            "num_train_timesteps": num_train_timesteps,
            "prediction_type": "epsilon",
            "sample_max_value": 1.0,
            "set_alpha_to_one": False,
            "skip_prk_steps": True,
            "steps_offset": 1,
            "timestep_spacing": "leading",
        }
        scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
        scheduler_type = "euler"
    else:
        beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
        beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
        scheduler = DDIMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            steps_offset=1,
            clip_sample=False,
            set_alpha_to_one=False,
            prediction_type=prediction_type,
        )
        # make sure scheduler works correctly with DDIM
        scheduler.register_to_config(clip_sample=False)

    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        pass  # keep the DDIMScheduler constructed above
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

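    # For example (a sketch; the checkpoint filename is hypothetical), passing
    # scheduler_type="euler" rebuilds the scheduler above from the same config:
    #
    #     pipe = download_from_original_stable_diffusion_ckpt(
    #         "v1-5-pruned.ckpt", scheduler_type="euler"
    #     )
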
    if pipeline_class == StableDiffusionUpscalePipeline:
        image_size = original_config.model.params.unet_config.params.image_size

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet_config["upcast_attention"] = upcast_attention

    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=path, extract_ema=extract_ema
    )

    init_contexts = []
    init_contexts.append(paddle.dtype_guard(paddle.float32))
    init_contexts.append(no_init_weights(_enable=True))
    if hasattr(paddle, "LazyGuard"):
        init_contexts.append(paddle.LazyGuard())
    with ContextManagers(init_contexts):
        unet = UNet2DConditionModel(**unet_config)

    faster_set_state_dict(unet, convert_diffusers_vae_unet_to_ppdiffusers(unet, converted_unet_checkpoint))

    # Convert the VAE model.
    if vae_path is None and vae is None:
        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

        if (
            "model" in original_config
            and "params" in original_config.model
            and "scale_factor" in original_config.model.params
        ):
            vae_scaling_factor = original_config.model.params.scale_factor
        else:
            vae_scaling_factor = 0.18215  # default SD scaling factor

        vae_config["scaling_factor"] = vae_scaling_factor

        init_contexts = []
        init_contexts.append(paddle.dtype_guard(paddle.float32))
        init_contexts.append(no_init_weights(_enable=True))
        if hasattr(paddle, "LazyGuard"):
            init_contexts.append(paddle.LazyGuard())
        with ContextManagers(init_contexts):
            vae = AutoencoderKL(**vae_config)

        faster_set_state_dict(vae, convert_diffusers_vae_unet_to_ppdiffusers(vae, converted_vae_checkpoint))
    elif vae is None:
        vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only)

if model_type == "FrozenOpenCLIPEmbedder":
|
| 1542 |
+
config_name = "stabilityai/stable-diffusion-2"
|
| 1543 |
+
config_kwargs = {"subfolder": "text_encoder"}
|
| 1544 |
+
|
| 1545 |
+
text_model = convert_open_clip_checkpoint(
|
| 1546 |
+
checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
|
| 1547 |
+
)
|
| 1548 |
+
|
| 1549 |
+
try:
|
| 1550 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
| 1551 |
+
"stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
|
| 1552 |
+
)
|
| 1553 |
+
except Exception:
|
| 1554 |
+
raise ValueError(
|
| 1555 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'."
|
| 1556 |
+
)
|
| 1557 |
+
|
| 1558 |
+
if stable_unclip is None:
|
| 1559 |
+
if controlnet:
|
| 1560 |
+
pipe = pipeline_class(
|
| 1561 |
+
vae=vae,
|
| 1562 |
+
text_encoder=text_model,
|
| 1563 |
+
tokenizer=tokenizer,
|
| 1564 |
+
unet=unet,
|
| 1565 |
+
scheduler=scheduler,
|
| 1566 |
+
controlnet=controlnet,
|
| 1567 |
+
safety_checker=None,
|
| 1568 |
+
feature_extractor=None,
|
| 1569 |
+
)
|
| 1570 |
+
if hasattr(pipe, "requires_safety_checker"):
|
| 1571 |
+
pipe.requires_safety_checker = False
|
| 1572 |
+
|
| 1573 |
+
elif pipeline_class == StableDiffusionUpscalePipeline:
|
| 1574 |
+
scheduler = DDIMScheduler.from_pretrained(
|
| 1575 |
+
"stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler"
|
| 1576 |
+
)
|
| 1577 |
+
low_res_scheduler = DDPMScheduler.from_pretrained(
|
| 1578 |
+
"stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
|
| 1579 |
+
)
|
| 1580 |
+
|
| 1581 |
+
pipe = pipeline_class(
|
| 1582 |
+
vae=vae,
|
| 1583 |
+
text_encoder=text_model,
|
| 1584 |
+
tokenizer=tokenizer,
|
| 1585 |
+
unet=unet,
|
| 1586 |
+
scheduler=scheduler,
|
| 1587 |
+
low_res_scheduler=low_res_scheduler,
|
| 1588 |
+
safety_checker=None,
|
| 1589 |
+
feature_extractor=None,
|
| 1590 |
+
)
|
| 1591 |
+
|
| 1592 |
+
else:
|
| 1593 |
+
pipe = pipeline_class(
|
| 1594 |
+
vae=vae,
|
| 1595 |
+
text_encoder=text_model,
|
| 1596 |
+
tokenizer=tokenizer,
|
| 1597 |
+
unet=unet,
|
| 1598 |
+
scheduler=scheduler,
|
| 1599 |
+
safety_checker=None,
|
| 1600 |
+
feature_extractor=None,
|
| 1601 |
+
)
|
| 1602 |
+
if hasattr(pipe, "requires_safety_checker"):
|
| 1603 |
+
pipe.requires_safety_checker = False
|
| 1604 |
+
|
| 1605 |
+
else:
|
| 1606 |
+
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
|
| 1607 |
+
original_config,
|
| 1608 |
+
clip_stats_path=clip_stats_path,
|
| 1609 |
+
)
|
| 1610 |
+
|
| 1611 |
+
if stable_unclip == "img2img":
|
| 1612 |
+
feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)
|
| 1613 |
+
|
| 1614 |
+
pipe = StableUnCLIPImg2ImgPipeline(
|
| 1615 |
+
# image encoding components
|
| 1616 |
+
feature_extractor=feature_extractor,
|
| 1617 |
+
image_encoder=image_encoder,
|
| 1618 |
+
# image noising components
|
| 1619 |
+
image_normalizer=image_normalizer,
|
| 1620 |
+
image_noising_scheduler=image_noising_scheduler,
|
| 1621 |
+
# regular denoising components
|
| 1622 |
+
tokenizer=tokenizer,
|
| 1623 |
+
text_encoder=text_model,
|
| 1624 |
+
unet=unet,
|
| 1625 |
+
scheduler=scheduler,
|
| 1626 |
+
# vae
|
| 1627 |
+
vae=vae,
|
| 1628 |
+
)
|
| 1629 |
+
elif stable_unclip == "txt2img":
|
| 1630 |
+
if stable_unclip_prior is None or stable_unclip_prior == "karlo":
|
| 1631 |
+
karlo_model = "kakaobrain/karlo-v1-alpha"
|
| 1632 |
+
prior = PriorTransformer.from_pretrained(
|
| 1633 |
+
karlo_model, subfolder="prior", local_files_only=local_files_only
|
| 1634 |
+
)
|
| 1635 |
+
|
| 1636 |
+
try:
|
| 1637 |
+
prior_tokenizer = CLIPTokenizer.from_pretrained(
|
| 1638 |
+
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
| 1639 |
+
)
|
| 1640 |
+
except Exception:
|
| 1641 |
+
raise ValueError(
|
| 1642 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
| 1643 |
+
)
|
| 1644 |
+
prior_text_model = CLIPTextModelWithProjection.from_pretrained(
|
| 1645 |
+
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
| 1646 |
+
)
|
| 1647 |
+
|
| 1648 |
+
prior_scheduler = UnCLIPScheduler.from_pretrained(
|
| 1649 |
+
karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only
|
| 1650 |
+
)
|
| 1651 |
+
prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
|
| 1652 |
+
else:
|
| 1653 |
+
raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
|
| 1654 |
+
|
| 1655 |
+
pipe = StableUnCLIPPipeline(
|
| 1656 |
+
# prior components
|
| 1657 |
+
prior_tokenizer=prior_tokenizer,
|
| 1658 |
+
prior_text_encoder=prior_text_model,
|
| 1659 |
+
prior=prior,
|
| 1660 |
+
prior_scheduler=prior_scheduler,
|
| 1661 |
+
# image noising components
|
| 1662 |
+
image_normalizer=image_normalizer,
|
| 1663 |
+
image_noising_scheduler=image_noising_scheduler,
|
| 1664 |
+
# regular denoising components
|
| 1665 |
+
tokenizer=tokenizer,
|
| 1666 |
+
text_encoder=text_model,
|
| 1667 |
+
unet=unet,
|
| 1668 |
+
scheduler=scheduler,
|
| 1669 |
+
# vae
|
| 1670 |
+
vae=vae,
|
| 1671 |
+
)
|
| 1672 |
+
else:
|
| 1673 |
+
raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
|
| 1674 |
+
elif model_type == "PaintByExample":
|
| 1675 |
+
vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
| 1676 |
+
try:
|
| 1677 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
| 1678 |
+
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
| 1679 |
+
)
|
| 1680 |
+
except Exception:
|
| 1681 |
+
raise ValueError(
|
| 1682 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
| 1683 |
+
)
|
| 1684 |
+
try:
|
| 1685 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(
|
| 1686 |
+
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
| 1687 |
+
)
|
| 1688 |
+
except Exception:
|
| 1689 |
+
raise ValueError(
|
| 1690 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
|
| 1691 |
+
)
|
| 1692 |
+
pipe = PaintByExamplePipeline(
|
| 1693 |
+
vae=vae,
|
| 1694 |
+
image_encoder=vision_model,
|
| 1695 |
+
unet=unet,
|
| 1696 |
+
scheduler=scheduler,
|
| 1697 |
+
safety_checker=None,
|
| 1698 |
+
feature_extractor=feature_extractor,
|
| 1699 |
+
)
|
| 1700 |
+
elif model_type == "FrozenCLIPEmbedder":
|
| 1701 |
+
text_model = convert_ldm_clip_checkpoint(
|
| 1702 |
+
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
|
| 1703 |
+
)
|
| 1704 |
+
try:
|
| 1705 |
+
tokenizer = (
|
| 1706 |
+
CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
|
| 1707 |
+
if tokenizer is None
|
| 1708 |
+
else tokenizer
|
| 1709 |
+
)
|
| 1710 |
+
except Exception:
|
| 1711 |
+
raise ValueError(
|
| 1712 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
| 1713 |
+
)
|
| 1714 |
+
|
| 1715 |
+
if load_safety_checker:
|
| 1716 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
| 1717 |
+
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
| 1718 |
+
)
|
| 1719 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(
|
| 1720 |
+
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
| 1721 |
+
)
|
| 1722 |
+
else:
|
| 1723 |
+
safety_checker = None
|
| 1724 |
+
feature_extractor = None
|
| 1725 |
+
|
| 1726 |
+
if controlnet:
|
| 1727 |
+
pipe = pipeline_class(
|
| 1728 |
+
vae=vae,
|
| 1729 |
+
text_encoder=text_model,
|
| 1730 |
+
tokenizer=tokenizer,
|
| 1731 |
+
unet=unet,
|
| 1732 |
+
controlnet=controlnet,
|
| 1733 |
+
scheduler=scheduler,
|
| 1734 |
+
safety_checker=safety_checker,
|
| 1735 |
+
feature_extractor=feature_extractor,
|
| 1736 |
+
requires_safety_checker=load_safety_checker,
|
| 1737 |
+
)
|
| 1738 |
+
else:
|
| 1739 |
+
pipe = pipeline_class(
|
| 1740 |
+
vae=vae,
|
| 1741 |
+
text_encoder=text_model,
|
| 1742 |
+
tokenizer=tokenizer,
|
| 1743 |
+
unet=unet,
|
| 1744 |
+
scheduler=scheduler,
|
| 1745 |
+
safety_checker=safety_checker,
|
| 1746 |
+
feature_extractor=feature_extractor,
|
| 1747 |
+
requires_safety_checker=load_safety_checker,
|
| 1748 |
+
)
|
| 1749 |
+
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
| 1750 |
+
if model_type == "SDXL":
|
| 1751 |
+
try:
|
| 1752 |
+
tokenizer = CLIPTokenizer.from_pretrained(
|
| 1753 |
+
"openai/clip-vit-large-patch14", local_files_only=local_files_only
|
| 1754 |
+
)
|
| 1755 |
+
except Exception:
|
| 1756 |
+
raise ValueError(
|
| 1757 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
| 1758 |
+
)
|
| 1759 |
+
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
|
| 1760 |
+
try:
|
| 1761 |
+
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
| 1762 |
+
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
| 1763 |
+
)
|
| 1764 |
+
except Exception:
|
| 1765 |
+
raise ValueError(
|
| 1766 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
| 1767 |
+
)
|
| 1768 |
+
|
| 1769 |
+
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
| 1770 |
+
config_kwargs = {"projection_dim": 1280}
|
| 1771 |
+
text_encoder_2 = convert_open_clip_checkpoint(
|
| 1772 |
+
checkpoint,
|
| 1773 |
+
config_name,
|
| 1774 |
+
prefix="conditioner.embedders.1.model.",
|
| 1775 |
+
has_projection=True,
|
| 1776 |
+
local_files_only=local_files_only,
|
| 1777 |
+
**config_kwargs,
|
| 1778 |
+
)
|
| 1779 |
+
|
| 1780 |
+
if controlnet:
|
| 1781 |
+
pipe = pipeline_class(
|
| 1782 |
+
vae=vae,
|
| 1783 |
+
text_encoder=text_encoder,
|
| 1784 |
+
tokenizer=tokenizer,
|
| 1785 |
+
text_encoder_2=text_encoder_2,
|
| 1786 |
+
tokenizer_2=tokenizer_2,
|
| 1787 |
+
unet=unet,
|
| 1788 |
+
controlnet=controlnet,
|
| 1789 |
+
scheduler=scheduler,
|
| 1790 |
+
force_zeros_for_empty_prompt=True,
|
| 1791 |
+
)
|
| 1792 |
+
elif adapter:
|
| 1793 |
+
pipe = pipeline_class(
|
| 1794 |
+
vae=vae,
|
| 1795 |
+
text_encoder=text_encoder,
|
| 1796 |
+
tokenizer=tokenizer,
|
| 1797 |
+
text_encoder_2=text_encoder_2,
|
| 1798 |
+
tokenizer_2=tokenizer_2,
|
| 1799 |
+
unet=unet,
|
| 1800 |
+
adapter=adapter,
|
| 1801 |
+
scheduler=scheduler,
|
| 1802 |
+
force_zeros_for_empty_prompt=True,
|
| 1803 |
+
)
|
| 1804 |
+
else:
|
| 1805 |
+
pipe = pipeline_class(
|
| 1806 |
+
vae=vae,
|
| 1807 |
+
text_encoder=text_encoder,
|
| 1808 |
+
tokenizer=tokenizer,
|
| 1809 |
+
text_encoder_2=text_encoder_2,
|
| 1810 |
+
tokenizer_2=tokenizer_2,
|
| 1811 |
+
unet=unet,
|
| 1812 |
+
scheduler=scheduler,
|
| 1813 |
+
force_zeros_for_empty_prompt=True,
|
| 1814 |
+
)
|
| 1815 |
+
else:
|
| 1816 |
+
tokenizer = None
|
| 1817 |
+
text_encoder = None
|
| 1818 |
+
try:
|
| 1819 |
+
tokenizer_2 = CLIPTokenizer.from_pretrained(
|
| 1820 |
+
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
|
| 1821 |
+
)
|
| 1822 |
+
except Exception:
|
| 1823 |
+
raise ValueError(
|
| 1824 |
+
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
|
| 1825 |
+
)
|
| 1826 |
+
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
| 1827 |
+
config_kwargs = {"projection_dim": 1280}
|
| 1828 |
+
text_encoder_2 = convert_open_clip_checkpoint(
|
| 1829 |
+
checkpoint,
|
| 1830 |
+
config_name,
|
| 1831 |
+
prefix="conditioner.embedders.0.model.",
|
| 1832 |
+
has_projection=True,
|
| 1833 |
+
local_files_only=local_files_only,
|
| 1834 |
+
**config_kwargs,
|
| 1835 |
+
)
|
| 1836 |
+
|
| 1837 |
+
pipe = StableDiffusionXLImg2ImgPipeline(
|
| 1838 |
+
vae=vae,
|
| 1839 |
+
text_encoder=text_encoder,
|
| 1840 |
+
tokenizer=tokenizer,
|
| 1841 |
+
text_encoder_2=text_encoder_2,
|
| 1842 |
+
tokenizer_2=tokenizer_2,
|
| 1843 |
+
unet=unet,
|
| 1844 |
+
scheduler=scheduler,
|
| 1845 |
+
requires_aesthetics_score=True,
|
| 1846 |
+
force_zeros_for_empty_prompt=False,
|
| 1847 |
+
)
|
| 1848 |
+
else:
|
| 1849 |
+
text_config = create_ldm_bert_config(original_config)
|
| 1850 |
+
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
| 1851 |
+
tokenizer = BertTokenizer.from_pretrained(
|
| 1852 |
+
"bert-base-uncased", local_files_only=local_files_only, model_max_length=77
|
| 1853 |
+
)
|
| 1854 |
+
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
| 1855 |
+
if paddle_dtype is not None:
|
| 1856 |
+
pipe.to(paddle_dtype=paddle_dtype)
|
| 1857 |
+
|
| 1858 |
+
return pipe
|
| 1859 |
+
|
| 1860 |
+
|
| 1861 |
+
def download_controlnet_from_original_ckpt(
|
| 1862 |
+
checkpoint_path: str,
|
| 1863 |
+
original_config_file: str,
|
| 1864 |
+
image_size: int = 512,
|
| 1865 |
+
extract_ema: bool = False,
|
| 1866 |
+
num_in_channels: Optional[int] = None,
|
| 1867 |
+
upcast_attention: Optional[bool] = None,
|
| 1868 |
+
device: str = None,
|
| 1869 |
+
from_safetensors: bool = False,
|
| 1870 |
+
use_linear_projection: Optional[bool] = None,
|
| 1871 |
+
cross_attention_dim: Optional[bool] = None,
|
| 1872 |
+
) -> DiffusionPipeline:
|
| 1873 |
+
if not is_omegaconf_available():
|
| 1874 |
+
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
| 1875 |
+
|
| 1876 |
+
from omegaconf import OmegaConf
|
| 1877 |
+
|
| 1878 |
+
checkpoint = smart_load(checkpoint_path, return_numpy=True)
|
| 1879 |
+
|
| 1880 |
+
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
| 1881 |
+
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
| 1882 |
+
while "state_dict" in checkpoint:
|
| 1883 |
+
checkpoint = checkpoint["state_dict"]
|
| 1884 |
+
|
| 1885 |
+
# must cast them to float32
|
| 1886 |
+
newcheckpoint = {}
|
| 1887 |
+
for k, v in checkpoint.items():
|
| 1888 |
+
try:
|
| 1889 |
+
if "int" in str(v.dtype):
|
| 1890 |
+
continue
|
| 1891 |
+
except Exception:
|
| 1892 |
+
continue
|
| 1893 |
+
newcheckpoint[k] = v.astype("float32")
|
| 1894 |
+
checkpoint = newcheckpoint
|
| 1895 |
+
|
| 1896 |
+
original_config = OmegaConf.load(original_config_file)
|
| 1897 |
+
|
| 1898 |
+
if num_in_channels is not None:
|
| 1899 |
+
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
| 1900 |
+
|
| 1901 |
+
if "control_stage_config" not in original_config.model.params:
|
| 1902 |
+
raise ValueError("`control_stage_config` not present in original config")
|
| 1903 |
+
|
| 1904 |
+
controlnet = convert_controlnet_checkpoint(
|
| 1905 |
+
checkpoint,
|
| 1906 |
+
original_config,
|
| 1907 |
+
checkpoint_path,
|
| 1908 |
+
image_size,
|
| 1909 |
+
upcast_attention,
|
| 1910 |
+
extract_ema,
|
| 1911 |
+
use_linear_projection=use_linear_projection,
|
| 1912 |
+
cross_attention_dim=cross_attention_dim,
|
| 1913 |
+
)
|
| 1914 |
+
|
| 1915 |
+
return controlnet
|
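# Usage sketch for `download_controlnet_from_original_ckpt`. Illustrative
# only: the checkpoint and config file names below are hypothetical
# placeholders, and `save_pretrained` assumes the returned ControlNet follows
# the usual ppdiffusers `ModelMixin` API.
#
#     controlnet = download_controlnet_from_original_ckpt(
#         checkpoint_path="control_sd15_canny.pth",   # hypothetical local file
#         original_config_file="cldm_v15.yaml",       # hypothetical local file
#         image_size=512,
#     )
#     controlnet.save_pretrained("./controlnet-canny")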
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_fastdeploy_stable_diffusion_inpaint.py
ADDED
@@ -0,0 +1,582 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List, Optional, Union

import numpy as np
import paddle
import PIL

from ppdiffusers.transformers import CLIPImageProcessor, CLIPTokenizer

from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, logging
from ..fastdeploy_utils import FastDeployDiffusionPipelineMixin, FastDeployRuntimeModel
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def prepare_mask_and_masked_image(image, mask, height=None, width=None, return_image: bool = False):
    """
    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will
    be converted to ``paddle.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3``
    for the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``paddle.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (``mask > 0.5``) and cast to ``paddle.float32`` too.

    Args:
        image (Union[np.array, PIL.Image, paddle.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``paddle.Tensor`` or a ``batch x channels x height x width`` ``paddle.Tensor``.
        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``paddle.Tensor`` or a ``batch x 1 x height x width`` ``paddle.Tensor``.


    Raises:
        ValueError: ``paddle.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``paddle.Tensor`` mask
        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
        TypeError: ``mask`` is a ``paddle.Tensor`` but ``image`` is not
        (or the other way around).

    Returns:
        tuple[paddle.Tensor]: The pair (mask, masked_image) as ``paddle.Tensor`` with 4
        dimensions: ``batch x channels x height x width``.
    """

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, paddle.Tensor):
        if not isinstance(mask, paddle.Tensor):
            raise TypeError(f"`image` is a paddle.Tensor but `mask` (type: {type(mask)}) is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.cast(dtype=paddle.float32)
    elif isinstance(mask, paddle.Tensor):
        raise TypeError(f"`mask` is a paddle.Tensor but `image` (type: {type(image)}) is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t. the passed height and width
            if width is None or height is None:
                w, h = image[0].size
            else:
                w, h = width, height
            w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
            image = [i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = paddle.to_tensor(image, dtype=paddle.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            # resize all masks w.r.t. the passed height and width
            if width is None or height is None:
                w, h = mask[0].size
            else:
                w, h = width, height
            w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
            mask = [i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = paddle.to_tensor(mask)

    masked_image = image * (mask < 0.5).cast(image.dtype)

    # n.b. ensure backwards compatibility as old function does not return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image

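# Usage sketch for `prepare_mask_and_masked_image` (illustrative; the inputs
# below are synthetic stand-ins, and running it requires paddle and PIL):
#
#     import numpy as np
#     import PIL.Image
#
#     img = PIL.Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
#     msk = PIL.Image.fromarray(np.full((512, 512), 255, dtype=np.uint8))  # inpaint everything
#     mask_t, masked_t, init_t = prepare_mask_and_masked_image(img, msk, 512, 512, return_image=True)
#     # mask_t: [1, 1, 512, 512] binarized float mask; masked_t: [1, 3, 512, 512] in [-1, 1]
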
class FastDeployStableDiffusionInpaintPipeline(DiffusionPipeline, FastDeployDiffusionPipelineMixin, IPAdapterMixin):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving etc.)

    Args:
        vae_encoder ([`FastDeployRuntimeModel`]):
            Variational Auto-Encoder (VAE) Model to encode images to latent representations.
        vae_decoder ([`FastDeployRuntimeModel`]):
            Variational Auto-Encoder (VAE) Model to decode images from latent representations.
        text_encoder ([`FastDeployRuntimeModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FastDeployRuntimeModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`] or [`DPMSolverMultistepScheduler`].
        safety_checker ([`FastDeployRuntimeModel`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["image_encoder", "safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae_encoder: FastDeployRuntimeModel,
        vae_decoder: FastDeployRuntimeModel,
        text_encoder: FastDeployRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: FastDeployRuntimeModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: FastDeployRuntimeModel,
        feature_extractor: CLIPImageProcessor,
        image_encoder: FastDeployRuntimeModel,
        requires_safety_checker: bool = False,
    ):
        super().__init__()
        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)
        self.post_init()

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[paddle.Tensor, PIL.Image.Image] = None,
        mask_image: Union[paddle.Tensor, PIL.Image.Image] = None,
        height: int = None,
        width: int = None,
        strength: float = 1.0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        timesteps: List[int] = None,
        add_predicted_noise: Optional[bool] = False,
        eta: float = 0.0,
        generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
        latents: Optional[paddle.Tensor] = None,
        parse_prompt_type: Optional[str] = "lpw",
        max_embeddings_multiples: Optional[int] = 3,
        prompt_embeds: Optional[paddle.Tensor] = None,
        negative_prompt_embeds: Optional[paddle.Tensor] = None,
        guidance_rescale: float = 0.0,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        controlnet_cond: Union[paddle.Tensor, PIL.Image.Image] = None,
        controlnet_conditioning_scale: float = 1.0,
        infer_op_dict: Dict[str, str] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`paddle.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. This is the image whose masked region will be inpainted.
            mask_image (`paddle.Tensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
                expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
            height (`int`, *optional*, defaults to None):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to None):
                The width in pixels of the generated image.
            strength (`float`, *optional*, defaults to 1.0):
                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
                is 1, the denoising process will be run on the masked area for the full number of iterations specified
                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise
                to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The reference number of denoising steps. More denoising steps usually lead to a higher quality image
                at the expense of slower inference. This parameter will be modulated by `strength`, as explained
                above.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            add_predicted_noise (`bool`, *optional*, defaults to False):
                Use predicted noise instead of random noise when constructing noisy versions of the original image in
                the reverse diffusion process.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`]; will be ignored for others.
            generator (`paddle.Generator`, *optional*):
                One or a list of paddle generator(s) to make generation deterministic.
            latents (`paddle.Tensor`, *optional*):
                Pre-generated noise tensor, sampled from a Gaussian distribution, to be used as inputs for image
                generation. If not provided, a noise tensor will be generated by sampling using the supplied random
                `generator`.
            prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`paddle.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Preprocess mask and image
        mask, masked_image, init_image = prepare_mask_and_masked_image(
            image,
            mask_image,
            height,
            width,
            return_image=True,
        )
        height, width = init_image.shape[-2:]

        # 1. Check inputs
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            strength,
        )
        infer_op_dict = self.prepare_infer_op_dict(infer_op_dict)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            parse_prompt_type=parse_prompt_type,
            max_embeddings_multiples=max_embeddings_multiples,
            infer_op=infer_op_dict.get("text_encoder", None),
        )

        if ip_adapter_image is not None:
            image_embeds, negative_image_embeds = self.encode_image(
                ip_adapter_image, num_images_per_prompt, infer_op=infer_op_dict.get("image_encoder", None)
            )
            if do_classifier_free_guidance:
                image_embeds = paddle.concat([negative_image_embeds, image_embeds])

        # 4. set timesteps
        timesteps, num_inference_steps = self.retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
        latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        # 5. Prepare latent variables
        num_channels_latents = self.vae_decoder_num_latent_channels
        num_channels_unet = self.unet_num_latent_channels
        is_legacy = return_image_latents = num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            height,
            width,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
            infer_op=infer_op_dict.get("vae_encoder", None),
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 6. Prepare mask latent variables
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            do_classifier_free_guidance,
            return_masked_image_latents=True,
            infer_op=infer_op_dict.get("vae_encoder", None),
        )

        # 7. Check that sizes of mask, masked image and latents match
        if num_channels_unet == 9:
            # default case for runwayml/stable-diffusion-inpainting
            num_channels_mask = mask.shape[1]
            num_channels_masked_image = masked_image_latents.shape[1]
            if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
                raise ValueError(
                    f"Incorrect configuration settings! Received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                    " `pipeline.unet` or your `mask_image` or `image` input."
                )
        elif num_channels_unet != 4:
            raise ValueError(f"The unet should have either 4 or 9 input channels, not {num_channels_unet}.")
        # do_controlnet
        do_controlnet = controlnet_cond is not None and num_channels_unet == 4
        if do_controlnet:
            control_image, control_conditioning_scale = self.prepare_controlnet_cond(
                controlnet_cond=controlnet_cond,
                controlnet_conditioning_scale=controlnet_conditioning_scale,
                width=width,
                height=height,
                batch_size=batch_size,
                num_images_per_prompt=num_images_per_prompt,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        if do_classifier_free_guidance:
            init_mask = mask[: mask.shape[0] // 2]
        else:
            init_mask = mask

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_scheduler_support_step_index = self.is_scheduler_support_step_index()

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
                if is_scheduler_support_step_index:
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, step_index=i)
                else:
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                output_shape = latent_model_input.shape
                if not is_legacy:
                    # concat latents, mask, masked_image_latents in the channel dimension
                    latent_model_input = paddle.concat([latent_model_input, mask, masked_image_latents], axis=1)

                unet_inputs = dict(
                    sample=latent_model_input,
                    timestep=t,
                    encoder_hidden_states=prompt_embeds,
                    infer_op=infer_op_dict.get("unet", None),
                    output_shape=output_shape,
                )
                if do_controlnet:
                    unet_inputs["controlnet_cond"] = control_image
                    unet_inputs["controlnet_conditioning_scale"] = control_conditioning_scale
                # Add image embeds for IP-Adapter
                # added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
                if ip_adapter_image:
                    unet_inputs["image_embeds"] = image_embeds

                # predict the noise residual
                noise_pred_unet = self.unet(**unet_inputs)[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred_unet.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = self.rescale_noise_cfg(
                            noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                        )
                else:
                    noise_pred = noise_pred_unet

                # compute the previous noisy sample x_t -> x_t-1
                if is_scheduler_support_step_index:
                    scheduler_output = self.scheduler.step(
                        noise_pred, t, latents, step_index=i, return_pred_original_sample=False, **extra_step_kwargs
                    )
                else:
                    scheduler_output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
                latents = scheduler_output.prev_sample

                if is_legacy:
                    if i < len(timesteps) - 1:
                        # masking
                        if add_predicted_noise:
                            init_latents_proper = self.scheduler.add_noise(image_latents, noise_pred_uncond, t)
                        else:
                            # https://github.com/huggingface/diffusers/pull/3749/files#diff-39d36ab1e622684e35fe6971c12fb44e24756bdc383aba3d7f6e3b1625bdaafc
                            noise_timestep = timesteps[i + 1]
                            init_latents_proper = self.scheduler.add_noise(image_latents, noise, noise_timestep)
                    else:
                        init_latents_proper = image_latents
                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)
                if i == len(timesteps) - 1:
                    # sync for accuracy it/s measure
                    paddle.device.synchronize()

        if not output_type == "latent":
            image = self._decode_vae_latents(
                latents / self.vae_scaling_factor, infer_op=infer_op_dict.get("vae_decoder", None)
            )
            image, has_nsfw_concept = self.run_safety_checker(image)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
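# Illustration of the legacy-UNet masking step in the denoising loop above:
# where `init_mask == 0` (region to keep) the latents are reset to a re-noised
# copy of the original image latents; where `init_mask == 1` (region to
# inpaint) the freshly denoised latents are kept. NumPy stand-ins for the
# paddle tensors; shapes are illustrative only.
import numpy as np

init_mask = np.zeros((1, 1, 64, 64), dtype=np.float32)
init_mask[..., 16:48, 16:48] = 1.0  # square region to inpaint
init_latents_proper = np.random.randn(1, 4, 64, 64).astype(np.float32)
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)

latents = (1 - init_mask) * init_latents_proper + init_mask * latents
assert np.allclose(latents[..., 0, 0], init_latents_proper[..., 0, 0])  # corner pixel was kept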
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_output.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        nsfw_content_detected (`List[bool]`)
            List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
            `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
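# A minimal sketch of constructing the output dataclass directly, as a
# pipeline does before returning. The import path is assumed from this file's
# location in the package:
import numpy as np
from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

imgs = np.zeros((1, 512, 512, 3), dtype=np.float32)  # (batch, H, W, C)
out = StableDiffusionPipelineOutput(images=imgs, nsfw_content_detected=[False])
print(out.images.shape, out.nsfw_content_detected)  # (1, 512, 512, 3) [False]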
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_paddleinfer_stable_diffusion.py
ADDED
@@ -0,0 +1,360 @@
| 1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import paddle
|
| 20 |
+
import PIL
|
| 21 |
+
|
| 22 |
+
from ppdiffusers.transformers import CLIPImageProcessor, CLIPTokenizer
|
| 23 |
+
|
| 24 |
+
from ...image_processor import PipelineImageInput
|
| 25 |
+
from ...loaders import IPAdapterMixin
|
| 26 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 27 |
+
from ...utils import logging
|
| 28 |
+
from ..paddleinfer_utils import (
|
| 29 |
+
PaddleInferDiffusionPipelineMixin,
|
| 30 |
+
PaddleInferRuntimeModel,
|
| 31 |
+
)
|
| 32 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 33 |
+
from . import StableDiffusionPipelineOutput
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PaddleInferStableDiffusionPipeline(DiffusionPipeline, PaddleInferDiffusionPipelineMixin, IPAdapterMixin):
|
| 39 |
+
r"""
|
| 40 |
+
Pipeline for text-to-image generation using Stable Diffusion.
|
| 41 |
+
|
| 42 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 43 |
+
library implements for all the pipelines (such as downloading or saving etc.)
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
vae_encoder ([`PaddleInferRuntimeModel`]):
|
| 47 |
+
Variational Auto-Encoder (VAE) Model to encode images to latent representations.
|
| 48 |
+
vae_decoder ([`PaddleInferRuntimeModel`]):
|
| 49 |
+
Variational Auto-Encoder (VAE) Model to decode images from latent representations.
|
| 50 |
+
text_encoder ([`PaddleInferRuntimeModel`]):
|
| 51 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
| 52 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 53 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 54 |
+
tokenizer (`CLIPTokenizer`):
|
| 55 |
+
Tokenizer of class
|
| 56 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 57 |
+
unet ([`PaddleInferRuntimeModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
| 58 |
+
scheduler ([`SchedulerMixin`]):
|
| 59 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 60 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
|
| 61 |
+
or [`DPMSolverMultistepScheduler`].
|
| 62 |
+
safety_checker ([`PaddleInferRuntimeModel`]):
|
| 63 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 64 |
+
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
| 65 |
+
feature_extractor ([`CLIPImageProcessor`]):
|
| 66 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
| 67 |
+
"""
|
| 68 |
+
_optional_components = ["vae_encoder", "image_encoder", "safety_checker", "feature_extractor"]
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
vae_encoder: PaddleInferRuntimeModel,
|
| 73 |
+
vae_decoder: PaddleInferRuntimeModel,
|
| 74 |
+
text_encoder: PaddleInferRuntimeModel,
|
| 75 |
+
tokenizer: CLIPTokenizer,
|
| 76 |
+
unet: PaddleInferRuntimeModel,
|
| 77 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 78 |
+
safety_checker: PaddleInferRuntimeModel,
|
| 79 |
+
feature_extractor: CLIPImageProcessor,
|
| 80 |
+
image_encoder: PaddleInferRuntimeModel,
|
| 81 |
+
requires_safety_checker: bool = False,
|
| 82 |
+
):
|
| 83 |
+
super().__init__()
|
| 84 |
+
if safety_checker is None and requires_safety_checker:
|
| 85 |
+
logger.warning(
|
| 86 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 87 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| 88 |
+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
|
| 89 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 90 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 91 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 92 |
+
)
|
| 93 |
+
if safety_checker is not None and feature_extractor is None:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 96 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
self.register_modules(
|
| 100 |
+
vae_encoder=vae_encoder,
|
| 101 |
+
vae_decoder=vae_decoder,
|
| 102 |
+
text_encoder=text_encoder,
|
| 103 |
+
tokenizer=tokenizer,
|
| 104 |
+
unet=unet,
|
| 105 |
+
scheduler=scheduler,
|
| 106 |
+
safety_checker=safety_checker,
|
| 107 |
+
feature_extractor=feature_extractor,
|
| 108 |
+
image_encoder=image_encoder,
|
| 109 |
+
)
|
| 110 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 111 |
+
self.post_init()
|
| 112 |
+
|
| 113 |
+
def __call__(
|
| 114 |
+
self,
|
| 115 |
+
prompt: Union[str, List[str]] = None,
|
| 116 |
+
height: Optional[int] = None,
|
| 117 |
+
width: Optional[int] = None,
|
| 118 |
+
num_inference_steps: int = 50,
|
| 119 |
+
guidance_scale: float = 7.5,
|
| 120 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 121 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 122 |
+
timesteps: List[int] = None,
|
| 123 |
+
eta: float = 0.0,
|
| 124 |
+
generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
|
| 125 |
+
latents: Optional[paddle.Tensor] = None,
|
| 126 |
+
parse_prompt_type: Optional[str] = "lpw",
|
| 127 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 128 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 129 |
+
negative_prompt_embeds: Optional[paddle.Tensor] = None,
|
| 130 |
+
guidance_rescale: float = 0.0,
|
| 131 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 132 |
+
output_type: Optional[str] = "pil",
|
| 133 |
+
return_dict: bool = True,
|
| 134 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
| 135 |
+
callback_steps: Optional[int] = 1,
|
| 136 |
+
controlnet_cond: Union[paddle.Tensor, PIL.Image.Image] = None,
|
| 137 |
+
controlnet_conditioning_scale: float = 1.0,
|
| 138 |
+
infer_op_dict: Dict[str, str] = None,
|
| 139 |
+
):
|
| 140 |
+
r"""
|
| 141 |
+
Function invoked when calling the pipeline for generation.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 145 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 146 |
+
instead.
|
| 147 |
+
height (`int`, *optional*, defaults to None):
|
| 148 |
+
The height in pixels of the generated image.
|
| 149 |
+
width (`int`, *optional*, defaults to None):
|
| 150 |
+
The width in pixels of the generated image.
|
| 151 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 152 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 153 |
+
expense of slower inference.
|
| 154 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 155 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 156 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 157 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 158 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
| 159 |
+
usually at the expense of lower image quality.
|
| 160 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 161 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 162 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
|
| 163 |
+
is less than `1`).
|
| 164 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 165 |
+
The number of images to generate per prompt.
|
| 166 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 167 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 168 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 169 |
+
generator (`paddle.Generator`, *optional*):
|
| 170 |
+
One or a list of paddle generator(s) to make generation deterministic.
|
| 171 |
+
latents (`paddle.Tensor`, *optional*):
|
| 172 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 173 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 174 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 175 |
+
prompt_embeds (`paddle.Tensor`, *optional*):
|
| 176 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 177 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 178 |
+
negative_prompt_embeds (`paddle.Tensor`, *optional*):
|
| 179 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 180 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 181 |
+
argument.
|
| 182 |
+
ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 183 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 184 |
+
The output format of the generated image. Choose between
|
| 185 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 186 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 187 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 188 |
+
plain tuple.
|
| 189 |
+
callback (`Callable`, *optional*):
|
| 190 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 191 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
|
| 192 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 193 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 194 |
+
called at every step.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 198 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 199 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 200 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 201 |
+
(nsfw) content, according to the `safety_checker`.
|
| 202 |
+
"""
|
| 203 |
+
# 0. Default height and width to unet
|
| 204 |
+
height = height or 512
|
| 205 |
+
width = width or 512
|
| 206 |
+
# 1. Check inputs. Raise error if not correct
|
| 207 |
+
self.check_inputs(
|
| 208 |
+
prompt,
|
| 209 |
+
height,
|
| 210 |
+
width,
|
| 211 |
+
callback_steps,
|
| 212 |
+
negative_prompt,
|
| 213 |
+
prompt_embeds,
|
| 214 |
+
negative_prompt_embeds,
|
| 215 |
+
)
|
| 216 |
+
infer_op_dict = self.prepare_infer_op_dict(infer_op_dict)
|
| 217 |
+
|
| 218 |
+
# 2. Define call parameters
|
| 219 |
+
if prompt is not None and isinstance(prompt, str):
|
| 220 |
+
batch_size = 1
|
| 221 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 222 |
+
batch_size = len(prompt)
|
| 223 |
+
else:
|
| 224 |
+
batch_size = prompt_embeds.shape[0]
|
| 225 |
+
|
| 226 |
+
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
| 227 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 228 |
+
# corresponds to doing no classifier free guidance.
|
| 229 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 230 |
+
|
| 231 |
+
# do_controlnet
|
| 232 |
+
do_controlnet = controlnet_cond is not None
|
| 233 |
+
if do_controlnet:
|
| 234 |
+
control_image, control_conditioning_scale = self.prepare_controlnet_cond(
|
| 235 |
+
controlnet_cond=controlnet_cond,
|
| 236 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
| 237 |
+
width=width,
|
| 238 |
+
height=height,
|
| 239 |
+
batch_size=batch_size,
|
| 240 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 241 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# 3. Encode input prompt
|
| 245 |
+
prompt_embeds = self._encode_prompt(
|
| 246 |
+
prompt,
|
| 247 |
+
num_images_per_prompt,
|
| 248 |
+
do_classifier_free_guidance,
|
| 249 |
+
negative_prompt,
|
| 250 |
+
prompt_embeds=prompt_embeds,
|
| 251 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 252 |
+
parse_prompt_type=parse_prompt_type,
|
| 253 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 254 |
+
infer_op=infer_op_dict.get("text_encoder", None),
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
if ip_adapter_image is not None:
|
| 258 |
+
image_embeds, negative_image_embeds = self.encode_image(
|
| 259 |
+
ip_adapter_image, num_images_per_prompt, infer_op=infer_op_dict.get("image_encoder", None)
|
| 260 |
+
)
|
| 261 |
+
if do_classifier_free_guidance:
|
| 262 |
+
image_embeds = paddle.concat([negative_image_embeds, image_embeds])
|
| 263 |
+
|
| 264 |
+
# 4. Prepare timesteps
|
| 265 |
+
timesteps, num_inference_steps = self.retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
|
| 266 |
+
# 5. Prepare latent variables
|
| 267 |
+
latents = self.prepare_latents(
|
| 268 |
+
batch_size * num_images_per_prompt,
|
| 269 |
+
height,
|
| 270 |
+
width,
|
| 271 |
+
generator,
|
| 272 |
+
latents,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 276 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 277 |
+
|
| 278 |
+
# 7. Denoising loop
|
| 279 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 280 |
+
|
| 281 |
+
is_scheduler_support_step_index = self.is_scheduler_support_step_index()
|
| 282 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 283 |
+
for i, t in enumerate(timesteps):
|
| 284 |
+
# expand the latents if we are doing classifier free guidance
|
| 285 |
+
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
|
| 286 |
+
if is_scheduler_support_step_index:
|
| 287 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, step_index=i)
|
| 288 |
+
else:
|
| 289 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 290 |
+
|
| 291 |
+
unet_inputs = dict(
|
| 292 |
+
sample=latent_model_input,
|
| 293 |
+
timestep=t,
|
| 294 |
+
encoder_hidden_states=prompt_embeds,
|
| 295 |
+
infer_op=infer_op_dict.get("unet", None),
|
| 296 |
+
output_shape=latent_model_input.shape,
|
| 297 |
+
)
|
| 298 |
+
if do_controlnet:
|
| 299 |
+
unet_inputs["controlnet_cond"] = control_image
|
| 300 |
+
unet_inputs["controlnet_conditioning_scale"] = control_conditioning_scale
|
| 301 |
+
# Add image embeds for IP-Adapter
|
| 302 |
+
# added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
| 303 |
+
if ip_adapter_image:
|
| 304 |
+
unet_inputs["image_embeds"] = image_embeds
|
| 305 |
+
|
| 306 |
+
# predict the noise residual
|
| 307 |
+
noise_pred_unet = self.unet(**unet_inputs)[0]
|
| 308 |
+
if str(os.environ.get("FLAGS_model_return_data")).lower() in ("true", "1"):
|
| 309 |
+
print(f"StableDiffusion infer: step {i+1} , origin output {noise_pred_unet.abs().numpy().mean()} ", flush=True)
|
| 310 |
+
|
| 311 |
+
# perform guidance
|
| 312 |
+
if do_classifier_free_guidance:
|
| 313 |
+
noise_pred_uncond, noise_pred_text = noise_pred_unet.chunk(2)
|
| 314 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 315 |
+
|
| 316 |
+
if guidance_rescale > 0.0:
|
| 317 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 318 |
+
noise_pred = self.rescale_noise_cfg(
|
| 319 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
| 320 |
+
)
|
| 321 |
+
else:
|
| 322 |
+
noise_pred = noise_pred_unet
|
| 323 |
+
|
| 324 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 325 |
+
if is_scheduler_support_step_index:
|
| 326 |
+
scheduler_output = self.scheduler.step(
|
| 327 |
+
noise_pred, t, latents, step_index=i, return_pred_original_sample=False, **extra_step_kwargs
|
| 328 |
+
)
|
| 329 |
+
else:
|
| 330 |
+
scheduler_output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
|
| 331 |
+
latents = scheduler_output.prev_sample
|
| 332 |
+
# call the callback, if provided
|
| 333 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 334 |
+
progress_bar.update()
|
| 335 |
+
if callback is not None and i % callback_steps == 0:
|
| 336 |
+
callback(i, t, latents)
|
| 337 |
+
if i == len(timesteps) - 1:
|
| 338 |
+
# synchronize for an accurate it/s measurement
|
| 339 |
+
paddle.device.synchronize()
|
| 340 |
+
|
| 341 |
+
if output_type != "latent":
|
| 342 |
+
image = self._decode_vae_latents(
|
| 343 |
+
latents / self.vae_scaling_factor, infer_op=infer_op_dict.get("vae_decoder", None)
|
| 344 |
+
)
|
| 345 |
+
image, has_nsfw_concept = self.run_safety_checker(image)
|
| 346 |
+
else:
|
| 347 |
+
image = latents
|
| 348 |
+
has_nsfw_concept = None
|
| 349 |
+
|
| 350 |
+
if has_nsfw_concept is None:
|
| 351 |
+
do_denormalize = [True] * image.shape[0]
|
| 352 |
+
else:
|
| 353 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 354 |
+
|
| 355 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 356 |
+
|
| 357 |
+
if not return_dict:
|
| 358 |
+
return (image, has_nsfw_concept)
|
| 359 |
+
|
| 360 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
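
Before the next file, a quick illustration of how the text-to-image pipeline above is typically driven. This is a minimal sketch, not part of the diff: the import path and model directory are assumptions, and `log_step` is a hypothetical helper that only demonstrates the documented `callback(step: int, timestep: int, latents: paddle.Tensor)` contract; the call arguments are taken directly from the `__call__` signature above.

```python
# Minimal usage sketch (assumptions: import path and model directory).
from ppdiffusers import PaddleInferStableDiffusionPipeline  # assumed export path

pipe = PaddleInferStableDiffusionPipeline.from_pretrained(
    "./stable-diffusion-v1-5-paddleinfer"  # hypothetical exported model dir
)

def log_step(step, timestep, latents):
    # Matches the documented callback(step, timestep, latents) contract.
    print(f"step {step}, t={int(timestep)}")

result = pipe(
    prompt="an astronaut riding a horse, photorealistic",
    num_inference_steps=50,   # default shown in the signature above
    guidance_scale=7.5,       # CFG weight w; values > 1 enable guidance
    callback=log_step,
    callback_steps=10,        # invoke the callback every 10 steps
)
result.images[0].save("astronaut.png")
```

With the default `output_type="pil"`, `result.images` is a list of `PIL.Image.Image`, and `result.nsfw_content_detected` carries the safety-checker flags described in the Returns section.
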
VLMEvalKit_old/PaddleMIX/ppdiffusers/ppdiffusers/pipelines/stable_diffusion/pipeline_paddleinfer_stable_diffusion_img2img.py
ADDED
|
@@ -0,0 +1,381 @@
|
| 1 |
+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import paddle
|
| 19 |
+
import PIL
|
| 20 |
+
|
| 21 |
+
from ppdiffusers.transformers import CLIPImageProcessor, CLIPTokenizer
|
| 22 |
+
|
| 23 |
+
from ...image_processor import PipelineImageInput
|
| 24 |
+
from ...loaders import IPAdapterMixin
|
| 25 |
+
from ...schedulers import KarrasDiffusionSchedulers
|
| 26 |
+
from ...utils import logging
|
| 27 |
+
from ..paddleinfer_utils import (
|
| 28 |
+
PaddleInferDiffusionPipelineMixin,
|
| 29 |
+
PaddleInferRuntimeModel,
|
| 30 |
+
)
|
| 31 |
+
from ..pipeline_utils import DiffusionPipeline
|
| 32 |
+
from . import StableDiffusionPipelineOutput
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PaddleInferStableDiffusionImg2ImgPipeline(DiffusionPipeline, PaddleInferDiffusionPipelineMixin, IPAdapterMixin):
|
| 38 |
+
r"""
|
| 39 |
+
Pipeline for text-guided image-to-image generation using Stable Diffusion.
|
| 40 |
+
|
| 41 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 42 |
+
library implements for all the pipelines (such as downloading or saving, etc.)
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
vae_encoder ([`PaddleInferRuntimeModel`]):
|
| 46 |
+
Variational Auto-Encoder (VAE) model to encode images into latent representations.
|
| 47 |
+
vae_decoder ([`PaddleInferRuntimeModel`]):
|
| 48 |
+
Variational Auto-Encoder (VAE) model to decode images from latent representations.
|
| 49 |
+
text_encoder ([`PaddleInferRuntimeModel`]):
|
| 50 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
| 51 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 52 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 53 |
+
tokenizer (`CLIPTokenizer`):
|
| 54 |
+
Tokenizer of class
|
| 55 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 56 |
+
unet ([`PaddleInferRuntimeModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
| 57 |
+
scheduler ([`SchedulerMixin`]):
|
| 58 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 59 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
|
| 60 |
+
or [`DPMSolverMultistepScheduler`].
|
| 61 |
+
safety_checker ([`PaddleInferRuntimeModel`]):
|
| 62 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 63 |
+
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
| 64 |
+
feature_extractor ([`CLIPImageProcessor`]):
|
| 65 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
| 66 |
+
"""
|
| 67 |
+
_optional_components = ["image_encoder", "safety_checker", "feature_extractor"]
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
vae_encoder: PaddleInferRuntimeModel,
|
| 72 |
+
vae_decoder: PaddleInferRuntimeModel,
|
| 73 |
+
text_encoder: PaddleInferRuntimeModel,
|
| 74 |
+
tokenizer: CLIPTokenizer,
|
| 75 |
+
unet: PaddleInferRuntimeModel,
|
| 76 |
+
scheduler: KarrasDiffusionSchedulers,
|
| 77 |
+
safety_checker: PaddleInferRuntimeModel,
|
| 78 |
+
feature_extractor: CLIPImageProcessor,
|
| 79 |
+
image_encoder: PaddleInferRuntimeModel,
|
| 80 |
+
requires_safety_checker: bool = False,
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
if safety_checker is None and requires_safety_checker:
|
| 84 |
+
logger.warning(
|
| 85 |
+
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
| 86 |
+
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
| 87 |
+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
|
| 88 |
+
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
| 89 |
+
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
| 90 |
+
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
| 91 |
+
)
|
| 92 |
+
if safety_checker is not None and feature_extractor is None:
|
| 93 |
+
raise ValueError(
|
| 94 |
+
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
| 95 |
+
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
self.register_modules(
|
| 99 |
+
vae_encoder=vae_encoder,
|
| 100 |
+
vae_decoder=vae_decoder,
|
| 101 |
+
text_encoder=text_encoder,
|
| 102 |
+
tokenizer=tokenizer,
|
| 103 |
+
unet=unet,
|
| 104 |
+
scheduler=scheduler,
|
| 105 |
+
safety_checker=safety_checker,
|
| 106 |
+
feature_extractor=feature_extractor,
|
| 107 |
+
image_encoder=image_encoder,
|
| 108 |
+
)
|
| 109 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 110 |
+
self.post_init()
|
| 111 |
+
|
| 112 |
+
def __call__(
|
| 113 |
+
self,
|
| 114 |
+
prompt: Union[str, List[str]] = None,
|
| 115 |
+
image: Union[paddle.Tensor, PIL.Image.Image] = None,
|
| 116 |
+
height: Optional[int] = None,
|
| 117 |
+
width: Optional[int] = None,
|
| 118 |
+
strength: float = 0.8,
|
| 119 |
+
num_inference_steps: int = 50,
|
| 120 |
+
guidance_scale: float = 7.5,
|
| 121 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 122 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 123 |
+
timesteps: List[int] = None,
|
| 124 |
+
eta: float = 0.0,
|
| 125 |
+
generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
|
| 126 |
+
latents: Optional[paddle.Tensor] = None,
|
| 127 |
+
parse_prompt_type: Optional[str] = "lpw",
|
| 128 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 129 |
+
prompt_embeds: Optional[paddle.Tensor] = None,
|
| 130 |
+
negative_prompt_embeds: Optional[paddle.Tensor] = None,
|
| 131 |
+
guidance_rescale: float = 0.0,
|
| 132 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 133 |
+
output_type: Optional[str] = "pil",
|
| 134 |
+
return_dict: bool = True,
|
| 135 |
+
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
|
| 136 |
+
callback_steps: Optional[int] = 1,
|
| 137 |
+
controlnet_cond: Union[paddle.Tensor, PIL.Image.Image] = None,
|
| 138 |
+
controlnet_conditioning_scale: float = 1.0,
|
| 139 |
+
infer_op_dict: Dict[str, str] = None,
|
| 140 |
+
):
|
| 141 |
+
r"""
|
| 142 |
+
Function invoked when calling the pipeline for generation.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 146 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
| 147 |
+
instead.
|
| 148 |
+
height (`int`, *optional*, defaults to None):
|
| 149 |
+
The height in pixels of the generated image. If not provided, it is inferred from the preprocessed `image`.
|
| 150 |
+
width (`int`, *optional*, defaults to None):
|
| 151 |
+
The width in pixels of the generated image. If not provided, it is inferred from the preprocessed `image`.
|
| 152 |
+
image (`paddle.Tensor` or `PIL.Image.Image`):
|
| 153 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 154 |
+
process.
|
| 155 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 156 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
| 157 |
+
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
| 158 |
+
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
| 159 |
+
be maximum and the denoising process will run for the full number of iterations specified in
|
| 160 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 161 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 162 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 163 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
| 164 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 165 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 166 |
+
`guidance_scale` is defined as `w` in equation 2 of [Imagen
|
| 167 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 168 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
| 169 |
+
usually at the expense of lower image quality.
|
| 170 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 171 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 172 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
|
| 173 |
+
is less than `1`).
|
| 174 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 175 |
+
The number of images to generate per prompt.
|
| 176 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 177 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 178 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 179 |
+
generator (`paddle.Generator`, *optional*):
|
| 180 |
+
One or a list of paddle generator(s) to make generation deterministic.
|
| 181 |
+
latents (`paddle.Tensor`, *optional*):
|
| 182 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 183 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 184 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 185 |
+
prompt_embeds (`paddle.Tensor`, *optional*):
|
| 186 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 187 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 188 |
+
negative_prompt_embeds (`paddle.Tensor`, *optional*):
|
| 189 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 190 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 191 |
+
argument.
|
| 192 |
+
ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 193 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 194 |
+
The output format of the generated image. Choose between
|
| 195 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 196 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 197 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 198 |
+
plain tuple.
|
| 199 |
+
callback (`Callable`, *optional*):
|
| 200 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 201 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
|
| 202 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 203 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 204 |
+
called at every step.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 208 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 209 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 210 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 211 |
+
(nsfw) content, according to the `safety_checker`.
|
| 212 |
+
"""
|
| 213 |
+
# 0. Preprocess image
|
| 214 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
| 215 |
+
height, width = init_image.shape[-2:]
|
| 216 |
+
|
| 217 |
+
# 1. Check inputs. Raise error if not correct
|
| 218 |
+
self.check_inputs(
|
| 219 |
+
prompt,
|
| 220 |
+
height,
|
| 221 |
+
width,
|
| 222 |
+
callback_steps,
|
| 223 |
+
negative_prompt,
|
| 224 |
+
prompt_embeds,
|
| 225 |
+
negative_prompt_embeds,
|
| 226 |
+
strength,
|
| 227 |
+
)
|
| 228 |
+
infer_op_dict = self.prepare_infer_op_dict(infer_op_dict)
|
| 229 |
+
|
| 230 |
+
# 2. Define call parameters
|
| 231 |
+
if prompt is not None and isinstance(prompt, str):
|
| 232 |
+
batch_size = 1
|
| 233 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 234 |
+
batch_size = len(prompt)
|
| 235 |
+
else:
|
| 236 |
+
batch_size = prompt_embeds.shape[0]
|
| 237 |
+
|
| 238 |
+
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
| 239 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 240 |
+
# corresponds to doing no classifier free guidance.
|
| 241 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 242 |
+
|
| 243 |
+
# do_controlnet
|
| 244 |
+
do_controlnet = controlnet_cond is not None
|
| 245 |
+
if do_controlnet:
|
| 246 |
+
control_image, control_conditioning_scale = self.prepare_controlnet_cond(
|
| 247 |
+
controlnet_cond=controlnet_cond,
|
| 248 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
| 249 |
+
width=width,
|
| 250 |
+
height=height,
|
| 251 |
+
batch_size=batch_size,
|
| 252 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 253 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# 3. Encode input prompt
|
| 257 |
+
prompt_embeds = self._encode_prompt(
|
| 258 |
+
prompt,
|
| 259 |
+
num_images_per_prompt,
|
| 260 |
+
do_classifier_free_guidance,
|
| 261 |
+
negative_prompt,
|
| 262 |
+
prompt_embeds=prompt_embeds,
|
| 263 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 264 |
+
parse_prompt_type=parse_prompt_type,
|
| 265 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 266 |
+
infer_op=infer_op_dict.get("text_encoder", None),
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
if ip_adapter_image is not None:
|
| 270 |
+
image_embeds, negative_image_embeds = self.encode_image(
|
| 271 |
+
ip_adapter_image, num_images_per_prompt, infer_op=infer_op_dict.get("image_encoder", None)
|
| 272 |
+
)
|
| 273 |
+
if do_classifier_free_guidance:
|
| 274 |
+
image_embeds = paddle.concat([negative_image_embeds, image_embeds])
|
| 275 |
+
|
| 276 |
+
# 4. Prepare timesteps
|
| 277 |
+
timesteps, num_inference_steps = self.retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
|
| 278 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
|
| 279 |
+
|
| 280 |
+
# 5. Prepare latent variables
|
| 281 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
| 282 |
+
latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
|
| 283 |
+
# check whether strength is 1.0; if so, initialise the latents with pure noise
|
| 284 |
+
is_strength_max = strength == 1.0
|
| 285 |
+
latents = self.prepare_latents(
|
| 286 |
+
batch_size * num_images_per_prompt,
|
| 287 |
+
height,
|
| 288 |
+
width,
|
| 289 |
+
generator,
|
| 290 |
+
latents,
|
| 291 |
+
image=init_image,
|
| 292 |
+
timestep=latent_timestep,
|
| 293 |
+
is_strength_max=is_strength_max,
|
| 294 |
+
infer_op=infer_op_dict.get("vae_encoder", None),
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 298 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 299 |
+
|
| 300 |
+
# 7. Denoising loop
|
| 301 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 302 |
+
|
| 303 |
+
is_scheduler_support_step_index = self.is_scheduler_support_step_index()
|
| 304 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 305 |
+
for i, t in enumerate(timesteps):
|
| 306 |
+
# expand the latents if we are doing classifier free guidance
|
| 307 |
+
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
|
| 308 |
+
if is_scheduler_support_step_index:
|
| 309 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, step_index=i)
|
| 310 |
+
else:
|
| 311 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 312 |
+
|
| 313 |
+
unet_inputs = dict(
|
| 314 |
+
sample=latent_model_input,
|
| 315 |
+
timestep=t,
|
| 316 |
+
encoder_hidden_states=prompt_embeds,
|
| 317 |
+
infer_op=infer_op_dict.get("unet", None),
|
| 318 |
+
output_shape=latent_model_input.shape,
|
| 319 |
+
)
|
| 320 |
+
if do_controlnet:
|
| 321 |
+
unet_inputs["controlnet_cond"] = control_image
|
| 322 |
+
unet_inputs["controlnet_conditioning_scale"] = control_conditioning_scale
|
| 323 |
+
# Add image embeds for IP-Adapter
|
| 324 |
+
# added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
| 325 |
+
if ip_adapter_image:
|
| 326 |
+
unet_inputs["image_embeds"] = image_embeds
|
| 327 |
+
|
| 328 |
+
# predict the noise residual
|
| 329 |
+
noise_pred_unet = self.unet(**unet_inputs)[0]
|
| 330 |
+
|
| 331 |
+
# perform guidance
|
| 332 |
+
if do_classifier_free_guidance:
|
| 333 |
+
noise_pred_uncond, noise_pred_text = noise_pred_unet.chunk(2)
|
| 334 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 335 |
+
|
| 336 |
+
if guidance_rescale > 0.0:
|
| 337 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 338 |
+
noise_pred = self.rescale_noise_cfg(
|
| 339 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
| 340 |
+
)
|
| 341 |
+
else:
|
| 342 |
+
noise_pred = noise_pred_unet
|
| 343 |
+
|
| 344 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 345 |
+
if is_scheduler_support_step_index:
|
| 346 |
+
scheduler_output = self.scheduler.step(
|
| 347 |
+
noise_pred, t, latents, step_index=i, return_pred_original_sample=False, **extra_step_kwargs
|
| 348 |
+
)
|
| 349 |
+
else:
|
| 350 |
+
scheduler_output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
|
| 351 |
+
latents = scheduler_output.prev_sample
|
| 352 |
+
|
| 353 |
+
# call the callback, if provided
|
| 354 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 355 |
+
progress_bar.update()
|
| 356 |
+
if callback is not None and i % callback_steps == 0:
|
| 357 |
+
callback(i, t, latents)
|
| 358 |
+
if i == len(timesteps) - 1:
|
| 359 |
+
# synchronize for an accurate it/s measurement
|
| 360 |
+
paddle.device.synchronize()
|
| 361 |
+
|
| 362 |
+
if output_type != "latent":
|
| 363 |
+
image = self._decode_vae_latents(
|
| 364 |
+
latents / self.vae_scaling_factor, infer_op=infer_op_dict.get("vae_decoder", None)
|
| 365 |
+
)
|
| 366 |
+
image, has_nsfw_concept = self.run_safety_checker(image)
|
| 367 |
+
else:
|
| 368 |
+
image = latents
|
| 369 |
+
has_nsfw_concept = None
|
| 370 |
+
|
| 371 |
+
if has_nsfw_concept is None:
|
| 372 |
+
do_denormalize = [True] * image.shape[0]
|
| 373 |
+
else:
|
| 374 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 375 |
+
|
| 376 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 377 |
+
|
| 378 |
+
if not return_dict:
|
| 379 |
+
return (image, has_nsfw_concept)
|
| 380 |
+
|
| 381 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
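
As with the text-to-image file, here is a hedged usage sketch for the img2img pipeline above; the import path and checkpoint directory are placeholders, while the arguments mirror the `__call__` signature in the diff. Note that `strength` modulates the schedule: in diffusers-style img2img pipelines, `num_inference_steps=50` with `strength=0.75` typically runs only the final ~37 denoising steps, which is why `get_timesteps` is called after `retrieve_timesteps` in step 4 above.

```python
# Minimal img2img usage sketch (assumptions: import path and model directory).
import PIL.Image
from ppdiffusers import PaddleInferStableDiffusionImg2ImgPipeline  # assumed export path

pipe = PaddleInferStableDiffusionImg2ImgPipeline.from_pretrained(
    "./stable-diffusion-v1-5-paddleinfer"  # hypothetical exported model dir
)

init_image = PIL.Image.open("sketch.png").convert("RGB")  # any RGB source image
result = pipe(
    prompt="a fantasy landscape, highly detailed",
    image=init_image,
    strength=0.75,            # 1.0 ignores `image`; lower keeps more of its structure
    num_inference_steps=50,   # effective steps are roughly strength * num_inference_steps
    guidance_scale=7.5,
)
result.images[0].save("fantasy.png")
```

The `controlnet_cond` and `ip_adapter_image` arguments documented above slot into the same call when a ControlNet condition image or an IP-Adapter reference image is available.
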