# DPLproject / app.py
# sshenai's picture
# Update app.py
# bfac0f6 verified
# -*- coding: utf-8 -*-
"""
Bird knowledge education system (Qwen3-optimized edition) by [your name]
ISOM5240 Group Project
"""
import os

# ========== Environment configuration ==========
# The cache directory must be exported BEFORE transformers is imported:
# the library resolves TRANSFORMERS_CACHE while it is being imported, so an
# assignment made afterwards does not change where models are cached.
os.environ["TRANSFORMERS_CACHE"] = "./model_cache"

import torch  # noqa: E402  (imports intentionally follow the env setup)
import transformers  # noqa: E402
import gradio as gr  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import (  # noqa: E402
    pipeline,
    AutoConfig,
    AutoImageProcessor,
)

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"当前设备: {DEVICE.upper()}")
# ========== 模型初始化 ==========
def init_models():
    """Load the three pipelines used by the app.

    Returns:
        tuple: ``(classifier, text_generator, tts)`` -- the
        image-classification, text-generation and text-to-speech pipelines,
        in that order.

    Raises:
        RuntimeError: If any of the models fails to load. The original
            exception is chained as ``__cause__``.
    """
    try:
        # 1. Bird species classifier, with an explicitly built image processor.
        image_processor = AutoImageProcessor.from_pretrained(
            "chriamue/bird-species-classifier",
            use_fast=True,  # ask for the fast processor implementation
        )
        classifier = pipeline(
            task="image-classification",
            model="chriamue/bird-species-classifier",
            # Image pipelines take their processor through ``image_processor``;
            # ``feature_extractor`` is the legacy slot for it.
            image_processor=image_processor,
            device=DEVICE,
        )

        # 2. Text generation model.
        qwen_config = AutoConfig.from_pretrained(
            "Qwen/Qwen-7B-Chat",
            trust_remote_code=True,
        )
        text_generator = pipeline(
            task="text-generation",
            model="Qwen/Qwen-7B-Chat",
            config=qwen_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # The config load above already opts in to remote code; the model
            # and tokenizer loaded by the pipeline need the same opt-in.
            trust_remote_code=True,
            # ``revision`` is a pipeline() argument that is forwarded to the
            # loaders. Repeating it inside ``model_kwargs`` makes the loader
            # receive the keyword twice, so it is specified only here.
            revision="main",
            model_kwargs={"cache_dir": "./model_cache"},
        )

        # 3. Speech synthesis model.
        tts = pipeline(
            task="text-to-speech",
            model="facebook/mms-tts-eng",
            device=DEVICE,
        )
        return classifier, text_generator, tts
    except Exception as e:
        # Keep the RuntimeError contract, but chain the cause for debugging.
        raise RuntimeError(f"模型加载失败: {str(e)}") from e
# ========== 核心逻辑 ==========
def generate_child_friendly_text(bird_name):
    """Generate a short, child-friendly description of a bird.

    Args:
        bird_name: Name of the bird, interpolated into the prompt.

    Returns:
        str: The generated description with surrounding whitespace removed,
        or a message starting with ``"文本生成失败: "`` if generation raised.
    """
    prompt = f"""用6-12岁儿童能理解的语言描述{bird_name}
★ 使用比喻(例如:羽毛像彩虹糖纸)
★ 包含趣味冷知识(例如:每天吃自身体重30%的食物)
★ 语句简短(每句不超过15个英文单词)
★ 避免专业术语"""
    try:
        response = text_generator(
            prompt,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Return only the continuation. Without this the prompt is echoed
            # back, and splitting on "描述" left the bird name and the whole
            # bullet list of instructions at the start of the result.
            return_full_text=False,
        )
        return response[0]['generated_text'].strip()
    except Exception as e:
        # Best-effort: callers receive a readable message instead of a raise.
        return f"文本生成失败: {str(e)}"
def process_image(img):
    """Run the full pipeline: classify the image, describe the bird, speak it.

    Args:
        img: The uploaded image, or ``None`` when there is no image to
            process (for example after the input has been cleared).

    Returns:
        tuple: ``(bird_name, description, audio)``. ``audio`` is a
        ``(sampling_rate, waveform)`` pair on success and ``None`` otherwise.
        For a ``None`` image the result is ``("", "", None)``. If any step
        raises, the result is ``("错误", "处理失败: <reason>", None)``.
    """
    # Nothing to classify: clear the outputs instead of reporting an error.
    if img is None:
        return "", "", None
    try:
        # 1. Image classification: keep the label of the first prediction.
        classification = classifier(img)
        bird_name = classification[0]['label']
        # 2. Description generation.
        description = generate_child_friendly_text(bird_name)
        # 3. Speech synthesis.
        speech = tts(description, forward_params={"speaker_id": 6})
        # An audio output component needs the sample rate together with the
        # samples. squeeze() drops a leading batch axis if the pipeline
        # returned one, so the waveform is not read as one multi-channel frame.
        waveform = speech["audio"].squeeze()
        return bird_name, description, (speech["sampling_rate"], waveform)
    except Exception as e:
        return "错误", f"处理失败: {str(e)}", None
# ========== 初始化验证 ==========
if __name__ == "__main__":
    # Load every model up front so a broken model fails at startup rather
    # than on the first request. The names are module globals that the
    # handler functions look up when they are called.
    classifier, text_generator, tts = init_models()

    # Assemble the Gradio interface.
    page_css = ".gradio-container {max-width: 800px}"
    sample_images = ["test_images/eagle.jpg", "test_images/penguin.jpg"]

    with gr.Blocks(theme=gr.themes.Soft(), css=page_css) as demo:
        gr.Markdown("# 🐦 鸟类知识小课堂(稳定版)")

        # Input image next to the spoken explanation.
        with gr.Row():
            img_input = gr.Image(
                type="pil",
                label="上传鸟类图片",
                height=300,
            )
            audio_output = gr.Audio(label="语音讲解", autoplay=True)

        # Text results.
        with gr.Column():
            name_output = gr.Textbox(label="识别结果")
            desc_output = gr.Textbox(label="趣味知识", lines=4)

        gr.Examples(examples=sample_images, inputs=img_input, label="示例图片")

        # Re-run the whole pipeline whenever the input image changes.
        result_widgets = [name_output, desc_output, audio_output]
        img_input.change(process_image, inputs=img_input, outputs=result_widgets)

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)