Spaces:
Runtime error
Runtime error
File size: 4,362 Bytes
d2f9ee7 f1cdf4e d2f9ee7 a3f53bc bfac0f6 ff33c1e bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 bfac0f6 d2f9ee7 d32f0b3 bfac0f6 d2f9ee7 d32f0b3 bfac0f6 f1cdf4e bfac0f6 f1cdf4e d2f9ee7 bfac0f6 f1cdf4e bfac0f6 d2f9ee7 bfac0f6 f1cdf4e bfac0f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# -*- coding: utf-8 -*-
"""
鸟类知识科普系统(Qwen3优化版) by [你的名字]
ISOM5240 Group Project
"""
import transformers
import os
import torch
import gradio as gr
from PIL import Image
from transformers import (
pipeline,
AutoConfig,
AutoImageProcessor
)
# ========== Environment configuration ==========
# Point the transformers cache at a local directory.
# NOTE(review): transformers is already imported above this line, and it
# presumably resolves its cache location at import time, so this assignment may
# not change where models are cached. Confirm, and either move it above the
# imports or pass cache_dir explicitly to every from_pretrained/pipeline call.
os.environ["TRANSFORMERS_CACHE"] = "./model_cache"

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"当前设备: {DEVICE.upper()}")
# ========== Model initialization ==========
def init_models():
    """Load and return the three pipelines used by the app.

    Returns:
        tuple: ``(classifier, text_generator, tts)``:
            - an image-classification pipeline for bird species,
            - a text-generation pipeline backed by Qwen/Qwen-7B-Chat,
            - a text-to-speech pipeline (facebook/mms-tts-eng).

    Raises:
        RuntimeError: if any model fails to load. The original exception is
            chained as ``__cause__`` so the real traceback is preserved.
    """
    try:
        # 1. Bird classifier. The image processor is loaded explicitly so the
        #    fast implementation is requested.
        image_processor = AutoImageProcessor.from_pretrained(
            "chriamue/bird-species-classifier",
            use_fast=True,
        )
        classifier = pipeline(
            task="image-classification",
            model="chriamue/bird-species-classifier",
            # pipeline() has a dedicated argument for vision preprocessors;
            # `feature_extractor` is the older name for this slot.
            image_processor=image_processor,
            device=DEVICE,
        )

        # 2. Text generator. Qwen-7B-Chat ships its own modelling/tokenizer
        #    code, so trust_remote_code must also be given to the pipeline
        #    (which loads the model and tokenizer), not only to AutoConfig.
        qwen_config = AutoConfig.from_pretrained(
            "Qwen/Qwen-7B-Chat",
            trust_remote_code=True,
        )
        text_generator = pipeline(
            task="text-generation",
            model="Qwen/Qwen-7B-Chat",
            config=qwen_config,
            trust_remote_code=True,
            # `revision` is a first-class pipeline argument. Putting it inside
            # model_kwargs as well can make from_pretrained receive the keyword
            # twice ("got multiple values for keyword argument 'revision'").
            revision="main",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            model_kwargs={"cache_dir": "./model_cache"},
        )

        # 3. Text-to-speech.
        tts = pipeline(
            task="text-to-speech",
            model="facebook/mms-tts-eng",
            device=DEVICE,
        )
    except Exception as e:
        raise RuntimeError(f"模型加载失败: {str(e)}") from e
    return classifier, text_generator, tts
# ========== Core logic ==========
def generate_child_friendly_text(bird_name):
    """Generate a child-friendly description of a bird species.

    Args:
        bird_name: Species label produced by the image classifier.

    Returns:
        str: The generated continuation with surrounding whitespace stripped.
            If generation raises, returns an error message starting with
            "文本生成失败" instead of raising.
    """
    # The prompt asks for an age-6-12 description with similes, a fun fact,
    # short sentences and no jargon.
    # NOTE(review): the prompt is Chinese, but the TTS model loaded in
    # init_models (mms-tts-eng) appears to be English-only; confirm which
    # output language is intended.
    prompt = f"""用6-12岁儿童能理解的语言描述{bird_name}:
★ 使用比喻(例如:羽毛像彩虹糖纸)
★ 包含趣味冷知识(例如:每天吃自身体重30%的食物)
★ 语句简短(每句不超过15个英文单词)
★ 避免专业术语"""
    try:
        response = text_generator(
            prompt,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Return only the newly generated text. This replaces the old
            # split-on-"描述" heuristic, which also truncated any answer that
            # itself contained that word.
            return_full_text=False,
        )
        return response[0]["generated_text"].strip()
    except Exception as e:
        return f"文本生成失败: {str(e)}"
def process_image(img):
    """Run the end-to-end pipeline on one uploaded image.

    Steps: classify the bird, generate a description, then synthesize speech.

    Args:
        img: PIL image from the Gradio input, or None when the input is cleared.

    Returns:
        tuple: ``(bird_name, description, audio)``.
            - ``audio`` is a ``(sampling_rate, waveform)`` tuple, the numpy
              format ``gr.Audio`` accepts, or None.
            - For a cleared input, returns ``("", "", None)``.
            - On any failure, returns ``("错误", "处理失败: <reason>", None)``.
    """
    # Gradio's `.change` event also fires when the image is cleared. Reset the
    # outputs instead of sending None to the classifier and showing an error.
    if img is None:
        return "", "", None
    try:
        # 1. Image classification: use the first returned label.
        classification = classifier(img)
        bird_name = classification[0]["label"]
        # 2. Child-friendly description.
        description = generate_child_friendly_text(bird_name)
        # 3. Speech synthesis.
        # NOTE(review): speaker_id likely has no effect on a single-speaker
        # MMS/VITS checkpoint; confirm whether it is needed.
        speech = tts(description, forward_params={"speaker_id": 6})
        # gr.Audio expects (sampling_rate, samples), not a bare array. The
        # waveform may carry a leading batch axis, so squeeze it to 1-D.
        audio = (speech["sampling_rate"], speech["audio"].squeeze())
        return bird_name, description, audio
    except Exception as e:
        return "错误", f"处理失败: {str(e)}", None
# ========== Startup ==========
if __name__ == "__main__":
    # Load every model up front so failures surface at startup. This also binds
    # the module-level names (classifier, text_generator, tts) that
    # process_image and generate_child_friendly_text read.
    classifier, text_generator, tts = init_models()

    # Only offer example images that exist on disk. A missing example file may
    # break interface construction.
    example_paths = [
        path
        for path in ("test_images/eagle.jpg", "test_images/penguin.jpg")
        if os.path.isfile(path)
    ]

    # Build the Gradio interface.
    with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
        gr.Markdown("# 🐦 鸟类知识小课堂(稳定版)")
        with gr.Row():
            img_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
            audio_output = gr.Audio(label="语音讲解", autoplay=True)
        with gr.Column():
            name_output = gr.Textbox(label="识别结果")
            desc_output = gr.Textbox(label="趣味知识", lines=4)
        if example_paths:
            gr.Examples(
                examples=example_paths,
                inputs=img_input,
                label="示例图片",
            )
        # Re-run the whole pipeline whenever the input image changes.
        img_input.change(
            process_image,
            inputs=img_input,
            outputs=[name_output, desc_output, audio_output],
        )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )