wjt6 committed on
Commit
3c41586
·
verified ·
1 Parent(s): 65acb4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -95
app.py CHANGED
@@ -1,110 +1,185 @@
1
- import os
2
- import torch
3
  import gradio as gr
4
- from transformers import pipeline
 
 
5
  from diffusers import StableDiffusionPipeline
 
 
 
6
 
7
- # 如果需要使用 Hugging Face 访问令牌,取消下面一行的注释并设环境变量 HUGGINGFACE_TOKEN
8
- # from huggingface_hub import login
9
- # login(token=os.getenv("HUGGINGFACE_TOKEN"))
10
-
11
- # Step 1: Prompt-to-Prompt 模块,使用 Flan-T5 生成结构化提示词
12
- llm = pipeline(
13
- "text2text-generation",
14
- model="google/flan-t5-large",
15
- device=0 if torch.cuda.is_available() else -1
16
- )
17
-
18
- # Step 2: 加载 Stable Diffusion 模型
19
- # 移除无效的 revision 参数,仅使用 torch_dtype 加速加载
20
- sd_v15 = StableDiffusionPipeline.from_pretrained(
21
- "runwayml/stable-diffusion-v1-5",
22
- torch_dtype=torch.float16
23
- )
24
- sd_v15 = sd_v15.to("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
- sd_xl = StableDiffusionPipeline.from_pretrained(
27
- "stabilityai/stable-diffusion-xl-base-1.0"
28
- )
29
- sd_xl = sd_xl.to("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- # 可选:语音输入模块,使用 Whisper
32
- asr = pipeline(
33
- "automatic-speech-recognition",
34
- model="openai/whisper-base",
35
- device=0 if torch.cuda.is_available() else -1
36
- )
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- def transcribe(audio_path):
39
- text = asr(audio_path)["text"]
40
- return text
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- def generate(description, model_choice, guidance_scale, negative_prompt, style):
44
- # 构造给 LLM 的指令
45
- instruction = (
46
- f"请将以下简短描述扩展为 Stable Diffusion 友好的提示词,包含细节和风格:\n"
47
- f"描述: '{description}'\n"
48
- f"风格: '{style}'"
49
- )
50
- result = llm(instruction, max_length=128)[0]["generated_text"].strip()
51
- prompt = result
52
- # 根据模型选择生成图像
53
- pipeline_model = sd_xl if model_choice == "SDXL" else sd_v15
54
- image = pipeline_model(
55
- prompt,
56
- guidance_scale=guidance_scale,
57
- negative_prompt=negative_prompt
58
- ).images[0]
59
- return prompt, image
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Step 3: 构建 Gradio 界面
62
- with gr.Blocks(title="Prompt-to-Image Generator") as demo:
63
- gr.Markdown("## 基于 LLM 的提示词生成与 Stable Diffusion 图像生成")
64
- with gr.Row():
65
- with gr.Column():
66
- desc_input = gr.Textbox(label="文本描述", placeholder="例如:空中的魔法树屋")
67
- style_dropdown = gr.Dropdown(
68
- choices=["幻想风格", "赛博朋克", "写实主义"],
69
- label="选择风格"
70
- )
71
- model_radio = gr.Radio(
72
- choices=["SD v1.5", "SDXL"],
73
- value="SD v1.5",
74
- label="选择模型"
75
- )
76
- guidance_slider = gr.Slider(
77
- minimum=0, maximum=20, step=0.5, value=7.5,
78
- label="Guidance Scale"
79
- )
80
- neg_text = gr.Textbox(
81
- label="反向提示词",
82
- placeholder="排除内容(如:低分辨率、水印)"
83
- )
84
- use_voice = gr.Checkbox(label="启用语音输入(加分项)")
85
- # 移除 'source' 参数以兼容 Gradio 版本
86
- audio_input = gr.Audio(type="filepath", label="语音输入")
87
- generate_btn = gr.Button("生成图像")
88
- with gr.Column():
89
- prompt_output = gr.Textbox(label="生成的提示词")
90
- image_output = gr.Image(label="生成的图像")
 
 
 
91
 
92
- # 绑定语音转文字(仅当启用时)
93
- def conditional_transcribe(audio_path, use_voice_flag):
94
- return transcribe(audio_path) if use_voice_flag else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- audio_input.change(
97
- fn=conditional_transcribe,
98
- inputs=[audio_input, use_voice],
99
- outputs=desc_input
 
 
 
 
 
100
  )
101
- # 点击按钮生成提示词并绘图
102
- generate_btn.click(
103
- fn=generate,
104
- inputs=[desc_input, model_radio, guidance_slider, neg_text, style_dropdown],
105
- outputs=[prompt_output, image_output]
 
106
  )
107
 
108
- # Step 4: 启动应用
109
  if __name__ == "__main__":
110
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import os
4
+ import logging
5
  from diffusers import StableDiffusionPipeline
6
+ from PIL import Image
7
+ from openai import OpenAI
8
+ from typing import Optional
9
 
10
# Logging: configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)

# API client: created only when OPENAI_API_KEY is present in the environment;
# otherwise `openai_client` is None and prompt generation reports the problem.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY:
    openai_client = OpenAI(api_key=OPENAI_API_KEY)
else:
    openai_client = None
 
16
 
17
# Initialise Stable Diffusion
def init_sd_pipeline():
    """Load the Stable Diffusion v1.5 pipeline and move it to a device.

    The device is "cuda" when ``torch.cuda.is_available()`` is true and
    "cpu" otherwise; weights are loaded as float16 on CUDA and float32 on
    CPU. The pipeline is requested with ``safety_checker=None`` and
    ``use_safetensors=True``.

    Returns:
        The pipeline after ``.to(device)``, or ``None`` if anything in the
        loading sequence raises. Callers must check for ``None``.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if device == "cuda" else torch.float32

        logging.info(f"正在加载模型到 {device},精度:{torch_dtype}")

        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch_dtype,
            safety_checker=None,
            use_safetensors=True,
        )
        return pipe.to(device)
    except Exception as e:
        # Start-up is best-effort so the UI can still come up without the
        # model; log the full traceback so the failure can be diagnosed.
        logging.exception(f"模型初始化失败: {str(e)}")
        return None


# VRAM optimisation
torch.cuda.empty_cache()
image_pipe = init_sd_pipeline()
39
 
40
# Text generation
def generate_prompt(user_input, temperature, top_p, max_tokens, repetition_penalty):
    """Expand a short user idea into a Stable Diffusion prompt via OpenAI chat.

    Args:
        user_input: The raw idea typed by the user.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling value forwarded to the API.
        max_tokens: Completion length limit; coerced to ``int`` before use.
        repetition_penalty: Forwarded to the API as ``presence_penalty``.
            NOTE(review): presence_penalty is an additive penalty, not a
            multiplicative repetition penalty -- confirm the intended range.

    Returns:
        The generated prompt text with surrounding whitespace stripped.

    Raises:
        gr.Error: If the OpenAI client is not configured, the input is blank,
            or the API call fails.
    """
    if not openai_client:
        raise gr.Error("OpenAI客户端未初始化,请检查API密钥")
    if not user_input or not user_input.strip():
        raise gr.Error("请输入原始想法")

    system_prompt = (
        "你是一个专业的提示词工程师,请将用户的想法转化为详细的Stable Diffusion提示词。\n"
        "遵循以下格式:\n"
        "[主体描述],[环境细节],[艺术风格],[画质参数]\n"
        "示例:\n"
        "魔法屋在空中漂浮,被五彩云环绕,赛博朋克风格,8k分辨率,超精细细节"
    )

    try:
        response = openai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_input},
            ],
            temperature=temperature,
            top_p=top_p,
            # UI sliders may deliver a float; the API expects an integer.
            max_tokens=int(max_tokens),
            presence_penalty=repetition_penalty,
        )
        # `content` can be None; treat that as an empty completion.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        logging.error(f"提示词生成失败: {str(e)}")
        # Surface the failure as a UI error instead of returning the error
        # text as if it were a prompt (it would otherwise be usable as the
        # image prompt downstream).
        raise gr.Error(f"生成失败: {str(e)}") from e
67
 
68
# Image generation
def generate_image(prompt, negative_prompt="", guidance=7.5, steps=25):
    """Render one image for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Positive prompt; must be non-empty.
        negative_prompt: Content to steer away from.
        guidance: Value passed as ``guidance_scale``.
        steps: Value passed (as ``int``) as ``num_inference_steps``.

    Returns:
        The first image of the pipeline result (``result.images[0]``).
        NOTE(review): the visible caller wires this to a gr.Gallery, which
        is documented to take a list of images -- confirm the caller wraps it.

    Raises:
        gr.Error: If the pipeline is not loaded, the prompt is empty, CUDA
            runs out of memory, or generation fails for any other reason.
    """
    if not image_pipe:
        raise gr.Error("图像模型未初始化")
    if not prompt:
        raise gr.Error("请输入提示词")

    try:
        logging.info(f"开始生成图像,参数:guidance={guidance}, steps={steps}")

        # torch.Generator has no .cuda() method; the device must be chosen
        # at construction time.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        seed = int(torch.rand(1).item() * 1e7)
        generator = torch.Generator(device=device).manual_seed(seed)

        result = image_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance,
            num_inference_steps=int(steps),
            generator=generator,
            num_images_per_prompt=1,
        )
        return result.images[0]
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        raise gr.Error("显存不足,请尝试:1. 简化提示词 2. 减小迭代次数 3. 降低引导强度")
    except Exception as e:
        logging.error(f"图像生成失败: {str(e)}")
        raise gr.Error(f"生成失败: {str(e)}") from e
97
 
98
# UI layout
# NOTE(review): the component nesting below was reconstructed from a
# whitespace-stripped diff; confirm it against the real app.py layout.
with gr.Blocks(theme=gr.themes.Soft(), title="AI创作平台") as app:
    gr.Markdown("## 🎨 智能创作平台 - 文本到图像生成工作流")

    with gr.Tabs():
        # Prompt design tab
        with gr.Tab("🖋 提示词设计"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_box = gr.Textbox(
                        label="原始想法",
                        placeholder="例:空中的魔法屋",
                        lines=3,
                        max_lines=6,
                    )
                    with gr.Accordion("advanced parameters", open=False):
                        temp_slider = gr.Slider(0.1, 1.5, 0.7,
                                                label="creative temperature")
                        top_p_slider = gr.Slider(0.1, 1.0, 0.9,
                                                 label="core sampling ratio")
                        max_len = gr.Slider(64, 2048, 512, step=64,
                                            label="maxlength")
                        # Label typo fixed ("reapted punishment").
                        rep_penalty = gr.Slider(1.0, 2.0, 1.2,
                                                label="repetition penalty")

                    gen_btn = gr.Button("生成专业提示词", variant="primary")

                output_prompt = gr.Textbox(
                    label="优化后的提示词",
                    lines=4,
                    interactive=True,
                    elem_classes=["prompt-box"],
                )

        # Image generation tab
        with gr.Tab("🖼 图像生成"):
            with gr.Row():
                with gr.Column(scale=1):
                    prompt_transfer = gr.Textbox(
                        label="当前提示词",
                        lines=3,
                        interactive=True,
                    )
                    neg_prompt = gr.Textbox(
                        label="排除内容",
                        placeholder="例:模糊、低质量、水印",
                        lines=2,
                    )
                    with gr.Row():
                        guidance_slider = gr.Slider(1, 20, 7.5,
                                                    label="guiding strength", step=0.5)
                        steps_slider = gr.Slider(10, 50, 25,
                                                 label="epochs", step=5)
                    image_btn = gr.Button("生成图像", variant="primary")

                gallery = gr.Gallery(
                    label="生成结果",
                    columns=2,
                    height=600,
                    object_fit="contain",
                )

    def _generate_for_gallery(*args):
        """Adapt generate_image's single image to the list a Gallery takes."""
        return [generate_image(*args)]

    # Event bindings
    # After the prompt is generated, mirror it into the image tab's prompt box.
    gen_btn.click(
        generate_prompt,
        [input_box, temp_slider, top_p_slider, max_len, rep_penalty],
        output_prompt,
    ).then(
        lambda x: x,
        output_prompt,
        prompt_transfer,
    )

    image_btn.click(
        _generate_for_gallery,
        [prompt_transfer, neg_prompt, guidance_slider, steps_slider],
        gallery,
        api_name='generate',
    )
177
 
178
# Run the application
if __name__ == "__main__":
    # Launch settings collected in one place, then handed to Gradio.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "share": False,
    }
    app.launch(**launch_options)