Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,6 +17,7 @@ from diffusers import FluxPipeline, FlowMatchEulerDiscreteScheduler
|
|
| 17 |
from PIL import Image
|
| 18 |
import traceback
|
| 19 |
import numpy as np
|
|
|
|
| 20 |
|
| 21 |
# 移除 Compel(FLUX 不兼容,简化处理)
|
| 22 |
COMPEL_AVAILABLE = False
|
|
@@ -59,23 +60,36 @@ pipeline = None
|
|
| 59 |
device = None
|
| 60 |
model_loaded = False
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def initialize_model():
|
| 63 |
-
"""
|
| 64 |
global pipeline, device, model_loaded
|
| 65 |
|
| 66 |
if model_loaded and pipeline is not None:
|
| 67 |
return True
|
| 68 |
|
| 69 |
try:
|
|
|
|
|
|
|
|
|
|
| 70 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 71 |
print(f"🖥️ Using device: {device}")
|
| 72 |
|
| 73 |
print(f"📦 Loading fixed model: {FIXED_MODEL}")
|
| 74 |
|
| 75 |
-
#
|
| 76 |
pipeline = FluxPipeline.from_pretrained(
|
| 77 |
FIXED_MODEL,
|
| 78 |
-
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
|
|
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
# 优化调度器(默认 FlowMatchEulerDiscreteScheduler)
|
|
@@ -84,30 +98,33 @@ def initialize_model():
|
|
| 84 |
)
|
| 85 |
pipeline = pipeline.to(device)
|
| 86 |
|
| 87 |
-
#
|
| 88 |
if torch.cuda.is_available():
|
| 89 |
-
|
| 90 |
-
pipeline.
|
| 91 |
-
pipeline.enable_vae_slicing()
|
|
|
|
| 92 |
|
| 93 |
# 移除 torch.compile(Spaces 不稳定)
|
| 94 |
-
print("✅ Model initialization complete (
|
| 95 |
model_loaded = True
|
| 96 |
return True
|
| 97 |
|
| 98 |
except Exception as e:
|
| 99 |
print(f"❌ Critical model loading error: {e}")
|
| 100 |
print(traceback.format_exc())
|
|
|
|
| 101 |
model_loaded = False
|
| 102 |
return False
|
| 103 |
|
| 104 |
def enhance_prompt(prompt: str, style: str) -> str:
|
| 105 |
"""增强提示词"""
|
| 106 |
-
|
|
|
|
| 107 |
|
| 108 |
style_terms = ""
|
| 109 |
if style in STYLE_ENHANCERS:
|
| 110 |
-
style_terms = ", " + ", ".join(STYLE_ENHANCERS[style])
|
| 111 |
|
| 112 |
style_suffix = STYLE_PRESETS.get(style, "")
|
| 113 |
|
|
@@ -122,12 +139,18 @@ def enhance_prompt(prompt: str, style: str) -> str:
|
|
| 122 |
enhanced_parts.append(quality_terms)
|
| 123 |
|
| 124 |
enhanced_prompt = ", ".join(filter(None, enhanced_parts))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
return enhanced_prompt
|
| 126 |
|
| 127 |
def apply_spaces_decorator(func):
|
| 128 |
-
"""应用 spaces
|
| 129 |
if SPACES_AVAILABLE:
|
| 130 |
-
|
|
|
|
| 131 |
return func
|
| 132 |
|
| 133 |
def create_metadata_content(prompt, enhanced_prompt, seed, steps, cfg_scale, width, height, style):
|
|
@@ -147,20 +170,26 @@ Model: FLUX.1-dev
|
|
| 147 |
"""
|
| 148 |
|
| 149 |
@apply_spaces_decorator
|
| 150 |
-
def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: int =
|
| 151 |
seed: int = -1, width: int = 1024, height: int = 1024, progress=gr.Progress()):
|
| 152 |
-
"""
|
| 153 |
-
if not prompt or prompt.strip() == "":
|
| 154 |
-
return None, "", ""
|
| 155 |
-
|
| 156 |
-
# 初始化模型
|
| 157 |
-
progress(0.1, desc="Loading model...")
|
| 158 |
-
if not initialize_model():
|
| 159 |
-
return None, "", "❌ Failed to load model"
|
| 160 |
-
|
| 161 |
-
progress(0.3, desc="Processing prompt...")
|
| 162 |
-
|
| 163 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
# 处理 seed
|
| 165 |
if seed == -1:
|
| 166 |
seed = random.randint(0, np.iinfo(np.int32).max)
|
|
@@ -175,29 +204,39 @@ def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: in
|
|
| 175 |
# 生成参数(官方示例:generator 用 cpu)
|
| 176 |
generator = torch.Generator("cpu").manual_seed(seed)
|
| 177 |
|
| 178 |
-
progress(0.
|
| 179 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# 直接使用标准 pipeline(无 Compel,添加 max_sequence_length)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
| 192 |
image = result.images[0]
|
| 193 |
print("✅ Inference complete")
|
| 194 |
|
| 195 |
-
progress(0.9, desc="
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
# 保存图像
|
| 198 |
filename = f"IMG_{seed}.png"
|
| 199 |
filepath = os.path.join(SAVE_DIR, filename)
|
| 200 |
-
image.save(filepath,
|
| 201 |
|
| 202 |
# 创建元数据内容
|
| 203 |
metadata_content = create_metadata_content(
|
|
@@ -211,13 +250,20 @@ def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: in
|
|
| 211 |
|
| 212 |
return image, generation_info, metadata_content
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
except Exception as e:
|
|
|
|
| 215 |
error_msg = str(e)
|
| 216 |
print(f"❌ Generation error: {error_msg}")
|
| 217 |
print(traceback.format_exc())
|
| 218 |
return None, "", f"❌ Generation failed: {error_msg}"
|
| 219 |
-
|
| 220 |
-
# ===== CSS
|
| 221 |
css = """
|
| 222 |
/* 全局容器 */
|
| 223 |
.gradio-container {
|
|
@@ -466,33 +512,33 @@ def create_interface():
|
|
| 466 |
precision=0
|
| 467 |
)
|
| 468 |
|
| 469 |
-
#
|
| 470 |
with gr.Group(elem_classes=["controls-section"]):
|
| 471 |
width_input = gr.Slider(
|
| 472 |
label="Width",
|
| 473 |
minimum=512,
|
| 474 |
-
maximum=
|
| 475 |
value=1024,
|
| 476 |
step=64
|
| 477 |
)
|
| 478 |
|
| 479 |
-
#
|
| 480 |
with gr.Group(elem_classes=["controls-section"]):
|
| 481 |
height_input = gr.Slider(
|
| 482 |
label="Height",
|
| 483 |
minimum=512,
|
| 484 |
-
maximum=
|
| 485 |
value=1024,
|
| 486 |
step=64
|
| 487 |
)
|
| 488 |
|
| 489 |
-
#
|
| 490 |
with gr.Group(elem_classes=["controls-section"]):
|
| 491 |
steps_input = gr.Slider(
|
| 492 |
label="Steps",
|
| 493 |
minimum=10,
|
| 494 |
-
maximum=
|
| 495 |
-
value=
|
| 496 |
step=1
|
| 497 |
)
|
| 498 |
|
|
@@ -555,7 +601,10 @@ def create_interface():
|
|
| 555 |
|
| 556 |
if image is not None:
|
| 557 |
# 提取实际使用的 seed
|
| 558 |
-
|
|
|
|
|
|
|
|
|
|
| 559 |
|
| 560 |
return (
|
| 561 |
image, # 图片输出
|
|
@@ -582,7 +631,7 @@ def create_interface():
|
|
| 582 |
if image_data is not None:
|
| 583 |
filename = f"IMG_{seed_val}.png"
|
| 584 |
filepath = os.path.join(SAVE_DIR, filename)
|
| 585 |
-
image_data.save(filepath,
|
| 586 |
return filepath
|
| 587 |
return None
|
| 588 |
|
|
@@ -671,7 +720,7 @@ if __name__ == "__main__":
|
|
| 671 |
print(f"🔧 CUDA: {'✅ Available' if torch.cuda.is_available() else '❌ Not Available'}")
|
| 672 |
|
| 673 |
app = create_interface()
|
| 674 |
-
app.queue(max_size=
|
| 675 |
|
| 676 |
app.launch(
|
| 677 |
server_name="0.0.0.0",
|
|
|
|
| 17 |
from PIL import Image
|
| 18 |
import traceback
|
| 19 |
import numpy as np
|
| 20 |
+
import gc # 添加垃圾回收
|
| 21 |
|
| 22 |
# 移除 Compel(FLUX 不兼容,简化处理)
|
| 23 |
COMPEL_AVAILABLE = False
|
|
|
|
| 60 |
device = None
|
| 61 |
model_loaded = False
|
| 62 |
|
| 63 |
+
def cleanup_memory():
|
| 64 |
+
"""清理GPU内存"""
|
| 65 |
+
if torch.cuda.is_available():
|
| 66 |
+
torch.cuda.empty_cache()
|
| 67 |
+
torch.cuda.synchronize()
|
| 68 |
+
gc.collect()
|
| 69 |
+
|
| 70 |
def initialize_model():
|
| 71 |
+
"""优化的模型初始化函数(针对ZeroGPU优化)"""
|
| 72 |
global pipeline, device, model_loaded
|
| 73 |
|
| 74 |
if model_loaded and pipeline is not None:
|
| 75 |
return True
|
| 76 |
|
| 77 |
try:
|
| 78 |
+
# 清理内存
|
| 79 |
+
cleanup_memory()
|
| 80 |
+
|
| 81 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 82 |
print(f"🖥️ Using device: {device}")
|
| 83 |
|
| 84 |
print(f"📦 Loading fixed model: {FIXED_MODEL}")
|
| 85 |
|
| 86 |
+
# ZeroGPU优化的模型加载
|
| 87 |
pipeline = FluxPipeline.from_pretrained(
|
| 88 |
FIXED_MODEL,
|
| 89 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 90 |
+
# 添加内存优化参数
|
| 91 |
+
variant=None,
|
| 92 |
+
use_safetensors=True
|
| 93 |
)
|
| 94 |
|
| 95 |
# 优化调度器(默认 FlowMatchEulerDiscreteScheduler)
|
|
|
|
| 98 |
)
|
| 99 |
pipeline = pipeline.to(device)
|
| 100 |
|
| 101 |
+
# ZeroGPU专用优化
|
| 102 |
if torch.cuda.is_available():
|
| 103 |
+
# 使用更保守的内存设置
|
| 104 |
+
pipeline.enable_model_cpu_offload() # 改用model_cpu_offload
|
| 105 |
+
pipeline.enable_vae_slicing()
|
| 106 |
+
pipeline.enable_vae_tiling() # 添加VAE tiling
|
| 107 |
|
| 108 |
# 移除 torch.compile(Spaces 不稳定)
|
| 109 |
+
print("✅ Model initialization complete (ZeroGPU optimized)")
|
| 110 |
model_loaded = True
|
| 111 |
return True
|
| 112 |
|
| 113 |
except Exception as e:
|
| 114 |
print(f"❌ Critical model loading error: {e}")
|
| 115 |
print(traceback.format_exc())
|
| 116 |
+
cleanup_memory()
|
| 117 |
model_loaded = False
|
| 118 |
return False
|
| 119 |
|
| 120 |
def enhance_prompt(prompt: str, style: str) -> str:
|
| 121 |
"""增强提示词"""
|
| 122 |
+
# 限制质量词数量,避免过长
|
| 123 |
+
quality_terms = ", ".join(QUALITY_ENHANCERS[:5]) # 只取前5个
|
| 124 |
|
| 125 |
style_terms = ""
|
| 126 |
if style in STYLE_ENHANCERS:
|
| 127 |
+
style_terms = ", " + ", ".join(STYLE_ENHANCERS[style][:3]) # 只取前3个
|
| 128 |
|
| 129 |
style_suffix = STYLE_PRESETS.get(style, "")
|
| 130 |
|
|
|
|
| 139 |
enhanced_parts.append(quality_terms)
|
| 140 |
|
| 141 |
enhanced_prompt = ", ".join(filter(None, enhanced_parts))
|
| 142 |
+
|
| 143 |
+
# 限制总长度,避免超出模型限制
|
| 144 |
+
if len(enhanced_prompt) > 500:
|
| 145 |
+
enhanced_prompt = enhanced_prompt[:500] + "..."
|
| 146 |
+
|
| 147 |
return enhanced_prompt
|
| 148 |
|
| 149 |
def apply_spaces_decorator(func):
|
| 150 |
+
"""应用 spaces 装饰器,增加更长的超时时间"""
|
| 151 |
if SPACES_AVAILABLE:
|
| 152 |
+
# 增加超时时间到120秒,并设置更大的内存限制
|
| 153 |
+
return spaces.GPU(duration=120)(func)
|
| 154 |
return func
|
| 155 |
|
| 156 |
def create_metadata_content(prompt, enhanced_prompt, seed, steps, cfg_scale, width, height, style):
|
|
|
|
| 170 |
"""
|
| 171 |
|
| 172 |
@apply_spaces_decorator
|
| 173 |
+
def generate_image(prompt: str, style: str, negative_prompt: str = "", steps: int = 20, cfg_scale: float = 3.5,
|
| 174 |
seed: int = -1, width: int = 1024, height: int = 1024, progress=gr.Progress()):
|
| 175 |
+
"""图像生成函数(ZeroGPU优化版本)"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
try:
|
| 177 |
+
if not prompt or prompt.strip() == "":
|
| 178 |
+
return None, "", "❌ Please enter a prompt"
|
| 179 |
+
|
| 180 |
+
# 参数验证和限制
|
| 181 |
+
steps = max(10, min(steps, 30)) # 限制步数范围,避免超时
|
| 182 |
+
width = min(width, 1024) # 限制最大尺寸
|
| 183 |
+
height = min(height, 1024)
|
| 184 |
+
|
| 185 |
+
# 初始化模型
|
| 186 |
+
progress(0.1, desc="Initializing model...")
|
| 187 |
+
if not initialize_model():
|
| 188 |
+
cleanup_memory()
|
| 189 |
+
return None, "", "❌ Failed to initialize model"
|
| 190 |
+
|
| 191 |
+
progress(0.2, desc="Processing prompt...")
|
| 192 |
+
|
| 193 |
# 处理 seed
|
| 194 |
if seed == -1:
|
| 195 |
seed = random.randint(0, np.iinfo(np.int32).max)
|
|
|
|
| 204 |
# 生成参数(官方示例:generator 用 cpu)
|
| 205 |
generator = torch.Generator("cpu").manual_seed(seed)
|
| 206 |
|
| 207 |
+
progress(0.4, desc="Starting generation...")
|
| 208 |
+
print(f"🔥 Starting inference: steps={steps}, guidance={cfg_scale}, size={width}x{height}")
|
| 209 |
+
|
| 210 |
+
# 清理内存
|
| 211 |
+
cleanup_memory()
|
| 212 |
|
| 213 |
# 直接使用标准 pipeline(无 Compel,添加 max_sequence_length)
|
| 214 |
+
with torch.no_grad(): # 确保不计算梯度
|
| 215 |
+
result = pipeline(
|
| 216 |
+
prompt=enhanced_prompt,
|
| 217 |
+
negative_prompt=negative_prompt,
|
| 218 |
+
num_inference_steps=steps,
|
| 219 |
+
guidance_scale=cfg_scale,
|
| 220 |
+
width=width,
|
| 221 |
+
height=height,
|
| 222 |
+
max_sequence_length=256, # 减少序列长度,节省内存
|
| 223 |
+
generator=generator,
|
| 224 |
+
output_type="pil"
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
image = result.images[0]
|
| 228 |
print("✅ Inference complete")
|
| 229 |
|
| 230 |
+
progress(0.9, desc="Finalizing...")
|
| 231 |
+
|
| 232 |
+
# 立即清理内存
|
| 233 |
+
del result
|
| 234 |
+
cleanup_memory()
|
| 235 |
|
| 236 |
# 保存图像
|
| 237 |
filename = f"IMG_{seed}.png"
|
| 238 |
filepath = os.path.join(SAVE_DIR, filename)
|
| 239 |
+
image.save(filepath, format="PNG", optimize=True)
|
| 240 |
|
| 241 |
# 创建元数据内容
|
| 242 |
metadata_content = create_metadata_content(
|
|
|
|
| 250 |
|
| 251 |
return image, generation_info, metadata_content
|
| 252 |
|
| 253 |
+
except torch.cuda.OutOfMemoryError as e:
|
| 254 |
+
cleanup_memory()
|
| 255 |
+
error_msg = "❌ GPU memory insufficient. Try reducing image size or steps."
|
| 256 |
+
print(f"CUDA OOM: {error_msg}")
|
| 257 |
+
return None, "", error_msg
|
| 258 |
+
|
| 259 |
except Exception as e:
|
| 260 |
+
cleanup_memory()
|
| 261 |
error_msg = str(e)
|
| 262 |
print(f"❌ Generation error: {error_msg}")
|
| 263 |
print(traceback.format_exc())
|
| 264 |
return None, "", f"❌ Generation failed: {error_msg}"
|
| 265 |
+
|
| 266 |
+
# ===== CSS 样式(保持不变)=====
|
| 267 |
css = """
|
| 268 |
/* 全局容器 */
|
| 269 |
.gradio-container {
|
|
|
|
| 512 |
precision=0
|
| 513 |
)
|
| 514 |
|
| 515 |
+
# 宽度选择(降低最大值)
|
| 516 |
with gr.Group(elem_classes=["controls-section"]):
|
| 517 |
width_input = gr.Slider(
|
| 518 |
label="Width",
|
| 519 |
minimum=512,
|
| 520 |
+
maximum=1024, # 降低最大值
|
| 521 |
value=1024,
|
| 522 |
step=64
|
| 523 |
)
|
| 524 |
|
| 525 |
+
# 高度选择(降低最大值)
|
| 526 |
with gr.Group(elem_classes=["controls-section"]):
|
| 527 |
height_input = gr.Slider(
|
| 528 |
label="Height",
|
| 529 |
minimum=512,
|
| 530 |
+
maximum=1024, # 降低最大值
|
| 531 |
value=1024,
|
| 532 |
step=64
|
| 533 |
)
|
| 534 |
|
| 535 |
+
# 高级参数(调整默认值)
|
| 536 |
with gr.Group(elem_classes=["controls-section"]):
|
| 537 |
steps_input = gr.Slider(
|
| 538 |
label="Steps",
|
| 539 |
minimum=10,
|
| 540 |
+
maximum=30, # 降低最大值
|
| 541 |
+
value=20, # 降低默认值
|
| 542 |
step=1
|
| 543 |
)
|
| 544 |
|
|
|
|
| 601 |
|
| 602 |
if image is not None:
|
| 603 |
# 提取实际使用的 seed
|
| 604 |
+
try:
|
| 605 |
+
actual_seed = seed if seed != -1 else int(info.split("Seed:")[1].split("|")[0].strip())
|
| 606 |
+
except:
|
| 607 |
+
actual_seed = seed if seed != -1 else random.randint(0, 999999)
|
| 608 |
|
| 609 |
return (
|
| 610 |
image, # 图片输出
|
|
|
|
| 631 |
if image_data is not None:
|
| 632 |
filename = f"IMG_{seed_val}.png"
|
| 633 |
filepath = os.path.join(SAVE_DIR, filename)
|
| 634 |
+
image_data.save(filepath, format="PNG", optimize=True)
|
| 635 |
return filepath
|
| 636 |
return None
|
| 637 |
|
|
|
|
| 720 |
print(f"🔧 CUDA: {'✅ Available' if torch.cuda.is_available() else '❌ Not Available'}")
|
| 721 |
|
| 722 |
app = create_interface()
|
| 723 |
+
app.queue(max_size=5, default_concurrency_limit=1) # 降低并发限制
|
| 724 |
|
| 725 |
app.launch(
|
| 726 |
server_name="0.0.0.0",
|