boheng.xie commited on
Commit
6ea79b4
·
1 Parent(s): cbdc720

add requirement file

Browse files
Files changed (2) hide show
  1. app.py +25 -10
  2. requirement.txt +6 -0
app.py CHANGED
@@ -3,25 +3,40 @@ import gradio as gr
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
6
- # 初始化模型(免费版Space选择小模型)
 
 
 
 
7
  model = StableDiffusionPipeline.from_pretrained(
8
  "CompVis/stable-diffusion-v1-4",
9
- torch_dtype=torch.float16,
10
  use_safetensors=True,
11
- low_cpu_mem_usage=True # 节省内存
12
- ).to("cuda") if torch.cuda.is_available() else None # 兼容CPU模式
 
 
 
13
 
14
  def generate(prompt):
15
- if model is None:
16
- return "当前环境不支持图像生成"
17
- image = model(prompt, num_inference_steps=15).images[0]
18
- return image
 
 
 
19
 
20
  # 基础界面
21
  interface = gr.Interface(
22
  fn=generate,
23
  inputs=gr.Textbox(label="输入表情包描述"),
24
- outputs=gr.Image(label="生成结果"),
 
 
 
25
  examples=[["一只生气的猫"], ["跳舞的香蕉"]]
26
  )
27
- interface.launch()
 
 
 
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
6
+ # Space环境中检测设备
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ print(f"使用设备: {device}")
9
+
10
+ # 初始化模型(适配Space环境)
11
  model = StableDiffusionPipeline.from_pretrained(
12
  "CompVis/stable-diffusion-v1-4",
13
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
  use_safetensors=True,
15
+ safety_checker=None, # 根据需要可开启或关闭
16
+ variant="fp16" if device == "cuda" else None,
17
+ low_cpu_mem_usage=True
18
+ )
19
+ model = model.to(device)
20
 
21
  def generate(prompt):
22
+ if not prompt:
23
+ return None, "请输入有效的提示词"
24
+ try:
25
+ image = model(prompt, num_inference_steps=15).images[0]
26
+ return image, "生成成功!"
27
+ except Exception as e:
28
+ return None, f"生成失败: {str(e)}"
29
 
30
  # 基础界面
31
  interface = gr.Interface(
32
  fn=generate,
33
  inputs=gr.Textbox(label="输入表情包描述"),
34
+ outputs=[
35
+ gr.Image(label="生成结果"),
36
+ gr.Textbox(label="状态")
37
+ ],
38
  examples=[["一只生气的猫"], ["跳舞的香蕉"]]
39
  )
40
+
41
+ # Hugging Face Space推荐的启动方式
42
+ interface.launch(share=False, server_name="0.0.0.0")
requirement.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ diffusers>=0.23.1
2
+ transformers>=4.30.0
3
+ accelerate>=0.20.0
4
+ torch>=2.0.0
5
+ gradio>=3.32.0
6
+ safetensors