nyanko7 committed
Commit 2cb0c87 · verified · 1 Parent(s): 0ac7cff

Update app.py
Files changed (1):
  1. app.py +569 -123
app.py CHANGED
@@ -1,139 +1,590 @@
- import io
  import inspect
- import os
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-
  import math
- import torch
  import random
- import torch.nn.functional as F
  import tempfile
- import gradio as gr
- import spaces
  import httpimport
- import json
- from PIL import Image
  from packaging import version
  from PIL.PngImagePlugin import PngInfo
 
  if torch.cuda.get_device_properties(0).major >= 8:
      torch.backends.cuda.matmul.allow_tf32 = True
      torch.backends.cudnn.allow_tf32 = True
 
- with httpimport.remote_repo(os.getenv("MODULE_URL")):
-     import pipeline
- pipe, pipe2, pipe_img2img, pipe2_img2img = pipeline.get_pipeline_initialize()
-
  theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
  device = "cuda"
  pipe = pipe.to(device)
  pipe2 = pipe2.to(device)
  PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
  NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
 
- import hashlib
- import base64
- import hmac
 
- import numpy as np
- import pickle
- import requests
- import codecs
 
44
- prompt: str,
45
- radio: str = "model-v2",
46
- preset: str = "year_2022, best quality, high quality, very aesthetic",
47
- h: int = 1216,
48
- w: int = 832,
49
- negative_prompt: str = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, ai-generated worst quality displeasing, bad quality",
50
- guidance_scale: float = 4.0,
51
- randomize_seed: bool = True,
52
- seed: int = 42,
53
- do_img2img: bool = False,
54
- init_image: Optional[str] = None,
55
- image2image_strength: float = 0,
56
- inference_steps = 25,
57
- ) -> bytes:
58
- url = os.getenv("TPU_INFERENCE_API")
59
- if(randomize_seed):
60
- seed = random.randint(0, 9007199254740991)
61
- randomize_seed = False
62
-
63
- payload = {
64
- "prompt": prompt,
65
- "radio": radio,
66
- "preset": preset,
67
- "height": h,
68
- "width": w,
69
- "negative_prompt": negative_prompt,
70
- "guidance_scale": guidance_scale,
71
- "randomize_seed": randomize_seed,
72
- "seed": seed,
73
- "do_img2img": do_img2img,
74
- "image2image_strength": image2image_strength,
75
- "init_image": init_image,
76
- "inference_steps": inference_steps,
77
- }
78
- response = requests.post(url, json=payload, timeout=60)
79
- if response.status_code != 200:
80
- raise Exception(f"Error calling API: {response.status_code} - {response.text}")
81
-
82
- image = Image.open(io.BytesIO(response.content))
83
- naifix = prompt[:40].replace(":", "_").replace("\\", "_").replace("/", "_") + f" s-{seed}-"
84
- with tempfile.NamedTemporaryFile(prefix=naifix, suffix=".png", delete=False) as tmpfile:
85
- parameters = {
86
- "prompt": prompt,
87
- "steps": 25,
88
- "height": h,
89
- "width": w,
90
- "scale": guidance_scale,
91
- "uncond_scale": 0.0,
92
- "cfg_rescale": 0.0,
93
- "seed": seed,
94
- "n_samples": 1,
95
- "hide_debug_overlay": False,
96
- "noise_schedule": "native",
97
- "legacy_v3_extend": False,
98
- "reference_information_extracted_multiple": [],
99
- "reference_strength_multiple": [],
100
- "sampler": "k_dpmpp_2m_sde",
101
- "controlnet_strength": 1.0,
102
- "controlnet_model": None,
103
- "dynamic_thresholding": False,
104
- "dynamic_thresholding_percentile": 0.999,
105
- "dynamic_thresholding_mimic_scale": 10.0,
106
- "sm": False,
107
- "sm_dyn": False,
108
- "skip_cfg_above_sigma": 23.69030960605558,
109
- "skip_cfg_below_sigma": 0.0,
110
- "lora_unet_weights": None,
111
- "lora_clip_weights": None,
112
- "deliberate_euler_ancestral_bug": True,
113
- "prefer_brownian": False,
114
- "cfg_sched_eligibility": "enable_for_post_summer_samplers",
115
- "explike_fine_detail": False,
116
- "minimize_sigma_inf": False,
117
- "uncond_per_vibe": True,
118
- "wonky_vibe_correlation": True,
119
- "version": 1,
120
- "uc": "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract],{{{{chibi,doll,+_+}}}},",
121
- }
122
- metadata_params = {
123
- "request_type": "PromptGenerateRequest",
124
- "signed_hash": sign_message(json.dumps(parameters), "novelai-client"),
125
- **parameters
126
- }
127
- metadata = PngInfo()
128
- metadata.add_text("Title", "AI generated image")
129
- metadata.add_text("Description", prompt)
130
- metadata.add_text("Software", "NovelAI")
131
- metadata.add_text("Source", "Stable Diffusion XL 7BCCAA2C")
132
- metadata.add_text("Nya", "Nya~")
133
- metadata.add_text("Generation time", f"1.{random.randint(1000000000, 9999999999)}")
134
- metadata.add_text("Comment", json.dumps(metadata_params))
135
- image.save(tmpfile, "png", pnginfo=metadata)
136
- return tmpfile.name, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  def sign_message(message, key):
      hmac_digest = hmac.new(key.encode(), message.encode(), hashlib.sha512).digest()
@@ -150,14 +601,6 @@ def run(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832, negative_promp
          init_image = init_image.resize((w, h))
          init_image = np.array(init_image)
 
-     if tpu_inference:
-         prompt = prompt.replace("!", " ").replace("\n", " ")  # characters unsupported by the remote endpoint
-         if do_img2img:
-             init_image = codecs.encode(pickle.dumps(init_image, protocol=pickle.HIGHEST_PROTOCOL), "base64").decode('latin1')
-             return tpu_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
-         else:
-             return tpu_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, inference_steps=inference_steps)
-
      return zero_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
 
  @spaces.GPU
@@ -244,6 +687,8 @@ def zero_inference_api(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832,
          image.save(tmpfile, "png", pnginfo=metadata)
      return tmpfile.name, seed
 
  with gr.Blocks(theme=theme) as demo:
      gr.Markdown('''# SDXL Experiments
  Just a simple demo for some SDXL model.''')
@@ -292,5 +737,6 @@ with gr.Blocks(theme=theme) as demo:
          outputs=[output, seed],
          concurrency_limit=1,
      )
  if __name__ == "__main__":
      demo.launch(share=True)
 
+ import base64
+ import codecs
+ import hashlib
+ import hmac
  import inspect
+ import io
+ import json
  import math
+ import os
+ import pickle
  import random
  import tempfile
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ # Third-party general libraries
  import httpimport
+ import numpy as np
+ import requests
  from packaging import version
+ from PIL import Image
  from PIL.PngImagePlugin import PngInfo
+ import PIL
+
+ # PyTorch
+ import torch
+ import torch.nn.functional as F
+
+ # Hugging Face & Diffusers
+ from transformers import CLIPTokenizer
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+ from diffusers.loaders import (
+     StableDiffusionXLLoraLoaderMixin,
+     TextualInversionLoaderMixin,
+ )
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
+ from diffusers.schedulers import EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     logging,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+
+ # App / UI
+ import gradio as gr
+ import spaces
 
  if torch.cuda.get_device_properties(0).major >= 8:
      torch.backends.cuda.matmul.allow_tf32 = True
      torch.backends.cudnn.allow_tf32 = True
 
  theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
  device = "cuda"
  pipe = pipe.to(device)
  pipe2 = pipe2.to(device)
  PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
  NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
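+ # Resolve a class object from a dotted path, e.g. get_class("diffusers.UNet2DConditionModel").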
+ def get_class(name: str):
+     import importlib
+
+     module_name, class_name = name.rsplit(".", 1)
+     module = importlib.import_module(module_name, package=None)
+     return getattr(module, class_name)
+
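+ # Same helper the stock diffusers pipelines use: a VAE encoder returns a
+ # distribution, so latents are sampled from it (or taken as its mode).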
+ def retrieve_latents(
+     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+ ):
+     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+         return encoder_output.latent_dist.sample(generator)
+     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+         return encoder_output.latent_dist.mode()
+     elif hasattr(encoder_output, "latents"):
+         return encoder_output.latents
+     else:
+         raise AttributeError("Could not access latents of provided encoder_output")
+
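+ # Prompt-attention parser in the style of the A1111 web UI; this variant also
+ # accepts NovelAI-style {curly} emphasis, at 1.05 per nesting level.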
+ def parse_prompt_attention(text):
+     """
+     Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+     Accepted tokens are:
+       (abc) - increases attention to abc by a multiplier of 1.05
+       (abc:3.12) - increases attention to abc by a multiplier of 3.12
+       {abc} - increases attention to abc by a multiplier of 1.05
+       [abc] - decreases attention to abc by a multiplier of 1.05
+       \\( - literal character '('
+       \\[ - literal character '['
+       \\) - literal character ')'
+       \\] - literal character ']'
+       \\ - literal character '\'
+       anything else - just text
+
+     The examples below are from the original A1111 parser, which used a 1.1 multiplier:
+
+     >>> parse_prompt_attention('normal text')
+     [['normal text', 1.0]]
+     >>> parse_prompt_attention('an (important) word')
+     [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+     >>> parse_prompt_attention('(unbalanced')
+     [['unbalanced', 1.1]]
+     >>> parse_prompt_attention('\\(literal\\]')
+     [['(literal]', 1.0]]
+     >>> parse_prompt_attention('(unnecessary)(parens)')
+     [['unnecessaryparens', 1.1]]
+     >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+     [['a ', 1.0],
+      ['house', 1.5730000000000004],
+      [' ', 1.1],
+      ['on', 1.0],
+      [' a ', 1.1],
+      ['hill', 0.55],
+      [', sun, ', 1.1],
+      ['sky', 1.4641000000000006],
+      ['.', 1.1]]
+     """
+     import re
+
+     re_attention = re.compile(
+         r"""
+         \{|\}|\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
+         \)|]|[^\\()\[\]:]+|:
+         """,
+         re.X,
+     )
+
+     re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
+
+     res = []
+     round_brackets = []
+     square_brackets = []
+     curly_brackets = []
+     round_bracket_multiplier = 1.05
+     curly_bracket_multiplier = 1.05
+     square_bracket_multiplier = 1 / 1.05
+
+     def multiply_range(start_position, multiplier):
+         for p in range(start_position, len(res)):
+             res[p][1] *= multiplier
+
+     for m in re_attention.finditer(text):
+         text = m.group(0)
+         weight = m.group(1)
+
+         if text.startswith("\\"):
+             res.append([text[1:], 1.0])
+         elif text == "(":
+             round_brackets.append(len(res))
+         elif text == "{":
+             curly_brackets.append(len(res))
+         elif text == "[":
+             square_brackets.append(len(res))
+         elif weight is not None and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), float(weight))
+         elif text == ")" and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), round_bracket_multiplier)
+         elif text == "}" and len(curly_brackets) > 0:
+             multiply_range(curly_brackets.pop(), curly_bracket_multiplier)
+         elif text == "]" and len(square_brackets) > 0:
+             multiply_range(square_brackets.pop(), square_bracket_multiplier)
+         else:
+             parts = re.split(re_break, text)
+             for i, part in enumerate(parts):
+                 if i > 0:
+                     res.append(["BREAK", -1])
+                 res.append([part, 1.0])
+
+     # apply the multipliers of any brackets left unclosed
+     for pos in round_brackets:
+         multiply_range(pos, round_bracket_multiplier)
+
+     for pos in curly_brackets:
+         multiply_range(pos, curly_bracket_multiplier)
+
+     for pos in square_brackets:
+         multiply_range(pos, square_bracket_multiplier)
+
+     if len(res) == 0:
+         res = [["", 1.0]]
+
+     # merge runs of identical weights
+     i = 0
+     while i + 1 < len(res):
+         if res[i][1] == res[i + 1][1]:
+             res[i][0] += res[i + 1][0]
+             res.pop(i + 1)
+         else:
+             i += 1
+
+     return res
+
+
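+ # Illustrative example: "a (red:1.5) cat" yields the token ids of "a red cat"
+ # with per-token weights [1.0, 1.5, 1.0].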
+ def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
+     """
+     Get prompt token ids and weights; works for both the prompt and the negative prompt.
+
+     Args:
+         clip_tokenizer (CLIPTokenizer)
+             A CLIPTokenizer
+         prompt (str)
+             A prompt string with weights
+
+     Returns:
+         text_tokens (list)
+             A list containing token ids
+         text_weights (list)
+             A list containing the corresponding weight of each token id
+
+     Example:
+         import torch
+         from transformers import CLIPTokenizer
+
+         clip_tokenizer = CLIPTokenizer.from_pretrained(
+             "stablediffusionapi/deliberate-v2"
+             , subfolder = "tokenizer"
+             , dtype = torch.float16
+         )
+
+         token_id_list, token_weight_list = get_prompts_tokens_with_weights(
+             clip_tokenizer = clip_tokenizer
+             , prompt = "a (red:1.5) cat"*70
+         )
+     """
+     texts_and_weights = parse_prompt_attention(prompt)
+     text_tokens, text_weights = [], []
+     for word, weight in texts_and_weights:
+         # tokenize and discard the starting and ending tokens,
+         # so prompts of any length can be tokenized
+         token = clip_tokenizer(word, truncation=False).input_ids[1:-1]
+         # the returned token is a 1d list, e.g. [320, 1125, 539, 320]
+
+         # merge the new tokens into the running token list
+         text_tokens = [*text_tokens, *token]
+
+         # each text chunk comes with a single weight, like ['red cat', 2.0];
+         # the weight has to be expanded to one entry per token
+         chunk_weights = [weight] * len(token)
+
+         # append the expanded weights to the running weight list
+         text_weights = [*text_weights, *chunk_weights]
+     return text_tokens, text_weights
+
+
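+ # CLIP accepts at most 77 positions (75 content tokens plus bos/eos), so an
+ # arbitrarily long token list is split here into 77-token chunks that are
+ # encoded separately and concatenated later.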
+ def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
+     """
+     Produce tokens and weights in groups of 77 and pad the missing tokens.
+
+     Args:
+         token_ids (list)
+             The token ids from the tokenizer
+         weights (list)
+             The weights list from get_prompts_tokens_with_weights
+         pad_last_block (bool)
+             Controls whether the last token group is padded to 75 tokens with eos
+     Returns:
+         new_token_ids (2d list)
+         new_weights (2d list)
+
+     Example:
+         token_groups, weight_groups = group_tokens_and_weights(
+             token_ids = token_id_list
+             , weights = token_weight_list
+         )
+     """
+     bos, eos = 49406, 49407
+
+     # these will be 2d lists
+     new_token_ids = []
+     new_weights = []
+     while len(token_ids) >= 75:
+         # get the first 75 tokens
+         head_75_tokens = [token_ids.pop(0) for _ in range(75)]
+         head_75_weights = [weights.pop(0) for _ in range(75)]
+
+         # wrap the chunk in bos/eos markers
+         temp_77_token_ids = [bos] + head_75_tokens + [eos]
+         temp_77_weights = [1.0] + head_75_weights + [1.0]
+
+         # add the 77-token chunk and its weights to the holder lists
+         new_token_ids.append(temp_77_token_ids)
+         new_weights.append(temp_77_weights)
+
+     # pad the leftover tokens
+     if len(token_ids) > 0:
+         padding_len = 75 - len(token_ids) if pad_last_block else 0
+
+         temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
+         new_token_ids.append(temp_77_token_ids)
+
+         temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
+         new_weights.append(temp_77_weights)
+
+     return new_token_ids, new_weights
+
+
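+ # Illustrative call, using the helpers defined above:
+ #   pe, neg_pe, pooled, neg_pooled = get_weighted_text_embeddings_sdxl(
+ #       pipe, prompt="a (red:1.5) cat", neg_prompt="lowres")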
+ def get_weighted_text_embeddings_sdxl(
+     pipe,
+     prompt: str = "",
+     prompt_2: str = None,
+     neg_prompt: str = "",
+     neg_prompt_2: str = None,
+     num_images_per_prompt: int = 1,
+     device: Optional[torch.device] = None,
+     clip_skip: Optional[int] = None,
+     lora_scale: Optional[float] = None,
+ ):
+     """
+     Process a weighted prompt of any length for Stable Diffusion XL, with no length limitation.
+
+     Args:
+         pipe (StableDiffusionXLPipeline)
+         prompt (str)
+         prompt_2 (str)
+         neg_prompt (str)
+         neg_prompt_2 (str)
+         num_images_per_prompt (int)
+         device (torch.device)
+         clip_skip (int)
+     Returns:
+         prompt_embeds (torch.Tensor)
+         neg_prompt_embeds (torch.Tensor)
+         pooled_prompt_embeds (torch.Tensor)
+         negative_pooled_prompt_embeds (torch.Tensor)
+     """
+     device = device or pipe._execution_device
+
+     # set the lora scale so that the monkey-patched LoRA
+     # forward of the text encoder can correctly access it
+     if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+         pipe._lora_scale = lora_scale
+
+     # dynamically adjust the LoRA scale
+     if pipe.text_encoder is not None:
+         if not USE_PEFT_BACKEND:
+             adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+         else:
+             scale_lora_layers(pipe.text_encoder, lora_scale)
+
+     if pipe.text_encoder_2 is not None:
+         if not USE_PEFT_BACKEND:
+             adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+         else:
+             scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
+     if prompt_2:
+         prompt = f"{prompt} {prompt_2}"
+
+     if neg_prompt_2:
+         neg_prompt = f"{neg_prompt} {neg_prompt_2}"
+
+     prompt_t1 = prompt_t2 = prompt
+     neg_prompt_t1 = neg_prompt_t2 = neg_prompt
+
+     if isinstance(pipe, TextualInversionLoaderMixin):
+         prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
+         neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
+         prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
+         neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)
+
+     eos = pipe.tokenizer.eos_token_id
+
+     # tokenizer 1
+     prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
+     neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)
+
+     # tokenizer 2
+     prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
+     neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)
+
+     # pad the shorter one for prompt set 1
+     prompt_token_len = len(prompt_tokens)
+     neg_prompt_token_len = len(neg_prompt_tokens)
+
+     if prompt_token_len > neg_prompt_token_len:
+         # pad the neg_prompt with eos tokens
+         neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+         neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+     else:
+         # pad the prompt
+         prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
+         prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
+
+     # pad the shorter one for token set 2
+     prompt_token_len_2 = len(prompt_tokens_2)
+     neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
+
+     if prompt_token_len_2 > neg_prompt_token_len_2:
+         # pad the neg_prompt with eos tokens
+         neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+         neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+     else:
+         # pad the prompt
+         prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+         prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
+
+     embeds = []
+     neg_embeds = []
+
+     prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
+
+     neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
+         neg_prompt_tokens.copy(), neg_prompt_weights.copy()
+     )
+
+     prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
+         prompt_tokens_2.copy(), prompt_weights_2.copy()
+     )
+
+     neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
+         neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
+     )
+
+     # encode the 77-token groups one chunk at a time and concatenate afterwards
+     for i in range(len(prompt_token_groups)):
+         # get positive prompt embeddings with weights
+         token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
+         weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
+
+         token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
+
+         # use first text encoder
+         prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
+
+         # use second text encoder
+         prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
+         pooled_prompt_embeds = prompt_embeds_2[0]
+
+         if clip_skip is None:
+             prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
+             prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
+         else:
+             # "2" because SDXL always indexes from the penultimate layer.
+             prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]
+             prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]
+
+         prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
+         token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
+
+         for j in range(len(weight_tensor)):
+             if weight_tensor[j] != 1.0:
+                 token_embedding[j] = (
+                     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
+                 )
+
+         token_embedding = token_embedding.unsqueeze(0)
+         embeds.append(token_embedding)
+
+         # get negative prompt embeddings with weights
+         neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
+         neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
+         neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
+
+         # use first text encoder
+         neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
+         neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
+
+         # use second text encoder
+         neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
+         neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
+         negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
+
+         neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
+         neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
+
+         for z in range(len(neg_weight_tensor)):
+             if neg_weight_tensor[z] != 1.0:
+                 neg_token_embedding[z] = (
+                     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
+                 )
+
+         neg_token_embedding = neg_token_embedding.unsqueeze(0)
+         neg_embeds.append(neg_token_embedding)
+
+     prompt_embeds = torch.cat(embeds, dim=1)
+     negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
+
+     bs_embed, seq_len, _ = prompt_embeds.shape
+     # duplicate text embeddings for each generation per prompt, using mps friendly method
+     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+     prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+     seq_len = negative_prompt_embeds.shape[1]
+     negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+     negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+     pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+         bs_embed * num_images_per_prompt, -1
+     )
+     negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
+         bs_embed * num_images_per_prompt, -1
+     )
+
+     if pipe.text_encoder is not None:
+         if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+             # Retrieve the original scale by scaling back the LoRA layers
+             unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+     if pipe.text_encoder_2 is not None:
+         if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+             # Retrieve the original scale by scaling back the LoRA layers
+             unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
+     return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
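+ # The two pipeline subclasses below swap in the weighted, length-unlimited
+ # prompt encoder; the img2img variant also trims the timestep schedule by
+ # strength and, when supported, passes the start index to the scheduler.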
+ class ModImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):
+
+     def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
+         return get_weighted_text_embeddings_sdxl(
+             pipe=self,
+             prompt=prompt,
+             neg_prompt=negative_prompt,
+             num_images_per_prompt=num_images_per_prompt,
+             clip_skip=clip_skip,
+             lora_scale=lora_scale,
+         )
+
+     def get_timesteps(self, num_inference_steps, strength, device, **kwargs):
+         # get the original timestep using init_timestep
+         init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+         t_start = int(max(num_inference_steps - init_timestep, 0))
+         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+         if hasattr(self.scheduler, "set_begin_index"):
+             self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+         return timesteps, num_inference_steps - t_start
+
+ class ModText2ImgPipeline(StableDiffusionXLPipeline):
+
+     def encode_prompt(self, prompt, num_images_per_prompt, negative_prompt, lora_scale, clip_skip, **kwargs):
+         return get_weighted_text_embeddings_sdxl(
+             pipe=self,
+             prompt=prompt,
+             neg_prompt=negative_prompt,
+             num_images_per_prompt=num_images_per_prompt,
+             clip_skip=clip_skip,
+             lora_scale=lora_scale,
+         )
+
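+ # Adapter so the flow-match scheduler can drive pipelines written against the
+ # EulerDiscrete interface: unit initial noise sigma, identity scale_model_input,
+ # and add_noise forwarded to scale_noise.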
+ class ModFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+
+     @property
+     def init_noise_sigma(self):
+         return 1.0
+
+     def scale_model_input(self, sample, timestep):
+         return sample
+
+     def add_noise(self, original_samples, noise, timesteps):
+         return self.scale_noise(original_samples, timesteps, noise)
+
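+ # Build both text2img pipelines from single-file SDXL checkpoints (paths read
+ # from the SDXL_MODEL / SDXL_MODEL_2 env vars, falling back to the arguments)
+ # and derive img2img pipelines that share the same weights.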
+ def get_pipeline_initialize(model_1="", model_2=""):
+     pipe = ModText2ImgPipeline.from_single_file(
+         os.getenv("SDXL_MODEL", model_1),
+         torch_dtype=torch.float16
+     )
+     pipe.fuse_qkv_projections()
+     pipe.unet.set_attention_backend("_flash_3_hub")
+     pipe.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
+     pipe.unet.to(memory_format=torch.channels_last)
+     pipe.vae.to(memory_format=torch.channels_last)
+
+     pipe2 = ModText2ImgPipeline.from_single_file(
+         os.getenv("SDXL_MODEL_2", model_2),
+         torch_dtype=torch.float16
+     )
+     pipe2.fuse_qkv_projections()
+     pipe2.unet.set_attention_backend("_flash_3_hub")
+     pipe2.scheduler = ModFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=2.0)
+     pipe2.unet.to(memory_format=torch.channels_last)
+     pipe2.vae.to(memory_format=torch.channels_last)
+
+     pipe_img2img = ModImg2ImgPipeline(
+         vae=pipe.vae,
+         unet=pipe.unet,
+         text_encoder=pipe.text_encoder,
+         text_encoder_2=pipe.text_encoder_2,
+         tokenizer=pipe.tokenizer,
+         tokenizer_2=pipe.tokenizer_2,
+         scheduler=pipe.scheduler,
+         image_encoder=pipe.image_encoder,
+         feature_extractor=pipe.feature_extractor,
+     )
+     pipe2_img2img = ModImg2ImgPipeline(
+         vae=pipe2.vae,
+         unet=pipe2.unet,
+         text_encoder=pipe2.text_encoder,
+         text_encoder_2=pipe2.text_encoder_2,
+         tokenizer=pipe2.tokenizer,
+         tokenizer_2=pipe2.tokenizer_2,
+         scheduler=pipe2.scheduler,
+         image_encoder=pipe2.image_encoder,
+         feature_extractor=pipe2.feature_extractor,
+     )
+     return pipe, pipe2, pipe_img2img, pipe2_img2img
 
  def sign_message(message, key):
      hmac_digest = hmac.new(key.encode(), message.encode(), hashlib.sha512).digest()

@@ -150,14 +601,6 @@ def run(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832, negative_promp
          init_image = init_image.resize((w, h))
          init_image = np.array(init_image)
 
      return zero_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
 
  @spaces.GPU

@@ -244,6 +687,8 @@ def zero_inference_api(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832,
          image.save(tmpfile, "png", pnginfo=metadata)
      return tmpfile.name, seed
 
+ pipe, pipe2, pipe_img2img, pipe2_img2img = pipeline.get_pipeline_initialize()
+
  with gr.Blocks(theme=theme) as demo:
      gr.Markdown('''# SDXL Experiments
  Just a simple demo for some SDXL model.''')

@@ -292,5 +737,6 @@ with gr.Blocks(theme=theme) as demo:
          outputs=[output, seed],
          concurrency_limit=1,
      )
+
  if __name__ == "__main__":
      demo.launch(share=True)