# Page header captured with this file: author GGPENG,
# commit "Update app.py" (a97d94f, verified), file size 2.51 kB.
import streamlit as st
import io
import requests
from PIL import Image
from io import BytesIO
# ----------------------------
# 配置 Hugging Face Inference API
# ----------------------------
import os
# Replace with the model repository you uploaded.
API_URL = "https://api-inference.huggingface.co/models/GGPENG/StyleDiffusion"
API_TOKEN = os.getenv("HF_TOKEN")
# Only send an Authorization header when a token is configured. Without this
# guard an unset HF_TOKEN would produce the literal credential "Bearer None".
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
# ----------------------------
# Streamlit page setup
# ----------------------------
# set_page_config is kept as the first Streamlit call in the script.
st.set_page_config(page_title="Fine-tuning style diffusion (API)", layout="wide")
st.title("Fine-tuning style diffusion 推理 Demo (API)")
st.write("只是训练了一个提示词 'A <new1> reference.'")

# Example prompt: shown as a hint and reused as the text box's default value.
_EXAMPLE_PROMPT = (
    "A <new1> reference. New Year image with a rabbit as the main element, "
    "in a 2D or anime style, and a festive background"
)
st.write("示例:" + _EXAMPLE_PROMPT)

# ----------------------------
# Prompt input
# ----------------------------
prompt = st.text_input("Prompt", value=_EXAMPLE_PROMPT)

# ----------------------------
# Generation parameters (read by generate() as module-level values)
# ----------------------------
steps = st.slider("Steps", min_value=10, max_value=320, value=100)
guidance = st.slider("Guidance", min_value=1.0, max_value=18.0, value=6.0)
# ----------------------------
# Generation function (calls the API)
# ----------------------------
def generate(prompt, timeout=300):
    """Request one image for *prompt* from the Hugging Face Inference API.

    The payload combines ``prompt`` with the module-level ``steps`` and
    ``guidance`` slider values. It is POSTed to ``API_URL`` with ``headers``,
    and the response body is decoded as an image.

    Args:
        prompt: Text forwarded as the ``inputs`` field of the payload.
        timeout: Seconds to wait for the HTTP request before giving up.
            Keyword with a default, so existing ``generate(prompt)`` calls
            keep working; it only stops the request from hanging forever.

    Returns:
        A fully decoded ``PIL.Image.Image`` on success. On failure it returns
        ``None`` after reporting the problem with ``st.error``.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "num_inference_steps": steps,
            "guidance_scale": guidance,
            # "height": 512,
            # "width": 512
        },
    }

    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as err:
        # Connection errors and timeouts: report them in the UI instead of
        # letting the exception abort the Streamlit script run.
        st.error(f"API请求失败:{err}")
        return None

    if response.status_code != 200:
        st.error(f"API请求失败:{response.status_code}, {response.text}")
        return None

    # Decode the raw response bytes as an image. Only raw image bytes are
    # handled here; a Base64/JSON body would fail to decode.
    try:
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy. load() forces the full decode, so corrupt or
        # truncated data is caught here rather than later in st.image/save.
        image.load()
    except (OSError, Image.DecompressionBombError):
        st.error("生成图像失败,请检查模型是否支持图像输出。")
        return None
    return image
# ----------------------------
# Generate button
# ----------------------------
if st.button("Generate"):
    # The spinner is shown only while the API request is in flight.
    with st.spinner("Generating via Hugging Face API..."):
        image = generate(prompt)
    # generate() returns None after it has already shown an error message.
    if image is not None:
        st.image(image, caption="Result", width=512)
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        st.download_button(
            "Download",
            buf.getvalue(),
            "result.png",
            # Declare the content type; bytes data otherwise defaults to a
            # generic binary MIME type.
            mime="image/png",
        )