Spaces:
Build error
Build error
File size: 3,190 Bytes
c7a77a7 4d8788c c7a77a7 4d8788c 9632f31 4d8788c 9632f31 4d8788c 11c2804 4d8788c 8ca204e 4d8788c c7a77a7 4d8788c c7a77a7 4d8788c 9632f31 4d8788c 9632f31 4d8788c 8ca204e 4d8788c 2130637 4d8788c c7a77a7 4d8788c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# -*- coding: utf-8 -*-
"""Bird knowledge popular-science app (鸟类知识智能科普系统).

Streamlit app with a three-stage pipeline:

1. an image-classification model names the bird in an uploaded photo;
2. a text-generation model writes a description of that bird aimed at
   6-12-year-old children;
3. a text-to-speech model reads the description aloud.

Run with ``streamlit run <this file>``.
"""
import streamlit as st
from PIL import Image
import tempfile
from transformers import pipeline, AutoConfig
import torch
# ========== Model configuration ==========
# Each entry maps a pipeline role to a Hugging Face model id ("model") and a
# dict of extra keyword arguments forwarded verbatim to
# ``transformers.pipeline`` ("config").
# Nothing in this dict touches the network: Streamlit re-executes the whole
# script on every interaction, so a download at module level would be
# repeated on every rerun, outside the ``st.cache_resource`` cache below.
MODEL_CONFIG = {
    "image_to_text": {
        "model": "chriamue/bird-species-classifier",
        # Request the fast processor variant. NOTE(review): whether
        # pipeline() applies use_fast to an image-only model depends on the
        # installed transformers version — confirm.
        "config": {"use_fast": True},
    },
    "text_generation": {
        "model": "Qwen/Qwen-7B-Chat",
        # Qwen-7B-Chat ships its own modelling/tokenizer code on the Hub, so
        # config, tokenizer and model must all be loaded with
        # trust_remote_code=True; pipeline() propagates it to each of them.
        "config": {"revision": "main", "trust_remote_code": True},
    },
    "text_to_speech": {
        "model": "facebook/mms-tts-eng",
        # No extra constructor arguments. A former ``speaker_id`` entry was
        # not a valid pipeline() argument (TextToAudioPipeline rejects it with
        # TypeError). If a speaker must be chosen, pass it per call via
        # ``forward_params`` — and only if the checkpoint supports several.
        "config": {},
    },
}


# ========== Model initialisation ==========
@st.cache_resource
def init_pipelines():
    """Build the three inference pipelines once and cache them.

    ``st.cache_resource`` memoises the returned tuple, so the models are
    loaded on the first successful call only and reused on later reruns.

    Returns:
        tuple: ``(img_pipeline, text_pipeline, tts_pipeline)`` — the
        image-classification, text-generation and text-to-speech pipelines.

    On any loading error the message is shown via ``st.error`` and the
    current script run is halted with ``st.stop()``.
    """
    try:
        img_pipeline = pipeline(
            "image-classification",
            model=MODEL_CONFIG["image_to_text"]["model"],
            **MODEL_CONFIG["image_to_text"]["config"],
        )
        # NOTE(review): device_map="auto" relies on the `accelerate` package,
        # and a 7B model in bfloat16 needs roughly 15 GB of memory — confirm
        # the deployment hardware can hold it.
        text_pipeline = pipeline(
            "text-generation",
            model=MODEL_CONFIG["text_generation"]["model"],
            torch_dtype=torch.bfloat16,
            device_map="auto",
            **MODEL_CONFIG["text_generation"]["config"],
        )
        tts_pipeline = pipeline(
            "text-to-speech",
            model=MODEL_CONFIG["text_to_speech"]["model"],
            **MODEL_CONFIG["text_to_speech"]["config"],
        )
        return img_pipeline, text_pipeline, tts_pipeline
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user and stop the
        # run instead of continuing with missing models.
        st.error(f"模型加载失败: {str(e)}")
        st.stop()
# ========== Core logic ==========
def generate_description(_pipe, bird_name, max_new_tokens=120):
    """Generate a child-friendly description of *bird_name*.

    Args:
        _pipe: a text-generation pipeline, or any callable with the same call
            signature that returns ``[{"generated_text": str}, ...]``.
        bird_name: label of the recognised bird; inserted into the prompt.
        max_new_tokens: upper bound on the number of generated tokens
            (default 120, matching the previous hard-coded value).

    Returns:
        str: only the generated continuation (the prompt is not included),
        with surrounding whitespace stripped.
    """
    prompt = f"用6-12岁儿童能理解的语言描述{bird_name},使用比喻和趣味知识:"
    # Ask the pipeline for the continuation only. Splitting the full text on
    # ":" would discard everything before the last colon of the answer
    # itself, truncating any description that contains a colon.
    outputs = _pipe(prompt, max_new_tokens=max_new_tokens, return_full_text=False)
    text = outputs[0]["generated_text"]
    # Defensive: if the callable echoed the prompt anyway, drop exactly it.
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()
# ========== Page layout ==========
# Page metadata goes first: Streamlit has historically required
# st.set_page_config to be the first element command of a script run.
st.set_page_config(page_icon="🐦", page_title="鸟类知识百科")
# Heading plus a one-line usage hint shown above the uploader.
st.title(body="🐦 智能鸟类科普系统")
st.markdown(body="上传鸟类图片,获取趣味知识讲解")
# ========== Main flow ==========
def main():
    """Run the upload -> identify -> describe -> narrate flow.

    Loads the (cached) pipelines, waits for an image upload, classifies it,
    shows the top label, generates a description for that label and plays
    the description back as audio. Does nothing further until a file is
    uploaded.
    """
    img_pipe, text_pipe, tts_pipe = init_pipelines()
    uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    with st.spinner("识别中..."):
        # Read the upload straight from memory. The earlier temp-file
        # round-trip wrote the bytes without flushing before re-opening the
        # path, so PIL could see an empty or truncated file; re-opening an
        # open NamedTemporaryFile by name also fails on Windows.
        # convert("RGB") normalises palette/alpha PNGs to 3 channels.
        image = Image.open(uploaded_file).convert("RGB")
        result = img_pipe(image)
    # Keep only the first (highest-ranked) label returned by the classifier.
    bird_name = result[0]['label']
    st.success(f"识别结果:{bird_name}")

    desc = generate_description(text_pipe, bird_name)
    st.subheader("趣味知识")
    st.write(desc)

    # NOTE(review): the description is generated from a Chinese prompt, but
    # MODEL_CONFIG selects facebook/mms-tts-eng — confirm that model can voice
    # this text, or pick a TTS checkpoint for the output language.
    try:
        audio = tts_pipe(desc[:1000])  # cap the text length sent to TTS
    except Exception as e:
        # The description is already on screen; a speech failure should not
        # take the whole page down with it.
        st.warning(f"语音合成失败: {str(e)}")
        return
    st.audio(audio["audio"], sample_rate=audio["sampling_rate"])


if __name__ == "__main__":
    main()
|