lucksadasd commited on
Commit
fd6c129
·
verified ·
1 Parent(s): 594c0d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -7
app.py CHANGED
@@ -16,13 +16,22 @@ llm = pipeline(
16
  )
17
 
18
  # Step 2: 加载并量化 Stable Diffusion 模型以加速推理
19
- # 使用 8-bit 量化和自动设备映射
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
- load_kwargs = {
22
- "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
23
- "device_map": "auto",
24
- "load_in_8bit": True # 需要安装 bitsandbytes
25
- }
 
 
 
 
 
 
 
 
 
26
 
27
  # SD v1.5
28
  sd_v15 = StableDiffusionPipeline.from_pretrained(
@@ -30,6 +39,7 @@ sd_v15 = StableDiffusionPipeline.from_pretrained(
30
  **load_kwargs
31
  )
32
  sd_v15.scheduler = DPMSolverMultistepScheduler.from_config(sd_v15.scheduler.config)
 
33
 
34
  # SD XL
35
  sd_xl = StableDiffusionPipeline.from_pretrained(
@@ -37,6 +47,7 @@ sd_xl = StableDiffusionPipeline.from_pretrained(
37
  **load_kwargs
38
  )
39
  sd_xl.scheduler = DPMSolverMultistepScheduler.from_config(sd_xl.scheduler.config)
 
40
 
41
  # 可选:语音输入模块,使用 Whisper
42
  asr = pipeline(
@@ -122,3 +133,4 @@ with gr.Blocks(title="Prompt-to-Image Generator") as demo:
122
  # Step 4: 启动应用
123
  if __name__ == "__main__":
124
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
16
  )
17
 
18
  # Step 2: 加载并量化 Stable Diffusion 模型以加速推理
19
+ # 根据硬件环境选择加载方式:
20
+ # - GPU 环境:8-bit 量化 + 自动设备映射
21
+ # - CPU 环境:浮点32 + 平衡设备映射
22
+ if torch.cuda.is_available():
23
+ device = "cuda"
24
+ load_kwargs = {
25
+ "torch_dtype": torch.float16,
26
+ "device_map": "auto",
27
+ "load_in_8bit": True # 需要安装 bitsandbytes
28
+ }
29
+ else:
30
+ device = "cpu"
31
+ load_kwargs = {
32
+ "torch_dtype": torch.float32,
33
+ "device_map": "balanced"
34
+ }
35
 
36
  # SD v1.5
37
  sd_v15 = StableDiffusionPipeline.from_pretrained(
 
39
  **load_kwargs
40
  )
41
  sd_v15.scheduler = DPMSolverMultistepScheduler.from_config(sd_v15.scheduler.config)
42
+ sd_v15 = sd_v15.to(device)
43
 
44
  # SD XL
45
  sd_xl = StableDiffusionPipeline.from_pretrained(
 
47
  **load_kwargs
48
  )
49
  sd_xl.scheduler = DPMSolverMultistepScheduler.from_config(sd_xl.scheduler.config)
50
+ sd_xl = sd_xl.to(device)
51
 
52
  # 可选:语音输入模块,使用 Whisper
53
  asr = pipeline(
 
133
  # Step 4: 启动应用
134
  if __name__ == "__main__":
135
  demo.launch(server_name="0.0.0.0", server_port=7860)
136
+