retwpay committed
Commit 219e674 · verified · 1 Parent(s): 2c374c6

Update app.py

Files changed (1):
  app.py +49 -276
app.py CHANGED
@@ -7,7 +7,6 @@ import random
 from diffusers import StableDiffusionXLPipeline
 from diffusers import EulerAncestralDiscreteScheduler
 import torch
-import re
 from compel import Compel, ReturnedEmbeddingsType
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -29,223 +28,32 @@ pipe.text_encoder_2.to(torch.float16)
 pipe.vae.to(torch.float16)
 pipe.unet.to(torch.float16)
 
-# Initialize Compel for long prompt processing
+# Added: Initialize Compel for long prompt processing
 compel = Compel(
     tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
     text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
     requires_pooled=[False, True],
-    truncate_long_prompts=False  # Enable long prompt processing
+    truncate_long_prompts=False
 )
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1216
 
-# =====================================
-# Long Prompt Processing Functions
-# =====================================
-
-def parse_prompt_attention(text):
-    """Parse prompt with attention weights like (word:1.2) or [word:0.8]"""
-    re_attention = re.compile(r"""
-        \\\(|
-        \\\)|
-        \\\[|
-        \\]|
-        \\\\|
-        \\|
-        \(|
-        \[|
-        :([+-]?[.\d]+)\)|
-        \)|
-        ]|
-        [^\\()\[\]:]+|
-        :
-    """, re.X)
-
-    res = []
-    round_brackets = []
-    square_brackets = []
-
-    round_bracket_multiplier = 1.1
-    square_bracket_multiplier = 1 / 1.1
-
-    def multiply_range(start_position, multiplier):
-        for p in range(start_position, len(res)):
-            res[p][1] *= multiplier
-
-    for m in re_attention.finditer(text):
-        text = m.group(0)
-        weight = m.group(1)
-
-        if text.startswith('\\'):
-            res.append([text[1:], 1.0])
-        elif text == '(':
-            round_brackets.append(len(res))
-        elif text == '[':
-            square_brackets.append(len(res))
-        elif weight is not None and len(round_brackets) > 0:
-            multiply_range(round_brackets.pop(), float(weight))
-        elif text == ')' and len(round_brackets) > 0:
-            multiply_range(round_brackets.pop(), round_bracket_multiplier)
-        elif text == ']' and len(square_brackets) > 0:
-            multiply_range(square_brackets.pop(), square_bracket_multiplier)
-        else:
-            parts = re.split(re.compile(r"\s*\bBREAK\b\s*", re.S), text)
-            for i, part in enumerate(parts):
-                if i > 0:
-                    res.append(["BREAK", -1])
-                res.append([part, 1.0])
-
-    for pos in round_brackets:
-        multiply_range(pos, round_bracket_multiplier)
-
-    for pos in square_brackets:
-        multiply_range(pos, square_bracket_multiplier)
-
-    if len(res) == 0:
-        res = [["", 1.0]]
-
-    # merge runs of identical weights
-    i = 0
-    while i + 1 < len(res):
-        if res[i][1] == res[i + 1][1]:
-            res[i][0] += res[i + 1][0]
-            res.pop(i + 1)
-        else:
-            i += 1
-
-    return res
-
-def prompt_attention_to_invoke_prompt(attention):
-    """Convert attention data back to compel format"""
-    tokens = []
-    for text, weight in attention:
-        weight = round(weight, 2)
-        if weight == 1.0:
-            tokens.append(text)
-        elif weight < 1.0:
-            if weight < 0.8:
-                tokens.append(f"({text}){weight}")
-            else:
-                tokens.append(f"({text})-" + "-" * int((1.0 - weight) * 10))
-        else:
-            if weight < 1.3:
-                tokens.append(f"({text})" + "+" * int((weight - 1.0) * 10))
-            else:
-                tokens.append(f"({text}){weight}")
-    return "".join(tokens)
-
-def tokenize_line(line, tokenizer):
-    """Split long prompts into chunks at appropriate boundaries"""
-    actual_prompt = line.lower().strip()
-    actual_tokens = tokenizer.tokenize(actual_prompt)
-    max_tokens = tokenizer.model_max_length - 2
-    comma_token = tokenizer.tokenize(',')[0]
-
-    chunks = []
-    chunk = []
-    for item in actual_tokens:
-        chunk.append(item)
-        if len(chunk) == max_tokens:
-            if chunk[-1] != comma_token:
-                for i in range(max_tokens-1, -1, -1):
-                    if chunk[i] == comma_token:
-                        actual_chunk, actual_prompt = detokenize(chunk[:i+1], actual_prompt)
-                        chunks.append(actual_chunk)
-                        chunk = chunk[i+1:]
-                        break
-                else:
-                    actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
-                    chunks.append(actual_chunk)
-                    chunk = []
-            else:
-                actual_chunk, actual_prompt = detokenize(chunk, actual_prompt)
-                chunks.append(actual_chunk)
-                chunk = []
-    if chunk:
-        actual_chunk, _ = detokenize(chunk, actual_prompt)
-        chunks.append(actual_chunk)
-
-    return chunks
-
-def detokenize(chunk, actual_prompt):
-    """Convert tokens back to text"""
-    chunk[-1] = chunk[-1].replace('</w>', '')
-    chanked_prompt = ''.join(chunk).strip()
-    while '</w>' in chanked_prompt:
-        if actual_prompt[chanked_prompt.find('</w>')] == ' ':
-            chanked_prompt = chanked_prompt.replace('</w>', ' ', 1)
-        else:
-            chanked_prompt = chanked_prompt.replace('</w>', '', 1)
-    actual_prompt = actual_prompt.replace(chanked_prompt, '')
-    return chanked_prompt.strip(), actual_prompt.strip()
-
-def merge_embeds(prompt_chunks, compel):
-    """Merge multiple prompt chunks with weighted combination"""
-    num_chunks = len(prompt_chunks)
-    if num_chunks != 0:
-        power_prompt = 1/(num_chunks*(num_chunks+1)//2)
-        prompt_embs = compel(prompt_chunks)
-        t_list = list(torch.split(prompt_embs, 1, dim=0))
-        for i in range(num_chunks):
-            t_list[-(i+1)] = t_list[-(i+1)] * ((i+1)*power_prompt)
-        prompt_emb = torch.stack(t_list, dim=0).sum(dim=0)
-    else:
-        prompt_emb = compel('')
-    return prompt_emb
-
-def process_long_prompt(prompt, pipeline, compel, only_convert_string=False):
-    """Main function to process long prompts with attention weights"""
-
-    # Fix excessive emphasis symbols
-    prompt = prompt.replace("((", "(").replace("))", ")").replace("\\", "\\\\\\")
-
-    # Parse attention weights
-    attention = parse_prompt_attention(prompt)
-    global_attention_chunks = []
-
-    for att in attention:
-        for chunk in att[0].split(','):
-            temp_prompt_chunks = tokenize_line(chunk, pipeline.tokenizer)
-            for small_chunk in temp_prompt_chunks:
-                temp_dict = {
-                    "weight": round(att[1], 2),
-                    "length": len(pipeline.tokenizer.tokenize(f'{small_chunk},')),
-                    "prompt": f'{small_chunk},'
-                }
-                global_attention_chunks.append(temp_dict)
-
-    max_tokens = pipeline.tokenizer.model_max_length - 2
-    global_prompt_chunks = []
-    current_list = []
-    current_length = 0
-
-    for item in global_attention_chunks:
-        if current_length + item['length'] > max_tokens:
-            global_prompt_chunks.append(current_list)
-            current_list = [[item['prompt'], item['weight']]]
-            current_length = item['length']
-        else:
-            if not current_list:
-                current_list.append([item['prompt'], item['weight']])
-            else:
-                if item['weight'] != current_list[-1][1]:
-                    current_list.append([item['prompt'], item['weight']])
-                else:
-                    current_list[-1][0] += f" {item['prompt']}"
-            current_length += item['length']
-
-    if current_list:
-        global_prompt_chunks.append(current_list)
-
-    if only_convert_string:
-        return ' '.join([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chunks])
-
-    return merge_embeds([prompt_attention_to_invoke_prompt(i) for i in global_prompt_chunks], compel)
-
+# Added: Simple long prompt processing function
+def process_long_prompt(prompt, negative_prompt=""):
+    """Simple long prompt processing using Compel"""
+    try:
+        conditioning, pooled = compel([prompt, negative_prompt])
+        return conditioning, pooled
+    except Exception as e:
+        print(f"Long prompt processing failed: {e}, falling back to standard processing")
+        return None, None
 
 @spaces.GPU
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_long_prompt):
+def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+    # Changed: Remove the 60-word limit warning and add a long-prompt check
+    use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
@@ -253,67 +61,49 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
     generator = torch.Generator(device=device).manual_seed(seed)
 
     try:
-        if enable_long_prompt:
-            # Use advanced prompt processing
-            print("Using advanced long prompt processing...")
-
-            # Process prompts with attention weights and chunking
-            if not negative_prompt:
-                negative_prompt = ""
-
-            processed_prompt = process_long_prompt(prompt, pipe, compel, only_convert_string=True)
-            processed_negative = process_long_prompt(negative_prompt, pipe, compel, only_convert_string=True)
-
-            # Get embeddings
-            conditioning, pooled = compel([processed_prompt, processed_negative])
-
-            # Generate with embeddings
-            output_image = pipe(
-                prompt_embeds=conditioning[0:1],
-                pooled_prompt_embeds=pooled[0:1],
-                negative_prompt_embeds=conditioning[1:2],
-                negative_pooled_prompt_embeds=pooled[1:2],
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                width=width,
-                height=height,
-                generator=generator
-            ).images[0]
-
-        else:
-            # Use standard processing with warning for long prompts
-            if len(prompt.split()) > 60:
-                print("Warning: Prompt may be too long. Consider enabling 'Long Prompt Processing'")
-
-            output_image = pipe(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                width=width,
-                height=height,
-                generator=generator
-            ).images[0]
-
-        return output_image
+        # Added: Try long prompt processing first if the prompt is long
+        if use_long_prompt:
+            print("Using long prompt processing...")
+            conditioning, pooled = process_long_prompt(prompt, negative_prompt)
+
+            if conditioning is not None:
+                output_image = pipe(
+                    prompt_embeds=conditioning[0:1],
+                    pooled_prompt_embeds=pooled[0:1],
+                    negative_prompt_embeds=conditioning[1:2],
+                    negative_pooled_prompt_embeds=pooled[1:2],
+                    guidance_scale=guidance_scale,
+                    num_inference_steps=num_inference_steps,
+                    width=width,
+                    height=height,
+                    generator=generator
+                ).images[0]
+                return output_image
+
+        # Fall back to standard processing
+        output_image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            generator=generator
+        ).images[0]
+
+        return output_image
     except RuntimeError as e:
         print(f"Error during generation: {e}")
+        # Return a blank image with error message
         error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
         return error_img
 
+
 css = """
 #col-container {
     margin: 0 auto;
     max-width: 520px;
 }
-.long-prompt-info {
-    background-color: #f0f8ff;
-    padding: 10px;
-    border-radius: 5px;
-    margin: 10px 0;
-    font-size: 12px;
-}
 """
 
 with gr.Blocks(css=css) as demo:
@@ -324,8 +114,8 @@ with gr.Blocks(css=css) as demo:
         prompt = gr.Text(
             label="Prompt",
            show_label=False,
-            max_lines=3,  # Increased for longer prompts
-            placeholder="Enter your prompt. Use (word:1.2) for emphasis or [word:0.8] for de-emphasis",
+            max_lines=3,  # Changed: Increased from 1 to 3 for longer prompts
+            placeholder="Enter your prompt (long prompts are automatically supported)",  # Changed: Updated placeholder
             container=False,
         )
@@ -334,28 +124,11 @@ with gr.Blocks(css=css) as demo:
         result = gr.Image(label="Result", show_label=False)
 
         with gr.Accordion("Advanced Settings", open=False):
-
-            # Long prompt processing toggle
-            enable_long_prompt = gr.Checkbox(
-                label="Enable Long Prompt Processing",
-                value=True,
-                info="Process very long prompts with attention weights like (word:1.2) or [word:0.8]"
-            )
-
-            with gr.Column(elem_class="long-prompt-info"):
-                gr.HTML("""
-                    <strong>Long Prompt Features:</strong><br>
-                    • <code>(word:1.2)</code> - Increase attention to 'word' by 1.2x<br>
-                    • <code>[word:0.8]</code> - Decrease attention to 'word' by 0.8x<br>
-                    • <code>((word))</code> - Strong emphasis (1.21x)<br>
-                    • <code>[[word]]</code> - Strong de-emphasis (0.83x)<br>
-                    • No token limit - write detailed prompts!
-                """)
 
             negative_prompt = gr.Text(
                 label="Negative prompt",
-                max_lines=2,  # Increased for longer negative prompts
-                placeholder="Enter a negative prompt (supports same weight syntax)",
+                max_lines=2,  # Changed: Increased from 1 to 2
+                placeholder="Enter a negative prompt",
                 value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn"
             )
@@ -405,7 +178,7 @@ with gr.Blocks(css=css) as demo:
 
     run_button.click(
         fn=infer,
-        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, enable_long_prompt],
+        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
         outputs=[result]
    )
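For context, the added path leans on Compel's handling of prompts longer than CLIP's 77-token window: with truncate_long_prompts=False, Compel encodes the prompt in 77-token segments and concatenates the resulting embeddings, and batching the positive and negative prompt in one call keeps both embedding tensors the same length, which is what the new infer() relies on when it slices conditioning[0:1] and conditioning[1:2]. A minimal standalone sketch of that path, mirroring the code added in this commit; the checkpoint id below is a placeholder, since these hunks do not show which model app.py loads, and a GPU is assumed as in the Space:

import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline

# Placeholder checkpoint, not taken from this commit.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False,  # keep tokens beyond the 77-token CLIP window
)

# One batched call, as the new process_long_prompt() does, so positive and
# negative embeddings come back with matching sequence lengths.
prompt = ", ".join(40 * ["an ornate clockwork owl on a mossy branch"])  # well over 77 tokens
conditioning, pooled = compel([prompt, "low quality, watermark"])

image = pipe(
    prompt_embeds=conditioning[0:1],
    pooled_prompt_embeds=pooled[0:1],
    negative_prompt_embeds=conditioning[1:2],
    negative_pooled_prompt_embeds=pooled[1:2],
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("long_prompt_test.png")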
184