# Standard library
import base64
import codecs
import hashlib
import hmac
import inspect
import io
import json
import math
import os
import pickle
import random
import tempfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# Third-party general libraries
import gradio as gr
import httpimport
import numpy as np
import requests
import spaces
from packaging import version
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo

# PyTorch
import torch
import torch.nn.functional as F

# Hugging Face & Diffusers
from transformers import CLIPTokenizer
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from diffusers.loaders import (
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
# Enable TF32 on Ampere (SM 8.0) and newer GPUs for faster matmuls; guard the
# device query so the module also imports on CPU-only startup.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
device = "cuda"
PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def get_class(name: str):
    import importlib
    module_name, class_name = name.rsplit(".", 1)
    module = importlib.import_module(module_name, package=None)
    return getattr(module, class_name)
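
# Usage sketch (illustrative): resolve a class from its dotted import path, e.g.
#   scheduler_cls = get_class("diffusers.EulerDiscreteScheduler")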

def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
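
# Usage sketch (illustrative; assumes `vae` is a loaded AutoencoderKL and `img`
# is a [1, 3, H, W] tensor scaled to [-1, 1]):
#   latents = retrieve_latents(vae.encode(img), generator=torch.Generator().manual_seed(0))
#   latents = latents * vae.config.scaling_factor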

def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.05
      {abc} - increases attention to abc by a multiplier of 1.05
      (abc:3.12) - sets attention to abc to 3.12
      [abc] - decreases attention to abc by a multiplier of 1.05
      \\( - literal character '('
      \\[ - literal character '['
      \\) - literal character ')'
      \\] - literal character ']'
      \\ - literal character '\'
      BREAK - splits the prompt into separately encoded chunks
      anything else - just text

    Examples (weights follow from the 1.05 multipliers used below; exact
    floating-point representations may vary):
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.05], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.05]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.05]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.43325],
     [' ', 1.05],
     ['on', 1.0],
     [' a ', 1.05],
     ['hill', 0.525],
     [', sun, ', 1.05],
     ['sky', 1.21550625],
     ['.', 1.05]]
    """
    import re

    # note: curly braces must be excluded from the plain-text character class,
    # otherwise the {...} alternatives can never match
    re_attention = re.compile(
        r"""
        \{|\}|\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
        \)|]|[^\\()\[\]{}:]+|:
        """,
        re.X,
    )
    re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
    res = []
    round_brackets = []
    square_brackets = []
    curly_brackets = []
    round_bracket_multiplier = 1.05
    curly_bracket_multiplier = 1.05
    square_bracket_multiplier = 1 / 1.05

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)
        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "{":
            curly_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "}" and len(curly_brackets) > 0:
            multiply_range(curly_brackets.pop(), curly_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            parts = re.split(re_break, text)
            for i, part in enumerate(parts):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])
    # apply the default multipliers to any brackets left unclosed
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in curly_brackets:
        multiply_range(pos, curly_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)
    if len(res) == 0:
        res = [["", 1.0]]
    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1
    return res
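
# Minimal sanity check for the parser above (illustrative; not called by the
# app -- run it manually to verify the 1.05 bracket multipliers):
def _demo_parse_prompt_attention():
    assert parse_prompt_attention("normal text") == [["normal text", 1.0]]
    assert parse_prompt_attention("\\(literal\\]") == [["(literal]", 1.0]]
    a, red, cat = parse_prompt_attention("a {red} cat")
    assert red[0] == "red" and abs(red[1] - 1.05) < 1e-9
    assert a[1] == 1.0 and cat[1] == 1.0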

def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
    """
    Get prompt token ids and weights. This function works for both prompt and negative prompt.
    Args:
        clip_tokenizer (CLIPTokenizer)
            A CLIPTokenizer
        prompt (str)
            A prompt string with weights
    Returns:
        text_tokens (list)
            A list containing token ids
        text_weights (list)
            A list containing the corresponding weight of each token id
    Example:
        import torch
        from transformers import CLIPTokenizer

        clip_tokenizer = CLIPTokenizer.from_pretrained(
            "stablediffusionapi/deliberate-v2",
            subfolder="tokenizer",
            dtype=torch.float16,
        )
        token_id_list, token_weight_list = get_prompts_tokens_with_weights(
            clip_tokenizer=clip_tokenizer,
            prompt="a (red:1.5) cat" * 70,
        )
    """
    texts_and_weights = parse_prompt_attention(prompt)
    text_tokens, text_weights = [], []
    for word, weight in texts_and_weights:
        # tokenize and discard the starting and the ending token,
        # so prompts of any length can be handled
        token = clip_tokenizer(word, truncation=False).input_ids[1:-1]
        # the returned token is a 1d list, e.g. [320, 1125, 539, 320]
        # merge the new tokens into the running token list
        text_tokens = [*text_tokens, *token]
        # each text chunk comes with a single weight, like ['red cat', 2.0];
        # expand that weight so every token in the chunk carries it
        chunk_weights = [weight] * len(token)
        # append the expanded weights to the running weight list
        text_weights = [*text_weights, *chunk_weights]
    return text_tokens, text_weights
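
# For example, "a (red:1.5) cat" yields the token ids for "a ", "red", " cat"
# with weights [1.0, 1.5, 1.0]; a chunk that tokenizes to several ids repeats
# its weight once per id.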

def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
    """
    Produce tokens and weights in groups of 77 (bos + 75 tokens + eos) and pad the missing tokens.
    Args:
        token_ids (list)
            The token ids from the tokenizer
        weights (list)
            The weights list from get_prompts_tokens_with_weights
        pad_last_block (bool)
            Whether to pad the last chunk to 75 tokens with eos
    Returns:
        new_token_ids (2d list)
        new_weights (2d list)
    Example:
        token_groups, weight_groups = group_tokens_and_weights(
            token_ids=token_id_list,
            weights=token_weight_list,
        )
    """
    bos, eos = 49406, 49407
    # these will be 2d lists
    new_token_ids = []
    new_weights = []
    while len(token_ids) >= 75:
        # get the first 75 tokens
        head_75_tokens = [token_ids.pop(0) for _ in range(75)]
        head_75_weights = [weights.pop(0) for _ in range(75)]
        # wrap the chunk in bos/eos tokens with neutral weights
        temp_77_token_ids = [bos] + head_75_tokens + [eos]
        temp_77_weights = [1.0] + head_75_weights + [1.0]
        # add the 77-entry token and weight chunks to the holder lists
        new_token_ids.append(temp_77_token_ids)
        new_weights.append(temp_77_weights)
    # pad the leftover tokens
    if len(token_ids) > 0:
        padding_len = 75 - len(token_ids) if pad_last_block else 0
        temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
        new_token_ids.append(temp_77_token_ids)
        temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
        new_weights.append(temp_77_weights)
    return new_token_ids, new_weights
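
# Chunking sanity check (illustrative; not called by the app): 150 ids split
# into two full 77-entry chunks, each wrapped in bos/eos.
def _demo_group_tokens_and_weights():
    ids, ws = list(range(150)), [1.0] * 150
    groups, weight_groups = group_tokens_and_weights(ids, ws, pad_last_block=True)
    assert len(groups) == 2 and all(len(g) == 77 for g in groups)
    assert groups[0][0] == 49406 and groups[0][-1] == 49407
    assert all(len(w) == 77 for w in weight_groups)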

def get_weighted_text_embeddings_sdxl(
    pipe,
    prompt: str = "",
    prompt_2: str = None,
    neg_prompt: str = "",
    neg_prompt_2: str = None,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
    lora_scale: Optional[float] = None,
):
    """
    Process long prompts with weights for Stable Diffusion XL, with no length limitation.
    Args:
        pipe (StableDiffusionXLPipeline)
        prompt (str)
        prompt_2 (str)
        neg_prompt (str)
        neg_prompt_2 (str)
        num_images_per_prompt (int)
        device (torch.device)
        clip_skip (int)
    Returns:
        prompt_embeds (torch.Tensor)
        neg_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)
    """
    device = device or pipe._execution_device
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
        pipe._lora_scale = lora_scale
        # dynamically adjust the LoRA scale
        if pipe.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder, lora_scale)
        if pipe.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(pipe.text_encoder_2, lora_scale)
    if prompt_2:
        prompt = f"{prompt} {prompt_2}"
    if neg_prompt_2:
        neg_prompt = f"{neg_prompt} {neg_prompt_2}"
    prompt_t1 = prompt_t2 = prompt
    neg_prompt_t1 = neg_prompt_t2 = neg_prompt
    if isinstance(pipe, TextualInversionLoaderMixin):
        prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
        neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
        prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
        neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)
    eos = pipe.tokenizer.eos_token_id
    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)
    # pad the shorter of prompt / negative prompt for tokenizer 1
    prompt_token_len = len(prompt_tokens)
    neg_prompt_token_len = len(neg_prompt_tokens)
    if prompt_token_len > neg_prompt_token_len:
        # pad the negative prompt with eos tokens
        neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
    else:
        # pad the prompt
        prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
    # pad the shorter of prompt / negative prompt for tokenizer 2
    prompt_token_len_2 = len(prompt_tokens_2)
    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
    if prompt_token_len_2 > neg_prompt_token_len_2:
        # pad the negative prompt with eos tokens
        neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
    else:
        # pad the prompt
        prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
    embeds = []
    neg_embeds = []
    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy(), neg_prompt_weights.copy()
    )
    prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
        prompt_tokens_2.copy(), prompt_weights_2.copy()
    )
    neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
        neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
    )
    # encode each 77-token chunk separately, then concatenate the chunk embeddings
    for i in range(len(prompt_token_groups)):
        # get positive prompt embeddings with weights
        token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
        weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
        token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
        # use first text encoder
        prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
        # use second text encoder
        prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds_2[0]
        if clip_skip is None:
            prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
            prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
        else:
            # "2" because SDXL always indexes from the penultimate layer.
            prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]
            prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]
        prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
        # reweight each token embedding by interpolating between it and the final
        # (eos) embedding: w == 1.0 leaves it unchanged, w > 1.0 pushes it away from eos
        for j in range(len(weight_tensor)):
            if weight_tensor[j] != 1.0:
                token_embedding[j] = (
                    token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
                )
        token_embedding = token_embedding.unsqueeze(0)
        embeds.append(token_embedding)
        # get negative prompt embeddings with weights
        neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
        neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
        neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
        # use first text encoder
        neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
        # use second text encoder
        neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
        negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
        for z in range(len(neg_weight_tensor)):
            if neg_weight_tensor[z] != 1.0:
                neg_token_embedding[z] = (
                    neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
                )
        neg_token_embedding = neg_token_embedding.unsqueeze(0)
        neg_embeds.append(neg_token_embedding)
    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, -1
    )
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, -1
    )
    if pipe.text_encoder is not None:
        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder, lora_scale)
    if pipe.text_encoder_2 is not None:
        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
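
# Usage sketch (illustrative): the four returned tensors plug directly into an
# SDXL pipeline call, which is how the encode_prompt overrides below consume them:
#   embeds, neg_embeds, pooled, neg_pooled = get_weighted_text_embeddings_sdxl(
#       pipe, prompt="a (red:1.5) cat", neg_prompt="lowres")
#   image = pipe(prompt_embeds=embeds, negative_prompt_embeds=neg_embeds,
#                pooled_prompt_embeds=pooled,
#                negative_pooled_prompt_embeds=neg_pooled).images[0]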

class ModImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
    def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
        return get_weighted_text_embeddings_sdxl(
            pipe=self,
            prompt=prompt,
            neg_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            lora_scale=lora_scale,
        )

    def get_timesteps(self, num_inference_steps, strength, device, **kwargs):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)
        return timesteps, num_inference_steps - t_start

class ModText2ImgPipeline(StableDiffusionXLPipeline):
    def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
        return get_weighted_text_embeddings_sdxl(
            pipe=self,
            prompt=prompt,
            neg_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            clip_skip=clip_skip,
            lora_scale=lora_scale,
        )
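
# FlowMatchEulerDiscreteScheduler exposes scale_noise() instead of the
# add_noise()/scale_model_input() interface the SDXL pipelines call, so the
# subclass below adapts it: initial latents use sigma 1.0, model inputs need no
# extra scaling, and img2img noising is forwarded to scale_noise().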
class ModFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    @property
    def init_noise_sigma(self):
        return 1.0

    def scale_model_input(self, sample, timestep):
        return sample

    def add_noise(self, original_samples, noise, timesteps):
        return self.scale_noise(original_samples, timesteps, noise)

def get_pipeline_initialize(model_1="", model_2=""):
    pipe = ModText2ImgPipeline.from_single_file(
        os.getenv("SDXL_MODEL", model_1),
        torch_dtype=torch.float16,
    )
    pipe.fuse_qkv_projections()
    pipe.unet.set_attention_backend("_flash_3_hub")
    pipe.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
    pipe.unet.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)
    pipe2 = ModText2ImgPipeline.from_single_file(
        os.getenv("SDXL_MODEL_2", model_2),
        torch_dtype=torch.float16,
    )
    pipe2.fuse_qkv_projections()
    pipe2.unet.set_attention_backend("_flash_3_hub")
    pipe2.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
    pipe2.unet.to(memory_format=torch.channels_last)
    pipe2.vae.to(memory_format=torch.channels_last)
    pipe_img2img = ModImg2ImgPipeline(
        vae=pipe.vae,
        unet=pipe.unet,
        text_encoder=pipe.text_encoder,
        text_encoder_2=pipe.text_encoder_2,
        tokenizer=pipe.tokenizer,
        tokenizer_2=pipe.tokenizer_2,
        scheduler=pipe.scheduler,
        image_encoder=pipe.image_encoder,
        feature_extractor=pipe.feature_extractor,
    )
    pipe2_img2img = ModImg2ImgPipeline(
        vae=pipe2.vae,
        unet=pipe2.unet,
        text_encoder=pipe2.text_encoder,
        text_encoder_2=pipe2.text_encoder_2,
        tokenizer=pipe2.tokenizer,
        tokenizer_2=pipe2.tokenizer_2,
        scheduler=pipe2.scheduler,
        image_encoder=pipe2.image_encoder,
        feature_extractor=pipe2.feature_extractor,
    )
    return pipe, pipe2, pipe_img2img, pipe2_img2img

def sign_message(message, key):
    hmac_digest = hmac.new(key.encode(), message.encode(), hashlib.sha512).digest()
    signed_hash = base64.b64encode(hmac_digest).decode()
    return signed_hash
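
# The "signed_hash" field written into the PNG metadata below is produced by
#   sign_message(json.dumps(parameters), "novelai-client")
# i.e. a base64-encoded HMAC-SHA512 digest of the serialized generation parameters.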

@spaces.GPU
def run(
    prompt,
    radio="model-v2",
    preset=PRESET_Q,
    h=1216,
    w=832,
    negative_prompt=NEGATIVE_PROMPT,
    guidance_scale=4.0,
    randomize_seed=True,
    seed=42,
    do_img2img=False,
    init_image=None,
    image2image_resize=True,
    image2image_strength=0,
    inference_steps=25,
    progress=gr.Progress(track_tqdm=True),
):
    if init_image is None:
        do_img2img = False
    if do_img2img and image2image_resize:
        # init_image arrives as an np.ndarray from gr.Image
        init_image = Image.fromarray(init_image)
        init_image = init_image.resize((w, h))
        init_image = np.array(init_image)
    prompt = prompt.strip()
    negative_prompt = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else None
    # with a blank prompt and no negative prompt, drop CFG for unconditional
    # generation (checked before the quality preset is appended)
    if not prompt and not negative_prompt:
        guidance_scale = 0.0
    prompt = f"{prompt}, {preset.strip()}" if prompt else preset.strip()
    print(f"Initial seed for prompt `{prompt}`", seed)
    if randomize_seed:
        seed = random.randint(0, 9007199254740991)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    inference_steps = min(inference_steps, 50)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if not do_img2img:
            if radio == "model-v2":
                image = pipe(prompt, height=h, width=w, negative_prompt=negative_prompt, guidance_scale=guidance_scale, guidance_rescale=0.75, generator=generator, num_inference_steps=inference_steps).images[0]
            else:
                image = pipe2(prompt, height=h, width=w, negative_prompt=negative_prompt, guidance_scale=guidance_scale, guidance_rescale=0.75, generator=generator, num_inference_steps=inference_steps).images[0]
        else:
            init_image = Image.fromarray(init_image)
            if radio == "model-v2":
                image = pipe_img2img(prompt, image=init_image, strength=image2image_strength, negative_prompt=negative_prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=inference_steps).images[0]
            else:
                image = pipe2_img2img(prompt, image=init_image, strength=image2image_strength, negative_prompt=negative_prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=inference_steps).images[0]
    naifix = prompt[:40].replace(":", "_").replace("\\", "_").replace("/", "_") + f" s-{seed}-"
    with tempfile.NamedTemporaryFile(prefix=naifix, suffix=".png", delete=False) as tmpfile:
        parameters = {
            "prompt": prompt,
            "steps": inference_steps,
            "height": h,
            "width": w,
            "scale": guidance_scale,
            "uncond_scale": 0.0,
            "cfg_rescale": 0.0,
            "seed": seed,
            "n_samples": 1,
            "hide_debug_overlay": False,
            "noise_schedule": "native",
            "legacy_v3_extend": False,
            "reference_information_extracted_multiple": [],
            "reference_strength_multiple": [],
            "sampler": "k_dpmpp_2m_sde",
            "controlnet_strength": 1.0,
            "controlnet_model": None,
            "dynamic_thresholding": False,
            "dynamic_thresholding_percentile": 0.999,
            "dynamic_thresholding_mimic_scale": 10.0,
            "sm": False,
            "sm_dyn": False,
            "skip_cfg_above_sigma": 23.69030960605558,
            "skip_cfg_below_sigma": 0.0,
            "lora_unet_weights": None,
            "lora_clip_weights": None,
            "deliberate_euler_ancestral_bug": True,
            "prefer_brownian": False,
            "cfg_sched_eligibility": "enable_for_post_summer_samplers",
            "explike_fine_detail": False,
            "minimize_sigma_inf": False,
            "uncond_per_vibe": True,
            "wonky_vibe_correlation": True,
            "version": 1,
            "uc": "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract],{{{{chibi,doll,+_+}}}},",
        }
        metadata_params = {
            "request_type": "PromptGenerateRequest",
            "signed_hash": sign_message(json.dumps(parameters), "novelai-client"),
            **parameters,
        }
        metadata = PngInfo()
        metadata.add_text("Title", "AI generated image")
        metadata.add_text("Description", prompt)
        metadata.add_text("Software", "NovelAI")
        metadata.add_text("Source", "Stable Diffusion XL 7BCCAA2C")
        metadata.add_text("Nya", "Nya~")
        metadata.add_text("Generation time", f"1.{random.randint(1000000000, 9999999999)}")
        metadata.add_text("Comment", json.dumps(metadata_params))
        image.save(tmpfile, "png", pnginfo=metadata)
    return tmpfile.name, seed

pipe, pipe2, pipe_img2img, pipe2_img2img = get_pipeline_initialize()
pipe, pipe2 = pipe.to(device), pipe2.to(device)

with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''# SDXL Experiments
Just a simple demo for some SDXL model.''')
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(show_label=False, scale=5, value="1girl, rurudo", placeholder="Your prompt", info="Leave blank to test unconditional generation")
                    button = gr.Button("Generate", min_width=120)
                preset = gr.Textbox(show_label=False, scale=5, value=PRESET_Q, info="Quality presets")
            radio = gr.Radio(["model-v2-beta", "model-v2"], value="model-v2", label="Choose the inference model")
            inference_steps = gr.Slider(label="Inference Steps", value=25, minimum=4, maximum=50, step=1)
            with gr.Row():
                height = gr.Slider(label="Height", value=1216, minimum=512, maximum=2560, step=64)
                width = gr.Slider(label="Width", value=832, minimum=512, maximum=2560, step=64)
            guidance_scale = gr.Number(label="CFG Guidance Scale", info="The guidance scale for CFG, ignored if no prompt is entered (unconditional generation)", value=4.0)
            negative_prompt = gr.Textbox(label="Negative prompt", value=NEGATIVE_PROMPT, info="Only applied to the CFG part; leave blank for unconditional generation")
            seed = gr.Number(label="Seed", value=42, info="Seed for the random number generator")
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            do_img2img = gr.Checkbox(label="Image to Image", value=False)
            init_image = gr.Image(label="Input Image", visible=False)
            image2image_resize = gr.Checkbox(label="Resize input image", value=False, visible=False)
            image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.7, visible=False)
        with gr.Column():
            output = gr.Image(type="filepath", interactive=False)
            gr.Examples(fn=run, examples=["mayano_top_gun_\\(umamusume\\), 1girl, rurudo", "sho (sho lwlw),[[[ohisashiburi]]],fukuro daizi,tianliang duohe fangdongye,[daidai ookami],year_2023, (wariza), depth of field, official_art"], inputs=prompt, outputs=[output, seed], cache_examples="lazy")
    do_img2img.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
        inputs=[do_img2img],
        outputs=[init_image, image2image_resize, image2image_strength],
    )
    gr.on(
        triggers=[
            button.click,
            prompt.submit,
        ],
        fn=run,
        inputs=[prompt, radio, preset, height, width, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_resize, image2image_strength, inference_steps],
        outputs=[output, seed],
        concurrency_limit=1,
    )

if __name__ == "__main__":
    demo.launch(share=True)