sshenai commited on
Commit
2130637
·
verified ·
1 Parent(s): 85a3dae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -100
app.py CHANGED
@@ -1,113 +1,70 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- 鸟类知识科普系统(修正版) by [你的名字]
4
- ISOM5240 Group Project
5
- """
6
-
7
- import gradio as gr
8
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
9
  from PIL import Image
 
 
10
  import torch
 
 
 
11
 
12
- # 初始化模型(兼容性优化)
13
- def init_models():
14
- # 鸟类分类模型(保持不变)
15
- classifier = pipeline(
16
- "image-classification",
17
- model="chriamue/bird-species-classifier",
18
- device=0 if torch.cuda.is_available() else -1
19
- )
20
 
21
- # 替换为DeepSeek-R1模型(兼容性配置)
22
- text_generator = pipeline(
23
- "text-generation",
24
- model="deepseek-ai/DeepSeek-R1",
25
- torch_dtype=torch.bfloat16,
26
- device_map="auto",
27
- model_kwargs={
28
- "load_in_4bit": True,
29
- "trust_remote_code": True # 必须开启远程代码执行
30
- }
31
- )
32
 
33
- # 语音合成模型(保持不变)
34
- tts = pipeline(
35
- "text-to-speech",
36
- model="facebook/mms-tts-eng",
37
- device=0 if torch.cuda.is_available() else -1
38
- )
39
 
40
- return classifier, text_generator, tts
41
-
42
- # 生成儿童友好的鸟类描述(优化Prompt)
43
- def generate_child_friendly_text(bird_name):
44
- PROMPT = f"""以6-12岁儿童能理解的语言介绍{bird_name}:
45
- 1. 用动物拟人化的方式描述特征(例如:穿彩色外套的鸟)
46
- 2. 解释生活习性时结合日常场景(如:像小朋友一样喜欢玩耍)
47
- 3. 包含一个趣味冷知识(例如:飞行距离相当于绕操场XX圈)
48
- 4. 语句长度控制在10-15个英文单词
49
- 5. 使用比喻手法代替专业术语"""
50
-
51
- response = text_generator(
52
- PROMPT,
53
- max_new_tokens=150,
54
- temperature=0.8,
55
- top_k=40,
56
- do_sample=True
57
- )
58
 
59
- # 后处理优化
60
- cleaned_text = response[0]['generated_text'].split('\n')[2]
61
- return cleaned_text.replace("**", "") # 去除多余符号
62
 
63
- # 主处理流程(增加异常处理)
64
- def process_image(image):
65
  try:
66
- classification = classifier(image)
67
- bird_name = classification[0]['label']
68
- description = generate_child_friendly_text(bird_name)
69
- speech = tts(description, forward_params={"speaker_id": 6})
70
-
71
- return {
72
- "bird_name": bird_name,
73
- "description": description,
74
- "audio": speech["audio"]
75
- }
76
- except Exception as e:
77
- return f"处理错误: {str(e)}"
78
 
79
- # 初始化模型(增加缓存清理)
80
- from transformers.utils import cached_file
81
- cached_file.cache_clear()
82
- classifier, text_generator, tts = init_models()
 
 
 
83
 
84
- # 创建Gradio界面(布局优化)
85
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px !important}") as demo:
86
- gr.Markdown("# 🐦 鸟类知识小课堂(稳定版)")
 
 
87
 
88
- with gr.Row():
89
- with gr.Column(scale=2):
90
- image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
91
- examples = gr.Examples(
92
- examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
93
- inputs=image_input,
94
- label="示例图片"
95
- )
96
-
97
- with gr.Column(scale=3):
98
- name_output = gr.Textbox(label="识别到的鸟类", interactive=False)
99
- text_output = gr.Textbox(label="趣味知识", lines=4, max_lines=6)
100
- audio_output = gr.Audio(label="语音讲解", autoplay=True, visible=True)
101
 
102
- image_input.change(
103
- process_image,
104
- inputs=image_input,
105
- outputs=[name_output, text_output, audio_output]
106
- )
107
-
108
- # 部署配置(增加硬件检测)
109
- if torch.cuda.is_available():
110
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
111
- else:
112
- print("警告:未检测到GPU,建议在Colab或A10G实例运行")
113
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # 导入必要库
2
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
 
 
 
 
 
3
  from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
  import torch
7
+ from transformers import pipeline
8
+ import wikipedia # 用于获取鸟类百科信息
9
+ from wikipedia.exceptions import DisambiguationError, PageError
10
 
11
+ # 1. 鸟类图片识别(使用指定模型)
12
+ def bird_classification(image_url):
13
+ model_name = "chriamue/bird-species-classifier"
14
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
15
+ model = AutoModelForImageClassification.from_pretrained(model_name)
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model.to(device)
 
18
 
19
+ # 下载并预处理图片
20
+ response = requests.get(image_url)
21
+ img = Image.open(BytesIO(response.content)).convert("RGB")
22
+ inputs = feature_extractor(img, return_tensors="pt").to(device)
 
 
 
 
 
 
 
23
 
24
+ # 模型推理
25
+ with torch.no_grad():
26
+ outputs = model(**inputs)
27
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
 
 
28
 
29
+ # 获取前1个预测结果
30
+ predicted_id = torch.argmax(probabilities).item()
31
+ labels = model.config.id2label
32
+ bird_species = labels[predicted_id]
33
+ confidence = round(probabilities[predicted_id].item(), 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ return bird_species, confidence
 
 
36
 
37
+ # 2. 鸟类信息获取(使用维基百科API)
38
+ def get_bird_info(species_name):
39
  try:
40
+ # 去除可能的多余标签(如模型输出中的括号内容)
41
+ clean_name = species_name.split("(")[0].strip()
42
+ # 从维基百科获取摘要(英文转中文)
43
+ summary = wikipedia.summary(clean_name, sentences=3, auto_suggest=False)
44
+ return summary
45
+ except (DisambiguationError, PageError):
46
+ return "抱歉,未找到该鸟类的详细信息。"
 
 
 
 
 
47
 
48
+ # 3. 文本转语音(使用TTS模型)
49
+ def text_to_speech(text, output_file="bird_info.mp3"):
50
+ tts = pipeline("text-to-speech", model="tts_models/en_US/tacotron2")
51
+ speech = tts(text)
52
+ with open(output_file, "wb") as f:
53
+ f.write(speech["audio"])
54
+ return output_file
55
 
56
+ # 主函数
57
+ def bird_knowledge_pipeline(image_url):
58
+ # 1. 鸟类识别
59
+ species, confidence = bird_classification(image_url)
60
+ print(f"识别结果:{species}(置信度:{confidence*100:.1f}%)")
61
 
62
+ # 2. 获取详细信息
63
+ info = get_bird_info(species)
64
+ print(f"鸟类介绍:\n{info}")
 
 
 
 
 
 
 
 
 
 
65
 
66
+ # 3. 生成语音
67
+ audio_file = text_to_speech(f"这是{species}的介绍:{info}")
68
+ print(f"语音文件已保存:{audio_file}")
69
+
70
+ return species, info, audio_file