|
|
import gradio as gr |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
import whisper |
|
|
import os |
|
|
|
|
|
|
|
|
# Speech-to-text model, loaded once at import time so every request reuses it.
whisper_model = whisper.load_model("base")
|
|
|
|
|
|
|
|
# Causal language model (and its matching tokenizer) used by generate_prompt
# to expand a user description into an image prompt. Both are loaded once at
# import time from the same checkpoint name.
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
def generate_prompt(user_input, max_new_tokens=50):
    """Expand a user description into a text prompt using the language model.

    The description is embedded in a fixed instruction sentence, fed to the
    module-level ``model`` and the first generated sequence is decoded back
    to text. For a causal LM the decoded sequence contains the instruction
    text followed by the model's continuation.

    Args:
        user_input: Free-form description supplied by the user.
        max_new_tokens: Upper bound on the number of tokens generated *after*
            the instruction. The budget is expressed in new tokens rather
            than total length because the instruction alone can be close to
            (or beyond) 100 tokens once byte-level BPE splits the Chinese
            characters, which would leave a total-length limit with no room
            for any continuation.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    instruction = f"根据以下描述生成一个结构良好的提示,适用于稳定扩散图像生成:'{user_input}'"

    # Calling the tokenizer (instead of .encode) also yields the attention
    # mask, which generate() needs to interpret the input unambiguously.
    encoded = tokenizer(instruction, return_tensors="pt")

    outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        # GPT-2 style checkpoints define no pad token; reuse EOS for padding.
        pad_token_id=tokenizer.eos_token_id,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
# Lazily-created Stable Diffusion pipeline shared by all generate_image calls.
_sd_pipeline = None


def _get_sd_pipeline():
    """Return the shared Stable Diffusion pipeline, loading it on first use."""
    global _sd_pipeline
    if _sd_pipeline is None:
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1-base"
        )
        pipe.to("cpu")
        _sd_pipeline = pipe
    return _sd_pipeline


def generate_image(prompt):
    """Render an image for ``prompt`` with Stable Diffusion on the CPU.

    The pipeline is loaded once and reused across calls; reloading the
    checkpoint for every request would repeat the (slow) weight loading
    each time an image is generated.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image produced by the pipeline for this prompt.
    """
    return _get_sd_pipeline()(prompt).images[0]
|
|
|
|
|
|
|
|
def transcribe_audio(audio_file):
    """Transcribe ``audio_file`` with the shared Whisper model.

    The transcription language is fixed to ``"zh"``.

    Args:
        audio_file: Audio input handed to ``whisper_model.transcribe``.

    Returns:
        The ``'text'`` entry of the transcription result.
    """
    return whisper_model.transcribe(audio_file, language="zh")["text"]
|
|
|
|
|
|
|
|
def generate_artwork(description, style, enhance_details, audio_file):
    """Turn a typed or spoken description into a prompt and an image.

    Args:
        description: Text description typed by the user. Replaced by the
            transcription when ``audio_file`` is provided.
        style: Selected style name. "奇幻" and "赛博朋克" map to their own
            style suffix; any other value (including no selection) falls
            back to the realistic-style suffix.
        enhance_details: When truthy, an "enhanced details" phrase is added
            to the prompt.
        audio_file: Optional recorded description; ``None`` when absent.

    Returns:
        A ``(prompt, image)`` tuple: the final prompt text and the image
        generated from it.

    Raises:
        gr.Error: If neither a non-blank description nor audio is available,
            so the UI reports the problem instead of running the (slow)
            models on an empty request.
    """
    # A recording takes precedence over whatever was typed in the textbox.
    if audio_file is not None:
        description = transcribe_audio(audio_file)
        print("Transcribed Description:", description)

    if not description or not description.strip():
        raise gr.Error("请输入描述或录制语音")

    prompt = generate_prompt(description)
    print("Generated Prompt:", prompt)

    # Collect the prompt fragments and join them with one uniform separator,
    # so modifiers never get glued onto the preceding text.
    style_suffixes = {"奇幻": "奇幻风格", "赛博朋克": "赛博朋克风格"}
    fragments = [prompt.strip()]
    if enhance_details:
        fragments.append("有增强细节")
    fragments.append(style_suffixes.get(style, "写实风格"))
    prompt = ", ".join(fragments)

    image = generate_image(prompt)
    print("Image Generated:", image)
    return prompt, image
|
|
|
|
|
|
|
|
# Gradio UI: one row of inputs, one row of outputs and a button that runs
# generate_artwork on the current input values.
with gr.Blocks() as demo:
    # Input controls.
    with gr.Row():
        description_input = gr.Textbox(
            label="请输入描述",
            placeholder="例如:天空中的魔法树屋",
        )
        style_selector = gr.Dropdown(
            choices=["奇幻", "赛博朋克", "现实主义"],
            label="选择风格",
        )
        detail_checkbox = gr.Checkbox(label="增强细节")
        audio_input = gr.Audio(label="录制您的描述", type="filepath")

    # Read-only result widgets.
    with gr.Row():
        output_prompt = gr.Textbox(label="生成的提示词", interactive=False)
        output_image = gr.Image(label="生成的图像", interactive=False)

    generate_button = gr.Button("生成作品")
    # Input order must match generate_artwork's parameter order, and output
    # order must match its (prompt, image) return value.
    generate_button.click(
        generate_artwork,
        inputs=[description_input, style_selector, detail_checkbox, audio_input],
        outputs=[output_prompt, output_image],
    )

demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|