File size: 3,190 Bytes
c7a77a7
 
4d8788c
c7a77a7
4d8788c
 
9632f31
4d8788c
 
9632f31
 
4d8788c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11c2804
4d8788c
 
 
 
8ca204e
4d8788c
 
 
 
 
 
 
 
 
 
 
 
 
c7a77a7
4d8788c
 
 
 
 
 
 
 
c7a77a7
4d8788c
 
9632f31
4d8788c
 
 
 
 
9632f31
4d8788c
 
 
 
 
 
 
 
8ca204e
4d8788c
2130637
4d8788c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a77a7
 
4d8788c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
"""
鸟类知识智能科普系统
"""

import streamlit as st
from PIL import Image
import tempfile
from transformers import pipeline, AutoConfig
import torch

# ========== 模型配置 ==========
# Hub id of the chat model; named once because it is needed both as the
# pipeline's model and as the source of its config object.
_TEXT_GENERATION_MODEL_ID = "Qwen/Qwen-7B-Chat"

# Per-task settings consumed by init_pipelines(): each entry holds the model
# id under "model" and task-specific options under "config".
# NOTE(review): Qwen-7B-Chat may need trust_remote_code=True to load — confirm.
# NOTE(review): AutoConfig.from_pretrained runs when this module is imported,
# i.e. outside the error handling in init_pipelines() — confirm that is intended.
MODEL_CONFIG = dict(
    image_to_text=dict(
        model="chriamue/bird-species-classifier",
        # Force the fast processor implementation.
        config=dict(use_fast=True),
    ),
    text_generation=dict(
        model=_TEXT_GENERATION_MODEL_ID,
        config=AutoConfig.from_pretrained(_TEXT_GENERATION_MODEL_ID, revision="main"),
    ),
    text_to_speech=dict(
        model="facebook/mms-tts-eng",
        # Originally annotated as a "child voice".
        # NOTE(review): confirm this checkpoint/pipeline accepts speaker_id.
        config=dict(speaker_id=6),
    ),
)

# ========== 模型初始化 ==========
@st.cache_resource
def init_pipelines():
    """Build the three inference pipelines used by the app.

    The result is cached via ``st.cache_resource`` so the models are not
    re-initialised on repeated calls. Model ids and options are read from
    ``MODEL_CONFIG``.

    Returns:
        A ``(image_classifier, text_generator, speech_synthesizer)`` tuple.
        If building any pipeline raises, the error is reported with
        ``st.error`` and ``st.stop()`` is called instead.
    """
    try:
        vision = MODEL_CONFIG["image_to_text"]
        classifier = pipeline(
            "image-classification",
            model=vision["model"],
            **vision["config"],
        )

        chat = MODEL_CONFIG["text_generation"]
        generator = pipeline(
            "text-generation",
            model=chat["model"],
            config=chat["config"],
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        speech = MODEL_CONFIG["text_to_speech"]
        synthesizer = pipeline(
            "text-to-speech",
            model=speech["model"],
            **speech["config"],
        )
    except Exception as exc:
        # UI boundary: surface any load failure to the user and halt the run.
        st.error(f"模型加载失败: {exc!s}")
        st.stop()
    else:
        return classifier, generator, synthesizer

# ========== 核心功能 ==========
def generate_description(_pipe, bird_name, max_new_tokens=120):
    """Generate a child-friendly description of a bird.

    Args:
        _pipe: Text-generation callable. It is invoked as
            ``_pipe(prompt, max_new_tokens=...)`` and must return a sequence
            whose first item is a mapping with a ``'generated_text'`` key.
        bird_name: Name of the bird to describe; interpolated into the prompt.
        max_new_tokens: Generation budget forwarded to ``_pipe``.

    Returns:
        The generated text with the echoed prompt removed. If the output does
        not start with the prompt, it is returned unchanged.
    """
    prompt = f"用6-12岁儿童能理解的语言描述{bird_name},使用比喻和趣味知识:"
    generated = _pipe(prompt, max_new_tokens=max_new_tokens)[0]['generated_text']

    # Remove exactly the echoed prompt. Splitting on ":" and keeping the last
    # piece would drop the start of the answer whenever the generated text
    # itself contains a colon.
    if generated.startswith(prompt):
        return generated[len(prompt):]
    return generated

# ========== 界面设计 ==========
# Page-level setup: browser tab title and icon.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command in a script — keep it ahead of the other st.* calls.
st.set_page_config(page_title="鸟类知识百科", page_icon="🐦")
# Main heading followed by a one-line usage hint.
st.title("🐦 智能鸟类科普系统")
st.markdown("上传鸟类图片,获取趣味知识讲解")

# 主流程
def main():
    """Run the upload -> classify -> describe -> speak flow.

    Loads the pipelines via ``init_pipelines()``, shows a file uploader, and
    once an image is provided: classifies it, displays the top label,
    generates and displays a description, and plays it back as audio.

    Returns:
        None. All output goes to the Streamlit page.
    """
    # Local import keeps this function self-contained; only it needs BytesIO.
    from io import BytesIO

    img_pipe, text_pipe, tts_pipe = init_pipelines()

    uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    with st.spinner("识别中..."):
        # Decode straight from the uploaded bytes. Writing them to a named
        # temporary file and reopening it by path is fragile: unflushed
        # buffered data may not be on disk yet (truncated image), and a
        # still-open temp file cannot be reopened by name on Windows.
        image = Image.open(BytesIO(uploaded_file.getvalue()))

        # Identify the bird; the first result's label is used.
        result = img_pipe(image)
        bird_name = result[0]['label']
        st.success(f"识别结果:{bird_name}")

        # Generate and show the description.
        desc = generate_description(text_pipe, bird_name)
        st.subheader("趣味知识")
        st.write(desc)

        # Speech synthesis; skipped when there is no text to speak.
        if desc.strip():
            audio = tts_pipe(desc[:1000])  # cap the text length sent to TTS
            st.audio(audio["audio"], sample_rate=audio["sampling_rate"])

# Launch the app when this file is executed as a script.
if __name__ == "__main__":
    main()