sjdnjn commited on
Commit
58b0886
·
verified ·
1 Parent(s): a095def

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import pipeline
5
+ from diffusers import StableDiffusionPipeline
6
+
7
+ # 如果需要使用 Hugging Face 访问令牌,取消下面一行的注释并设置环境变量 HUGGINGFACE_TOKEN
8
+ # from huggingface_hub import login
9
+ # login(token=os.getenv("HUGGINGFACE_TOKEN"))
10
+
11
+ # Step 1: Prompt-to-Prompt 模块,使用 Flan-T5 生成结构化提示词
12
+ llm = pipeline(
13
+ "text2text-generation",
14
+ model="google/flan-t5-large",
15
+ device=0 if torch.cuda.is_available() else -1
16
+ )
17
+
18
+ # Step 2: 加载 Stable Diffusion 模型
19
+ # 移除无效的 revision 参数,仅使用 torch_dtype 加速加载
20
+ sd_v15 = StableDiffusionPipeline.from_pretrained(
21
+ "runwayml/stable-diffusion-v1-5",
22
+ torch_dtype=torch.float16
23
+ )
24
+ sd_v15 = sd_v15.to("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ sd_xl = StableDiffusionPipeline.from_pretrained(
27
+ "stabilityai/stable-diffusion-xl-base-1.0"
28
+ )
29
+ sd_xl = sd_xl.to("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ # 可选:语音输入模块,使用 Whisper
32
+ asr = pipeline(
33
+ "automatic-speech-recognition",
34
+ model="openai/whisper-base",
35
+ device=0 if torch.cuda.is_available() else -1
36
+ )
37
+
38
+ def transcribe(audio_path):
39
+ text = asr(audio_path)["text"]
40
+ return text
41
+
42
+
43
+ def generate(description, model_choice, guidance_scale, negative_prompt, style):
44
+ # 构造给 LLM 的指令
45
+ instruction = (
46
+ f"请将以下简短描述扩展为 Stable Diffusion 友好的提示词,包含细节和风格:\n"
47
+ f"描述: '{description}'\n"
48
+ f"风格: '{style}'"
49
+ )
50
+ result = llm(instruction, max_length=128)[0]["generated_text"].strip()
51
+ prompt = result
52
+ # 根据模型选择生成图像
53
+ pipeline_model = sd_xl if model_choice == "SDXL" else sd_v15
54
+ image = pipeline_model(
55
+ prompt,
56
+ guidance_scale=guidance_scale,
57
+ negative_prompt=negative_prompt
58
+ ).images[0]
59
+ return prompt, image
60
+
61
+ # Step 3: 构建 Gradio 界面
62
+ with gr.Blocks(title="Prompt-to-Image Generator") as demo:
63
+ gr.Markdown("## 基于 LLM 的提示词生成与 Stable Diffusion 图像生成")
64
+ with gr.Row():
65
+ with gr.Column():
66
+ desc_input = gr.Textbox(label="文本描述", placeholder="例如:空中的魔法树屋")
67
+ style_dropdown = gr.Dropdown(
68
+ choices=["幻想风格", "赛博朋克", "写实主义"],
69
+ label="选择风格"
70
+ )
71
+ model_radio = gr.Radio(
72
+ choices=["SD v1.5", "SDXL"],
73
+ value="SD v1.5",
74
+ label="选择模型"
75
+ )
76
+ guidance_slider = gr.Slider(
77
+ minimum=0, maximum=20, step=0.5, value=7.5,
78
+ label="Guidance Scale"
79
+ )
80
+ neg_text = gr.Textbox(
81
+ label="反向提示词",
82
+ placeholder="排除内容(如:低分辨率、水印)"
83
+ )
84
+ use_voice = gr.Checkbox(label="启用语音输入(加分项)")
85
+ # 移除 'source' 参数以兼容 Gradio 版本
86
+ audio_input = gr.Audio(type="filepath", label="语音输入")
87
+ generate_btn = gr.Button("生成图像")
88
+ with gr.Column():
89
+ prompt_output = gr.Textbox(label="生成的提示词")
90
+ image_output = gr.Image(label="生成的图像")
91
+
92
+ # 绑定语音转文字(仅当启用时)
93
+ def conditional_transcribe(audio_path, use_voice_flag):
94
+ return transcribe(audio_path) if use_voice_flag else None
95
+
96
+ audio_input.change(
97
+ fn=conditional_transcribe,
98
+ inputs=[audio_input, use_voice],
99
+ outputs=desc_input
100
+ )
101
+ # 点击按钮生成提示词并绘图
102
+ generate_btn.click(
103
+ fn=generate,
104
+ inputs=[desc_input, model_radio, guidance_slider, neg_text, style_dropdown],
105
+ outputs=[prompt_output, image_output]
106
+ )
107
+
108
+ # Step 4: 启动应用
109
+ if __name__ == "__main__":
110
+ demo.launch(server_name="0.0.0.0", server_port=7860)