lenML commited on
Commit
e487bbb
·
verified ·
1 Parent(s): 00d4335

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -133
app.py CHANGED
@@ -11,7 +11,7 @@ import warnings
11
  # 忽略警告
12
  warnings.filterwarnings("ignore")
13
 
14
- # ==================== 1. 分辨率配置 (严格保持原样) ====================
15
  RES_CHOICES = {
16
  "1024": [
17
  "720x1280 (9:16)",
@@ -65,46 +65,82 @@ def get_resolution(resolution_str):
65
  return width - width % 8, height - height % 8
66
  return 1024, 1024
67
 
68
- # ==================== 2. 模型加载与优化 ====================
69
  print("🚀 Loading Z-Image-Turbo pipeline...")
70
 
71
- # 加载模型
 
72
  pipe = DiffusionPipeline.from_pretrained(
73
  "Tongyi-MAI/Z-Image-Turbo",
74
  torch_dtype=torch.bfloat16,
75
  low_cpu_mem_usage=True,
76
  use_safetensors=True,
 
77
  )
78
 
79
- # 【核心加速】使用 FlowMatchEulerDiscreteScheduler 并设置 shift=3.0
80
- # 这是 Turbo 模型速度快的关键,能让模型在 4-10 步内生成高质量图片
81
- scheduler_config = dict(pipe.scheduler.config)
82
- scheduler_config.pop("algorithm_type", None)
83
- pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
84
- scheduler_config,
85
- shift=3.0
86
- )
 
 
 
87
 
88
  # 移动到 GPU
89
  pipe.to("cuda")
90
 
91
- # 【显存优化】开启 xFormers (如果环境支持)
92
- try:
93
- pipe.enable_xformers_memory_efficient_attention()
94
- print("✅ XFormers enabled")
95
- except Exception as e:
96
- print(f"⚠️ XFormers not available: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- # 【显存优化】开启 VAE 切片
99
  try:
100
  pipe.vae.enable_slicing()
101
  except:
102
  pass
103
 
104
- # 注意:移除了 pipe.transformer = torch.compile(...)
105
- # 原因是它导致了 'Cannot construct ConstantVariable for value of type torch.device' 错误。
106
- # 目前仅靠 Scheduler 优化和 xFormers 已经足够快且稳定。
107
-
108
  # ==================== 3. 生成逻辑 ====================
109
  @spaces.GPU
110
  def generate_image(
@@ -117,37 +153,33 @@ def generate_image(
117
  seed,
118
  randomize_seed,
119
  negative_prompt,
120
- gallery_history, # 接收历史记录
121
  progress=gr.Progress(track_tqdm=True)
122
  ):
123
  if gallery_history is None:
124
  gallery_history = []
125
 
126
  try:
127
- # 1. 输入校验
128
  if not prompt or len(prompt.strip()) < 2:
129
  raise gr.Error("请输提示词 (Prompt)")
130
 
131
  prompt = prompt.strip()
132
  neg_prompt = negative_prompt.strip() if negative_prompt else None
133
 
134
- # 2. 分辨率计算
135
  if use_custom_res:
136
  width = int(custom_width) - int(custom_width) % 8
137
  height = int(custom_height) - int(custom_height) % 8
138
  else:
139
  width, height = get_resolution(resolution_choice)
140
 
141
- # 3. 种子处理
142
  if randomize_seed:
143
  seed = random.randint(0, 2**32 - 1)
144
  seed = int(seed)
145
 
146
- # 4. 生成
147
  start_time = time.time()
148
  generator = torch.Generator("cuda").manual_seed(seed)
149
 
150
- # 清理显存
151
  torch.cuda.empty_cache()
152
 
153
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
@@ -156,7 +188,7 @@ def generate_image(
156
  height=height,
157
  width=width,
158
  num_inference_steps=int(num_inference_steps),
159
- guidance_scale=0.0, # Turbo 模型不需要 guidance_scale
160
  generator=generator,
161
  negative_prompt=neg_prompt,
162
  max_sequence_length=512,
@@ -164,190 +196,135 @@ def generate_image(
164
 
165
  gen_time = time.time() - start_time
166
 
167
- # 5. 构建历史记录
168
- # Gallery 格式: [(image_path_or_obj, label), ...]
169
- info_label = f"{width}x{height} | Seed: {seed} | {gen_time:.1f}s"
170
  gallery_history.insert(0, (image, info_label))
171
 
172
  return gallery_history, seed
173
 
174
  except Exception as e:
175
- raise gr.Error(f"生成失败: {str(e)}")
176
 
177
- # ==================== 4. UI 样式 (CSS) ====================
178
  css = """
179
- /* 全局字体 */
180
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
181
  body, .gradio-container { font-family: 'Inter', sans-serif !important; }
182
 
183
- /* 标题样式 */
184
  .header-container { text-align: center; margin-bottom: 20px; }
185
  .header-title {
186
  font-size: 2.5rem; font-weight: 800; margin: 0;
187
- background: linear-gradient(135deg, #6366f1, #a855f7, #ec4899);
188
  -webkit-background-clip: text; -webkit-text-fill-color: transparent;
189
  }
190
  .header-subtitle { font-size: 1rem; color: #6b7280; font-weight: 500; }
191
 
192
- /* 按钮样式 */
193
  .primary-btn {
194
- background: linear-gradient(90deg, #4f46e5 0%, #7c3aed 100%) !important;
195
  border: none !important;
196
  color: white !important;
197
  font-weight: 600 !important;
198
  font-size: 1.1rem !important;
199
- box-shadow: 0 4px 6px -1px rgba(79, 70, 229, 0.2) !important;
200
- transition: all 0.2s !important;
201
  }
202
- .primary-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(79, 70, 229, 0.3) !important; }
203
 
204
- /* 输入框和面板 */
205
  .panel-container {
206
- background: #ffffff;
207
- border: 1px solid #e5e7eb;
208
- border-radius: 16px;
209
- padding: 20px;
210
- box-shadow: 0 1px 3px rgba(0,0,0,0.05);
211
  }
212
- /* 暗黑模式适配 */
213
  .dark .panel-container { background: #1f2937; border-color: #374151; }
214
-
215
- /* 画廊样式 */
216
- #output-gallery { min-height: 600px; }
217
  """
218
 
219
- # ==================== 5. Gradio 界面构建 ====================
220
- with gr.Blocks(theme=gr.themes.Soft(), css=css, title="Z-Image-Turbo") as demo:
221
 
222
- # 头部
223
  gr.HTML("""
224
  <div class="header-container">
225
  <h1 class="header-title">⚡ Z-Image-Turbo</h1>
226
- <p class="header-subtitle">Ultra-Fast Generation • 8 Steps • Gallery History</p>
227
  </div>
228
  """)
229
 
230
  with gr.Row():
231
- # --- 左侧:控制面板 ---
232
  with gr.Column(scale=4, min_width=320):
233
  with gr.Group(elem_classes="panel-container"):
234
  prompt = gr.Textbox(
235
- label="提示词 (Prompt)",
236
- placeholder="Describe your imagination...",
237
- lines=4,
238
- show_label=True
239
  )
240
  negative_prompt = gr.Textbox(
241
- label="反向提示词 (Negative Prompt)",
242
- placeholder="Blurry, low quality, ugly...",
243
- lines=2
244
  )
245
-
246
- generate_btn = gr.Button("🚀 Generate Image", elem_classes="primary-btn")
247
 
248
  with gr.Group(elem_classes="panel-container"):
249
- gr.Markdown("### 📐 尺寸设置")
250
-
251
- # 分辨率分类
252
- with gr.Row():
253
- res_category = gr.Radio(
254
- choices=["1024", "1280", "1536"],
255
- value="1024",
256
- label="基准分辨率",
257
- container=False
258
- )
259
-
260
- # 具体分辨率下拉框
261
  resolution_dropdown = gr.Dropdown(
262
  choices=RES_CHOICES["1024"],
263
  value=RES_CHOICES["1024"][0],
264
- label="选择比例",
265
- show_label=False,
266
- interactive=True
267
  )
268
 
269
- # 自定义分辨率开关
270
- with gr.Accordion("自定义尺寸 (高级)", open=False):
271
- use_custom_res = gr.Checkbox(label="启用自定义尺寸", value=False)
272
  with gr.Row(visible=False) as custom_res_row:
273
- width_slider = gr.Slider(512, 1536, value=1024, step=64, label="")
274
- height_slider = gr.Slider(512, 1536, value=1024, step=64, label="")
275
 
276
- with gr.Accordion("⚙️ 高级设置", open=False):
277
  with gr.Group(elem_classes="panel-container"):
278
- steps_slider = gr.Slider(
279
- minimum=4, maximum=20, value=8, step=1,
280
- label="步数 (Steps) - 推荐 4-8"
281
- )
282
  with gr.Row():
283
- random_seed = gr.Checkbox(label="随机种子", value=True)
284
- seed_input = gr.Number(label="种子", value=42, visible=False, precision=0)
285
 
286
- # --- 右侧:画廊展示 ---
287
  with gr.Column(scale=6, min_width=500):
288
  output_gallery = gr.Gallery(
289
- label="生成历史 (History)",
290
  value=[],
291
  columns=[2],
292
  rows=[2],
293
  object_fit="contain",
294
  height="auto",
295
- elem_id="output-gallery",
296
- show_label=True,
297
  show_share_button=True,
298
  show_download_button=True,
299
  interactive=False
300
  )
301
-
302
  with gr.Row():
303
- last_seed_display = gr.Textbox(label="当前图种子", interactive=False, scale=3)
304
- clear_history_btn = gr.Button("🗑️ 清空历史", scale=1, variant="secondary")
305
 
306
- # ==================== 6. 交互逻辑 ====================
307
-
308
- # 1. 切换分辨率分类时,更新下拉框
309
  def update_resolution_list(category):
310
- new_choices = RES_CHOICES[category]
311
- return gr.Dropdown(choices=new_choices, value=new_choices[0])
312
 
313
- res_category.change(
314
- fn=update_resolution_list,
315
- inputs=[res_category],
316
- outputs=[resolution_dropdown]
317
- )
318
 
319
- # 2. 切换自定义分辨率显示
320
- def toggle_custom(is_custom):
321
- return gr.Row(visible=is_custom), gr.Dropdown(interactive=not is_custom)
322
-
323
  use_custom_res.change(
324
- fn=toggle_custom,
325
- inputs=[use_custom_res],
326
- outputs=[custom_res_row, resolution_dropdown]
327
  )
328
 
329
- # 3. 切换种子输入框
330
- random_seed.change(
331
- fn=lambda x: gr.Number(visible=not x),
332
- inputs=[random_seed],
333
- outputs=[seed_input]
334
- )
335
 
336
- # 4. 生成按钮点击
337
  generate_btn.click(
338
  fn=generate_image,
339
- inputs=[
340
- prompt, resolution_dropdown, use_custom_res, width_slider, height_slider,
341
- steps_slider, seed_input, random_seed, negative_prompt, output_gallery
342
- ],
343
  outputs=[output_gallery, last_seed_display]
344
  )
345
 
346
- # 5. 清空历史
347
- clear_history_btn.click(
348
- fn=lambda: ([], ""),
349
- outputs=[output_gallery, last_seed_display]
350
- )
351
 
352
  if __name__ == "__main__":
353
  demo.launch()
 
11
  # 忽略警告
12
  warnings.filterwarnings("ignore")
13
 
14
+ # ==================== 1. 分辨率配置 ====================
15
  RES_CHOICES = {
16
  "1024": [
17
  "720x1280 (9:16)",
 
65
  return width - width % 8, height - height % 8
66
  return 1024, 1024
67
 
68
+ # ==================== 2. 模型加载与核心优化 ====================
69
  print("🚀 Loading Z-Image-Turbo pipeline...")
70
 
71
+ # 必须设置为 True,才能加载 Z-Image 自定义的 Pipeline 和 Transformer 类
72
+ # 否则无法调用 set_attention_backend
73
  pipe = DiffusionPipeline.from_pretrained(
74
  "Tongyi-MAI/Z-Image-Turbo",
75
  torch_dtype=torch.bfloat16,
76
  low_cpu_mem_usage=True,
77
  use_safetensors=True,
78
+ trust_remote_code=True,
79
  )
80
 
81
+ # 使用 FlowMatchEulerDiscreteScheduler 并设置 shift=3.0
82
+ try:
83
+ scheduler_config = dict(pipe.scheduler.config)
84
+ scheduler_config.pop("algorithm_type", None)
85
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
86
+ scheduler_config,
87
+ shift=3.0
88
+ )
89
+ print("✅ Scheduler optimized with shift=3.0")
90
+ except Exception as e:
91
+ print(f"⚠️ Scheduler config warning: {e}")
92
 
93
  # 移动到 GPU
94
  pipe.to("cuda")
95
 
96
+ # 尝试按顺序启用最快的后端
97
+ def enable_best_attention_backend(pipeline):
98
+ # 检查 pipeline.transformer 是否有 set_attention_backend 方法
99
+ # 这是 Z-Image 自定义类特有的
100
+ if hasattr(pipeline.transformer, "set_attention_backend"):
101
+ try:
102
+ # 优先尝试 Flash Attention 2 (A100/A10G)
103
+ print("⚡ Attempting to set backend to 'flash_attention_2'...")
104
+ pipeline.transformer.set_attention_backend("flash_attention_2")
105
+ print("✅ Attention backend set to: flash_attention_2")
106
+ return
107
+ except Exception as e:
108
+ print(f"ℹ️ Flash Attention 2 not available: {e}")
109
+
110
+ try:
111
+ # 其次尝试 xFormers (T4/V100 通用)
112
+ print("⚡ Attempting to set backend to 'xformers'...")
113
+ pipeline.transformer.set_attention_backend("xformers")
114
+ print("✅ Attention backend set to: xformers")
115
+ return
116
+ except Exception as e:
117
+ print(f"ℹ️ xFormers not available: {e}")
118
+
119
+ try:
120
+ # 最后使用 PyTorch 2.0 Native SDPA
121
+ print("⚡ Setting backend to 'native' (SDPA)...")
122
+ pipeline.transformer.set_attention_backend("native")
123
+ print("✅ Attention backend set to: native")
124
+ except Exception as e:
125
+ print(f"⚠️ Could not set custom attention backend: {e}")
126
+ else:
127
+ print("⚠️ Warning: Transformer model does not support 'set_attention_backend'. Custom code might not be loaded.")
128
+ # 如果加载失败,尝试标准的 xformers
129
+ try:
130
+ pipeline.enable_xformers_memory_efficient_attention()
131
+ print("✅ Standard xFormers enabled as fallback")
132
+ except:
133
+ pass
134
+
135
+ # 执行后端设置
136
+ enable_best_attention_backend(pipe)
137
 
138
+ # VAE 内存优化
139
  try:
140
  pipe.vae.enable_slicing()
141
  except:
142
  pass
143
 
 
 
 
 
144
  # ==================== 3. 生成逻辑 ====================
145
  @spaces.GPU
146
  def generate_image(
 
153
  seed,
154
  randomize_seed,
155
  negative_prompt,
156
+ gallery_history,
157
  progress=gr.Progress(track_tqdm=True)
158
  ):
159
  if gallery_history is None:
160
  gallery_history = []
161
 
162
  try:
 
163
  if not prompt or len(prompt.strip()) < 2:
164
  raise gr.Error("请输提示词 (Prompt)")
165
 
166
  prompt = prompt.strip()
167
  neg_prompt = negative_prompt.strip() if negative_prompt else None
168
 
 
169
  if use_custom_res:
170
  width = int(custom_width) - int(custom_width) % 8
171
  height = int(custom_height) - int(custom_height) % 8
172
  else:
173
  width, height = get_resolution(resolution_choice)
174
 
 
175
  if randomize_seed:
176
  seed = random.randint(0, 2**32 - 1)
177
  seed = int(seed)
178
 
 
179
  start_time = time.time()
180
  generator = torch.Generator("cuda").manual_seed(seed)
181
 
182
+ # 清理显存确保最大空间
183
  torch.cuda.empty_cache()
184
 
185
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
 
188
  height=height,
189
  width=width,
190
  num_inference_steps=int(num_inference_steps),
191
+ guidance_scale=0.0,
192
  generator=generator,
193
  negative_prompt=neg_prompt,
194
  max_sequence_length=512,
 
196
 
197
  gen_time = time.time() - start_time
198
 
199
+ # 格式化历史记录
200
+ info_label = f"{width}x{height} | Steps: {num_inference_steps} | Seed: {seed} | {gen_time:.2f}s"
 
201
  gallery_history.insert(0, (image, info_label))
202
 
203
  return gallery_history, seed
204
 
205
  except Exception as e:
206
+ raise gr.Error(f"生成错误: {str(e)}")
207
 
208
+ # ==================== 4. UI 样式 ====================
209
  css = """
 
210
  @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
211
  body, .gradio-container { font-family: 'Inter', sans-serif !important; }
212
 
 
213
  .header-container { text-align: center; margin-bottom: 20px; }
214
  .header-title {
215
  font-size: 2.5rem; font-weight: 800; margin: 0;
216
+ background: linear-gradient(135deg, #f59e0b, #ea580c);
217
  -webkit-background-clip: text; -webkit-text-fill-color: transparent;
218
  }
219
  .header-subtitle { font-size: 1rem; color: #6b7280; font-weight: 500; }
220
 
 
221
  .primary-btn {
222
+ background: linear-gradient(90deg, #f59e0b 0%, #d97706 100%) !important;
223
  border: none !important;
224
  color: white !important;
225
  font-weight: 600 !important;
226
  font-size: 1.1rem !important;
227
+ box-shadow: 0 4px 6px -1px rgba(245, 158, 11, 0.2) !important;
 
228
  }
229
+ .primary-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(245, 158, 11, 0.3) !important; }
230
 
 
231
  .panel-container {
232
+ background: #ffffff; border: 1px solid #e5e7eb; border-radius: 12px; padding: 15px;
 
 
 
 
233
  }
 
234
  .dark .panel-container { background: #1f2937; border-color: #374151; }
 
 
 
235
  """
236
 
237
+ # ==================== 5. Gradio 界面 ====================
238
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css=css, title="Z-Image-Turbo") as demo:
239
 
 
240
  gr.HTML("""
241
  <div class="header-container">
242
  <h1 class="header-title">⚡ Z-Image-Turbo</h1>
243
+ <p class="header-subtitle">Optimized Backend • 8 Steps • Gallery History</p>
244
  </div>
245
  """)
246
 
247
  with gr.Row():
248
+ # --- 控制面板 ---
249
  with gr.Column(scale=4, min_width=320):
250
  with gr.Group(elem_classes="panel-container"):
251
  prompt = gr.Textbox(
252
+ label="Prompt",
253
+ placeholder="Enter your prompt here...",
254
+ lines=3
 
255
  )
256
  negative_prompt = gr.Textbox(
257
+ label="Negative Prompt",
258
+ placeholder="Low quality, blurry...",
259
+ lines=1
260
  )
261
+ generate_btn = gr.Button("🚀 Generate", elem_classes="primary-btn")
 
262
 
263
  with gr.Group(elem_classes="panel-container"):
264
+ gr.Markdown("### 📐 Resolution")
265
+ res_category = gr.Radio(
266
+ choices=["1024", "1280", "1536"],
267
+ value="1024",
268
+ label="Resolution Base",
269
+ container=False
270
+ )
 
 
 
 
 
271
  resolution_dropdown = gr.Dropdown(
272
  choices=RES_CHOICES["1024"],
273
  value=RES_CHOICES["1024"][0],
274
+ label="Select Ratio",
275
+ show_label=False
 
276
  )
277
 
278
+ with gr.Accordion("Custom Size", open=False):
279
+ use_custom_res = gr.Checkbox(label="Enable Custom", value=False)
 
280
  with gr.Row(visible=False) as custom_res_row:
281
+ width_slider = gr.Slider(512, 1536, value=1024, step=64, label="W")
282
+ height_slider = gr.Slider(512, 1536, value=1024, step=64, label="H")
283
 
284
+ with gr.Accordion("⚙️ Settings", open=False):
285
  with gr.Group(elem_classes="panel-container"):
286
+ steps_slider = gr.Slider(4, 20, value=8, step=1, label="Steps")
 
 
 
287
  with gr.Row():
288
+ random_seed = gr.Checkbox(label="Random Seed", value=True)
289
+ seed_input = gr.Number(label="Seed", value=42, visible=False, precision=0)
290
 
291
+ # --- 画廊 ---
292
  with gr.Column(scale=6, min_width=500):
293
  output_gallery = gr.Gallery(
294
+ label="History",
295
  value=[],
296
  columns=[2],
297
  rows=[2],
298
  object_fit="contain",
299
  height="auto",
 
 
300
  show_share_button=True,
301
  show_download_button=True,
302
  interactive=False
303
  )
 
304
  with gr.Row():
305
+ last_seed_display = gr.Textbox(label="Last Seed", interactive=False, scale=3)
306
+ clear_btn = gr.Button("🗑️ Clear", scale=1, variant="secondary")
307
 
308
+ # 交互逻辑
 
 
309
  def update_resolution_list(category):
310
+ return gr.Dropdown(choices=RES_CHOICES[category], value=RES_CHOICES[category][0])
 
311
 
312
+ res_category.change(update_resolution_list, inputs=res_category, outputs=resolution_dropdown)
 
 
 
 
313
 
 
 
 
 
314
  use_custom_res.change(
315
+ lambda x: (gr.Row(visible=x), gr.Dropdown(interactive=not x)),
316
+ inputs=use_custom_res, outputs=[custom_res_row, resolution_dropdown]
 
317
  )
318
 
319
+ random_seed.change(lambda x: gr.Number(visible=not x), inputs=random_seed, outputs=seed_input)
 
 
 
 
 
320
 
 
321
  generate_btn.click(
322
  fn=generate_image,
323
+ inputs=[prompt, resolution_dropdown, use_custom_res, width_slider, height_slider, steps_slider, seed_input, random_seed, negative_prompt, output_gallery],
 
 
 
324
  outputs=[output_gallery, last_seed_display]
325
  )
326
 
327
+ clear_btn.click(lambda: ([], ""), outputs=[output_gallery, last_seed_display])
 
 
 
 
328
 
329
  if __name__ == "__main__":
330
  demo.launch()