mywork / app.py
txh17's picture
Update app.py
3a430c7 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from diffusers import StableDiffusionPipeline
import whisper
import os
# 加载 Whisper 模型
whisper_model = whisper.load_model("base")
# 使用 DistilGPT-2 生成中文提示词
model_name = "distilgpt2" # DistilGPT-2 是一个较小的 GPT-2 版本
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def generate_prompt(user_input):
# 编码输入并生成提示
inputs = tokenizer.encode(f"根据以下描述生成一个结构良好的提示,适用于稳定扩散图像生成:'{user_input}'", return_tensors="pt")
# 使用模型生成响应
outputs = model.generate(inputs, max_length=100, num_return_sequences=1)
# 解码输出并返回
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 使用 Stable Diffusion 生成图像
def generate_image(prompt):
# 使用 Stable Diffusion 生成图像
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe.to("cpu")
# 生成图像
image = pipe(prompt).images[0]
return image
# 语音转文本
def transcribe_audio(audio_file):
result = whisper_model.transcribe(audio_file, language="zh") # 指定中文识别
return result['text']
# 生成艺术作品
def generate_artwork(description, style, enhance_details, audio_file):
# 如果上传了音频文件,进行语音转文本
if audio_file is not None:
description = transcribe_audio(audio_file)
print("Transcribed Description:", description) # 打印转录后的描述
# 生成提示词
prompt = generate_prompt(description)
print("Generated Prompt:", prompt) # 打印生成的提示词
# 如果选中了增强细节的选项,修改提示
if enhance_details:
prompt += "有增强细节 "
# 根据用户选择的风格修改提示
if style == "奇幻":
prompt += " 奇幻风格"
elif style == "赛博朋克":
prompt += "赛博朋克风格"
else:
prompt += "写实风格"
# 生成图像
image = generate_image(prompt)
print("Image Generated:", image) # 打印生成的图像
return prompt, image
# 创建 Gradio 界面
with gr.Blocks() as demo:
with gr.Row():
description_input = gr.Textbox(label="请输入描述", placeholder="例如:天空中的魔法树屋")
style_selector = gr.Dropdown(choices=["奇幻", "赛博朋克", "现实主义"], label="选择风格")
detail_checkbox = gr.Checkbox(label="增强细节")
audio_input = gr.Audio(label="录制您的描述", type="filepath")
with gr.Row():
output_prompt = gr.Textbox(label="生成的提示词", interactive=False)
output_image = gr.Image(label="生成的图像", interactive=False)
generate_button = gr.Button("生成作品")
generate_button.click(generate_artwork,
inputs=[description_input, style_selector, detail_checkbox, audio_input],
outputs=[output_prompt, output_image])
demo.launch()