lucksadasd commited on
Commit
a89ee91
·
verified ·
1 Parent(s): 16f7650

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -15,17 +15,34 @@ llm = pipeline(
15
  device=0 if torch.cuda.is_available() else -1
16
  )
17
 
18
- # Step 2: 加载 Stable Diffusion 模型
19
- # 移除无效的 revision 参数,仅使用 torch_dtype 加速加载
20
  sd_v15 = StableDiffusionPipeline.from_pretrained(
21
  "runwayml/stable-diffusion-v1-5",
22
  torch_dtype=torch.float16
23
  )
 
 
 
 
 
 
 
 
 
24
  sd_v15 = sd_v15.to("cuda" if torch.cuda.is_available() else "cpu")
25
 
 
26
  sd_xl = StableDiffusionPipeline.from_pretrained(
27
- "stabilityai/stable-diffusion-xl-base-1.0"
 
28
  )
 
 
 
 
 
 
29
  sd_xl = sd_xl.to("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
  # 可选:语音输入模块,使用 Whisper
@@ -82,7 +99,6 @@ with gr.Blocks(title="Prompt-to-Image Generator") as demo:
82
  placeholder="排除内容(如:低分辨率、水印)"
83
  )
84
  use_voice = gr.Checkbox(label="启用语音输入(加分项)")
85
- # 移除 'source' 参数以兼容 Gradio 版本
86
  audio_input = gr.Audio(type="filepath", label="语音输入")
87
  generate_btn = gr.Button("生成图像")
88
  with gr.Column():
 
15
  device=0 if torch.cuda.is_available() else -1
16
  )
17
 
18
+ # Step 2: 加载 Stable Diffusion 模型并优化以加速推理
19
+ # SD v1.5
20
  sd_v15 = StableDiffusionPipeline.from_pretrained(
21
  "runwayml/stable-diffusion-v1-5",
22
  torch_dtype=torch.float16
23
  )
24
+ # 开启注意力切片,减少显存峰值
25
+ sd_v15.enable_attention_slicing()
26
+ # 如果安装了 xformers,启用更高效的注意力实现
27
+ try:
28
+ sd_v15.enable_xformers_memory_efficient_attention()
29
+ except Exception:
30
+ pass
31
+ # 启用CPU内存卸载,减轻GPU显存压力
32
+ sd_v15.enable_model_cpu_offload()
33
  sd_v15 = sd_v15.to("cuda" if torch.cuda.is_available() else "cpu")
34
 
35
+ # SD XL
36
  sd_xl = StableDiffusionPipeline.from_pretrained(
37
+ "stabilityai/stable-diffusion-xl-base-1.0",
38
+ torch_dtype=torch.float16
39
  )
40
+ sd_xl.enable_attention_slicing()
41
+ try:
42
+ sd_xl.enable_xformers_memory_efficient_attention()
43
+ except Exception:
44
+ pass
45
+ sd_xl.enable_model_cpu_offload()
46
  sd_xl = sd_xl.to("cuda" if torch.cuda.is_available() else "cpu")
47
 
48
  # 可选:语音输入模块,使用 Whisper
 
99
  placeholder="排除内容(如:低分辨率、水印)"
100
  )
101
  use_voice = gr.Checkbox(label="启用语音输入(加分项)")
 
102
  audio_input = gr.Audio(type="filepath", label="语音输入")
103
  generate_btn = gr.Button("生成图像")
104
  with gr.Column():