# Header scraped from the Hugging Face file view (not part of the program):
# author GGPENG · commit 435643f "Update app.py" (verified) · 3.98 kB
import streamlit as st
import io
import requests
from PIL import Image
from io import BytesIO
# ----------------------------
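# Hugging Face Inference API configuration (disabled: the remote-API version below is commented out; the active app runs the model locally)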
# ----------------------------
import os
# API_URL = "https://api-inference.huggingface.co/models/GGPENG/StyleDiffusion" # 替换为你上传的模型仓库
# API_TOKEN = os.getenv("HF_TOKEN")
# headers = {"Authorization": f"Bearer {API_TOKEN}"}
# # ----------------------------
# # Streamlit 页面设置
# # ----------------------------
# st.set_page_config(page_title="Fine-tuning style diffusion (API)", layout="wide")
# st.title("Fine-tuning style diffusion 推理 Demo (API)")
# st.write("只是训练了一个提示词 'A <new1> reference.'")
# st.write("示例:A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background")
# # ----------------------------
# # Prompt 输入
# # ----------------------------
# prompt = st.text_input(
# "Prompt",
# "A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background"
# )
# # ----------------------------
# # 参数调节
# # ----------------------------
# steps = st.slider("Steps", 10, 320, 100)
# guidance = st.slider("Guidance", 1.0, 18.0, 6.0)
# # ----------------------------
# # 生成函数(调用 API)
# # ----------------------------
# def generate(prompt):
# payload = {
# "inputs": prompt,
# "parameters": {
# "num_inference_steps": steps,
# "guidance_scale": guidance,
# # "height": 512,
# # "width": 512
# }
# }
# response = requests.post(API_URL, headers=headers, json=payload)
# if response.status_code != 200:
# st.error(f"API请求失败:{response.status_code}, {response.text}")
# return None
# # 将返回的字节流或 Base64 数据转换为 PIL Image
# try:
# image = Image.open(BytesIO(response.content))
# except:
# st.error("生成图像失败,请检查模型是否支持图像输出。")
# return None
# return image
# # ----------------------------
# # 生成按钮
# # ----------------------------
# if st.button("Generate"):
# with st.spinner("Generating via Hugging Face API..."):
# image = generate(prompt)
# if image:
# st.image(image, caption="Result", width=512)
# buf = io.BytesIO()
# image.save(buf, format="PNG")
# st.download_button(
# "Download",
# buf.getvalue(),
# "result.png"
# )
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import io
# ----------------------------
# Page setup
# ----------------------------
import streamlit as st

# st.set_page_config must be the first Streamlit command, so it runs before the
# cached loader below (st.cache_resource shows a spinner while it executes).
st.set_page_config(page_title="Custom Style Diffusion Demo", layout="wide")
st.title("Custom Style Diffusion 本地推理 Demo")

# ----------------------------
# Load the base model and the custom fine-tuned weights
# ----------------------------
base_model = "runwayml/stable-diffusion-v1-5"
ckpt_path = "./pytorch_custom_diffusion_weights.bin"


@st.cache_resource
def load_pipeline(model_id, weights_path):
    """Build the Stable Diffusion pipeline once per server process.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, each slider move or button click would reload the
    model weights and move them to the device again.

    Args:
        model_id: Hugging Face Hub id (or local path) of the base model.
        weights_path: path to the Custom Diffusion attention-processor
            weights file.

    Returns:
        A StableDiffusionPipeline on CUDA when available, otherwise on CPU.
    """
    use_cuda = torch.cuda.is_available()
    # Half precision is meant for GPU inference; CPU inference needs float32.
    dtype = torch.float16 if use_cuda else torch.float32
    pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    pipeline = pipeline.to("cuda" if use_cuda else "cpu")

    # Assumes the checkpoint holds Diffusers-format UNet attention-processor
    # deltas (Custom Diffusion); load_attn_procs accepts a direct file path.
    pipeline.unet.load_attn_procs(weights_path)

    # Custom Diffusion training also saves the learned modifier-token embedding
    # (e.g. "<new1>.bin") next to the attention weights. Without it the
    # tokenizer does not know "<new1>" and splits it into ordinary sub-tokens.
    # NOTE(review): file name assumed from the Diffusers Custom Diffusion
    # guide; it is only loaded if present.
    token_path = os.path.join(os.path.dirname(weights_path) or ".", "<new1>.bin")
    if os.path.isfile(token_path):
        pipeline.load_textual_inversion(token_path)
    return pipeline


pipe = load_pipeline(base_model, ckpt_path)

# ----------------------------
# Inputs
# ----------------------------
prompt = st.text_input("Prompt", "A <new1> reference. New Year image with a rabbit in 2D anime style")
steps = st.slider("Steps", 10, 320, 50)
guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5)

# ----------------------------
# Generation
# ----------------------------
if st.button("Generate"):
    with st.spinner("Generating image..."):
        result = pipe(prompt, num_inference_steps=steps, guidance_scale=guidance)
    buf = io.BytesIO()
    result.images[0].save(buf, format="PNG")
    # Keep the encoded PNG across reruns. Clicking the download button reruns
    # the script with st.button("Generate") == False, which would otherwise
    # wipe the displayed result.
    st.session_state["result_png"] = buf.getvalue()

result_png = st.session_state.get("result_png")
if result_png is not None:
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favor of use_container_width; switch once the pinned version allows it.
    st.image(result_png, caption="Result", use_column_width=True)
    st.download_button("Download Image", result_png, "result.png")