File size: 4,229 Bytes
449e89a
 
 
 
 
 
c12ed16
449e89a
 
 
c12ed16
449e89a
c12ed16
449e89a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c12ed16
449e89a
 
 
 
c12ed16
449e89a
c12ed16
 
 
 
 
449e89a
 
 
 
 
c12ed16
 
 
 
 
 
 
 
 
449e89a
 
 
 
 
 
 
 
c12ed16
449e89a
 
 
 
 
 
 
 
 
c12ed16
449e89a
 
 
c12ed16
449e89a
 
c1dcb58
449e89a
182027e
 
449e89a
 
c12ed16
 
c1dcb58
449e89a
 
c12ed16
449e89a
 
 
182027e
 
449e89a
182027e
 
 
c12ed16
449e89a
182027e
 
449e89a
 
 
 
 
 
 
 
 
 
c12ed16
449e89a
 
c12ed16
 
449e89a
 
 
 
 
c12ed16
 
61b9cc3
c12ed16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from transformers import pipeline

# 如果需要使用 Hugging Face 访问令牌,取消下面两行的注释并设置环境变量 HUGGINGFACE_TOKEN
# from huggingface_hub import login
# login(token=os.getenv("HUGGINGFACE_TOKEN"))

# Step 1: Prompt-to-Prompt module. Flan-T5 rewrites a short user description
# into a more structured Stable Diffusion prompt. The pipeline runs on GPU 0
# when CUDA is available, otherwise on the CPU (device index -1).
llm = pipeline(
    "text2text-generation",
    model="google/flan-t5-large",
    device=(0 if torch.cuda.is_available() else -1),
)

# Step 2: Load the Stable Diffusion pipelines.
# Half precision is only used on CUDA. Many CPU kernels have no float16
# implementation, so CPU-only hosts load the weights in float32 instead.
_SD_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_SD_DTYPE = torch.float16 if _SD_DEVICE == "cuda" else torch.float32

# The former (invalid) `revision` argument was removed; torch_dtype alone
# selects the weight precision.
# NOTE(review): the runwayml/stable-diffusion-v1-5 repository has reportedly
# been withdrawn from the Hugging Face Hub. Confirm it still resolves, or
# point this at a maintained mirror.
sd_v15 = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_SD_DTYPE,
)
sd_v15 = sd_v15.to(_SD_DEVICE)

# SDXL checkpoints must be loaded with StableDiffusionXLPipeline. The model
# ships a second text encoder and a UNet that expects extra conditioning
# (pooled text embeddings and time ids), which StableDiffusionPipeline never
# provides. It uses the same dtype as sd_v15 so both fit on the GPU together.
sd_xl = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=_SD_DTYPE,
)
sd_xl = sd_xl.to(_SD_DEVICE)

# Optional: speech input module. A Whisper (base) speech-recognition pipeline
# turns recorded or uploaded audio into text. It runs on GPU 0 when CUDA is
# available, otherwise on the CPU (device index -1).
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base",
    device=(0 if torch.cuda.is_available() else -1),
)


def transcribe(audio_path):
    """Transcribe an audio file into text using the module-level ``asr`` pipeline.

    :param audio_path: path of the audio file to transcribe
    :return: the ``"text"`` field of the ASR pipeline's result
    """
    return asr(audio_path)["text"]


def generate(description, model_choice, guidance_scale, negative_prompt, style):
    """Expand a short description into a Stable Diffusion prompt and render it.

    The module-level Flan-T5 pipeline ``llm`` rewrites the description into a
    richer prompt. That prompt is then passed to the selected Stable Diffusion
    pipeline.

    :param description: short user description of the desired image
    :param model_choice: ``"SDXL"`` selects ``sd_xl``; any other value selects ``sd_v15``
    :param guidance_scale: guidance scale forwarded to the diffusion pipeline
    :param negative_prompt: text to steer the image away from; blank means none
    :param style: style name chosen in the UI, or ``None`` if none was chosen
    :return: tuple ``(prompt, image)``: the prompt actually used and the first
        image produced by the pipeline
    :raises gr.Error: if ``description`` is empty or whitespace-only
    """
    description = (description or "").strip()
    if not description:
        # Show a readable message in the UI instead of rendering an image
        # from an empty prompt.
        raise gr.Error("Please enter a text description.")

    # Flan-T5 is English-centric and its tokenizer has very limited CJK
    # coverage, so the instruction must be in English to be understood.
    # NOTE(review): non-English descriptions will likewise be poorly handled.
    instruction_lines = [
        "Expand the following short description into a detailed Stable "
        "Diffusion prompt, adding visual details and style keywords.",
        f"Description: '{description}'",
    ]
    # The style dropdown has no default value. Omit the line rather than
    # sending the literal text 'None' to the model.
    if style:
        instruction_lines.append(f"Style: '{style}'")
    instruction = "\n".join(instruction_lines)

    prompt = llm(instruction, max_length=128)[0]["generated_text"].strip()
    if not prompt:
        # Fall back to the user's own text if the LLM returns nothing usable.
        prompt = f"{description}, {style}" if style else description

    pipeline_model = sd_xl if model_choice == "SDXL" else sd_v15
    image = pipeline_model(
        prompt,
        guidance_scale=guidance_scale,
        # A blank string would still be encoded as a prompt. None lets each
        # pipeline fall back to its own default unconditional embedding.
        negative_prompt=negative_prompt or None,
    ).images[0]
    return prompt, image


# Step 3: Build the Gradio interface.
with gr.Blocks(title="Prompt-to-Image Generator") as demo:
    gr.Markdown("## 基于 LLM 的提示词生成与 Stable Diffusion 图像生成")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            desc_input = gr.Textbox(label="文本描述", placeholder="Example:blue sky")
            style_dropdown = gr.Dropdown(
                choices=["Fancy", "Science", "Reality"],
                label="style",
            )
            model_radio = gr.Radio(
                choices=["SD v1.5", "SDXL"],
                value="SD v1.5",
                label="model",
            )
            guidance_slider = gr.Slider(
                minimum=0, maximum=20, step=0.5, value=7.5,
                label="Guidance Scale",
            )
            neg_text = gr.Textbox(label="reverse_prompt")
            use_voice = gr.Checkbox(label="voice_input")
            audio_input = gr.Audio(type="filepath", label="voice_input")
            generate_btn = gr.Button("generate")

        # Right column: generated prompt and image.
        with gr.Column():
            prompt_output = gr.Textbox(label="generated prompt")
            image_output = gr.Image(label="generated image")

    def conditional_transcribe(audio_path, use_voice_flag):
        """Replace the description with the transcript of *audio_path*.

        Returns a no-op ``gr.update()`` when voice input is disabled or there
        is no audio (e.g. it was just cleared). Returning ``None`` would
        instead wipe whatever the user had typed, and ``transcribe(None)``
        would fail.
        """
        if not use_voice_flag or audio_path is None:
            return gr.update()
        return transcribe(audio_path)

    # Transcribe whenever new audio arrives...
    audio_input.change(
        fn=conditional_transcribe,
        inputs=[audio_input, use_voice],
        outputs=desc_input,
    )
    # ...and also when voice input is switched on after audio was provided.
    use_voice.change(
        fn=conditional_transcribe,
        inputs=[audio_input, use_voice],
        outputs=desc_input,
    )

    # Clicking the button expands the prompt and renders the image.
    generate_btn.click(
        fn=generate,
        inputs=[desc_input, model_radio, guidance_slider, neg_text, style_dropdown],
        outputs=[prompt_output, image_output],
    )

# Step 4: Launch the app.
# NOTE: server_name="0.0.0.0" listens on every network interface, and
# share=True additionally requests a temporary public share link from Gradio.
if __name__ == "__main__":
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_options)