GGPENG's picture
Update app.py
55e61c2 verified
raw
history blame
2.79 kB
# app.py
import streamlit as st
import requests
import io
from PIL import Image
import base64
import os
# ----------------------------
# Page setup
# ----------------------------
# Page configuration comes first, before any element is rendered.
st.set_page_config(layout="wide", page_title="StyleDiffusion API Demo")
st.title("Style Diffusion 推理 Demo (Hugging Face API)")
st.write("直接调用 Hugging Face 公有模型,无需下载权重")
# ----------------------------
# User inputs
# ----------------------------
# Free-text prompt; the default demonstrates the "<new1>" style token.
prompt = st.text_input(
    "Prompt",
    value="A <new1> reference. New Year image with a rabbit in 2D anime style",
)
# Number of denoising steps (integer slider: 10..320, default 50).
steps = st.slider("Steps", min_value=10, max_value=320, value=50)
# Classifier-free guidance strength (float slider: 1.0..20.0, default 7.5).
guidance = st.slider("Guidance Scale", min_value=1.0, max_value=20.0, value=7.5)
# ----------------------------
# Hugging Face API configuration
# ----------------------------
API_URL = "https://router.huggingface.co/models/GGPENG/StyleDiffusion"
# Optional access token read from the HF_TOKEN environment variable; the
# original note says it may be left unset when the model is public.
# NOTE(review): confirm the router endpoint accepts unauthenticated calls.
API_TOKEN = os.environ.get("HF_TOKEN")

# Request headers: empty unless a (non-empty) token is available.
headers = {}
if API_TOKEN:
    headers["Authorization"] = f"Bearer {API_TOKEN}"
# ----------------------------
# Generation
# ----------------------------
# JSON keys that may carry the base64-encoded image, checked in this order.
_IMAGE_JSON_KEYS = ("image_base64", "generated_image")


def _decode_base64_image(encoded):
    """Decode a base64 image string, tolerating a ``data:...;base64,`` prefix.

    ``base64.b64decode`` (non-validating mode) silently drops characters outside
    the base64 alphabet but keeps the rest, so an un-stripped data-URI prefix
    would inject characters such as ``image/png`` into the payload and corrupt
    the decoded bytes.

    Raises:
        binascii.Error / TypeError: if *encoded* is not valid base64 text.
    """
    if isinstance(encoded, str) and encoded.startswith("data:"):
        encoded = encoded.partition(",")[2]
    return base64.b64decode(encoded)


def _read_image_bytes(response):
    """Return the raw image bytes carried by a successful API response.

    Accepted body shapes:
    * an ``image/*`` body (the Hugging Face text-to-image task replies with raw
      image bytes), used as-is;
    * JSON: a non-empty list whose first element is an object, or a single
      object, holding the base64 image under one of ``_IMAGE_JSON_KEYS``.

    Returns:
        The image bytes, or ``None`` after reporting via ``st.error`` when the
        body carries no image.

    Raises:
        ValueError: malformed JSON or base64 (handled by the caller).
    """
    content_type = response.headers.get("Content-Type", "")
    if content_type.startswith("image/"):
        return response.content

    res_json = response.json()
    if isinstance(res_json, list) and res_json:
        record = res_json[0]
    elif isinstance(res_json, dict):
        record = res_json
    else:
        st.error("API 返回格式不正确")
        return None

    # Only look up keys in a real object; `in` on a str would do a substring test.
    if isinstance(record, dict):
        for key in _IMAGE_JSON_KEYS:
            if key in record:
                return _decode_base64_image(record[key])
    st.error("API 返回不包含图像字段")
    return None


def generate(prompt, steps, guidance, timeout=300):
    """Request one image from the Hugging Face endpoint at ``API_URL``.

    Args:
        prompt: Text prompt, sent as ``inputs``.
        steps: Sent as ``parameters.num_inference_steps``.
        guidance: Sent as ``parameters.guidance_scale``.
        timeout: Seconds passed to ``requests.post`` as its timeout, so a
            stalled endpoint cannot block the script run forever.

    Returns:
        A fully decoded ``PIL.Image.Image`` on success, otherwise ``None``.
        Every failure is reported on the page with ``st.error`` instead of
        raising, so callers only need to check for ``None``.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "num_inference_steps": steps,
            "guidance_scale": guidance,
        },
    }
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as e:
        # Connection errors / timeouts: report instead of a raw traceback.
        st.error(f"API请求失败:{e}")
        return None

    if response.status_code != 200:
        st.error(f"API请求失败:{response.status_code} {response.text}")
        return None

    try:
        img_bytes = _read_image_bytes(response)
        if img_bytes is None:
            return None  # already reported by _read_image_bytes
        image = Image.open(io.BytesIO(img_bytes))
        # Image.open is lazy; decode now so corrupt data is reported here
        # rather than crashing later inside st.image() / image.save().
        image.load()
    except Exception as e:  # UI boundary: JSON, base64 and PIL raise many types
        st.error(f"解析图像失败: {e}")
        return None
    return image
# ----------------------------
# Generate button
# ----------------------------
# Streamlit re-runs this whole script on every interaction, including a click
# on the download button, and st.button() is True only on the run right after
# it was pressed. The latest result is therefore kept in st.session_state so it
# stays visible (and downloadable) across those reruns.
if st.button("Generate"):
    if not prompt.strip():
        st.warning("请输入 Prompt")
    else:
        # Drop the previous result so a failed run does not leave a stale image.
        st.session_state.pop("result_png", None)
        with st.spinner("生成中,请稍候..."):
            image = generate(prompt, steps, guidance)
        if image is not None:
            buf = io.BytesIO()
            image.save(buf, format="PNG")
            st.session_state["result_png"] = buf.getvalue()

result_png = st.session_state.get("result_png")
if result_png is not None:
    # NOTE(review): recent Streamlit releases deprecate use_column_width in
    # favour of use_container_width; switch once the deployed version allows.
    st.image(result_png, caption="生成结果", use_column_width=True)
    st.download_button(
        "下载图片",
        result_png,
        "result.png",
        mime="image/png",
    )