alexander00001 commited on
Commit
23f8bf6
·
verified ·
1 Parent(s): 6d8ec97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -50
app.py CHANGED
@@ -17,6 +17,7 @@ from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
17
  from PIL import Image
18
  import traceback
19
  import numpy as np
 
20
 
21
  # 移除 Compel(FLUX 不兼容,简化处理)
22
  COMPEL_AVAILABLE = False
@@ -59,23 +60,36 @@ pipeline = None
59
  device = None
60
  model_loaded = False
61
 
 
 
 
 
 
 
 
62
  def initialize_model():
63
- """优化的模型初始化函数(基于官方示例)"""
64
  global pipeline, device, model_loaded
65
 
66
  if model_loaded and pipeline is not None:
67
  return True
68
 
69
  try:
 
 
 
70
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
  print(f"🖥️ Using device: {device}")
72
 
73
  print(f"📦 Loading fixed model: {FIXED_MODEL}")
74
 
75
- # 基础模型加载(官方示例)
76
  pipeline = FluxPipeline.from_pretrained(
77
  FIXED_MODEL,
78
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
 
 
79
  )
80
 
81
  # 优化调度器(默认 FlowMatchEulerDiscreteScheduler)
@@ -84,30 +98,33 @@ def initialize_model():
84
  )
85
  pipeline = pipeline.to(device)
86
 
87
- # 统一数据类型和优化(官方推荐)
88
  if torch.cuda.is_available():
89
- pipeline.vae.to(torch.bfloat16)
90
- pipeline.enable_sequential_cpu_offload() # 官方 VRAM 优化
91
- pipeline.enable_vae_slicing() # 保留 VAE 分片
 
92
 
93
  # 移除 torch.compile(Spaces 不稳定)
94
- print("✅ Model initialization complete (no torch.compile for stability)")
95
  model_loaded = True
96
  return True
97
 
98
  except Exception as e:
99
  print(f"❌ Critical model loading error: {e}")
100
  print(traceback.format_exc())
 
101
  model_loaded = False
102
  return False
103
 
104
  def enhance_prompt(prompt: str, style: str) -> str:
105
  """增强提示词"""
106
- quality_terms = ", ".join(QUALITY_ENHANCERS)
 
107
 
108
  style_terms = ""
109
  if style in STYLE_ENHANCERS:
110
- style_terms = ", " + ", ".join(STYLE_ENHANCERS[style])
111
 
112
  style_suffix = STYLE_PRESETS.get(style, "")
113
 
@@ -122,12 +139,18 @@ def enhance_prompt(prompt: str, style: str) -> str:
122
  enhanced_parts.append(quality_terms)
123
 
124
  enhanced_prompt = ", ".join(filter(None, enhanced_parts))
 
 
 
 
 
125
  return enhanced_prompt
126
 
127
  def apply_spaces_decorator(func):
128
- """应用 spaces 装饰器"""
129
  if SPACES_AVAILABLE:
130
- return spaces.GPU(duration=60)(func)
 
131
  return func
132
 
133
  def create_metadata_content(prompt, enhanced_prompt, seed, steps, cfg_scale, width, height, style):
@@ -147,20 +170,26 @@ Model: FLUX.1-dev
147
  """
148
 
149
  @apply_spaces_decorator
150
- def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: int = 28, cfg_scale: float = 3.5,
151
  seed: int = -1, width: int = 1024, height: int = 1024, progress=gr.Progress()):
152
- """图像生成函数(基于官方示例)"""
153
- if not prompt or prompt.strip() == "":
154
- return None, "", ""
155
-
156
- # 初始化模型
157
- progress(0.1, desc="Loading model...")
158
- if not initialize_model():
159
- return None, "", "❌ Failed to load model"
160
-
161
- progress(0.3, desc="Processing prompt...")
162
-
163
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  # 处理 seed
165
  if seed == -1:
166
  seed = random.randint(0, np.iinfo(np.int32).max)
@@ -175,29 +204,39 @@ def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: in
175
  # 生成参数(官方示例:generator 用 cpu)
176
  generator = torch.Generator("cpu").manual_seed(seed)
177
 
178
- progress(0.5, desc="Generating image...")
179
- print(f"🔄 Starting inference: prompt='{enhanced_prompt[:50]}...', steps={steps}, guidance={cfg_scale}")
 
 
 
180
 
181
  # 直接使用标准 pipeline(无 Compel,添加 max_sequence_length)
182
- result = pipeline(
183
- prompt=enhanced_prompt,
184
- negative_prompt=negative_prompt,
185
- num_inference_steps=steps,
186
- guidance_scale=cfg_scale,
187
- width=width,
188
- height=height,
189
- max_sequence_length=512, # 官方必须参数!
190
- generator=generator
191
- )
 
 
 
192
  image = result.images[0]
193
  print("✅ Inference complete")
194
 
195
- progress(0.9, desc="Saving image...")
 
 
 
 
196
 
197
  # 保存图像
198
  filename = f"IMG_{seed}.png"
199
  filepath = os.path.join(SAVE_DIR, filename)
200
- image.save(filepath, quality=95, optimize=True)
201
 
202
  # 创建元数据内容
203
  metadata_content = create_metadata_content(
@@ -211,13 +250,20 @@ def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: in
211
 
212
  return image, generation_info, metadata_content
213
 
 
 
 
 
 
 
214
  except Exception as e:
 
215
  error_msg = str(e)
216
  print(f"❌ Generation error: {error_msg}")
217
  print(traceback.format_exc())
218
  return None, "", f"❌ Generation failed: {error_msg}"
219
-
220
- # ===== CSS 样式(保持不变) =====
221
  css = """
222
  /* 全局容器 */
223
  .gradio-container {
@@ -466,33 +512,33 @@ def create_interface():
466
  precision=0
467
  )
468
 
469
- # 宽度选择
470
  with gr.Group(elem_classes=["controls-section"]):
471
  width_input = gr.Slider(
472
  label="Width",
473
  minimum=512,
474
- maximum=2048,
475
  value=1024,
476
  step=64
477
  )
478
 
479
- # 高度选择
480
  with gr.Group(elem_classes=["controls-section"]):
481
  height_input = gr.Slider(
482
  label="Height",
483
  minimum=512,
484
- maximum=2048,
485
  value=1024,
486
  step=64
487
  )
488
 
489
- # 高级参数
490
  with gr.Group(elem_classes=["controls-section"]):
491
  steps_input = gr.Slider(
492
  label="Steps",
493
  minimum=10,
494
- maximum=50,
495
- value=28,
496
  step=1
497
  )
498
 
@@ -555,7 +601,10 @@ def create_interface():
555
 
556
  if image is not None:
557
  # 提取实际使用的 seed
558
- actual_seed = seed if seed != -1 else int(info.split("Seed:")[1].split("|")[0].strip())
 
 
 
559
 
560
  return (
561
  image, # 图片输出
@@ -582,7 +631,7 @@ def create_interface():
582
  if image_data is not None:
583
  filename = f"IMG_{seed_val}.png"
584
  filepath = os.path.join(SAVE_DIR, filename)
585
- image_data.save(filepath, quality=95, optimize=True)
586
  return filepath
587
  return None
588
 
@@ -671,7 +720,7 @@ if __name__ == "__main__":
671
  print(f"🔧 CUDA: {'✅ Available' if torch.cuda.is_available() else '❌ Not Available'}")
672
 
673
  app = create_interface()
674
- app.queue(max_size=10, default_concurrency_limit=2)
675
 
676
  app.launch(
677
  server_name="0.0.0.0",
 
17
  from PIL import Image
18
  import traceback
19
  import numpy as np
20
+ import gc # 添加垃圾回收
21
 
22
  # 移除 Compel(FLUX 不兼容,简化处理)
23
  COMPEL_AVAILABLE = False
 
60
  device = None
61
  model_loaded = False
62
 
63
def cleanup_memory():
    """Release cached GPU memory and run Python garbage collection.

    Safe to call on CPU-only hosts: the CUDA calls are guarded, while
    ``gc.collect()`` is useful regardless of device (the original only
    collected when CUDA was available).
    """
    # Collect unreachable Python objects FIRST so that any CUDA tensors they
    # held are released before the allocator cache is emptied — otherwise
    # empty_cache() cannot return those blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Block until pending kernels finish so freed memory is truly reusable.
        torch.cuda.synchronize()
69
+
70
  def initialize_model():
71
+ """优化的模型初始化函数(针对ZeroGPU优化)"""
72
  global pipeline, device, model_loaded
73
 
74
  if model_loaded and pipeline is not None:
75
  return True
76
 
77
  try:
78
+ # 清理内存
79
+ cleanup_memory()
80
+
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
  print(f"🖥️ Using device: {device}")
83
 
84
  print(f"📦 Loading fixed model: {FIXED_MODEL}")
85
 
86
+ # ZeroGPU优化的模型加载
87
  pipeline = FluxPipeline.from_pretrained(
88
  FIXED_MODEL,
89
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
90
+ # 添加内存优化参数
91
+ variant=None,
92
+ use_safetensors=True
93
  )
94
 
95
  # 优化调度器(默认 FlowMatchEulerDiscreteScheduler)
 
98
  )
99
  pipeline = pipeline.to(device)
100
 
101
+ # ZeroGPU专用优化
102
  if torch.cuda.is_available():
103
+ # 使用更保守的内存设置
104
+ pipeline.enable_model_cpu_offload() # 改用model_cpu_offload
105
+ pipeline.enable_vae_slicing()
106
+ pipeline.enable_vae_tiling() # 添加VAE tiling
107
 
108
  # 移除 torch.compile(Spaces 不稳定)
109
+ print("✅ Model initialization complete (ZeroGPU optimized)")
110
  model_loaded = True
111
  return True
112
 
113
  except Exception as e:
114
  print(f"❌ Critical model loading error: {e}")
115
  print(traceback.format_exc())
116
+ cleanup_memory()
117
  model_loaded = False
118
  return False
119
 
120
  def enhance_prompt(prompt: str, style: str) -> str:
121
  """增强提示词"""
122
+ # 限制质量词数量,避免过长
123
+ quality_terms = ", ".join(QUALITY_ENHANCERS[:5]) # 只取前5个
124
 
125
  style_terms = ""
126
  if style in STYLE_ENHANCERS:
127
+ style_terms = ", " + ", ".join(STYLE_ENHANCERS[style][:3]) # 只取前3个
128
 
129
  style_suffix = STYLE_PRESETS.get(style, "")
130
 
 
139
  enhanced_parts.append(quality_terms)
140
 
141
  enhanced_prompt = ", ".join(filter(None, enhanced_parts))
142
+
143
+ # 限制总长度,避免超出模型限制
144
+ if len(enhanced_prompt) > 500:
145
+ enhanced_prompt = enhanced_prompt[:500] + "..."
146
+
147
  return enhanced_prompt
148
 
149
def apply_spaces_decorator(func):
    """Wrap *func* with the HF Spaces GPU decorator when running on Spaces.

    Outside of Spaces (SPACES_AVAILABLE is False) the function is returned
    unchanged, so local runs behave identically.
    """
    # Guard clause: nothing to do when the `spaces` runtime is absent.
    if not SPACES_AVAILABLE:
        return func
    # 120-second GPU budget to accommodate FLUX inference on ZeroGPU.
    return spaces.GPU(duration=120)(func)
155
 
156
  def create_metadata_content(prompt, enhanced_prompt, seed, steps, cfg_scale, width, height, style):
 
170
  """
171
 
172
  @apply_spaces_decorator
173
+ def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: int = 20, cfg_scale: float = 3.5,
174
  seed: int = -1, width: int = 1024, height: int = 1024, progress=gr.Progress()):
175
+ """图像生成函数(ZeroGPU优化版本)"""
 
 
 
 
 
 
 
 
 
 
176
  try:
177
+ if not prompt or prompt.strip() == "":
178
+ return None, "", "❌ Please enter a prompt"
179
+
180
+ # 参数验证和限制
181
+ steps = max(10, min(steps, 30)) # 限制步数范围,避免超时
182
+ width = min(width, 1024) # 限制最大尺寸
183
+ height = min(height, 1024)
184
+
185
+ # 初始化模型
186
+ progress(0.1, desc="Initializing model...")
187
+ if not initialize_model():
188
+ cleanup_memory()
189
+ return None, "", "❌ Failed to initialize model"
190
+
191
+ progress(0.2, desc="Processing prompt...")
192
+
193
  # 处理 seed
194
  if seed == -1:
195
  seed = random.randint(0, np.iinfo(np.int32).max)
 
204
  # 生成参数(官方示例:generator 用 cpu)
205
  generator = torch.Generator("cpu").manual_seed(seed)
206
 
207
+ progress(0.4, desc="Starting generation...")
208
+ print(f"🔥 Starting inference: steps={steps}, guidance={cfg_scale}, size={width}x{height}")
209
+
210
+ # 清理内存
211
+ cleanup_memory()
212
 
213
  # 直接使用标准 pipeline(无 Compel,添加 max_sequence_length)
214
+ with torch.no_grad(): # 确保不计算梯度
215
+ result = pipeline(
216
+ prompt=enhanced_prompt,
217
+ negative_prompt=negative_prompt,
218
+ num_inference_steps=steps,
219
+ guidance_scale=cfg_scale,
220
+ width=width,
221
+ height=height,
222
+ max_sequence_length=256, # 减少序列长度,节省内存
223
+ generator=generator,
224
+ output_type="pil"
225
+ )
226
+
227
  image = result.images[0]
228
  print("✅ Inference complete")
229
 
230
+ progress(0.9, desc="Finalizing...")
231
+
232
+ # 立即清理内存
233
+ del result
234
+ cleanup_memory()
235
 
236
  # 保存图像
237
  filename = f"IMG_{seed}.png"
238
  filepath = os.path.join(SAVE_DIR, filename)
239
+ image.save(filepath, format="PNG", optimize=True)
240
 
241
  # 创建元数据内容
242
  metadata_content = create_metadata_content(
 
250
 
251
  return image, generation_info, metadata_content
252
 
253
+ except torch.cuda.OutOfMemoryError as e:
254
+ cleanup_memory()
255
+ error_msg = "❌ GPU memory insufficient. Try reducing image size or steps."
256
+ print(f"CUDA OOM: {error_msg}")
257
+ return None, "", error_msg
258
+
259
  except Exception as e:
260
+ cleanup_memory()
261
  error_msg = str(e)
262
  print(f"❌ Generation error: {error_msg}")
263
  print(traceback.format_exc())
264
  return None, "", f"❌ Generation failed: {error_msg}"
265
+
266
+ # ===== CSS 样式(保持不变)=====
267
  css = """
268
  /* 全局容器 */
269
  .gradio-container {
 
512
  precision=0
513
  )
514
 
515
+ # 宽度选择(降低最大值)
516
  with gr.Group(elem_classes=["controls-section"]):
517
  width_input = gr.Slider(
518
  label="Width",
519
  minimum=512,
520
+ maximum=1024, # 降低最大值
521
  value=1024,
522
  step=64
523
  )
524
 
525
+ # 高度选择(降低最大值)
526
  with gr.Group(elem_classes=["controls-section"]):
527
  height_input = gr.Slider(
528
  label="Height",
529
  minimum=512,
530
+ maximum=1024, # 降低最大值
531
  value=1024,
532
  step=64
533
  )
534
 
535
+ # 高级参数(调整默认值)
536
  with gr.Group(elem_classes=["controls-section"]):
537
  steps_input = gr.Slider(
538
  label="Steps",
539
  minimum=10,
540
+ maximum=30, # 降低最大值
541
+ value=20, # 降低默认值
542
  step=1
543
  )
544
 
 
601
 
602
  if image is not None:
603
  # 提取实际使用的 seed
604
+ try:
605
+ actual_seed = seed if seed != -1 else int(info.split("Seed:")[1].split("|")[0].strip())
606
+ except:
607
+ actual_seed = seed if seed != -1 else random.randint(0, 999999)
608
 
609
  return (
610
  image, # 图片输出
 
631
  if image_data is not None:
632
  filename = f"IMG_{seed_val}.png"
633
  filepath = os.path.join(SAVE_DIR, filename)
634
+ image_data.save(filepath, format="PNG", optimize=True)
635
  return filepath
636
  return None
637
 
 
720
  print(f"🔧 CUDA: {'✅ Available' if torch.cuda.is_available() else '❌ Not Available'}")
721
 
722
  app = create_interface()
723
+ app.queue(max_size=5, default_concurrency_limit=1) # 降低并发限制
724
 
725
  app.launch(
726
  server_name="0.0.0.0",