sshenai commited on
Commit
d2f9ee7
·
verified ·
1 Parent(s): 39c4a98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py CHANGED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 鸟类知识科普系统(Qwen3优化版) by [你的名字]
4
+ ISOM5240 Group Project
5
+ """
6
+
7
+ import gradio as gr
8
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
9
+ from PIL import Image
10
+ import torch
11
+
12
+ # 强制清理旧版缓存
13
+ from transformers.utils import move_cache
14
+ move_cache()
15
+
16
+ # 初始化模型(兼容Qwen3)
17
+ def init_models():
18
+ # 鸟类分类模型(保持不变)
19
+ classifier = pipeline(
20
+ "image-classification",
21
+ model="chriamue/bird-species-classifier",
22
+ device=0 if torch.cuda.is_available() else -1
23
+ )
24
+
25
+ # 更新为Qwen3模型(官方支持版本)
26
+ text_generator = pipeline(
27
+ "text-generation",
28
+ model="Qwen/Qwen-7B-Chat", # 使用官方维护版本
29
+ device_map="auto",
30
+ torch_dtype=torch.bfloat16,
31
+ trust_remote_code=True, # 必须开启
32
+ model_kwargs={
33
+ "revision": "main",
34
+ "force_download": True # 替换弃用参数
35
+ }
36
+ )
37
+
38
+ # 语音合成模型(保持不变)
39
+ tts = pipeline(
40
+ "text-to-speech",
41
+ model="facebook/mms-tts-eng",
42
+ device=0 if torch.cuda.is_available() else -1
43
+ )
44
+
45
+ return classifier, text_generator, tts
46
+
47
+ # 生成儿童友好的鸟类描述
48
+ def generate_child_friendly_text(bird_name):
49
+ PROMPT = f"""以6-12岁儿童能理解的方式描述{bird_name}:
50
+ 1. 用比喻手法(如:羽毛像彩虹糖纸)
51
+ 2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物)
52
+ 3. 语句长度不超过15个英文单词
53
+ 4. 避免使用专业术语"""
54
+
55
+ response = text_generator(
56
+ PROMPT,
57
+ max_new_tokens=150,
58
+ temperature=0.7,
59
+ do_sample=True
60
+ )
61
+
62
+ return response[0]['generated_text'].split('\n')[2]
63
+
64
+ # 主处理流程
65
+ def process_image(image):
66
+ try:
67
+ classification = classifier(image)
68
+ bird_name = classification[0]['label']
69
+ description = generate_child_friendly_text(bird_name)
70
+ speech = tts(description, forward_params={"speaker_id": 6})
71
+
72
+ return {
73
+ "bird_name": bird_name,
74
+ "description": description,
75
+ "audio": speech["audio"]
76
+ }
77
+ except Exception as e:
78
+ return f"处理错误: {str(e)}"
79
+
80
+ # 初始化模型
81
+ classifier, text_generator, tts = init_models()
82
+
83
+ # 创建Gradio界面
84
+ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
85
+ gr.Markdown("# 🐦 鸟类知识小课堂(Qwen3版)")
86
+
87
+ with gr.Row():
88
+ image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
89
+ audio_output = gr.Audio(label="语音讲解", autoplay=True)
90
+
91
+ with gr.Column():
92
+ name_output = gr.Textbox(label="识别到的鸟类")
93
+ text_output = gr.Textbox(label="趣味知识", lines=4)
94
+
95
+ examples = gr.Examples(
96
+ examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
97
+ inputs=image_input,
98
+ label="示例图片"
99
+ )
100
+
101
+ image_input.change(
102
+ process_image,
103
+ inputs=image_input,
104
+ outputs=[name_output, text_output, audio_output]
105
+ )
106
+
107
+ # 部署配置
108
+ if __name__ == "__main__":
109
+ demo.launch(server_name="0.0.0.0", server_port=7860)