File size: 3,978 Bytes
a97d94f
 
 
 
 
 
 
 
 
 
 
 
 
 
435643f
 
a97d94f
435643f
a97d94f
435643f
 
 
 
a97d94f
435643f
 
 
a97d94f
435643f
 
 
 
 
 
 
a97d94f
435643f
 
 
 
 
a97d94f
435643f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a97d94f
435643f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a97d94f
435643f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a97d94f
 
435643f
 
 
 
 
 
a97d94f
435643f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a97d94f
435643f
 
 
 
 
a97d94f
 
435643f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import io
import requests
from PIL import Image
from io import BytesIO

# ----------------------------
# Hugging Face Inference API configuration
# (this API-based variant is commented out below and kept for reference only)
# ----------------------------
import os




# API_URL = "https://api-inference.huggingface.co/models/GGPENG/StyleDiffusion"  # 替换为你上传的模型仓库
# API_TOKEN = os.getenv("HF_TOKEN")

# headers = {"Authorization": f"Bearer {API_TOKEN}"}

# # ----------------------------
# # Streamlit 页面设置
# # ----------------------------
# st.set_page_config(page_title="Fine-tuning style diffusion (API)", layout="wide")

# st.title("Fine-tuning style diffusion 推理 Demo (API)")
# st.write("只是训练了一个提示词 'A <new1> reference.'")
# st.write("示例:A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background")

# # ----------------------------
# # Prompt 输入
# # ----------------------------
# prompt = st.text_input(
#     "Prompt",
#     "A <new1> reference. New Year image with a rabbit as the main element, in a 2D or anime style, and a festive background"
# )

# # ----------------------------
# # 参数调节
# # ----------------------------
# steps = st.slider("Steps", 10, 320, 100)
# guidance = st.slider("Guidance", 1.0, 18.0, 6.0)

# # ----------------------------
# # 生成函数(调用 API)
# # ----------------------------
# def generate(prompt):
#     payload = {
#         "inputs": prompt,
#         "parameters": {
#             "num_inference_steps": steps,
#             "guidance_scale": guidance,
#             # "height": 512,
#             # "width": 512
#         }
#     }

#     response = requests.post(API_URL, headers=headers, json=payload)
    
#     if response.status_code != 200:
#         st.error(f"API请求失败:{response.status_code}, {response.text}")
#         return None
    
#     # 将返回的字节流或 Base64 数据转换为 PIL Image
#     try:
#         image = Image.open(BytesIO(response.content))
#     except:
#         st.error("生成图像失败,请检查模型是否支持图像输出。")
#         return None
#     return image

# # ----------------------------
# # 生成按钮
# # ----------------------------
# if st.button("Generate"):
#     with st.spinner("Generating via Hugging Face API..."):
#         image = generate(prompt)
    
#     if image:
#         st.image(image, caption="Result", width=512)
#         buf = io.BytesIO()
#         image.save(buf, format="PNG")
#         st.download_button(
#             "Download",
#             buf.getvalue(),
#             "result.png"
#         )



from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import io

import streamlit as st

# ----------------------------
# Model configuration
# ----------------------------
base_model = "runwayml/stable-diffusion-v1-5"
# Fine-tuned custom-diffusion attention weights, expected next to this script.
ckpt_path = "./pytorch_custom_diffusion_weights.bin"


@st.cache_resource(show_spinner="Loading model...")
def _load_pipeline(model_id, weights_path):
    """Build the Stable Diffusion pipeline and attach the fine-tuned weights.

    Streamlit re-executes this script from the top on every widget
    interaction, so the pipeline is cached as a resource instead of being
    rebuilt on each rerun.

    Args:
        model_id: Hub id (or local path) of the base Stable Diffusion model.
        weights_path: Path to the attention-processor weights file.

    Returns:
        The pipeline, placed on CUDA when available and on CPU otherwise.
    """
    # Half precision is only used on GPU; fall back to float32 on CPU so the
    # app still starts on hosts without CUDA instead of crashing in `.to()`.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    else:
        device, dtype = "cpu", torch.float32

    loaded = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    loaded = loaded.to(device)

    # Pass the directory plus an explicit `weight_name`, so the loader reads
    # the named .bin file rather than guessing a default weight file name.
    weights_dir, weights_file = os.path.split(weights_path)
    loaded.unet.load_attn_procs(weights_dir or ".", weight_name=weights_file)
    # NOTE(review): the default prompt below uses the `<new1>` token, but only
    # attention weights are loaded here. Confirm whether the learned token
    # embedding also has to be loaded (e.g. via `load_textual_inversion`).
    return loaded


# ----------------------------
# Page setup (kept ahead of every other Streamlit call)
# ----------------------------
st.set_page_config(page_title="Custom Style Diffusion Demo", layout="wide")
st.title("Custom Style Diffusion 本地推理 Demo")

pipe = _load_pipeline(base_model, ckpt_path)

# ----------------------------
# Prompt and sampling parameters
# ----------------------------
prompt = st.text_input("Prompt", "A <new1> reference. New Year image with a rabbit in 2D anime style")
steps = st.slider("Steps", 10, 320, 50)
guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5)

# ----------------------------
# Generate, display, and offer the result for download
# ----------------------------
if st.button("Generate"):
    with st.spinner("Generating image..."):
        result = pipe(prompt, num_inference_steps=steps, guidance_scale=guidance)
        image = result.images[0]

    # NOTE(review): `use_column_width` is reported as deprecated in newer
    # Streamlit releases; confirm against the Streamlit version in use.
    st.image(image, caption="Result", use_column_width=True)

    # Serialize to an in-memory PNG so the download button can serve bytes.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    st.download_button("Download Image", buf.getvalue(), "result.png")