sshenai committed on
Commit
50d090f
·
verified ·
1 Parent(s): 0bd9eee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -47
app.py CHANGED
@@ -1,27 +1,36 @@
 
 
 
 
 
 
1
  import gradio as gr
2
- from transformers import pipeline
3
  from PIL import Image
4
  import torch
5
 
6
- # 初始化模型(缓存加载)
7
  def init_models():
8
- # 鸟类分类模型
9
  classifier = pipeline(
10
  "image-classification",
11
  model="chriamue/bird-species-classifier",
12
  device=0 if torch.cuda.is_available() else -1
13
  )
14
 
15
- # 文本生成模型(设置量化降低显存占用)
16
  text_generator = pipeline(
17
  "text-generation",
18
- model="Qwen/Qwen3-235B-A22B",
19
  torch_dtype=torch.bfloat16,
20
  device_map="auto",
21
- model_kwargs={"load_in_4bit": True}
 
 
 
22
  )
23
 
24
- # 语音合成模型
25
  tts = pipeline(
26
  "text-to-speech",
27
  model="facebook/mms-tts-eng",
@@ -30,64 +39,65 @@ def init_models():
30
 
31
  return classifier, text_generator, tts
32
 
33
- # 生成儿童友好的鸟类描述
34
  def generate_child_friendly_text(bird_name):
35
- PROMPT = f"""请用简单易懂的语言,向6-12岁儿童介绍{bird_name}:
36
- 1. 用比喻手法描述外形特征
37
- 2. 解释生活习性时使用拟人化
38
- 3. 包含一个有趣的小知识
39
- 4. 语句长度不超过15个英文单词
40
- 5. 避免使用专业术语"""
41
-
42
  response = text_generator(
43
  PROMPT,
44
- max_new_tokens=200,
45
- temperature=0.7,
 
46
  do_sample=True
47
  )
48
 
49
- return response[0]['generated_text'].split('\n')[2:] # 提取核心内容
 
 
50
 
51
- # 主处理流程
52
  def process_image(image):
53
  try:
54
- # Step 1: 鸟类识别
55
  classification = classifier(image)
56
  bird_name = classification[0]['label']
57
-
58
- # Step 2: 生成描述
59
  description = generate_child_friendly_text(bird_name)
60
-
61
- # Step 3: 语音合成
62
- speech = tts(description, forward_params={"speaker_id": 6}) # 使用儿童语音
63
 
64
  return {
65
  "bird_name": bird_name,
66
- "description": "\n".join(description),
67
  "audio": speech["audio"]
68
  }
69
  except Exception as e:
70
  return f"处理错误: {str(e)}"
71
 
72
- # 初始化模型
 
 
73
  classifier, text_generator, tts = init_models()
74
 
75
- # 创建Gradio界面
76
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
77
- gr.Markdown("# 🐦 鸟类知识小课堂")
78
 
79
  with gr.Row():
80
- image_input = gr.Image(type="pil", label="上传鸟类图片")
81
- audio_output = gr.Audio(label="语音讲解", autoplay=True)
82
-
83
- with gr.Column():
84
- name_output = gr.Textbox(label="识别到的鸟类")
85
- text_output = gr.Textbox(label="趣味知识", lines=4)
86
-
87
- examples = gr.Examples(
88
- examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
89
- inputs=image_input
90
- )
 
91
 
92
  image_input.change(
93
  process_image,
@@ -95,9 +105,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
95
  outputs=[name_output, text_output, audio_output]
96
  )
97
 
98
- # 部署配置
99
- demo.launch(
100
- server_name="0.0.0.0",
101
- server_port=7860,
102
- share=True
103
- )
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 鸟类知识科普系统(修正版) by [你的名字]
4
+ ISOM5240 Group Project
5
+ """
6
+
7
  import gradio as gr
8
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
9
  from PIL import Image
10
  import torch
11
 
12
+ # 初始化模型(兼容性优化)
13
  def init_models():
14
+ # 鸟类分类模型(保持不变)
15
  classifier = pipeline(
16
  "image-classification",
17
  model="chriamue/bird-species-classifier",
18
  device=0 if torch.cuda.is_available() else -1
19
  )
20
 
21
+ # 替换为DeepSeek-R1模型(兼容性配置)
22
  text_generator = pipeline(
23
  "text-generation",
24
+ model="deepseek-ai/DeepSeek-R1",
25
  torch_dtype=torch.bfloat16,
26
  device_map="auto",
27
+ model_kwargs={
28
+ "load_in_4bit": True,
29
+ "trust_remote_code": True # 必须开启远程代码执行
30
+ }
31
  )
32
 
33
+ # 语音合成模型(保持不变)
34
  tts = pipeline(
35
  "text-to-speech",
36
  model="facebook/mms-tts-eng",
 
39
 
40
  return classifier, text_generator, tts
41
 
42
+ # 生成儿童友好的鸟类描述(优化Prompt)
43
  def generate_child_friendly_text(bird_name):
44
+ PROMPT = f"""6-12岁儿童能理解的语言介绍{bird_name}:
45
+ 1. 用动物拟人化的方式描述特征(例如:穿彩色外套的鸟)
46
+ 2. 解释生活习性时结合日常场景(如:像小朋友一样喜欢玩耍)
47
+ 3. 包含一个趣味冷知识(例如:飞行距离相当于绕操场XX圈)
48
+ 4. 语句长度控制在10-15个英文单词
49
+ 5. 使用比喻手法代替专业术语"""
50
+
51
  response = text_generator(
52
  PROMPT,
53
+ max_new_tokens=150,
54
+ temperature=0.8,
55
+ top_k=40,
56
  do_sample=True
57
  )
58
 
59
+ # 后处理优化
60
+ cleaned_text = response[0]['generated_text'].split('\n')[2]
61
+ return cleaned_text.replace("**", "") # 去除多余符号
62
 
63
+ # 主处理流程(增加异常处理)
64
  def process_image(image):
65
  try:
 
66
  classification = classifier(image)
67
  bird_name = classification[0]['label']
 
 
68
  description = generate_child_friendly_text(bird_name)
69
+ speech = tts(description, forward_params={"speaker_id": 6})
 
 
70
 
71
  return {
72
  "bird_name": bird_name,
73
+ "description": description,
74
  "audio": speech["audio"]
75
  }
76
  except Exception as e:
77
  return f"处理错误: {str(e)}"
78
 
79
+ # 初始化模型(增加缓存清理)
80
+ from transformers.utils import cached_file
81
+ cached_file.cache_clear()
82
  classifier, text_generator, tts = init_models()
83
 
84
+ # 创建Gradio界面(布局优化)
85
+ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px !important}") as demo:
86
+ gr.Markdown("# 🐦 鸟类知识小课堂(稳定版)")
87
 
88
  with gr.Row():
89
+ with gr.Column(scale=2):
90
+ image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
91
+ examples = gr.Examples(
92
+ examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
93
+ inputs=image_input,
94
+ label="示例图片"
95
+ )
96
+
97
+ with gr.Column(scale=3):
98
+ name_output = gr.Textbox(label="识别到的鸟类", interactive=False)
99
+ text_output = gr.Textbox(label="趣味知识", lines=4, max_lines=6)
100
+ audio_output = gr.Audio(label="语音讲解", autoplay=True, visible=True)
101
 
102
  image_input.change(
103
  process_image,
 
105
  outputs=[name_output, text_output, audio_output]
106
  )
107
 
108
+ # 部署配置(增加硬件检测)
109
+ if torch.cuda.is_available():
110
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
111
+ else:
112
+ print("警告:未检测到GPU,建议在Colab或A10G实例运行")
113
+ demo.launch(server_name="0.0.0.0", server_port=7860)