TSXu committed on
Commit
d3ccd4b
·
1 Parent(s): c6a1e05

Add batch generation, torch.compile acceleration, fix dtype issues

Browse files

- Add batch generation option (1-4 images) with gallery output
- Enable torch.compile() for inference acceleration
- Fix bfloat16/float32 dtype mismatches in pipeline
- Change default inference steps to 25
- Update examples with batch_size parameter

Files changed (4) hide show
  1. app.py +54 -27
  2. inference.py +8 -0
  3. src/flux/sampling.py +4 -4
  4. src/flux/xflux_pipeline.py +4 -4
app.py CHANGED
@@ -75,7 +75,7 @@ def init_generator():
75
  font_descriptions_path='dataset/chirography.json',
76
  author_descriptions_path='dataset/calligraphy_styles_en.json',
77
  use_deepspeed=False,
78
- use_4bit_quantization=False, # Disable 4-bit quantization for faster init
79
  )
80
  return generator
81
 
@@ -110,6 +110,7 @@ def generate_calligraphy(
110
  num_steps: int,
111
  seed: int,
112
  random_seed: bool,
 
113
  ):
114
  """
115
  Generate calligraphy based on user inputs
@@ -121,10 +122,13 @@ def generate_calligraphy(
121
  num_steps: Number of denoising steps
122
  seed: Random seed
123
  random_seed: Whether to use random seed
 
124
 
125
  Returns:
126
- Generated image and condition image
127
  """
 
 
128
  # Validate text - must be 1-7 characters
129
  if len(text) < 1:
130
  raise gr.Error("文本不能为空 / Text cannot be empty")
@@ -146,22 +150,34 @@ def generate_calligraphy(
146
 
147
  # Handle seed
148
  if random_seed:
149
- import torch
150
  seed = torch.randint(0, 2**32, (1,)).item()
151
 
152
  # Initialize generator if needed
153
  gen = init_generator()
154
 
155
- # Generate
156
- result_img, cond_img = gen.generate(
157
- text=text,
158
- font_style=font,
159
- author=author,
160
- num_steps=num_steps,
161
- seed=seed,
162
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- return result_img, f"Seed: {seed}"
165
 
166
 
167
  # Create Gradio interface
@@ -215,7 +231,7 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
215
  label="生成步数 / Inference Steps",
216
  minimum=10,
217
  maximum=50,
218
- value=39,
219
  step=1,
220
  info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
221
  )
@@ -231,6 +247,15 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
231
  value=False
232
  )
233
 
 
 
 
 
 
 
 
 
 
234
  generate_btn = gr.Button("🎨 生成书法 / Generate Calligraphy", variant="primary", size="lg")
235
 
236
  with gr.Column(scale=1):
@@ -238,15 +263,15 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
238
  gr.Markdown("### 🖼️ 生成结果 / Generated Result")
239
  gr.Markdown("") # Add spacing
240
 
241
- with gr.Row():
242
- gr.Column(scale=1) # Left spacer
243
- with gr.Column(scale=2):
244
- output_image = gr.Image(
245
- show_label=False,
246
- type="pil",
247
- height=600
248
- )
249
- gr.Column(scale=1) # Right spacer
250
 
251
  seed_info = gr.Textbox(
252
  label="种子信息 / Seed Info",
@@ -283,18 +308,19 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
283
  num_steps,
284
  seed,
285
  random_seed,
 
286
  ],
287
- outputs=[output_image, seed_info]
288
  )
289
 
290
  # Examples
291
  gr.Markdown("### 📋 示例 / Examples")
292
  gr.Examples(
293
  examples=[
294
- ["春风得意马蹄疾", "赵佶\\宋徽宗", "楷 (Regular Script)", 39, 42, False],
295
- ["海内存知己", "黄庭坚", "行 (Running Script)", 39, 42, False],
296
- ["天道酬勤", "王羲之", "草 (Cursive Script)", 39, 42, False],
297
- ["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 39, 42, False],
298
  ],
299
  inputs=[
300
  text_input,
@@ -303,6 +329,7 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
303
  num_steps,
304
  seed,
305
  random_seed,
 
306
  ],
307
  )
308
 
 
75
  font_descriptions_path='dataset/chirography.json',
76
  author_descriptions_path='dataset/calligraphy_styles_en.json',
77
  use_deepspeed=False,
78
+ use_4bit_quantization=False, # Disabled - quantization overhead not worth it
79
  )
80
  return generator
81
 
 
110
  num_steps: int,
111
  seed: int,
112
  random_seed: bool,
113
+ batch_size: int = 1,
114
  ):
115
  """
116
  Generate calligraphy based on user inputs
 
122
  num_steps: Number of denoising steps
123
  seed: Random seed
124
  random_seed: Whether to use random seed
125
+ batch_size: Number of images to generate
126
 
127
  Returns:
128
+ Generated images (gallery) and seed info
129
  """
130
+ import torch
131
+
132
  # Validate text - must be 1-7 characters
133
  if len(text) < 1:
134
  raise gr.Error("文本不能为空 / Text cannot be empty")
 
150
 
151
  # Handle seed
152
  if random_seed:
 
153
  seed = torch.randint(0, 2**32, (1,)).item()
154
 
155
  # Initialize generator if needed
156
  gen = init_generator()
157
 
158
+ # Generate batch of images
159
+ results = []
160
+ seeds_used = []
161
+
162
+ for i in range(batch_size):
163
+ current_seed = seed + i # Increment seed for each image in batch
164
+ result_img, cond_img = gen.generate(
165
+ text=text,
166
+ font_style=font,
167
+ author=author,
168
+ num_steps=num_steps,
169
+ seed=current_seed,
170
+ )
171
+ results.append(result_img)
172
+ seeds_used.append(current_seed)
173
+
174
+ # Format seed info
175
+ if batch_size == 1:
176
+ seed_info = f"Seed: {seeds_used[0]}"
177
+ else:
178
+ seed_info = f"Seeds: {seeds_used[0]} - {seeds_used[-1]} ({batch_size} images)"
179
 
180
+ return results, seed_info
181
 
182
 
183
  # Create Gradio interface
 
231
  label="生成步数 / Inference Steps",
232
  minimum=10,
233
  maximum=50,
234
+ value=25,
235
  step=1,
236
  info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
237
  )
 
247
  value=False
248
  )
249
 
250
+ batch_size = gr.Slider(
251
+ label="批量生成数量 / Batch Size",
252
+ minimum=1,
253
+ maximum=4,
254
+ value=1,
255
+ step=1,
256
+ info="生成多张图片以选择最佳效果 / Generate multiple images to pick the best"
257
+ )
258
+
259
  generate_btn = gr.Button("🎨 生成书法 / Generate Calligraphy", variant="primary", size="lg")
260
 
261
  with gr.Column(scale=1):
 
263
  gr.Markdown("### 🖼️ 生成结果 / Generated Result")
264
  gr.Markdown("") # Add spacing
265
 
266
+ output_gallery = gr.Gallery(
267
+ label="生成结果 / Generated Results",
268
+ show_label=False,
269
+ columns=2,
270
+ rows=2,
271
+ height=650,
272
+ object_fit="contain",
273
+ allow_preview=True
274
+ )
275
 
276
  seed_info = gr.Textbox(
277
  label="种子信息 / Seed Info",
 
308
  num_steps,
309
  seed,
310
  random_seed,
311
+ batch_size,
312
  ],
313
+ outputs=[output_gallery, seed_info]
314
  )
315
 
316
  # Examples
317
  gr.Markdown("### 📋 示例 / Examples")
318
  gr.Examples(
319
  examples=[
320
+ ["春风得意马蹄疾", "赵佶\\宋徽宗", "楷 (Regular Script)", 25, 42, False, 1],
321
+ ["海内存知己", "黄庭坚", "行 (Running Script)", 25, 42, False, 1],
322
+ ["天道酬勤", "王羲之", "草 (Cursive Script)", 25, 42, False, 1],
323
+ ["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 25, 42, False, 1],
324
  ],
325
  inputs=[
326
  text_input,
 
329
  num_steps,
330
  seed,
331
  random_seed,
332
+ batch_size,
333
  ],
334
  )
335
 
inference.py CHANGED
@@ -338,6 +338,14 @@ class CalligraphyGenerator:
338
  if not use_deepspeed:
339
  print(f"Moving model to {self.device}...")
340
  model = model.to(self.device)
 
 
 
 
 
 
 
 
341
 
342
  return model
343
 
 
338
  if not use_deepspeed:
339
  print(f"Moving model to {self.device}...")
340
  model = model.to(self.device)
341
+
342
+ # Apply torch.compile for faster inference (PyTorch 2.0+)
343
+ try:
344
+ print("Applying torch.compile() for acceleration...")
345
+ model = torch.compile(model, mode="reduce-overhead")
346
+ print("torch.compile() applied successfully!")
347
+ except Exception as e:
348
+ print(f"torch.compile() not available or failed: {e}")
349
 
350
  return model
351
 
src/flux/sampling.py CHANGED
@@ -61,10 +61,10 @@ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[st
61
 
62
  return {
63
  "img": img,
64
- "img_ids": img_ids.to(img.device),
65
- "txt": txt.to(img.device, dtype=img_dtype),
66
- "txt_ids": txt_ids.to(img.device),
67
- "vec": vec.to(img.device, dtype=img_dtype),
68
  }
69
 
70
 
 
61
 
62
  return {
63
  "img": img,
64
+ "img_ids": img_ids.to(device=img.device, dtype=img_dtype),
65
+ "txt": txt.to(device=img.device, dtype=img_dtype),
66
+ "txt_ids": txt_ids.to(device=img.device, dtype=img_dtype),
67
+ "vec": vec.to(device=img.device, dtype=img_dtype),
68
  }
69
 
70
 
src/flux/xflux_pipeline.py CHANGED
@@ -195,13 +195,13 @@ class XFluxPipeline:
195
  padding="max_length",
196
  max_length=required_chars
197
  )["input_ids"]
198
- cond_txt_latent = self.embed_tokens(cond_text_token).to(self.device, torch.bfloat16)
199
 
200
  if not is_generation:
201
  cond_txt_latent = torch.rand(
202
  cond_txt_latent.size(),
203
  device=self.device,
204
- dtype=torch.bfloat16,
205
  generator=torch.Generator(device=self.device).manual_seed(seed)
206
  )
207
 
@@ -226,7 +226,7 @@ class XFluxPipeline:
226
  controlnet_image = self.annotator(controlnet_image, width, height)
227
  controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
228
  controlnet_image = controlnet_image.permute(
229
- 2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)
230
 
231
  return self.forward(
232
  prompt,
@@ -313,7 +313,7 @@ class XFluxPipeline:
313
  ):
314
  x = get_noise(
315
  1, height, width, device=self.device,
316
- dtype=torch.bfloat16, seed=seed
317
  )
318
 
319
  timesteps = get_schedule(
 
195
  padding="max_length",
196
  max_length=required_chars
197
  )["input_ids"]
198
+ cond_txt_latent = self.embed_tokens(cond_text_token).to(self.device, torch.float32)
199
 
200
  if not is_generation:
201
  cond_txt_latent = torch.rand(
202
  cond_txt_latent.size(),
203
  device=self.device,
204
+ dtype=torch.float32,
205
  generator=torch.Generator(device=self.device).manual_seed(seed)
206
  )
207
 
 
226
  controlnet_image = self.annotator(controlnet_image, width, height)
227
  controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
228
  controlnet_image = controlnet_image.permute(
229
+ 2, 0, 1).unsqueeze(0).to(torch.float32).to(self.device)
230
 
231
  return self.forward(
232
  prompt,
 
313
  ):
314
  x = get_noise(
315
  1, height, width, device=self.device,
316
+ dtype=torch.float32, seed=seed
317
  )
318
 
319
  timesteps = get_schedule(