sshenai commited on
Commit
d32f0b3
·
verified ·
1 Parent(s): e164259

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -51
app.py CHANGED
@@ -1,73 +1,99 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- 鸟类知识科普系统(Qwen3优化版) by [你的名字]
4
  ISOM5240 Group Project
5
  """
6
 
7
  import gradio as gr
8
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
9
  from PIL import Image
10
  import torch
 
 
 
11
 
12
- # 强制清理旧版缓存
13
- from transformers.utils import move_cache
14
- move_cache()
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # 初始化模型(兼容Qwen3)
17
  def init_models():
18
- # 鸟类分类模型(保持不变)
 
 
19
  classifier = pipeline(
20
- "image-classification",
21
  model="chriamue/bird-species-classifier",
22
  device=0 if torch.cuda.is_available() else -1
23
  )
24
 
25
- # 更新为Qwen3模型(官方支持版本)
26
  text_generator = pipeline(
27
- "text-generation",
28
- model="Qwen/Qwen-7B-Chat", # 使用官方维护版本
29
  device_map="auto",
30
  torch_dtype=torch.bfloat16,
31
- trust_remote_code=True, # 必须开启
32
  model_kwargs={
33
  "revision": "main",
34
- "force_download": True # 替换弃用参数
35
  }
36
  )
37
 
38
- # 语音合成模型(保持不变)
39
  tts = pipeline(
40
- "text-to-speech",
41
  model="facebook/mms-tts-eng",
42
  device=0 if torch.cuda.is_available() else -1
43
  )
44
 
45
  return classifier, text_generator, tts
46
 
47
- # 生成儿童友好的鸟类描述
48
  def generate_child_friendly_text(bird_name):
 
49
  PROMPT = f"""以6-12岁儿童能理解的方式描述{bird_name}:
50
- 1. 用比喻手法(如:羽毛像彩虹糖纸)
51
- 2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物)
52
- 3. 语句长度不超过15个英文单词
53
- 4. 避免使用专业术语"""
54
 
55
  response = text_generator(
56
  PROMPT,
57
  max_new_tokens=150,
58
  temperature=0.7,
 
59
  do_sample=True
60
  )
61
 
62
- return response[0]['generated_text'].split('\n')[2]
 
 
 
63
 
64
- # 主处理流程
65
  def process_image(image):
 
66
  try:
 
67
  classification = classifier(image)
68
  bird_name = classification[0]['label']
 
 
69
  description = generate_child_friendly_text(bird_name)
70
- speech = tts(description, forward_params={"speaker_id": 6})
 
 
71
 
72
  return {
73
  "bird_name": bird_name,
@@ -77,33 +103,38 @@ def process_image(image):
77
  except Exception as e:
78
  return f"处理错误: {str(e)}"
79
 
80
- # 初始化模型
81
- classifier, text_generator, tts = init_models()
82
-
83
- # 创建Gradio界面
84
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
85
- gr.Markdown("# 🐦 鸟类知识小课堂(Qwen3版)")
86
-
87
- with gr.Row():
88
- image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
89
- audio_output = gr.Audio(label="语音讲解", autoplay=True)
90
-
91
- with gr.Column():
92
- name_output = gr.Textbox(label="识别到的鸟类")
93
- text_output = gr.Textbox(label="趣味知识", lines=4)
94
 
95
- examples = gr.Examples(
96
- examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
97
- inputs=image_input,
98
- label="示例图片"
99
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- image_input.change(
102
- process_image,
103
- inputs=image_input,
104
- outputs=[name_output, text_output, audio_output]
105
- )
106
-
107
- # 部署配置
108
- if __name__ == "__main__":
109
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ 鸟类知识科普系统(兼容版)by [你的名字]
4
  ISOM5240 Group Project
5
  """
6
 
7
  import gradio as gr
8
+ from transformers import pipeline, AutoModelForCausalLM
9
  from PIL import Image
10
  import torch
11
+ import shutil
12
+ import os
13
+ from pathlib import Path
14
 
15
+ # ---- 兼容性缓存清理方案 ----
16
+ def clear_hf_cache():
17
+ """清除Hugging Face缓存目录(兼容所有版本)"""
18
+ cache_paths = [
19
+ Path("~/.cache/huggingface/hub"), # Linux/Mac
20
+ Path(os.environ.get("TRANSFORMERS_CACHE", "")), # 自定义缓存路径
21
+ Path("transformers") # Colab环境
22
+ ]
23
+
24
+ for path in cache_paths:
25
+ expanded_path = path.expanduser()
26
+ if expanded_path.exists():
27
+ print(f"清理缓存目录: {expanded_path}")
28
+ shutil.rmtree(expanded_path, ignore_errors=True)
29
 
30
+ # ---- 模型初始化 ----
31
  def init_models():
32
+ clear_hf_cache() # 执行缓存清理
33
+
34
+ # 1. 鸟类分类模型
35
  classifier = pipeline(
36
+ task="image-classification",
37
  model="chriamue/bird-species-classifier",
38
  device=0 if torch.cuda.is_available() else -1
39
  )
40
 
41
+ # 2. 文本生成模型(Qwen3
42
  text_generator = pipeline(
43
+ task="text-generation",
44
+ model="Qwen/Qwen-7B-Chat",
45
  device_map="auto",
46
  torch_dtype=torch.bfloat16,
47
+ trust_remote_code=True,
48
  model_kwargs={
49
  "revision": "main",
50
+ "force_download": True
51
  }
52
  )
53
 
54
+ # 3. 语音合成模型
55
  tts = pipeline(
56
+ task="text-to-speech",
57
  model="facebook/mms-tts-eng",
58
  device=0 if torch.cuda.is_available() else -1
59
  )
60
 
61
  return classifier, text_generator, tts
62
 
63
+ # ---- 核心处理逻辑 ----
64
  def generate_child_friendly_text(bird_name):
65
+ """生成儿童友好的鸟类描述"""
66
  PROMPT = f"""以6-12岁儿童能理解的方式描述{bird_name}:
67
+ 1. 使用比喻手法(如:羽毛像彩虹糖纸)
68
+ 2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物)
69
+ 3. 语句长度不超过15个英文单词
70
+ 4. 避免使用专业术语"""
71
 
72
  response = text_generator(
73
  PROMPT,
74
  max_new_tokens=150,
75
  temperature=0.7,
76
+ top_p=0.9,
77
  do_sample=True
78
  )
79
 
80
+ # 清洗输出文本
81
+ full_text = response[0]['generated_text']
82
+ clean_text = full_text.split("描述{}:".format(bird_name))[-1].strip()
83
+ return clean_text.replace("**", "").replace("```", "")
84
 
 
85
  def process_image(image):
86
+ """处理图片生成结果的完整流程"""
87
  try:
88
+ # 步骤1: 鸟类识别
89
  classification = classifier(image)
90
  bird_name = classification[0]['label']
91
+
92
+ # 步骤2: 生成描述
93
  description = generate_child_friendly_text(bird_name)
94
+
95
+ # 步骤3: 语音合成
96
+ speech = tts(description, forward_params={"speaker_id": 6}) # 使用儿童音色
97
 
98
  return {
99
  "bird_name": bird_name,
 
103
  except Exception as e:
104
  return f"处理错误: {str(e)}"
105
 
106
+ # ---- 初始化与界面 ----
107
+ if __name__ == "__main__":
108
+ # 初始化模型(显式指定设备)
109
+ classifier, text_generator, tts = init_models()
 
 
 
 
 
 
 
 
 
 
110
 
111
+ # 构建Gradio界面
112
+ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
113
+ gr.Markdown("# 🐦 鸟类知识小课堂(兼容版)")
114
+
115
+ with gr.Row():
116
+ image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
117
+ audio_output = gr.Audio(label="语音讲解", autoplay=True)
118
+
119
+ with gr.Column():
120
+ name_output = gr.Textbox(label="识别到的鸟类")
121
+ text_output = gr.Textbox(label="趣味知识", lines=4)
122
+
123
+ examples = gr.Examples(
124
+ examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
125
+ inputs=image_input,
126
+ label="示例图片"
127
+ )
128
+
129
+ image_input.change(
130
+ process_image,
131
+ inputs=image_input,
132
+ outputs=[name_output, text_output, audio_output]
133
+ )
134
 
135
+ # 启动服务
136
+ demo.launch(
137
+ server_name="0.0.0.0",
138
+ server_port=7860,
139
+ share=True
140
+ )