# toaru-xl-model / app.py
# Author: nyanko7 — last commit 29bedde ("Update app.py")
import base64
import gradio as gr
import spaces
import codecs
import hashlib
import hmac
import inspect
import io
import json
import math
import os
import pickle
import random
import tempfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
# Third-party general libraries
import httpimport
import numpy as np
import requests
from packaging import version
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import PIL
# PyTorch
import torch
import torch.nn.functional as F
# Hugging Face & Diffusers
from transformers import CLIPTokenizer
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers.loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
# Enable TF32 matmul/cuDNN kernels only on GPUs with compute capability >= 8 (Ampere and newer).
# The is_available() guard keeps the import from crashing when no CUDA device is visible.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
device = "cuda"
# Quality tags appended to every user prompt by run().
PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
# Default text of the "Negative prompt" box in the UI.
NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def get_class(name: str):
    """Return the object named by a dotted path such as ``"package.module.ClassName"``.

    Everything before the last dot is imported as a module, and the final component
    is looked up on that module. A name without any dot raises ValueError (the split
    cannot be unpacked). Import and attribute errors propagate unchanged.
    """
    from importlib import import_module

    module_path, attr_name = name.rsplit(".", 1)
    return getattr(import_module(module_path), attr_name)
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder output object.

    If the object has ``latent_dist``:
    - ``sample_mode == "sample"`` draws a sample using ``generator``;
    - ``sample_mode == "argmax"`` returns the distribution's mode.

    Otherwise, or for any other mode, it falls back to a ``latents`` attribute.

    Raises:
        AttributeError: if none of the above applies.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
def parse_prompt_attention(text):
    """
    Parse a prompt with attention markup into a list of ``[text, weight]`` pairs.

    Accepted tokens are:
      (abc)         - multiplies the weight of abc by 1.05
      (abc:3.12)    - multiplies the weight of abc by 3.12
      {abc}         - multiplies the weight of abc by 1.05
      [abc]         - divides the weight of abc by 1.05
      \\( \\) \\[ \\] \\{ \\}  - literal bracket characters
      \\\\            - literal backslash
      BREAK         - emitted as the marker pair ["BREAK", -1]
      anything else - plain text with weight 1.0

    Brackets that are never closed still apply their multiplier to everything
    after them. A closing bracket without a matching opener is kept as literal
    text. Adjacent runs with identical weights are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.05], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.05]]
    >>> parse_prompt_attention('{curly}')
    [['curly', 1.05]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    """
    import re

    # Order matters: escapes are tried before bare brackets. The plain-text class
    # excludes every bracket (including {}) so closing brackets are never swallowed
    # into a text run.
    re_attention = re.compile(
        r"""
    \{|\}|\\\(|\\\)|\\\[|\\]|\\\{|\\\}|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
    \)|]|[^\\()\[\]{}:]+|:
    """,
        re.X,
    )
    re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)

    res = []
    round_brackets = []
    square_brackets = []
    curly_brackets = []

    round_bracket_multiplier = 1.05
    curly_bracket_multiplier = 1.05
    square_bracket_multiplier = 1 / 1.05

    def multiply_range(start_position, multiplier):
        # Apply the multiplier to every segment emitted since the bracket opened.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "{":
            curly_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "}" and curly_brackets:
            multiply_range(curly_brackets.pop(), curly_bracket_multiplier)
        elif token == "]" and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            for i, part in enumerate(re.split(re_break, token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Unclosed brackets still take effect from where they were opened.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in curly_brackets:
        multiply_range(pos, curly_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
    """
    Tokenize a weighted prompt into flat token-id and per-token weight lists.

    Works for both the prompt and the negative prompt, with no length limit.
    BOS/EOS are not included; group_tokens_and_weights adds them per chunk.

    Args:
        clip_tokenizer (CLIPTokenizer): tokenizer used to encode each text segment.
        prompt (str): prompt string using the attention syntax of parse_prompt_attention.

    Returns:
        text_tokens (list): token ids.
        text_weights (list): one weight per token id.

    A ``BREAK`` keyword produces no tokens. Instead, the lists are padded with EOS
    (weight 1.0) up to the next multiple of 75 tokens, so the text after it starts
    a new 75-token chunk.

    Example:
        tokenizer = CLIPTokenizer.from_pretrained("stablediffusionapi/deliberate-v2", subfolder="tokenizer")
        token_ids, token_weights = get_prompts_tokens_with_weights(tokenizer, "a (red:1.5) cat" * 70)
    """
    chunk_len = 75  # tokens per chunk, excluding BOS/EOS (matches group_tokens_and_weights)
    text_tokens, text_weights = [], []
    for word, weight in parse_prompt_attention(prompt):
        if word == "BREAK" and weight < 0:
            # parse_prompt_attention emits ["BREAK", -1] (possibly rescaled by enclosing
            # brackets). Tokenizing it would insert "break" with a negative weight, so
            # pad to the chunk boundary instead.
            pad = -len(text_tokens) % chunk_len
            text_tokens.extend([clip_tokenizer.eos_token_id] * pad)
            text_weights.extend([1.0] * pad)
            continue
        # Tokenize without truncation and drop the BOS/EOS the tokenizer adds around each segment.
        token = clip_tokenizer(word, truncation=False).input_ids[1:-1]
        text_tokens.extend(token)
        # Every token of a segment shares that segment's weight.
        text_weights.extend([weight] * len(token))
    return text_tokens, text_weights
def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
    """
    Split token ids and weights into CLIP chunks of BOS + up to 75 tokens + EOS.

    Args:
        token_ids (list): token ids without BOS/EOS, e.g. from get_prompts_tokens_with_weights.
        weights (list): one weight per entry of ``token_ids``.
        pad_last_block (bool): if True, pad a final partial chunk with EOS (weight 1.0)
            up to 77 entries.

    Returns:
        new_token_ids (2d list): one list per chunk.
        new_weights (2d list): matching weights; BOS/EOS/padding positions get 1.0.
        Empty input yields ``([], [])``.

    The input lists are not modified.

    Example:
        token_groups, weight_groups = group_tokens_and_weights(token_id_list, token_weight_list)
    """
    bos, eos = 49406, 49407
    chunk_len = 75
    new_token_ids = []
    new_weights = []
    for start in range(0, len(token_ids), chunk_len):
        chunk_tokens = token_ids[start:start + chunk_len]
        chunk_weights = weights[start:start + chunk_len]
        # Full chunks already hold 75 tokens, so only a short final chunk is ever padded.
        padding_len = chunk_len - len(chunk_tokens) if pad_last_block else 0
        new_token_ids.append([bos] + chunk_tokens + [eos] * padding_len + [eos])
        new_weights.append([1.0] + chunk_weights + [1.0] * padding_len + [1.0])
    return new_token_ids, new_weights
def _pad_token_list(tokens: list, weights: list, length: int, pad_id: int):
    """Return new ``(tokens, weights)`` lists right-padded to ``length`` with ``pad_id`` / weight 1.0."""
    missing = max(length - len(tokens), 0)
    return tokens + [pad_id] * missing, weights + [1.0] * missing


def _apply_token_weights(token_embedding: torch.Tensor, weights: list) -> torch.Tensor:
    """Re-weight one chunk's per-token embeddings relative to its last position.

    Each position j with weight w_j != 1.0 becomes ``e[-1] + (e[j] - e[-1]) * w_j``.
    Positions with weight 1.0 are returned untouched. ``token_embedding`` is the
    (seq_len, dim) hidden state of one chunk, and ``weights`` has seq_len entries.
    """
    # Weights are rounded through float16 (as the original per-token loop did) and
    # then cast to the embedding dtype for the arithmetic.
    weight = torch.tensor(weights, dtype=torch.float16, device=token_embedding.device)
    keep = (weight == 1.0).unsqueeze(-1)
    weight = weight.to(token_embedding.dtype).unsqueeze(-1)
    anchor = token_embedding[-1]
    weighted = anchor + (token_embedding - anchor) * weight
    return torch.where(keep, token_embedding, weighted)


def _encode_sdxl_chunk(pipe, tokens_1: list, tokens_2: list, weights: list, device, clip_skip: Optional[int]):
    """Run one chunk through both SDXL text encoders.

    Returns ``(embedding, pooled)``:
    - ``embedding``: the weighted concatenation (last dim) of both encoders'
      selected hidden states, with a leading batch dimension of 1;
    - ``pooled``: ``text_encoder_2``'s first output.
    """
    # "2" because SDXL always indexes from the penultimate layer; clip_skip counts further back.
    layer = -2 if clip_skip is None else -(clip_skip + 2)
    ids_1 = torch.tensor([tokens_1], dtype=torch.long, device=device)
    ids_2 = torch.tensor([tokens_2], dtype=torch.long, device=device)
    out_1 = pipe.text_encoder(ids_1, output_hidden_states=True)
    out_2 = pipe.text_encoder_2(ids_2, output_hidden_states=True)
    hidden = torch.concat([out_1.hidden_states[layer], out_2.hidden_states[layer]], dim=-1).squeeze(0)
    return _apply_token_weights(hidden, weights).unsqueeze(0), out_2[0]


def get_weighted_text_embeddings_sdxl(
    pipe,
    prompt: str = "",
    prompt_2: str = None,
    neg_prompt: str = "",
    neg_prompt_2: str = None,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    lora_scale: Optional[float] = None,
):
    """
    Encode an arbitrarily long, attention-weighted prompt pair for Stable Diffusion XL.

    Both prompts are tokenized with both SDXL tokenizers and split into chunks of
    BOS + 75 tokens + EOS (see group_tokens_and_weights). Each chunk is encoded with
    both text encoders, and the chunk embeddings are concatenated along the sequence
    axis. Per-token weights from the first tokenizer are applied to the concatenated
    hidden states.

    Args:
        pipe: SDXL pipeline providing tokenizer/tokenizer_2 and text_encoder/text_encoder_2.
        prompt (str): positive prompt; None is treated as "".
        prompt_2 (str): optional text appended to ``prompt``, separated by a space.
        neg_prompt (str): negative prompt; None (e.g. a blank UI field) is treated as "".
        neg_prompt_2 (str): optional text appended to ``neg_prompt``.
        num_images_per_prompt (int): repeat count of every embedding along the batch axis.
        device (torch.device): device for token tensors; defaults to ``pipe._execution_device``.
        clip_skip (int): extra layers to skip before the penultimate hidden state.
            Applied to both positive and negative prompts.
        lora_scale (float): LoRA scale applied to the text encoders for this call.

    Returns:
        prompt_embeds (torch.Tensor)
        negative_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor): from ``text_encoder_2`` on the last chunk.
        negative_pooled_prompt_embeds (torch.Tensor): likewise for the negative prompt.
    """
    device = device or pipe._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
        pipe._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if pipe.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder, lora_scale)

        if pipe.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder_2, lora_scale)

    # Pipelines pass negative_prompt=None when the user leaves it blank.
    prompt = prompt or ""
    neg_prompt = neg_prompt or ""
    if prompt_2:
        prompt = f"{prompt} {prompt_2}"
    if neg_prompt_2:
        neg_prompt = f"{neg_prompt} {neg_prompt_2}"

    prompt_t1 = prompt_t2 = prompt
    neg_prompt_t1 = neg_prompt_t2 = neg_prompt

    if isinstance(pipe, TextualInversionLoaderMixin):
        prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
        neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
        prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
        neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)

    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id

    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)

    # Pad all four sequences to a common length with EOS. Positive/negative and both
    # tokenizers then yield the same number of chunks, which the loop below zips together.
    target_len = max(len(prompt_tokens), len(neg_prompt_tokens), len(prompt_tokens_2), len(neg_prompt_tokens_2))
    prompt_tokens, prompt_weights = _pad_token_list(prompt_tokens, prompt_weights, target_len, eos)
    neg_prompt_tokens, neg_prompt_weights = _pad_token_list(neg_prompt_tokens, neg_prompt_weights, target_len, eos)
    prompt_tokens_2, prompt_weights_2 = _pad_token_list(prompt_tokens_2, prompt_weights_2, target_len, eos)
    neg_prompt_tokens_2, neg_prompt_weights_2 = _pad_token_list(
        neg_prompt_tokens_2, neg_prompt_weights_2, target_len, eos
    )

    def _to_chunks(tokens, weights):
        token_groups, weight_groups = group_tokens_and_weights(tokens.copy(), weights.copy())
        if not token_groups:
            # Nothing to encode at all: use a bare [BOS, EOS] chunk instead of no chunk.
            return [[bos, eos]], [[1.0, 1.0]]
        return token_groups, weight_groups

    prompt_token_groups, prompt_weight_groups = _to_chunks(prompt_tokens, prompt_weights)
    neg_prompt_token_groups, neg_prompt_weight_groups = _to_chunks(neg_prompt_tokens, neg_prompt_weights)
    # Only tokenizer-1 weights are applied, so tokenizer-2 weight groups are not needed.
    prompt_token_groups_2, _ = _to_chunks(prompt_tokens_2, prompt_weights_2)
    neg_prompt_token_groups_2, _ = _to_chunks(neg_prompt_tokens_2, neg_prompt_weights_2)

    embeds = []
    neg_embeds = []
    # Encode chunk by chunk; pooled embeddings end up holding the last chunk's values.
    for tokens_1, tokens_2, weights, neg_tokens_1, neg_tokens_2, neg_weights in zip(
        prompt_token_groups,
        prompt_token_groups_2,
        prompt_weight_groups,
        neg_prompt_token_groups,
        neg_prompt_token_groups_2,
        neg_prompt_weight_groups,
    ):
        token_embedding, pooled_prompt_embeds = _encode_sdxl_chunk(
            pipe, tokens_1, tokens_2, weights, device, clip_skip
        )
        embeds.append(token_embedding)
        neg_token_embedding, negative_pooled_prompt_embeds = _encode_sdxl_chunk(
            pipe, neg_tokens_1, neg_tokens_2, neg_weights, device, clip_skip
        )
        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, -1
    )
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, -1
    )

    if pipe.text_encoder is not None:
        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder, lora_scale)

    if pipe.text_encoder_2 is not None:
        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
class ModImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    """SDXL img2img pipeline with weighted long-prompt encoding and a custom timestep cut-off."""

    def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
        """Encode prompts with get_weighted_text_embeddings_sdxl.

        ``prompt_2``, ``negative_prompt_2`` and ``device`` are forwarded when supplied.
        Other keyword arguments (e.g. precomputed embeddings) are ignored.

        Returns:
            (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
        """
        return get_weighted_text_embeddings_sdxl(
            pipe=self,
            prompt=prompt,
            prompt_2=kwargs.get("prompt_2"),
            neg_prompt=negative_prompt,
            neg_prompt_2=kwargs.get("negative_prompt_2"),
            num_images_per_prompt=num_images_per_prompt,
            device=kwargs.get("device"),
            clip_skip=clip_skip,
            lora_scale=lora_scale,
        )

    def get_timesteps(self, num_inference_steps, strength, device, **kwargs):
        """Return ``(timesteps, remaining_steps)``: the tail of the schedule covering ``strength``.

        ``init_timestep`` is not truncated to int, so ``t_start`` is computed from the
        fractional value and then truncated. ``device`` and extra kwargs are unused.
        """
        # get the original timestep using init_timestep
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start
class ModText2ImgPipeline(StableDiffusionXLPipeline):
    """SDXL text-to-image pipeline whose prompt encoding supports weighted, unlimited-length prompts."""

    def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
        """Encode prompts with get_weighted_text_embeddings_sdxl.

        ``prompt_2``, ``negative_prompt_2`` and ``device`` are forwarded when supplied.
        Other keyword arguments (e.g. precomputed embeddings) are ignored.

        Returns:
            (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
        """
        return get_weighted_text_embeddings_sdxl(
            pipe=self,
            prompt=prompt,
            prompt_2=kwargs.get("prompt_2"),
            neg_prompt=negative_prompt,
            neg_prompt_2=kwargs.get("negative_prompt_2"),
            num_images_per_prompt=num_images_per_prompt,
            device=kwargs.get("device"),
            clip_skip=clip_skip,
            lora_scale=lora_scale,
        )
class ModFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """Flow-matching Euler scheduler with three scheduler hooks overridden.

    - ``init_noise_sigma`` is the constant 1.0.
    - ``scale_model_input`` returns its input unchanged.
    - ``add_noise`` delegates to ``scale_noise`` with the arguments reordered.

    NOTE(review): presumably this adapts the scheduler to the SDXL pipelines'
    hook calls, which are assumed to pass arguments positionally. Confirm before
    renaming any parameter.
    """

    @property
    def init_noise_sigma(self):
        # Constant; read-only because it is a property without a setter.
        return 1.0

    def scale_model_input(self, x, y):
        # Identity: the second argument (presumably the timestep) is ignored.
        return x

    def add_noise(self, x, n, t):
        # Presumably x = clean sample, n = noise, t = timestep; scale_noise receives them
        # in the order (sample, timestep, noise).
        sample, noise, timestep = x, n, t
        return self.scale_noise(sample, timestep, noise)
def _load_text2img_pipeline(checkpoint):
    """Load one single-file SDXL checkpoint in fp16 and apply this app's tweaks.

    The tweaks are: fused QKV, the "_flash_3_hub" attention backend, the flow-match
    scheduler, and channels_last memory format for the UNet and VAE.

    Raises:
        ValueError: if ``checkpoint`` is empty (no env var and no explicit path).
    """
    if not checkpoint:
        raise ValueError("No SDXL checkpoint path given: set SDXL_MODEL / SDXL_MODEL_2 or pass a path.")
    pipe = ModText2ImgPipeline.from_single_file(checkpoint, torch_dtype=torch.float16)
    pipe.fuse_qkv_projections()
    pipe.unet.set_attention_backend("_flash_3_hub")
    pipe.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    return pipe


def _build_img2img_pipeline(pipe):
    """Build an img2img pipeline that reuses every component of ``pipe``, including its scheduler object."""
    return ModImg2ImgPipeline(
        vae=pipe.vae,
        unet=pipe.unet,
        text_encoder=pipe.text_encoder,
        text_encoder_2=pipe.text_encoder_2,
        tokenizer=pipe.tokenizer,
        tokenizer_2=pipe.tokenizer_2,
        scheduler=pipe.scheduler,
        image_encoder=pipe.image_encoder,
        feature_extractor=pipe.feature_extractor,
    )


def get_pipeline_initialize(model_1="", model_2=""):
    """Load both SDXL models and derive an img2img pipeline from each.

    The environment variables SDXL_MODEL / SDXL_MODEL_2, when set, take precedence
    over ``model_1`` / ``model_2``.

    Returns:
        (pipe, pipe2, pipe_img2img, pipe2_img2img)
    """
    pipe = _load_text2img_pipeline(os.getenv("SDXL_MODEL", model_1))
    pipe2 = _load_text2img_pipeline(os.getenv("SDXL_MODEL_2", model_2))
    return pipe, pipe2, _build_img2img_pipeline(pipe), _build_img2img_pipeline(pipe2)
def sign_message(message, key):
    """Return the base64 text of the HMAC-SHA512 of ``message`` keyed with ``key``.

    Both arguments are encoded with str.encode()'s default UTF-8.
    """
    mac = hmac.new(key.encode(), msg=message.encode(), digestmod=hashlib.sha512)
    return base64.b64encode(mac.digest()).decode()
def _build_png_metadata(prompt, inference_steps, h, w, guidance_scale, seed):
    """Return the PngInfo text chunks embedded in every saved image.

    The "Comment" chunk is a JSON dict of parameters plus a "signed_hash" computed
    by sign_message over the parameter JSON. Many entries (e.g. "sampler",
    "cfg_rescale", "uc") are fixed values, not the settings actually used.
    """
    parameters = {
        "prompt": prompt,
        "steps": inference_steps,
        "height": h,
        "width": w,
        "scale": guidance_scale,
        "uncond_scale": 0.0,
        "cfg_rescale": 0.0,
        "seed": seed,
        "n_samples": 1,
        "hide_debug_overlay": False,
        "noise_schedule": "native",
        "legacy_v3_extend": False,
        "reference_information_extracted_multiple": [],
        "reference_strength_multiple": [],
        "sampler": "k_dpmpp_2m_sde",
        "controlnet_strength": 1.0,
        "controlnet_model": None,
        "dynamic_thresholding": False,
        "dynamic_thresholding_percentile": 0.999,
        "dynamic_thresholding_mimic_scale": 10.0,
        "sm": False,
        "sm_dyn": False,
        "skip_cfg_above_sigma": 23.69030960605558,
        "skip_cfg_below_sigma": 0.0,
        "lora_unet_weights": None,
        "lora_clip_weights": None,
        "deliberate_euler_ancestral_bug": True,
        "prefer_brownian": False,
        "cfg_sched_eligibility": "enable_for_post_summer_samplers",
        "explike_fine_detail": False,
        "minimize_sigma_inf": False,
        "uncond_per_vibe": True,
        "wonky_vibe_correlation": True,
        "version": 1,
        "uc": "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract],{{{{chibi,doll,+_+}}}},",
    }
    metadata_params = {
        "request_type": "PromptGenerateRequest",
        "signed_hash": sign_message(json.dumps(parameters), "novelai-client"),
        **parameters,
    }
    metadata = PngInfo()
    metadata.add_text("Title", "AI generated image")
    metadata.add_text("Description", prompt)
    metadata.add_text("Software", "NovelAI")
    metadata.add_text("Source", "Stable Diffusion XL 7BCCAA2C")
    metadata.add_text("Nya", "Nya~")
    # Random placeholder, not a measured duration.
    metadata.add_text("Generation time", f"1.{random.randint(1000000000, 9999999999)}")
    metadata.add_text("Comment", json.dumps(metadata_params))
    return metadata


@spaces.GPU
def run(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832, negative_prompt=NEGATIVE_PROMPT, guidance_scale=4.0, randomize_seed=True, seed=42, do_img2img=False, init_image=None, image2image_resize=True, image2image_strength=0, inference_steps=25, progress=gr.Progress(track_tqdm=True)):
    """Generate one image and return ``(png_path, seed_used)``.

    Model selection: ``radio == "model-v2"`` uses pipe / pipe_img2img, and any other
    value uses pipe2 / pipe2_img2img.

    Prompt handling: ``prompt`` and ``preset`` are joined with ", " (empty parts
    skipped). A blank ``negative_prompt`` is passed as None. When both prompts end
    up empty, guidance is disabled.

    img2img runs only when ``do_img2img`` is set and ``init_image`` (a numpy array)
    is given. The image is resized to (w, h) first if ``image2image_resize`` is set.

    Steps are capped at 50. The PNG is written to a temp file that is not deleted.
    """
    # Gradio number/slider inputs may arrive as floats; PIL resize, the step count
    # and torch.Generator.manual_seed need ints.
    h, w = int(h), int(w)
    inference_steps = min(int(inference_steps), 50)

    do_img2img = bool(do_img2img) and init_image is not None
    init_pil = None
    if do_img2img:
        # init_image: np.ndarray
        init_pil = Image.fromarray(init_image)
        if image2image_resize:
            init_pil = init_pil.resize((w, h))

    # Join without a dangling ", " when either the prompt or the preset is blank.
    prompt = ", ".join(part for part in ((prompt or "").strip(), (preset or "").strip()) if part)
    negative_prompt = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else None
    print(f"Initial seed for prompt `{prompt}`", seed)
    if randomize_seed:
        seed = random.randint(0, 9007199254740991)
    seed = int(seed)
    if not prompt and not negative_prompt:
        guidance_scale = 0.0

    text2img, img2img = (pipe, pipe_img2img) if radio == "model-v2" else (pipe2, pipe2_img2img)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if do_img2img:
            image = img2img(prompt, image=init_pil, strength=image2image_strength, negative_prompt=negative_prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=inference_steps).images[0]
        else:
            image = text2img(prompt, height=h, width=w, negative_prompt=negative_prompt, guidance_scale=guidance_scale, guidance_rescale=0.75, generator=generator, num_inference_steps=inference_steps).images[0]

    # File-name prefix: first 40 prompt chars with path-hostile characters replaced.
    naifix = prompt[:40].replace(":", "_").replace("\\", "_").replace("/", "_") + f" s-{seed}-"
    with tempfile.NamedTemporaryFile(prefix=naifix, suffix=".png", delete=False) as tmpfile:
        image.save(tmpfile, "png", pnginfo=_build_png_metadata(prompt, inference_steps, h, w, guidance_scale, seed))
    return tmpfile.name, seed
# Load both models once at import time. Only the text2img pipelines are moved to the
# device: the img2img pipelines were built from the same module objects.
pipe, pipe2, pipe_img2img, pipe2_img2img = get_pipeline_initialize()
pipe = pipe.to(device)
pipe2 = pipe2.to(device)
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''# SDXL Experiments
Just a simple demo for some SDXL model.''')
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(show_label=False, scale=5, value="1girl, rurudo", placeholder="Your prompt", info="Leave blank to test unconditional generation")
                    button = gr.Button("Generate", min_width=120)
                preset = gr.Textbox(show_label=False, scale=5, value=PRESET_Q, info="Quality presets")
            radio = gr.Radio(["model-v2-beta", "model-v2"], value="model-v2", label="Choose the inference model")
            inference_steps = gr.Slider(label="Inference Steps", value=25, minimum=4, maximum=50, step=1)
            with gr.Row():
                height = gr.Slider(label="Height", value=1216, minimum=512, maximum=2560, step=64)
                width = gr.Slider(label="Width", value=832, minimum=512, maximum=2560, step=64)
            guidance_scale = gr.Number(label="CFG Guidance Scale", info="The guidance scale for CFG, ignored if no prompt is entered (unconditional generation)", value=4.0)
            negative_prompt = gr.Textbox(label="Negative prompt", value=NEGATIVE_PROMPT, info="Is only applied for the CFG part, leave blank for unconditional generation")
            # precision=0 makes Gradio deliver the seed as an int, as torch.Generator.manual_seed requires.
            seed = gr.Number(label="Seed", value=42, precision=0, info="Seed for random number generator")
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            do_img2img = gr.Checkbox(label="Image to Image", value=False)
            init_image = gr.Image(label="Input Image", visible=False)
            image2image_resize = gr.Checkbox(label="Resize input image", value=False, visible=False)
            image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.7, visible=False)
        with gr.Column():
            output = gr.Image(type="filepath", interactive=False)

    # Raw string: the backslashes are prompt-syntax escapes for literal parentheses,
    # not Python escape sequences.
    gr.Examples(fn=run, examples=[r"mayano_top_gun_\(umamusume\), 1girl, rurudo", "sho (sho lwlw),[[[ohisashiburi]]],fukuro daizi,tianliang duohe fangdongye,[daidai ookami],year_2023, (wariza), depth of field, official_art"], inputs=prompt, outputs=[output, seed], cache_examples="lazy")

    # The three img2img controls are visible only while "Image to Image" is ticked.
    do_img2img.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
        inputs=[do_img2img],
        outputs=[init_image, image2image_resize, image2image_strength]
    )

    gr.on(
        triggers=[
            button.click,
            prompt.submit
        ],
        fn=run,
        inputs=[prompt, radio, preset, height, width, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_resize, image2image_strength, inference_steps],
        outputs=[output, seed],
        concurrency_limit=1,
    )
if __name__ == "__main__":
    # Start the Gradio server only when run as a script; share=True also requests a public share link.
    demo.launch(share=True)