Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,139 +1,590 @@
-import …
 import inspect
-import …
-
-
 import math
-import …
 import random
-import torch.nn.functional as F
 import tempfile
-import …
-
 import httpimport
-import …
-
 from packaging import version
 from PIL.PngImagePlugin import PngInfo
 
 if torch.cuda.get_device_properties(0).major >= 8:
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
 
-with httpimport.remote_repo(os.getenv("MODULE_URL")):
-    import pipeline
-    pipe, pipe2, pipe_img2img, pipe2_img2img = pipeline.get_pipeline_initialize()
-
 theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
 device="cuda"
 pipe = pipe.to(device)
 pipe2 = pipe2.to(device)
 PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
 NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
 
-
-import …
-import hmac
 
-
-
-
-import codecs
 
-def …
-… [old lines 44-136 were deleted; their content did not survive the page extraction]
 
 def sign_message(message, key):
     hmac_digest = hmac.new(key.encode(), message.encode(), hashlib.sha512).digest()
@@ -150,14 +601,6 @@ def run(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832, negative_promp
     init_image = init_image.resize((w, h))
     init_image = np.array(init_image)
 
-    if tpu_inference:
-        prompt = prompt.replace("!", " ").replace("\n", " ")  # remote endpoint unsupported
-        if do_img2img:
-            init_image = codecs.encode(pickle.dumps(init_image, protocol=pickle.HIGHEST_PROTOCOL), "base64").decode('latin1')
-            return tpu_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
-        else:
-            return tpu_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, inference_steps=inference_steps)
-
     return zero_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
 
 @spaces.GPU
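The deleted branch above shipped the init image to the remote endpoint as a pickled, base64-encoded string. For reference, a minimal sketch of that encoding and the matching decode (the receiving side is not part of this diff):

    import codecs
    import pickle

    import numpy as np

    arr = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for init_image
    payload = codecs.encode(
        pickle.dumps(arr, protocol=pickle.HIGHEST_PROTOCOL), "base64"
    ).decode("latin1")  # what the removed code sent

    # hypothetical receiver: invert the base64 codec, then unpickle
    restored = pickle.loads(codecs.decode(payload.encode("latin1"), "base64"))
    assert (restored == arr).all()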
@@ -244,6 +687,8 @@ def zero_inference_api(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832,
     image.save(tmpfile, "png", pnginfo=metadata)
     return tmpfile.name, seed
 
+pipe, pipe2, pipe_img2img, pipe2_img2img = pipeline.get_pipeline_initialize()
+
 with gr.Blocks(theme=theme) as demo:
     gr.Markdown('''# SDXL Experiments
 Just a simple demo for some SDXL model.''')
@@ -292,5 +737,6 @@ with gr.Blocks(theme=theme) as demo:
         outputs=[output, seed],
         concurrency_limit=1,
     )
+
 if __name__ == "__main__":
     demo.launch(share=True)
@@ -1,139 +1,590 @@
+import base64
+import codecs
+import hashlib
+import hmac
 import inspect
+import io
+import json
 import math
+import os
+import pickle
 import random
 import tempfile
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+# Third-party general libraries
 import httpimport
+import numpy as np
+import requests
 from packaging import version
+from PIL import Image
 from PIL.PngImagePlugin import PngInfo
+import PIL
+
+# PyTorch
+import torch
+import torch.nn.functional as F
+
+# Hugging Face & Diffusers
+from transformers import CLIPTokenizer
+from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+from diffusers.loaders import (
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+# App / UI
+import gradio as gr
+import spaces
 
 if torch.cuda.get_device_properties(0).major >= 8:
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
 
 theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
 device="cuda"
 pipe = pipe.to(device)
 pipe2 = pipe2.to(device)
 PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
 NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+def get_class(name: str):
+    import importlib
+
+    module_name, class_name = name.rsplit(".", 1)
+    module = importlib.import_module(module_name, package=None)
+    return getattr(module, class_name)
 
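get_class is a small dynamic-import helper; a hypothetical use, resolving a class from its dotted path:

    # hypothetical usage: fetch a class object by fully qualified name
    scheduler_cls = get_class("diffusers.schedulers.EulerDiscreteScheduler")
    print(scheduler_cls.__name__)  # EulerDiscreteScheduler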
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
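retrieve_latents normalizes VAE encoder outputs, accepting either an object carrying a latent distribution or one that already holds latents. A minimal sketch of the intended call site (image_tensor is illustrative):

    # sample latents from the VAE posterior, then apply the VAE's scaling factor
    generator = torch.Generator(device="cuda").manual_seed(0)
    latents = retrieve_latents(pipe.vae.encode(image_tensor), generator=generator)
    latents = latents * pipe.vae.config.scaling_factor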
+
def parse_prompt_attention(text):
|
| 81 |
+
"""
|
| 82 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
| 83 |
+
Accepted tokens are:
|
| 84 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
| 85 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
| 86 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
| 87 |
+
\\( - literal character '('
|
| 88 |
+
\\[ - literal character '['
|
| 89 |
+
\\) - literal character ')'
|
| 90 |
+
\\] - literal character ']'
|
| 91 |
+
\\ - literal character '\'
|
| 92 |
+
anything else - just text
|
| 93 |
+
|
| 94 |
+
>>> parse_prompt_attention('normal text')
|
| 95 |
+
[['normal text', 1.0]]
|
| 96 |
+
>>> parse_prompt_attention('an (important) word')
|
| 97 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
| 98 |
+
>>> parse_prompt_attention('(unbalanced')
|
| 99 |
+
[['unbalanced', 1.1]]
|
| 100 |
+
>>> parse_prompt_attention('\\(literal\\]')
|
| 101 |
+
[['(literal]', 1.0]]
|
| 102 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
| 103 |
+
[['unnecessaryparens', 1.1]]
|
| 104 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
| 105 |
+
[['a ', 1.0],
|
| 106 |
+
['house', 1.5730000000000004],
|
| 107 |
+
[' ', 1.1],
|
| 108 |
+
['on', 1.0],
|
| 109 |
+
[' a ', 1.1],
|
| 110 |
+
['hill', 0.55],
|
| 111 |
+
[', sun, ', 1.1],
|
| 112 |
+
['sky', 1.4641000000000006],
|
| 113 |
+
['.', 1.1]]
|
| 114 |
+
"""
|
| 115 |
+
import re
|
| 116 |
+
|
| 117 |
+
re_attention = re.compile(
|
| 118 |
+
r"""
|
| 119 |
+
\{|\}|\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
|
| 120 |
+
\)|]|[^\\()\[\]:]+|:
|
| 121 |
+
""",
|
| 122 |
+
re.X,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
| 126 |
+
|
| 127 |
+
res = []
|
| 128 |
+
round_brackets = []
|
| 129 |
+
square_brackets = []
|
| 130 |
+
curly_brackets = []
|
| 131 |
+
round_bracket_multiplier = 1.05
|
| 132 |
+
curly_bracket_multiplier = 1.05
|
| 133 |
+
square_bracket_multiplier = 1 / 1.05
|
| 134 |
+
|
| 135 |
+
def multiply_range(start_position, multiplier):
|
| 136 |
+
for p in range(start_position, len(res)):
|
| 137 |
+
res[p][1] *= multiplier
|
| 138 |
+
|
| 139 |
+
for m in re_attention.finditer(text):
|
| 140 |
+
text = m.group(0)
|
| 141 |
+
weight = m.group(1)
|
| 142 |
+
|
| 143 |
+
if text.startswith("\\"):
|
| 144 |
+
res.append([text[1:], 1.0])
|
| 145 |
+
elif text == "(":
|
| 146 |
+
round_brackets.append(len(res))
|
| 147 |
+
elif text == "{":
|
| 148 |
+
curly_brackets.append(len(res))
|
| 149 |
+
elif text == "[":
|
| 150 |
+
square_brackets.append(len(res))
|
| 151 |
+
elif weight is not None and len(round_brackets) > 0:
|
| 152 |
+
multiply_range(round_brackets.pop(), float(weight))
|
| 153 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 154 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 155 |
+
elif text == "}" and len(round_brackets) > 0:
|
| 156 |
+
multiply_range(curly_brackets.pop(), curly_bracket_multiplier)
|
| 157 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 158 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 159 |
+
else:
|
| 160 |
+
parts = re.split(re_break, text)
|
| 161 |
+
for i, part in enumerate(parts):
|
| 162 |
+
if i > 0:
|
| 163 |
+
res.append(["BREAK", -1])
|
| 164 |
+
res.append([part, 1.0])
|
| 165 |
+
|
| 166 |
+
for pos in round_brackets:
|
| 167 |
+
multiply_range(pos, round_bracket_multiplier)
|
| 168 |
+
|
| 169 |
+
for pos in square_brackets:
|
| 170 |
+
multiply_range(pos, square_bracket_multiplier)
|
| 171 |
+
|
| 172 |
+
if len(res) == 0:
|
| 173 |
+
res = [["", 1.0]]
|
| 174 |
+
|
| 175 |
+
# merge runs of identical weights
|
| 176 |
+
i = 0
|
| 177 |
+
while i + 1 < len(res):
|
| 178 |
+
if res[i][1] == res[i + 1][1]:
|
| 179 |
+
res[i][0] += res[i + 1][0]
|
| 180 |
+
res.pop(i + 1)
|
| 181 |
+
else:
|
| 182 |
+
i += 1
|
| 183 |
+
|
| 184 |
+
return res
|
| 185 |
+
|
| 186 |
+
|
| 187 |
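Note that the docstring examples above assume the original 1.1 bracket multiplier, while this version sets round_bracket_multiplier to 1.05, so plain (...) groups are weighted more gently here. With an explicit weight the multiplier is irrelevant; a quick check of the parser:

    print(parse_prompt_attention("a (red:1.5) cat"))
    # [['a ', 1.0], ['red', 1.5], [' cat', 1.0]]
    print(parse_prompt_attention("an (important) word"))
    # [['an ', 1.0], ['important', 1.05], [' word', 1.0]]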
+
def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
|
| 188 |
+
"""
|
| 189 |
+
Get prompt token ids and weights, this function works for both prompt and negative prompt
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
pipe (CLIPTokenizer)
|
| 193 |
+
A CLIPTokenizer
|
| 194 |
+
prompt (str)
|
| 195 |
+
A prompt string with weights
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
text_tokens (list)
|
| 199 |
+
A list contains token ids
|
| 200 |
+
text_weight (list)
|
| 201 |
+
A list contains the correspondent weight of token ids
|
| 202 |
+
|
| 203 |
+
Example:
|
| 204 |
+
import torch
|
| 205 |
+
from transformers import CLIPTokenizer
|
| 206 |
+
|
| 207 |
+
clip_tokenizer = CLIPTokenizer.from_pretrained(
|
| 208 |
+
"stablediffusionapi/deliberate-v2"
|
| 209 |
+
, subfolder = "tokenizer"
|
| 210 |
+
, dtype = torch.float16
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
token_id_list, token_weight_list = get_prompts_tokens_with_weights(
|
| 214 |
+
clip_tokenizer = clip_tokenizer
|
| 215 |
+
,prompt = "a (red:1.5) cat"*70
|
| 216 |
+
)
|
| 217 |
+
"""
|
| 218 |
+
texts_and_weights = parse_prompt_attention(prompt)
|
| 219 |
+
text_tokens, text_weights = [], []
|
| 220 |
+
for word, weight in texts_and_weights:
|
| 221 |
+
# tokenize and discard the starting and the ending token
|
| 222 |
+
token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that tokenize whatever length prompt
|
| 223 |
+
# the returned token is a 1d list: [320, 1125, 539, 320]
|
| 224 |
+
|
| 225 |
+
# merge the new tokens to the all tokens holder: text_tokens
|
| 226 |
+
text_tokens = [*text_tokens, *token]
|
| 227 |
+
|
| 228 |
+
# each token chunk will come with one weight, like ['red cat', 2.0]
|
| 229 |
+
# need to expand weight for each token.
|
| 230 |
+
chunk_weights = [weight] * len(token)
|
| 231 |
+
|
| 232 |
+
# append the weight back to the weight holder: text_weights
|
| 233 |
+
text_weights = [*text_weights, *chunk_weights]
|
| 234 |
+
return text_tokens, text_weights
|
| 235 |
+
|
| 236 |
+
|
| 237 |
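A sketch of the helper's output shape (actual ids depend on the tokenizer vocabulary):

    tokens, weights = get_prompts_tokens_with_weights(pipe.tokenizer, "a (red:1.5) cat")
    # tokens:  the CLIP BPE ids for "a", "red", "cat" (no BOS/EOS yet)
    # weights: [1.0, 1.5, 1.0], one weight per token id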
+
def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
|
| 238 |
+
"""
|
| 239 |
+
Produce tokens and weights in groups and pad the missing tokens
|
| 240 |
+
|
| 241 |
+
Args:
|
| 242 |
+
token_ids (list)
|
| 243 |
+
The token ids from tokenizer
|
| 244 |
+
weights (list)
|
| 245 |
+
The weights list from function get_prompts_tokens_with_weights
|
| 246 |
+
pad_last_block (bool)
|
| 247 |
+
Control if fill the last token list to 75 tokens with eos
|
| 248 |
+
Returns:
|
| 249 |
+
new_token_ids (2d list)
|
| 250 |
+
new_weights (2d list)
|
| 251 |
+
|
| 252 |
+
Example:
|
| 253 |
+
token_groups,weight_groups = group_tokens_and_weights(
|
| 254 |
+
token_ids = token_id_list
|
| 255 |
+
, weights = token_weight_list
|
| 256 |
+
)
|
| 257 |
+
"""
|
| 258 |
+
bos, eos = 49406, 49407
|
| 259 |
+
|
| 260 |
+
# this will be a 2d list
|
| 261 |
+
new_token_ids = []
|
| 262 |
+
new_weights = []
|
| 263 |
+
while len(token_ids) >= 75:
|
| 264 |
+
# get the first 75 tokens
|
| 265 |
+
head_75_tokens = [token_ids.pop(0) for _ in range(75)]
|
| 266 |
+
head_75_weights = [weights.pop(0) for _ in range(75)]
|
| 267 |
+
|
| 268 |
+
# extract token ids and weights
|
| 269 |
+
temp_77_token_ids = [bos] + head_75_tokens + [eos]
|
| 270 |
+
temp_77_weights = [1.0] + head_75_weights + [1.0]
|
| 271 |
+
|
| 272 |
+
# add 77 token and weights chunk to the holder list
|
| 273 |
+
new_token_ids.append(temp_77_token_ids)
|
| 274 |
+
new_weights.append(temp_77_weights)
|
| 275 |
+
|
| 276 |
+
# padding the left
|
| 277 |
+
if len(token_ids) > 0:
|
| 278 |
+
padding_len = 75 - len(token_ids) if pad_last_block else 0
|
| 279 |
+
|
| 280 |
+
temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
|
| 281 |
+
new_token_ids.append(temp_77_token_ids)
|
| 282 |
+
|
| 283 |
+
temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
|
| 284 |
+
new_weights.append(temp_77_weights)
|
| 285 |
+
|
| 286 |
+
return new_token_ids, new_weights
|
| 287 |
+
|
| 288 |
+
|
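The 75-token chunking in action, sketched with dummy ids:

    ids = list(range(1000, 1100))  # 100 dummy token ids
    ws = [1.0] * 100
    token_groups, weight_groups = group_tokens_and_weights(ids.copy(), ws.copy())
    # token_groups[0] -> [49406] + first 75 ids + [49407], a full 77-token chunk
    # token_groups[1] -> [49406] + remaining 25 ids + [49407], unpadded since pad_last_block=False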
+def get_weighted_text_embeddings_sdxl(
+    pipe,
+    prompt: str = "",
+    prompt_2: str = None,
+    neg_prompt: str = "",
+    neg_prompt_2: str = None,
+    num_images_per_prompt: int = 1,
+    device: Optional[torch.device] = None,
+    clip_skip: Optional[int] = None,
+    lora_scale: Optional[int] = None,
+):
+    """
+    This function processes long prompts with weights, with no length limitation,
+    for Stable Diffusion XL.
+
+    Args:
+        pipe (StableDiffusionXLPipeline)
+        prompt (str)
+        prompt_2 (str)
+        neg_prompt (str)
+        neg_prompt_2 (str)
+        num_images_per_prompt (int)
+        device (torch.device)
+        clip_skip (int)
+    Returns:
+        prompt_embeds (torch.Tensor)
+        neg_prompt_embeds (torch.Tensor)
+    """
+    device = device or pipe._execution_device
+
+    # set the lora scale so that the monkey-patched LoRA
+    # function of the text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if pipe.text_encoder is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder, lora_scale)
+
+        if pipe.text_encoder_2 is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
+    if prompt_2:
+        prompt = f"{prompt} {prompt_2}"
+
+    if neg_prompt_2:
+        neg_prompt = f"{neg_prompt} {neg_prompt_2}"
+
+    prompt_t1 = prompt_t2 = prompt
+    neg_prompt_t1 = neg_prompt_t2 = neg_prompt
+
+    if isinstance(pipe, TextualInversionLoaderMixin):
+        prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
+        neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
+        prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
+        neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)
+
+    eos = pipe.tokenizer.eos_token_id
+
+    # tokenizer 1
+    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
+    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
+
+    # tokenizer 2
+    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
+    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)
+
+    # pad the shorter one for prompt set 1
+    prompt_token_len = len(prompt_tokens)
+    neg_prompt_token_len = len(neg_prompt_tokens)
+
+    if prompt_token_len > neg_prompt_token_len:
+        # pad the neg_prompt with eos tokens
+        neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+        neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+    else:
+        # pad the prompt
+        prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+        prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+
+    # pad the shorter one for token set 2
+    prompt_token_len_2 = len(prompt_tokens_2)
+    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
+
+    if prompt_token_len_2 > neg_prompt_token_len_2:
+        # pad the neg_prompt with eos tokens
+        neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+        neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+    else:
+        # pad the prompt
+        prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+        prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+
+    embeds = []
+    neg_embeds = []
+
+    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
+
+    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
+        neg_prompt_tokens.copy(), neg_prompt_weights.copy()
+    )
+
+    prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
+        prompt_tokens_2.copy(), prompt_weights_2.copy()
+    )
+
+    neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
+        neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
+    )
+
+    # getting prompt embeddings one token at a time does not work; encode per 77-token group.
+    for i in range(len(prompt_token_groups)):
+        # get positive prompt embeddings with weights
+        token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
+        weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
+
+        token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
+
+        # use the first text encoder
+        prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
+
+        # use the second text encoder
+        prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
+        pooled_prompt_embeds = prompt_embeds_2[0]
+
+        if clip_skip is None:
+            prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
+            prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
+        else:
+            # "2" because SDXL always indexes from the penultimate layer.
+            prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]
+            prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]
+
+        prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
+        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
+
+        for j in range(len(weight_tensor)):
+            if weight_tensor[j] != 1.0:
+                token_embedding[j] = (
+                    token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
+                )
+
+        token_embedding = token_embedding.unsqueeze(0)
+        embeds.append(token_embedding)
+
+        # get negative prompt embeddings with weights
+        neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
+        neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
+        neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
+
+        # use the first text encoder
+        neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
+        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
+
+        # use the second text encoder
+        neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
+        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
+        negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+
+        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
+        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
+
+        for z in range(len(neg_weight_tensor)):
+            if neg_weight_tensor[z] != 1.0:
+                neg_token_embedding[z] = (
+                    neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
+                )
+
+        neg_token_embedding = neg_token_embedding.unsqueeze(0)
+        neg_embeds.append(neg_token_embedding)
+
+    prompt_embeds = torch.cat(embeds, dim=1)
+    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
+
+    bs_embed, seq_len, _ = prompt_embeds.shape
+    # duplicate text embeddings for each generation per prompt, using an mps-friendly method
+    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+    seq_len = negative_prompt_embeds.shape[1]
+    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+    negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+        bs_embed * num_images_per_prompt, -1
+    )
+    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+        bs_embed * num_images_per_prompt, -1
+    )
+
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+    if pipe.text_encoder_2 is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
+    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
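The four return values line up with what the stock SDXL pipelines expect; a minimal sketch of calling the helper directly (prompt text is illustrative):

    prompt_embeds, negative_prompt_embeds, pooled, neg_pooled = get_weighted_text_embeddings_sdxl(
        pipe, prompt="a (red:1.5) cat", neg_prompt=NEGATIVE_PROMPT
    )
    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled,
        negative_pooled_prompt_embeds=neg_pooled,
    ).images[0]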
+class ModImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
+
+    def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
+        return get_weighted_text_embeddings_sdxl(
+            pipe=self,
+            prompt=prompt,
+            neg_prompt=negative_prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
+        )
+
+    def get_timesteps(self, num_inference_steps, strength, device, **kwargs):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
+class ModText2ImgPipeline(StableDiffusionXLPipeline):
+
+    def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
+        return get_weighted_text_embeddings_sdxl(
+            pipe=self,
+            prompt=prompt,
+            neg_prompt=negative_prompt,
+            num_images_per_prompt=num_images_per_prompt,
+            clip_skip=clip_skip,
+            lora_scale=lora_scale,
+        )
+
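Both subclasses work because the stock SDXL __call__ delegates prompt handling to self.encode_prompt and unpacks the same four-tuple (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) that get_weighted_text_embeddings_sdxl returns, so weighted prompts flow through the normal call path. A sketch, with a placeholder checkpoint path:

    pipe = ModText2ImgPipeline.from_single_file(
        "model.safetensors", torch_dtype=torch.float16  # placeholder path
    ).to("cuda")
    image = pipe(prompt="a (red:1.5) cat", negative_prompt="lowres").images[0]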
+class ModFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+
+    @property
+    def init_noise_sigma(self):
+        return 1.0
+
+    def scale_model_input(self, x, y):
+        return x
+
+    def add_noise(self, x, n, t):
+        return self.scale_noise(x, t, n)
+
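This shim presents the flow-matching scheduler through the interface the SDXL pipelines expect: an init_noise_sigma of 1.0 so initial latents are not rescaled, an identity scale_model_input, and an add_noise that forwards to scale_noise(sample, timestep, noise) for the img2img path. In diffusers, scale_noise implements the flow-matching interpolation, roughly:

    # noisy = (1 - sigma_t) * sample + sigma_t * noise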
+def get_pipeline_initialize(model_1="", model_2=""):
+    pipe = ModText2ImgPipeline.from_single_file(
+        os.getenv("SDXL_MODEL", model_1),
+        torch_dtype=torch.float16
+    )
+    pipe.fuse_qkv_projections()
+    pipe.unet.set_attention_backend("_flash_3_hub")
+    pipe.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
+    pipe.unet.to(memory_format=torch.channels_last)
+    pipe.vae.to(memory_format=torch.channels_last)
+
+    pipe2 = ModText2ImgPipeline.from_single_file(
+        os.getenv("SDXL_MODEL_2", model_2),
+        torch_dtype=torch.float16
+    )
+    pipe2.fuse_qkv_projections()
+    pipe2.unet.set_attention_backend("_flash_3_hub")
+    pipe2.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
+    pipe2.unet.to(memory_format=torch.channels_last)
+    pipe2.vae.to(memory_format=torch.channels_last)
+
+    pipe_img2img = ModImg2ImgPipeline(
+        vae=pipe.vae,
+        unet=pipe.unet,
+        text_encoder=pipe.text_encoder,
+        text_encoder_2=pipe.text_encoder_2,
+        tokenizer=pipe.tokenizer,
+        tokenizer_2=pipe.tokenizer_2,
+        scheduler=pipe.scheduler,
+        image_encoder=pipe.image_encoder,
+        feature_extractor=pipe.feature_extractor,
+    )
+    pipe2_img2img = ModImg2ImgPipeline(
+        vae=pipe2.vae,
+        unet=pipe2.unet,
+        text_encoder=pipe2.text_encoder,
+        text_encoder_2=pipe2.text_encoder_2,
+        tokenizer=pipe2.tokenizer,
+        tokenizer_2=pipe2.tokenizer_2,
+        scheduler=pipe2.scheduler,
+        image_encoder=pipe2.image_encoder,
+        feature_extractor=pipe2.feature_extractor,
+    )
+    return pipe, pipe2, pipe_img2img, pipe2_img2img
 
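Since each img2img pipeline is assembled from the same component objects as its text2img counterpart, the pair shares weights and costs no extra GPU memory. A sketch, assuming SDXL_MODEL and SDXL_MODEL_2 are set:

    pipe, pipe2, pipe_img2img, pipe2_img2img = get_pipeline_initialize()
    assert pipe.unet is pipe_img2img.unet  # one UNet, two entry points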
 
 def sign_message(message, key):
     hmac_digest = hmac.new(key.encode(), message.encode(), hashlib.sha512).digest()