File size: 4,362 Bytes
d2f9ee7
 
f1cdf4e
d2f9ee7
 
a3f53bc
bfac0f6
ff33c1e
bfac0f6
 
d2f9ee7
 
bfac0f6
 
 
 
 
d2f9ee7
bfac0f6
 
 
 
d2f9ee7
bfac0f6
d2f9ee7
bfac0f6
 
 
 
 
 
 
 
 
 
 
 
 
d2f9ee7
bfac0f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2f9ee7
bfac0f6
 
 
 
 
 
d2f9ee7
bfac0f6
 
 
 
 
 
 
 
 
 
 
d2f9ee7
bfac0f6
 
d2f9ee7
bfac0f6
 
d2f9ee7
d32f0b3
bfac0f6
d2f9ee7
d32f0b3
bfac0f6
 
f1cdf4e
bfac0f6
f1cdf4e
d2f9ee7
bfac0f6
f1cdf4e
bfac0f6
 
 
 
d2f9ee7
bfac0f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1cdf4e
bfac0f6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# -*- coding: utf-8 -*-
"""
鸟类知识科普系统(Qwen3优化版) by [你的名字]
ISOM5240 Group Project
"""


import os

# The Hugging Face libraries resolve their cache directory once, at import
# time, so the variable has to be exported BEFORE `transformers` is imported;
# assigning it afterwards is silently ignored.
# NOTE(review): recent transformers releases warn that TRANSFORMERS_CACHE is
# deprecated in favour of HF_HOME — confirm against the pinned version.
os.environ["TRANSFORMERS_CACHE"] = "./model_cache"

import transformers  # noqa: E402
import torch  # noqa: E402
import gradio as gr  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import (  # noqa: E402
    pipeline,
    AutoConfig,
    AutoImageProcessor
)

# ========== Environment configuration ==========
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"当前设备: {DEVICE.upper()}")

# ========== 模型初始化 ==========
def init_models():
    """Load the three pipelines the app depends on.

    Returns:
        A ``(classifier, text_generator, tts)`` tuple holding the
        image-classification, text-generation and text-to-speech pipelines,
        in that order.

    Raises:
        RuntimeError: If any model fails to load. The underlying exception is
            chained as ``__cause__`` so the original traceback is preserved.
    """
    try:
        # 1. Bird species classifier, with an explicitly constructed (fast)
        #    image processor handed to the pipeline.
        image_processor = AutoImageProcessor.from_pretrained(
            "chriamue/bird-species-classifier",
            use_fast=True
        )
        classifier = pipeline(
            task="image-classification",
            model="chriamue/bird-species-classifier",
            # Image pipelines take the processor via `image_processor`;
            # `feature_extractor` is the legacy argument name.
            image_processor=image_processor,
            device=DEVICE
        )

        # 2. Text generation model. The checkpoint ships custom modelling
        #    code, so remote code must be trusted for the config AND for the
        #    model/tokenizer loaded by the pipeline.
        qwen_config = AutoConfig.from_pretrained(
            "Qwen/Qwen-7B-Chat",
            trust_remote_code=True
        )
        text_generator = pipeline(
            task="text-generation",
            model="Qwen/Qwen-7B-Chat",
            config=qwen_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # `revision` is a first-class pipeline argument that is forwarded
            # to every from_pretrained call; repeating it inside
            # `model_kwargs` makes the same keyword arrive twice.
            revision="main",
            trust_remote_code=True,
            model_kwargs={"cache_dir": "./model_cache"}
        )

        # 3. Speech synthesis model
        tts = pipeline(
            task="text-to-speech",
            model="facebook/mms-tts-eng",
            device=DEVICE
        )

        return classifier, text_generator, tts

    except Exception as e:
        raise RuntimeError(f"模型加载失败: {str(e)}") from e

# ========== 核心逻辑 ==========
def generate_child_friendly_text(bird_name):
    """Generate a short, child-friendly description of a bird.

    Args:
        bird_name: Species label that is interpolated into the prompt.

    Returns:
        The generated description with surrounding whitespace removed. If
        generation raises, a string starting with ``"文本生成失败: "`` followed
        by the error message is returned instead of propagating the exception.
    """
    prompt = f"""用6-12岁儿童能理解的语言描述{bird_name}
    ★ 使用比喻(例如:羽毛像彩虹糖纸)
    ★ 包含趣味冷知识(例如:每天吃自身体重30%的食物)
    ★ 语句简短(每句不超过15个英文单词)
    ★ 避免专业术语"""

    try:
        response = text_generator(
            prompt,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Ask for the completion only. By default the pipeline returns
            # prompt + completion, and cutting on a keyword from the prompt
            # leaves the instruction bullets in the result (and truncates the
            # answer whenever the model repeats that keyword).
            return_full_text=False
        )
        generated = response[0]['generated_text']
        # Defensive: drop the prompt if it was echoed back anyway.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        return generated.strip()
    except Exception as e:
        return f"文本生成失败: {str(e)}"

def process_image(img):
    """Run the classify -> describe -> speak flow for one uploaded image.

    Args:
        img: The image supplied by the UI, or ``None`` when the image widget
            has been cleared.

    Returns:
        A ``(bird_name, description, audio)`` tuple. ``audio`` is a
        ``(sampling_rate, waveform)`` pair on success. For a ``None`` image
        the result is ``("", "", None)`` so the outputs are simply cleared.
        If any step raises, ``("错误", "处理失败: <message>", None)`` is
        returned instead of propagating the exception.
    """
    # A cleared image widget also fires the change event; there is nothing to
    # classify, so reset the outputs instead of surfacing an error.
    if img is None:
        return "", "", None

    try:
        # 1. Image classification: take the label of the first prediction.
        classification = classifier(img)
        bird_name = classification[0]['label']

        # 2. Generate the description
        description = generate_child_friendly_text(bird_name)

        # 3. Speech synthesis
        speech = tts(description, forward_params={"speaker_id": 6})

        # The audio output needs the sample rate alongside the samples; the
        # bare array alone cannot be played. squeeze() drops a leading batch
        # axis if the pipeline returned one (no-op for 1-D audio).
        waveform = speech["audio"].squeeze()
        return bird_name, description, (speech["sampling_rate"], waveform)

    except Exception as e:
        return "错误", f"处理失败: {str(e)}", None

# ========== 初始化验证 ==========
if __name__ == "__main__":
    # Load every model before building the UI so a load failure aborts
    # start-up. These three names are module globals that the handler
    # functions read at call time.
    classifier, text_generator, tts = init_models()

    page_theme = gr.themes.Soft()
    page_css = ".gradio-container {max-width: 800px}"
    example_paths = ["test_images/eagle.jpg", "test_images/penguin.jpg"]

    # Build the Gradio interface
    with gr.Blocks(theme=page_theme, css=page_css) as demo:
        gr.Markdown("# 🐦 鸟类知识小课堂(稳定版)")

        # Top row: image upload next to the spoken explanation.
        with gr.Row():
            img_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
            audio_output = gr.Audio(label="语音讲解", autoplay=True)

        # Below: recognised species and the generated description.
        with gr.Column():
            name_output = gr.Textbox(label="识别结果")
            desc_output = gr.Textbox(label="趣味知识", lines=4)

        gr.Examples(examples=example_paths, inputs=img_input, label="示例图片")

        # Re-run the full flow whenever the image changes.
        result_widgets = [name_output, desc_output, audio_output]
        img_input.change(process_image, inputs=img_input, outputs=result_widgets)

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)