TSXu committed on
Commit
89e2699
·
1 Parent(s): e7cbbce

UI improvements: move status bar to right side, simplify layout, update defaults to Wang Xizhi

Browse files
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
  Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
 
4
  """
5
 
6
  # IMPORTANT: import spaces first before any CUDA-related packages
@@ -9,6 +10,7 @@ import spaces
9
  import gradio as gr
10
  import json
11
  import csv
 
12
 
13
  # Load author and font mappings from CSV
14
  def load_author_fonts_from_csv(csv_path):
@@ -83,84 +85,81 @@ def init_generator():
83
  def update_font_choices(author: str):
84
  """
85
  Update available font choices based on selected author
86
-
87
- Args:
88
- author: Selected author name
89
-
90
- Returns:
91
- Updated dropdown with available fonts for the author
92
  """
93
  if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
94
- # If no author or synthetic, show all font types
95
  choices = list(FONT_STYLE_NAMES.values())
96
  else:
97
- # Show only fonts available for this author
98
  available_fonts = AUTHOR_FONTS[author]
99
  choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
100
 
101
- # Return updated dropdown with first choice as default
102
  return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
103
 
104
 
105
- @spaces.GPU(duration=300) # 5 minutes for model loading + generation
106
- def generate_calligraphy(
 
 
 
 
 
 
 
 
107
  text: str,
108
  author_dropdown: str,
109
  font_style: str,
110
  num_steps: int,
111
- seed: int,
112
- random_seed: bool,
113
- batch_size: int = 1,
114
  ):
115
  """
116
- Generate calligraphy based on user inputs
117
 
118
  Args:
119
  text: Input text (1-7 characters)
120
- author_dropdown: Selected author from dropdown
121
- font_style: Selected font style (display name)
122
- num_steps: Number of denoising steps
123
- seed: Random seed
124
- random_seed: Whether to use random seed
125
- batch_size: Number of images to generate
126
 
127
- Returns:
128
- Generated images (gallery) and seed info
129
  """
130
  import torch
131
 
132
- # Validate text - must be 1-7 characters
133
  if len(text) < 1:
134
  raise gr.Error("文本不能为空 / Text cannot be empty")
135
  if len(text) > 7:
136
  raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
137
 
138
- # Extract font style value from display name
139
- font = None
140
- for font_key, font_display in FONT_STYLE_NAMES.items():
141
- if font_display == font_style:
142
- font = font_key
143
- break
144
-
145
  if font is None:
146
  raise gr.Error(f"无法识别的字体风格 / Unknown font style: {font_style}")
147
 
148
  # Determine author
149
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
150
 
151
- # Handle seed
152
- if random_seed:
153
- seed = torch.randint(0, 2**32, (1,)).item()
154
 
155
- # Initialize generator if needed
156
  gen = init_generator()
157
 
158
- # Generate batch of images
 
 
159
  results = []
160
  seeds_used = []
161
 
162
- for i in range(batch_size):
163
- current_seed = seed + i # Increment seed for each image in batch
 
 
 
 
164
  result_img, cond_img = gen.generate(
165
  text=text,
166
  font_style=font,
@@ -168,16 +167,19 @@ def generate_calligraphy(
168
  num_steps=num_steps,
169
  seed=current_seed,
170
  )
171
- results.append(result_img)
 
172
  seeds_used.append(current_seed)
 
 
 
173
 
174
- # Format seed info
175
- if batch_size == 1:
176
- seed_info = f"Seed: {seeds_used[0]}"
177
  else:
178
- seed_info = f"Seeds: {seeds_used[0]} - {seeds_used[-1]} ({batch_size} images)"
179
-
180
- return results, seed_info
181
 
182
 
183
  # Create Gradio interface
@@ -199,8 +201,8 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
199
 
200
  text_input = gr.Textbox(
201
  label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
202
- placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 春风得意马蹄疾",
203
- value="春风得意马蹄疾",
204
  max_lines=1
205
  )
206
 
@@ -209,19 +211,19 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
209
  author_dropdown = gr.Dropdown(
210
  label="1. 选择书法家 / Select Calligrapher",
211
  choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
212
- value="赵佶\\宋徽宗",
213
  info="先选择历史书法家 / Choose a historical calligrapher first"
214
  )
215
 
216
- # Get initial fonts for default author (赵佶\宋徽宗)
217
- initial_author = "赵佶\\宋徽宗"
218
  initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
219
  initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
220
 
221
  font_style = gr.Dropdown(
222
  label="2. 选择字体风格 / Select Font Style",
223
  choices=initial_font_choices,
224
- value=" (Regular Script)",
225
  info="根据所选书法家显示可用字体 / Shows available fonts for selected calligrapher"
226
  )
227
 
@@ -236,45 +238,40 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
236
  info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
237
  )
238
 
239
- with gr.Row():
240
- seed = gr.Number(
241
- label="随机种子 / Seed",
242
- value=42,
243
- precision=0
244
- )
245
- random_seed = gr.Checkbox(
246
- label="随机种子 / Random Seed",
247
- value=False
248
- )
249
 
250
- batch_size = gr.Slider(
251
- label="批量生成数量 / Batch Size",
252
  minimum=1,
253
- maximum=4,
254
  value=1,
255
- step=1,
256
- info="生成多张图片以选择最佳效果 / Generate multiple images to pick the best"
257
  )
258
 
259
- generate_btn = gr.Button("🎨 生成书法 / Generate Calligraphy", variant="primary", size="lg")
260
 
261
  with gr.Column(scale=1):
262
  # Output section
263
- gr.Markdown("### 🖼️ 生成结果 / Generated Result")
264
- gr.Markdown("") # Add spacing
265
 
266
  output_gallery = gr.Gallery(
267
  label="生成结果 / Generated Results",
268
  show_label=False,
269
  columns=2,
270
  rows=2,
271
- height=650,
272
  object_fit="contain",
273
  allow_preview=True
274
  )
275
 
276
- seed_info = gr.Textbox(
277
- label="种子信息 / Seed Info",
 
278
  interactive=False
279
  )
280
 
@@ -291,45 +288,42 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
291
  gr.Markdown(author_info_md)
292
 
293
  # Event handlers
294
- # Update font choices when author changes
295
  author_dropdown.change(
296
  fn=update_font_choices,
297
  inputs=[author_dropdown],
298
  outputs=[font_style]
299
  )
300
 
301
- # Generate button click
302
  generate_btn.click(
303
- fn=generate_calligraphy,
304
  inputs=[
305
  text_input,
306
  author_dropdown,
307
  font_style,
308
  num_steps,
309
- seed,
310
- random_seed,
311
- batch_size,
312
  ],
313
- outputs=[output_gallery, seed_info]
314
  )
315
 
316
  # Examples
317
  gr.Markdown("### 📋 示例 / Examples")
318
  gr.Examples(
319
  examples=[
320
- ["春风得意马蹄疾", "赵佶\\宋徽宗", " (Regular Script)", 25, 42, False, 1],
321
- ["海内存知己", "黄庭坚", " (Running Script)", 25, 42, False, 1],
322
- ["天道酬勤", "王羲之", " (Cursive Script)", 25, 42, False, 1],
323
- ["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 25, 42, False, 1],
324
  ],
325
  inputs=[
326
  text_input,
327
  author_dropdown,
328
  font_style,
329
  num_steps,
330
- seed,
331
- random_seed,
332
- batch_size,
333
  ],
334
  )
335
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
  Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
4
+ With interactive session mode to avoid model reloading
5
  """
6
 
7
  # IMPORTANT: import spaces first before any CUDA-related packages
 
10
  import gradio as gr
11
  import json
12
  import csv
13
+ import time
14
 
15
  # Load author and font mappings from CSV
16
  def load_author_fonts_from_csv(csv_path):
 
85
  def update_font_choices(author: str):
86
  """
87
  Update available font choices based on selected author
 
 
 
 
 
 
88
  """
89
  if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
 
90
  choices = list(FONT_STYLE_NAMES.values())
91
  else:
 
92
  available_fonts = AUTHOR_FONTS[author]
93
  choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
94
 
 
95
  return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
96
 
97
 
98
+ def parse_font_style(font_style: str) -> str:
99
+ """Extract font key from display name"""
100
+ for font_key, font_display in FONT_STYLE_NAMES.items():
101
+ if font_display == font_style:
102
+ return font_key
103
+ return None
104
+
105
+
106
+ @spaces.GPU(duration=600) # 10 minutes session for multiple generations
107
+ def interactive_session(
108
  text: str,
109
  author_dropdown: str,
110
  font_style: str,
111
  num_steps: int,
112
+ start_seed: int,
113
+ num_images: int,
114
+ progress=gr.Progress()
115
  ):
116
  """
117
+ Interactive session: load model once, generate multiple images
118
 
119
  Args:
120
  text: Input text (1-7 characters)
121
+ author_dropdown: Selected author
122
+ font_style: Font style
123
+ num_steps: Inference steps
124
+ start_seed: Starting seed
125
+ num_images: Number of images to generate (each with different seed)
 
126
 
127
+ Yields:
128
+ Progress status, gallery of results
129
  """
130
  import torch
131
 
132
+ # Validate text
133
  if len(text) < 1:
134
  raise gr.Error("文本不能为空 / Text cannot be empty")
135
  if len(text) > 7:
136
  raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
137
 
138
+ # Parse font style
139
+ font = parse_font_style(font_style)
 
 
 
 
 
140
  if font is None:
141
  raise gr.Error(f"无法识别的字体风格 / Unknown font style: {font_style}")
142
 
143
  # Determine author
144
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
145
 
146
+ # Step 1: Load model (only once per session)
147
+ yield "⏳ 正在加载模型... / Loading model...", []
 
148
 
 
149
  gen = init_generator()
150
 
151
+ yield "✅ 模型加载完成!开始生成... / Model loaded! Starting generation...", []
152
+
153
+ # Step 2: Generate multiple images
154
  results = []
155
  seeds_used = []
156
 
157
+ for i in range(num_images):
158
+ current_seed = start_seed + i
159
+ progress((i + 1) / num_images, desc=f"生成第 {i+1}/{num_images} 张...")
160
+
161
+ yield f"🎨 正在生成第 {i+1}/{num_images} 张 (Seed: {current_seed})...", results
162
+
163
  result_img, cond_img = gen.generate(
164
  text=text,
165
  font_style=font,
 
167
  num_steps=num_steps,
168
  seed=current_seed,
169
  )
170
+
171
+ results.append((result_img, f"Seed: {current_seed}"))
172
  seeds_used.append(current_seed)
173
+
174
+ # Yield intermediate results so user can see progress
175
+ yield f"✅ 已完成 {i+1}/{num_images} 张 (Seed: {current_seed})", results
176
 
177
+ # Final yield with all seeds info
178
+ if num_images > 1:
179
+ final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
180
  else:
181
+ final_status = f"✅ 完成!Seed: {seeds_used[0]}"
182
+ yield final_status, results
 
183
 
184
 
185
  # Create Gradio interface
 
201
 
202
  text_input = gr.Textbox(
203
  label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
204
+ placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 天道酬勤",
205
+ value="天道酬勤",
206
  max_lines=1
207
  )
208
 
 
211
  author_dropdown = gr.Dropdown(
212
  label="1. 选择书法家 / Select Calligrapher",
213
  choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
214
+ value="王羲之",
215
  info="先选择历史书法家 / Choose a historical calligrapher first"
216
  )
217
 
218
+ # Get initial fonts for default author (王羲之)
219
+ initial_author = "王羲之"
220
  initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
221
  initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
222
 
223
  font_style = gr.Dropdown(
224
  label="2. 选择字体风格 / Select Font Style",
225
  choices=initial_font_choices,
226
+ value=" (Cursive Script)",
227
  info="根据所选书法家显示可用字体 / Shows available fonts for selected calligrapher"
228
  )
229
 
 
238
  info="更多步数 = 更高质量,但更慢 / More steps = higher quality, but slower"
239
  )
240
 
241
+ start_seed = gr.Number(
242
+ label="起始种子 / Start Seed",
243
+ value=42,
244
+ precision=0
245
+ )
 
 
 
 
 
246
 
247
+ num_images = gr.Slider(
248
+ label="生成数量 / Number of Images",
249
  minimum=1,
250
+ maximum=8,
251
  value=1,
252
+ step=1
 
253
  )
254
 
255
+ generate_btn = gr.Button("🎨 开始生成 / Start Generation", variant="primary", size="lg")
256
 
257
  with gr.Column(scale=1):
258
  # Output section
259
+ gr.Markdown("### 🖼️ 生成结果 / Generated Results")
260
+ gr.Markdown("*点击图片可放大查看 / Click image to enlarge*")
261
 
262
  output_gallery = gr.Gallery(
263
  label="生成结果 / Generated Results",
264
  show_label=False,
265
  columns=2,
266
  rows=2,
267
+ height=550,
268
  object_fit="contain",
269
  allow_preview=True
270
  )
271
 
272
+ status_text = gr.Textbox(
273
+ label="状态 / Status",
274
+ value="准备就绪 / Ready",
275
  interactive=False
276
  )
277
 
 
288
  gr.Markdown(author_info_md)
289
 
290
  # Event handlers
 
291
  author_dropdown.change(
292
  fn=update_font_choices,
293
  inputs=[author_dropdown],
294
  outputs=[font_style]
295
  )
296
 
297
+ # Generate button - uses streaming for live updates
298
  generate_btn.click(
299
+ fn=interactive_session,
300
  inputs=[
301
  text_input,
302
  author_dropdown,
303
  font_style,
304
  num_steps,
305
+ start_seed,
306
+ num_images,
 
307
  ],
308
+ outputs=[status_text, output_gallery]
309
  )
310
 
311
  # Examples
312
  gr.Markdown("### 📋 示例 / Examples")
313
  gr.Examples(
314
  examples=[
315
+ ["天道酬勤", "王羲之", " (Cursive Script)", 25, 42, 1],
316
+ ["春风得意马蹄疾", "赵佶\\宋徽宗", " (Regular Script)", 25, 42, 1],
317
+ ["海内存知己", "黄庭坚", " (Running Script)", 25, 42, 1],
318
+ ["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 25, 42, 1],
319
  ],
320
  inputs=[
321
  text_input,
322
  author_dropdown,
323
  font_style,
324
  num_steps,
325
+ start_seed,
326
+ num_images,
 
327
  ],
328
  )
329
 
inference.py CHANGED
@@ -341,8 +341,8 @@ class CalligraphyGenerator:
341
 
342
  # Move to GPU only if NOT using DeepSpeed (DeepSpeed will handle device placement)
343
  if not use_deepspeed:
344
- print(f"Moving model to {self.device}...")
345
- model = model.to(self.device)
346
 
347
  # Enable optimized attention backends
348
  try:
 
341
 
342
  # Move to GPU only if NOT using DeepSpeed (DeepSpeed will handle device placement)
343
  if not use_deepspeed:
344
+ print(f"Moving model to {self.device} and converting to float32...")
345
+ model = model.to(device=self.device, dtype=torch.float32)
346
 
347
  # Enable optimized attention backends
348
  try:
src/flux/modules/layers.py CHANGED
@@ -34,19 +34,21 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
34
  :param max_period: controls the minimum frequency of the embeddings.
35
  :return: an (N, D) Tensor of positional embeddings.
36
  """
 
 
 
 
37
  t = time_factor * t
38
  half = dim // 2
39
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
- t.device
41
- )
42
 
43
  args = t[:, None].float() * freqs[None]
44
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
  if dim % 2:
46
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
- if torch.is_floating_point(t):
48
- embedding = embedding.to(t)
49
- return embedding
50
 
51
 
52
  class MLPEmbedder(nn.Module):
 
34
  :param max_period: controls the minimum frequency of the embeddings.
35
  :return: an (N, D) Tensor of positional embeddings.
36
  """
37
+ # Store original dtype and device
38
+ orig_dtype = t.dtype
39
+ orig_device = t.device
40
+
41
  t = time_factor * t
42
  half = dim // 2
43
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=orig_device) / half)
 
 
44
 
45
  args = t[:, None].float() * freqs[None]
46
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
47
  if dim % 2:
48
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
49
+
50
+ # Always convert to original dtype
51
+ return embedding.to(dtype=orig_dtype, device=orig_device)
52
 
53
 
54
  class MLPEmbedder(nn.Module):
src/flux/xflux_pipeline.py CHANGED
@@ -225,6 +225,7 @@ class XFluxPipeline:
225
  if self.controlnet_loaded:
226
  controlnet_image = self.annotator(controlnet_image, width, height)
227
  controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
 
228
  controlnet_image = controlnet_image.permute(
229
  2, 0, 1).unsqueeze(0).to(torch.float32).to(self.device)
230
 
@@ -311,6 +312,7 @@ class XFluxPipeline:
311
  neg_ip_scale=1.0,
312
  is_generation=True,
313
  ):
 
314
  x = get_noise(
315
  1, height, width, device=self.device,
316
  dtype=torch.float32, seed=seed
@@ -328,7 +330,8 @@ class XFluxPipeline:
328
 
329
  if not self.controlnet_loaded and controlnet_image is not None: # tianshuo
330
  # width //= 2
331
- cond_latent = self.ae.encode(controlnet_image.to(self.device, dtype=torch.float32))
 
332
 
333
  inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
334
  neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
 
225
  if self.controlnet_loaded:
226
  controlnet_image = self.annotator(controlnet_image, width, height)
227
  controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
228
+ # Keep as float32 for VAE encoding, will be converted to model dtype after
229
  controlnet_image = controlnet_image.permute(
230
  2, 0, 1).unsqueeze(0).to(torch.float32).to(self.device)
231
 
 
312
  neg_ip_scale=1.0,
313
  is_generation=True,
314
  ):
315
+ # Use float32 for stable inference
316
  x = get_noise(
317
  1, height, width, device=self.device,
318
  dtype=torch.float32, seed=seed
 
330
 
331
  if not self.controlnet_loaded and controlnet_image is not None: # tianshuo
332
  # width //= 2
333
+ # VAE expects float32 (controlnet_image is already float32)
334
+ cond_latent = self.ae.encode(controlnet_image)
335
 
336
  inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
337
  neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)