retwpay committed
Commit bd37b6c · verified · 1 parent: 49e16ba

Update app.py

Files changed (1): app.py (+287 -21)

app.py CHANGED
@@ -7,6 +7,8 @@ import random
 from diffusers import StableDiffusionXLPipeline
 from diffusers import EulerAncestralDiscreteScheduler
 import torch
+import re
+from compel import Compel, ReturnedEmbeddingsType

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -27,44 +29,291 @@ pipe.text_encoder_2.to(torch.float16)
 pipe.vae.to(torch.float16)
 pipe.unet.to(torch.float16)

+# Initialize Compel for long prompt processing
+compel = Compel(
+    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+    requires_pooled=[False, True],
+    truncate_long_prompts=False  # Enable long prompt processing
+)
+
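+# Note (assumption based on compel's documented SDXL usage): with
+# requires_pooled=[False, True], a compel(...) call returns a
+# (prompt_embeds, pooled_prompt_embeds) pair, which is how infer() below consumes it.
+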
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1216
+
+# =====================================
+# Long Prompt Processing Functions
+# =====================================
+
+def parse_prompt_attention(text):
+    """Parse prompt with attention weights like (word:1.2) or [word:0.8]"""
+    re_attention = re.compile(r"""
+        \\\(|
+        \\\)|
+        \\\[|
+        \\]|
+        \\\\|
+        \\|
+        \(|
+        \[|
+        :([+-]?[.\d]+)\)|
+        \)|
+        ]|
+        [^\\()\[\]:]+|
+        :
+    """, re.X)
+
+    res = []
+    round_brackets = []
+    square_brackets = []
+
+    round_bracket_multiplier = 1.1
+    square_bracket_multiplier = 1 / 1.1
+
+    def multiply_range(start_position, multiplier):
+        for p in range(start_position, len(res)):
+            res[p][1] *= multiplier
+
+    for m in re_attention.finditer(text):
+        text = m.group(0)
+        weight = m.group(1)
+
+        if text.startswith('\\'):
+            res.append([text[1:], 1.0])
+        elif text == '(':
+            round_brackets.append(len(res))
+        elif text == '[':
+            square_brackets.append(len(res))
+        elif weight is not None and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), float(weight))
+        elif text == ')' and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), round_bracket_multiplier)
+        elif text == ']' and len(square_brackets) > 0:
+            multiply_range(square_brackets.pop(), square_bracket_multiplier)
+        else:
+            parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
+            for i, part in enumerate(parts):
+                if i > 0:
+                    res.append(["BREAK", -1])
+                res.append([part, 1.0])
+
+    for pos in round_brackets:
+        multiply_range(pos, round_bracket_multiplier)
+
+    for pos in square_brackets:
+        multiply_range(pos, square_bracket_multiplier)
+
+    if len(res) == 0:
+        res = [["", 1.0]]
+
+    # Merge runs of identical weights
+    i = 0
+    while i + 1 < len(res):
+        if res[i][1] == res[i + 1][1]:
+            res[i][0] += res[i + 1][0]
+            res.pop(i + 1)
+        else:
+            i += 1
+
+    return res
+
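+# For example: parse_prompt_attention("a (red:1.2) ball")
+# returns [['a ', 1.0], ['red', 1.2], [' ball', 1.0]]
+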
+def prompt_attention_to_invoke_prompt(attention):
+    """Convert attention data back to compel syntax"""
+    tokens = []
+    for text, weight in attention:
+        weight = round(weight, 2)
+        if weight == 1.0:
+            tokens.append(text)
+        elif weight < 1.0:
+            if weight < 0.8:
+                tokens.append(f"({text}){weight}")
+            else:
+                tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
+        else:
+            if weight < 1.3:
+                tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
+            else:
+                tokens.append(f"({text}){weight}")
+    return "".join(tokens)
+
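+# For example: [['a ', 1.0], ['red', 1.5], [' ball', 1.0]] becomes "a (red)1.5 ball";
+# weights of 1.3 and above are emitted numerically, milder ones as +/- marks
+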
+def tokenize_line(line, tokenizer):
+    """Split long prompts into chunks at appropriate boundaries"""
+    actual_prompt = line.lower().strip()
+    actual_tokens = tokenizer.tokenize(actual_prompt)
+    max_tokens = tokenizer.model_max_length - 2
+    comma_token = tokenizer.tokenize(',')[0]
+
+    chunks = []
+    chunk = []
+    for item in actual_tokens:
+        chunk.append(item)
+        if len(chunk) == max_tokens:
+            if chunk[-1] != comma_token:
+                # Prefer to break at the last comma inside the window
+                for i in range(max_tokens - 1, -1, -1):
+                    if chunk[i] == comma_token:
+                        actual_chunk, actual_prompt = detokenize(chunk[:i + 1], actual_prompt)
+                        chunks.append(actual_chunk)
+                        chunk = chunk[i + 1:]
+                        break
+                else:
+                    actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
+                    chunks.append(actual_chunk)
+                    chunk = []
+            else:
+                actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
+                chunks.append(actual_chunk)
+                chunk = []
+    if chunk:
+        actual_chunk, _ = detokenize(chunk, actual_prompt)
+        chunks.append(actual_chunk)
+
+    return chunks
+
+def detokenize(chunk, actual_prompt):
+    """Convert tokens back to text"""
+    chunk[-1] = chunk[-1].replace('</w>', '')
+    chunked_prompt = ''.join(chunk).strip()
+    while '</w>' in chunked_prompt:
+        if actual_prompt[chunked_prompt.find('</w>')] == ' ':
+            chunked_prompt = chunked_prompt.replace('</w>', ' ', 1)
+        else:
+            chunked_prompt = chunked_prompt.replace('</w>', '', 1)
+    actual_prompt = actual_prompt.replace(chunked_prompt, '')
+    return chunked_prompt.strip(), actual_prompt.strip()
+
+def merge_embeds(prompt_chunks, compel):
+    """Merge multiple prompt chunks with a weighted combination"""
+    num_chunks = len(prompt_chunks)
+    if num_chunks != 0:
+        power_prompt = 1 / (num_chunks * (num_chunks + 1) // 2)
+        prompt_embs = compel(prompt_chunks)
+        t_list = list(torch.split(prompt_embs, 1, dim=0))
+        for i in range(num_chunks):
+            t_list[-(i + 1)] = t_list[-(i + 1)] * ((i + 1) * power_prompt)
+        prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
+    else:
+        prompt_emb = compel('')
+    return prompt_emb
+
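+# Note: with n chunks, the i-th chunk from the start (0-indexed) is scaled by
+# (n - i) / (n * (n + 1) / 2), so the weights sum to 1 and earlier chunks
+# contribute most to the merged embedding
+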
+def process_long_prompt(prompt, pipeline, compel, only_convert_string=False):
+    """Main function to process long prompts with attention weights"""
+
+    # Fix excessive emphasis symbols
+    prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
+
+    # Parse attention weights
+    attention = parse_prompt_attention(prompt)
+    global_attention_chunks = []
+
+    for att in attention:
+        for chunk in att[0].split(','):
+            temp_prompt_chunks = tokenize_line(chunk, pipeline.tokenizer)
+            for small_chunk in temp_prompt_chunks:
+                temp_dict = {
+                    "weight": round(att[1], 2),
+                    "length": len(pipeline.tokenizer.tokenize(f'{small_chunk},')),
+                    "prompt": f'{small_chunk},'
+                }
+                global_attention_chunks.append(temp_dict)
+
+    max_tokens = pipeline.tokenizer.model_max_length - 2
+    global_prompt_chunks = []
+    current_list = []
+    current_length = 0
+
+    for item in global_attention_chunks:
+        if current_length + item['length'] > max_tokens:
+            global_prompt_chunks.append(current_list)
+            current_list = [[item['prompt'], item['weight']]]
+            current_length = item['length']
+        else:
+            if not current_list:
+                current_list.append([item['prompt'], item['weight']])
+            else:
+                if item['weight'] != current_list[-1][1]:
+                    current_list.append([item['prompt'], item['weight']])
+                else:
+                    current_list[-1][0] += f" {item['prompt']}"
+            current_length += item['length']
+
+    if current_list:
+        global_prompt_chunks.append(current_list)
+
+    if only_convert_string:
+        return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chunks])
+
+    return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chunks], compel)
+
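+# Overall flow: parse attention weights -> split each piece into chunks of at
+# most (model_max_length - 2) tokens, preferring comma boundaries -> regroup
+# chunks by weight -> emit compel-syntax strings (or merged embeddings when
+# only_convert_string=False)
+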
 @spaces.GPU
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
-    # Check and truncate prompt if too long (CLIP can only handle 77 tokens)
-    if len(prompt.split()) > 60:  # Rough estimate to avoid exceeding the token limit
-        print("Warning: Prompt may be too long and will be truncated by the model")
-
+def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_long_prompt):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)

     generator = torch.Generator(device=device).manual_seed(seed)

     try:
-        output_image = pipe(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            width=width,
-            height=height,
-            generator=generator
-        ).images[0]
+        if enable_long_prompt:
+            # Use advanced prompt processing
+            print("Using advanced long prompt processing...")
+
+            # Process prompts with attention weights and chunking
+            if not negative_prompt:
+                negative_prompt = ""
+
+            processed_prompt = process_long_prompt(prompt, pipe, compel, only_convert_string=True)
+            processed_negative = process_long_prompt(negative_prompt, pipe, compel, only_convert_string=True)
+
+            # Get embeddings
+            conditioning, pooled = compel([processed_prompt, processed_negative])
+
+            # Generate with embeddings
+            output_image = pipe(
+                prompt_embeds=conditioning[0:1],
+                pooled_prompt_embeds=pooled[0:1],
+                negative_prompt_embeds=conditioning[1:2],
+                negative_pooled_prompt_embeds=pooled[1:2],
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                width=width,
+                height=height,
+                generator=generator
+            ).images[0]
+        else:
+            # Use standard processing, with a warning for long prompts
+            if len(prompt.split()) > 60:
+                print("Warning: Prompt may be too long. Consider enabling 'Long Prompt Processing'")
+
+            output_image = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                width=width,
+                height=height,
+                generator=generator
+            ).images[0]

         return output_image
+
     except RuntimeError as e:
         print(f"Error during generation: {e}")
-        # Return a blank image with error message
         error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
         return error_img

-
 css = """
 #col-container {
     margin: 0 auto;
     max-width: 520px;
 }
+.long-prompt-info {
+    background-color: #f0f8ff;
+    padding: 10px;
+    border-radius: 5px;
+    margin: 10px 0;
+    font-size: 12px;
+}
 """

 with gr.Blocks(css=css) as demo:
@@ -75,8 +324,8 @@ with gr.Blocks(css=css) as demo:
         prompt = gr.Text(
             label="Prompt",
             show_label=False,
-            max_lines=1,
-            placeholder="Enter your prompt (keep it under 60 words for best results)",
+            max_lines=3,  # Increased for longer prompts
+            placeholder="Enter your prompt. Use (word:1.2) for emphasis or [word:0.8] for de-emphasis",
             container=False,
         )

@@ -85,11 +334,28 @@ with gr.Blocks(css=css) as demo:
         result = gr.Image(label="Result", show_label=False)

         with gr.Accordion("Advanced Settings", open=False):
+            # Long prompt processing toggle
+            enable_long_prompt = gr.Checkbox(
+                label="Enable Long Prompt Processing",
+                value=True,
+                info="Process very long prompts with attention weights like (word:1.2) or [word:0.8]"
+            )
+
+            with gr.Column(elem_classes="long-prompt-info"):
+                gr.HTML("""
+                    <strong>Long Prompt Features:</strong><br>
+                    • <code>(word:1.2)</code> - weight 'word' by 1.2x<br>
+                    • <code>[word:0.8]</code> - weight 'word' by 0.8x<br>
+                    • <code>((word))</code> - strong emphasis (1.21x)<br>
+                    • <code>[[word]]</code> - strong de-emphasis (0.83x)<br>
+                    • No hard token limit - write detailed prompts!
+                """)

             negative_prompt = gr.Text(
                 label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
+                max_lines=2,  # Increased for longer negative prompts
+                placeholder="Enter a negative prompt (supports the same weight syntax)",
                 value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
             )

@@ -139,7 +405,7 @@ with gr.Blocks(css=css) as demo:

     run_button.click(
         fn=infer,
-        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
+        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_long_prompt],
         outputs=[result]
     )
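
A minimal usage sketch of the new infer signature (illustrative values only; assumes app.py is importable as app and a GPU is available for the @spaces.GPU-wrapped function):

    from app import infer

    image = infer(
        prompt="masterpiece, best quality, (detailed eyes:1.3), flowing silver hair, [simple background:0.8]",
        negative_prompt="nsfw, (low quality, worst quality:1.2)",
        seed=0,
        randomize_seed=False,
        width=832,
        height=1216,
        guidance_scale=7.0,
        num_inference_steps=28,
        enable_long_prompt=True,  # route through the compel-based long-prompt path
    )
    image.save("output.png")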
411