Spaces:
Sleeping
Sleeping
| # app.py | |
| import streamlit as st | |
| import requests | |
| import io | |
| from PIL import Image | |
| import base64 | |
| import os | |
# ----------------------------
# Page setup
# ----------------------------
st.set_page_config(layout="wide", page_title="StyleDiffusion API Demo")
st.title("Style Diffusion 推理 Demo (Hugging Face API)")
st.write("直接调用 Hugging Face 公有模型,无需下载权重")

# ----------------------------
# User inputs
# ----------------------------
# Free-text prompt, pre-filled with an example.
# NOTE(review): "<new1>" is presumably a custom token of the model -- confirm.
prompt = st.text_input(
    label="Prompt",
    value="A <new1> reference. New Year image with a rabbit in 2D anime style",
)

# Slider for the step count: range 10-320, initial value 50.
steps = st.slider(label="Steps", min_value=10, max_value=320, value=50)

# Slider for the guidance scale: range 1.0-20.0, initial value 7.5.
guidance = st.slider(
    label="Guidance Scale",
    min_value=1.0,
    max_value=20.0,
    value=7.5,
)
# ----------------------------
# Hugging Face API configuration
# ----------------------------
# NOTE(review): confirm this router path against the current Hugging Face
# inference documentation; it is kept exactly as it was.
API_URL = "https://router.huggingface.co/models/GGPENG/StyleDiffusion"

# The token is optional (it may be left unset for a public model).
# Surrounding whitespace is stripped and a blank value is treated as "no
# token": a stray trailing newline from a secrets file would otherwise end up
# inside the Authorization header and produce an invalid header value.
API_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None

# Send the Authorization header only when a usable token is present.
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
# ----------------------------
# Generation function
# ----------------------------
def generate(prompt, steps, guidance, timeout=(10, 300)):
    """Request one image from the inference endpoint and decode it.

    Posts *prompt* and the sampling parameters to ``API_URL`` using the
    module-level ``headers``.  Every failure (network error, non-200 status,
    unexpected payload, undecodable image) is reported through ``st.error``
    and signalled by returning ``None``; nothing is raised to the caller.

    Args:
        prompt: Text sent as the ``inputs`` field of the request.
        steps: Value sent as ``parameters.num_inference_steps``.
        guidance: Value sent as ``parameters.guidance_scale``.
        timeout: Passed to ``requests.post``; either a single number of
            seconds or a ``(connect, read)`` tuple.  The default allows a
            long read because generation can be slow.

    Returns:
        A fully loaded ``PIL.Image.Image`` on success, otherwise ``None``.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "num_inference_steps": steps,
            "guidance_scale": guidance,
        },
    }

    # Without a timeout the request could block the app forever; connection
    # problems are surfaced in the UI instead of crashing the script run.
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.exceptions.RequestException as e:
        st.error(f"API请求失败:{e}")
        return None

    if response.status_code != 200:
        st.error(f"API请求失败:{response.status_code} {response.text}")
        return None

    # UI boundary: any decoding problem is shown to the user, never raised.
    try:
        content_type = response.headers.get("Content-Type", "")
        if content_type.lower().startswith("image/"):
            # The endpoint answered with the image itself as the raw body.
            img_bytes = response.content
        else:
            res_json = response.json()
            if not isinstance(res_json, list) or not res_json:
                st.error("API 返回格式不正确")
                return None

            first = res_json[0]
            if "image_base64" in first:
                img_bytes = base64.b64decode(first["image_base64"])
            elif "generated_image" in first:
                img_bytes = base64.b64decode(first["generated_image"])
            else:
                st.error("API 返回不包含图像字段")
                return None

        image = Image.open(io.BytesIO(img_bytes))
        # Image.open is lazy; force the decode here so corrupt or truncated
        # data is reported by this function rather than failing in the caller.
        image.load()
        return image
    except Exception as e:
        st.error(f"解析图像失败: {e}")
        return None
# ----------------------------
# Generate button
# ----------------------------
if st.button("Generate"):
    if prompt.strip():
        # Show a spinner while the generation request runs.
        with st.spinner("生成中,请稍候..."):
            image = generate(prompt, steps, guidance)

        # generate() returns None after reporting an error; show nothing then.
        if image:
            st.image(image, caption="生成结果", use_column_width=True)

            # Re-encode the result as PNG in memory for the download button.
            png_buffer = io.BytesIO()
            image.save(png_buffer, format="PNG")
            st.download_button(
                label="下载图片",
                data=png_buffer.getvalue(),
                file_name="result.png",
            )
    else:
        # An empty or whitespace-only prompt is rejected without calling
        # generate().
        st.warning("请输入 Prompt")