GGPENG commited on
Commit
55e61c2
·
verified ·
1 Parent(s): 435643f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -113
app.py CHANGED
@@ -1,128 +1,87 @@
 
1
  import streamlit as st
2
- import io
3
  import requests
 
4
  from PIL import Image
5
- from io import BytesIO
 
6
 
7
  # ----------------------------
8
- # Hugging Face Inference API
9
  # ----------------------------
10
- import os
11
-
12
-
13
-
14
-
15
- # API_URL = "https://api-inference.huggingface.co/models/GGPENG/StyleDiffusion" # 替换为你上传的模型仓库
16
- # API_TOKEN = os.getenv("HF_TOKEN")
17
-
18
- # headers = {"Authorization": f"Bearer {API_TOKEN}"}
19
-
20
- # # ----------------------------
21
- # # Streamlit 页面设置
22
- # # ----------------------------
23
- # st.set_page_config(page_title="Fine-tuning style diffusion (API)", layout="wide")
24
-
25
- # st.title("Fine-tuning style diffusion 推理 Demo (API)")
26
- # st.write("只是训练了一个提示词 'A <new1> reference.'")
27
- # st.write("示例:A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background")
28
-
29
- # # ----------------------------
30
- # # Prompt 输入
31
- # # ----------------------------
32
- # prompt = st.text_input(
33
- # "Prompt",
34
- # "A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background"
35
- # )
36
-
37
- # # ----------------------------
38
- # # 参数调节
39
- # # ----------------------------
40
- # steps = st.slider("Steps", 10, 320, 100)
41
- # guidance = st.slider("Guidance", 1.0, 18.0, 6.0)
42
-
43
- # # ----------------------------
44
- # # 生成函数(调用 API)
45
- # # ----------------------------
46
- # def generate(prompt):
47
- # payload = {
48
- # "inputs": prompt,
49
- # "parameters": {
50
- # "num_inference_steps": steps,
51
- # "guidance_scale": guidance,
52
- # # "height": 512,
53
- # # "width": 512
54
- # }
55
- # }
56
-
57
- # response = requests.post(API_URL, headers=headers, json=payload)
58
-
59
- # if response.status_code != 200:
60
- # st.error(f"API请求失败:{response.status_code}, {response.text}")
61
- # return None
62
-
63
- # # 将返回的字节流或 Base64 数据转换为 PIL Image
64
- # try:
65
- # image = Image.open(BytesIO(response.content))
66
- # except:
67
- # st.error("生成图像失败,请检查模型是否支持图像输出。")
68
- # return None
69
- # return image
70
-
71
- # # ----------------------------
72
- # # 生成按钮
73
- # # ----------------------------
74
- # if st.button("Generate"):
75
- # with st.spinner("Generating via Hugging Face API..."):
76
- # image = generate(prompt)
77
-
78
- # if image:
79
- # st.image(image, caption="Result", width=512)
80
- # buf = io.BytesIO()
81
- # image.save(buf, format="PNG")
82
- # st.download_button(
83
- # "Download",
84
- # buf.getvalue(),
85
- # "result.png"
86
- # )
87
-
88
-
89
-
90
- from diffusers import StableDiffusionPipeline
91
- import torch
92
- from PIL import Image
93
- import io
94
 
95
  # ----------------------------
96
- # 加载基础模型
97
  # ----------------------------
98
- base_model = "runwayml/stable-diffusion-v1-5"
99
- pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
100
- pipe = pipe.to("cuda") # 有GPU加速
 
 
 
101
 
102
  # ----------------------------
103
- # 加载自定义微调权重
104
  # ----------------------------
105
- ckpt_path = "./pytorch_custom_diffusion_weights.bin"
106
-
107
- # 假设你用的是 Diffusers 支持的 UNet 权重增量加载
108
- pipe.unet.load_attn_procs(ckpt_path)
109
-
110
-
111
- import streamlit as st
112
 
113
- st.set_page_config(page_title="Custom Style Diffusion Demo", layout="wide")
114
- st.title("Custom Style Diffusion 本地推理 Demo")
115
-
116
- prompt = st.text_input("Prompt", "A <new1> reference. New Year image with a rabbit in 2D anime style")
117
- steps = st.slider("Steps", 10, 320, 50)
118
- guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
 
 
120
  if st.button("Generate"):
121
- with st.spinner("Generating image..."):
122
- result = pipe(prompt, num_inference_steps=steps, guidance_scale=guidance)
123
- image = result.images[0]
124
- st.image(image, caption="Result", use_column_width=True)
125
-
126
- buf = io.BytesIO()
127
- image.save(buf, format="PNG")
128
- st.download_button("Download Image", buf.getvalue(), "result.png")
 
 
 
 
 
 
 
 
1
+ # app.py
2
  import streamlit as st
 
3
  import requests
4
+ import io
5
  from PIL import Image
6
+ import base64
7
+ import os
8
 
9
  # ----------------------------
10
+ # 页面设
11
  # ----------------------------
12
+ st.set_page_config(page_title="StyleDiffusion API Demo", layout="wide")
13
+ st.title("Style Diffusion 推理 Demo (Hugging Face API)")
14
+ st.write("直接调用 Hugging Face 公有模型,无需下载权重")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # ----------------------------
17
+ # 用户输入
18
  # ----------------------------
19
+ prompt = st.text_input(
20
+ "Prompt",
21
+ "A <new1> reference. New Year image with a rabbit in 2D anime style"
22
+ )
23
+ steps = st.slider("Steps", 10, 320, 50)
24
+ guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5)
25
 
26
  # ----------------------------
27
+ # Hugging Face API 配置
28
  # ----------------------------
29
+ API_URL = "https://router.huggingface.co/models/GGPENG/StyleDiffusion"
30
+ API_TOKEN = os.getenv("HF_TOKEN") # 如果模型是公有的,可以留空
31
+ headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
 
 
 
 
32
 
33
+ # ----------------------------
34
+ # 生成函数
35
+ # ----------------------------
36
+ def generate(prompt, steps, guidance):
37
+ payload = {
38
+ "inputs": prompt,
39
+ "parameters": {
40
+ "num_inference_steps": steps,
41
+ "guidance_scale": guidance
42
+ }
43
+ }
44
+
45
+ response = requests.post(API_URL, headers=headers, json=payload)
46
+
47
+ if response.status_code != 200:
48
+ st.error(f"API请求失败:{response.status_code} {response.text}")
49
+ return None
50
+
51
+ try:
52
+ res_json = response.json()
53
+ if isinstance(res_json, list) and len(res_json) > 0:
54
+ if "image_base64" in res_json[0]:
55
+ img_bytes = base64.b64decode(res_json[0]["image_base64"])
56
+ elif "generated_image" in res_json[0]:
57
+ img_bytes = base64.b64decode(res_json[0]["generated_image"])
58
+ else:
59
+ st.error("API 返回不包含图像字段")
60
+ return None
61
+ return Image.open(io.BytesIO(img_bytes))
62
+ else:
63
+ st.error("API 返回格式不正确")
64
+ return None
65
+ except Exception as e:
66
+ st.error(f"解析图像失败: {e}")
67
+ return None
68
 
69
+ # ----------------------------
70
+ # 生成按钮
71
+ # ----------------------------
72
  if st.button("Generate"):
73
+ if not prompt.strip():
74
+ st.warning("请输入 Prompt")
75
+ else:
76
+ with st.spinner("生成中,请稍候..."):
77
+ image = generate(prompt, steps, guidance)
78
+
79
+ if image:
80
+ st.image(image, caption="生成结果", use_column_width=True)
81
+ buf = io.BytesIO()
82
+ image.save(buf, format="PNG")
83
+ st.download_button(
84
+ "下载图片",
85
+ buf.getvalue(),
86
+ "result.png"
87
+ )