GGPENG commited on
Commit
6cc9580
·
verified ·
1 Parent(s): 4020064

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -77
app.py CHANGED
@@ -1,87 +1,47 @@
1
  # app.py
2
- import streamlit as st
3
- import requests
4
- import io
5
- from PIL import Image
6
- import base64
7
- import os
8
 
9
- # ----------------------------
10
- # 页面设置
11
- # ----------------------------
12
- st.set_page_config(page_title="StyleDiffusion API Demo", layout="wide")
13
- st.title("Style Diffusion 推理 Demo (Hugging Face API)")
14
- st.write("直接调用 Hugging Face 公有模型,无需下载权重")
15
 
16
- # ----------------------------
17
- # 用户输入
18
- # ----------------------------
19
- prompt = st.text_input(
20
- "Prompt",
21
- "A <new1> reference. New Year image with a rabbit in 2D anime style"
22
  )
23
- steps = st.slider("Steps", 10, 320, 50)
24
- guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5)
25
 
26
- # ----------------------------
27
- # Hugging Face API 配置
28
- # ----------------------------
29
- API_URL = "https://router.huggingface.co/models/GGPENG/StyleDiffusion"
30
- API_TOKEN = os.getenv("HF_TOKEN") # 如果模型是公有的,可以留空
31
- headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
32
 
33
- # ----------------------------
34
- # 生成函数
35
- # ----------------------------
36
- def generate(prompt, steps, guidance):
37
- payload = {
38
- "inputs": prompt,
39
- "parameters": {
40
- "num_inference_steps": steps,
41
- "guidance_scale": guidance
42
- }
43
- }
44
-
45
- response = requests.post(API_URL, headers=headers, json=payload)
46
 
47
- if response.status_code != 200:
48
- st.error(f"API请求失败:{response.status_code} {response.text}")
49
- return None
50
 
51
- try:
52
- res_json = response.json()
53
- if isinstance(res_json, list) and len(res_json) > 0:
54
- if "image_base64" in res_json[0]:
55
- img_bytes = base64.b64decode(res_json[0]["image_base64"])
56
- elif "generated_image" in res_json[0]:
57
- img_bytes = base64.b64decode(res_json[0]["generated_image"])
58
- else:
59
- st.error("API 返回不包含图像字段")
60
- return None
61
- return Image.open(io.BytesIO(img_bytes))
62
- else:
63
- st.error("API 返回格式不正确")
64
- return None
65
- except Exception as e:
66
- st.error(f"解析图像失败: {e}")
67
- return None
68
 
69
- # ----------------------------
70
- # 生成按钮
71
- # ----------------------------
72
- if st.button("Generate"):
73
- if not prompt.strip():
74
- st.warning("请输入 Prompt")
75
- else:
76
- with st.spinner("生成中,请稍候..."):
77
- image = generate(prompt, steps, guidance)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- if image:
80
- st.image(image, caption="生成结果", use_column_width=True)
81
- buf = io.BytesIO()
82
- image.save(buf, format="PNG")
83
- st.download_button(
84
- "下载图片",
85
- buf.getvalue(),
86
- "result.png"
87
- )
 
1
  # app.py
2
+ import torch
3
+ import gradio as gr
4
+ from diffusers import DiffusionPipeline
 
 
 
5
 
6
+ print("Loading pipeline...")
 
 
 
 
 
7
 
8
+ pipe = DiffusionPipeline.from_pretrained(
9
+ "GGPENG/StyleDiffusion",
10
+ torch_dtype=torch.bfloat16,
 
 
 
11
  )
 
 
12
 
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
14
 
15
+ pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ pipe.unet.load_attn_procs(
18
+ "./pytorch_custom_diffusion_weights.bin"
19
+ )
20
 
21
+ def generate(prompt, steps, guidance):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ image = pipe(
24
+ prompt,
25
+ num_inference_steps=steps,
26
+ guidance_scale=guidance,
27
+ eta=1
28
+ ).images[0]
29
+
30
+ return image
31
+
32
+
33
+ demo = gr.Interface(
34
+ fn=generate,
35
+ inputs=[
36
+ gr.Textbox(
37
+ label="Prompt",
38
+ value="A <new1> reference. New Year image with a rabbit as the main element"
39
+ ),
40
+ gr.Slider(10, 320, value=100, label="Steps"),
41
+ gr.Slider(1, 18, value=6, label="Guidance"),
42
+ ],
43
+ outputs=gr.Image(),
44
+ title="Fine-tuning style diffusion Demo"
45
+ )
46
 
47
+ demo.launch()