Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,13 +16,22 @@ llm = pipeline(
|
|
| 16 |
)
|
| 17 |
|
| 18 |
# Step 2: 加载并量化 Stable Diffusion 模型以加速推理
|
| 19 |
-
#
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# SD v1.5
|
| 28 |
sd_v15 = StableDiffusionPipeline.from_pretrained(
|
|
@@ -30,6 +39,7 @@ sd_v15 = StableDiffusionPipeline.from_pretrained(
|
|
| 30 |
**load_kwargs
|
| 31 |
)
|
| 32 |
sd_v15.scheduler = DPMSolverMultistepScheduler.from_config(sd_v15.scheduler.config)
|
|
|
|
| 33 |
|
| 34 |
# SD XL
|
| 35 |
sd_xl = StableDiffusionPipeline.from_pretrained(
|
|
@@ -37,6 +47,7 @@ sd_xl = StableDiffusionPipeline.from_pretrained(
|
|
| 37 |
**load_kwargs
|
| 38 |
)
|
| 39 |
sd_xl.scheduler = DPMSolverMultistepScheduler.from_config(sd_xl.scheduler.config)
|
|
|
|
| 40 |
|
| 41 |
# 可选:语音输入模块,使用 Whisper
|
| 42 |
asr = pipeline(
|
|
@@ -122,3 +133,4 @@ with gr.Blocks(title="Prompt-to-Image Generator") as demo:
|
|
| 122 |
# Step 4: 启动应用
|
| 123 |
if __name__ == "__main__":
|
| 124 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
# Step 2: 加载并量化 Stable Diffusion 模型以加速推理
|
| 19 |
+
# 根据硬件环境选择加载方式:
|
| 20 |
+
# - GPU 环境:8-bit 量化 + 自动设备映射
|
| 21 |
+
# - CPU 环境:浮点32 + 平衡设备映射
|
| 22 |
+
if torch.cuda.is_available():
|
| 23 |
+
device = "cuda"
|
| 24 |
+
load_kwargs = {
|
| 25 |
+
"torch_dtype": torch.float16,
|
| 26 |
+
"device_map": "auto",
|
| 27 |
+
"load_in_8bit": True # 需要安装 bitsandbytes
|
| 28 |
+
}
|
| 29 |
+
else:
|
| 30 |
+
device = "cpu"
|
| 31 |
+
load_kwargs = {
|
| 32 |
+
"torch_dtype": torch.float32,
|
| 33 |
+
"device_map": "balanced"
|
| 34 |
+
}
|
| 35 |
|
| 36 |
# SD v1.5
|
| 37 |
sd_v15 = StableDiffusionPipeline.from_pretrained(
|
|
|
|
| 39 |
**load_kwargs
|
| 40 |
)
|
| 41 |
sd_v15.scheduler = DPMSolverMultistepScheduler.from_config(sd_v15.scheduler.config)
|
| 42 |
+
sd_v15 = sd_v15.to(device)
|
| 43 |
|
| 44 |
# SD XL
|
| 45 |
sd_xl = StableDiffusionPipeline.from_pretrained(
|
|
|
|
| 47 |
**load_kwargs
|
| 48 |
)
|
| 49 |
sd_xl.scheduler = DPMSolverMultistepScheduler.from_config(sd_xl.scheduler.config)
|
| 50 |
+
sd_xl = sd_xl.to(device)
|
| 51 |
|
| 52 |
# 可选:语音输入模块,使用 Whisper
|
| 53 |
asr = pipeline(
|
|
|
|
| 133 |
# Step 4: 启动应用
|
| 134 |
if __name__ == "__main__":
|
| 135 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 136 |
+
|