Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """ | |
| 鸟类知识科普系统(Qwen3优化版) by [你的名字] | |
| ISOM5240 Group Project | |
| """ | |
| import transformers | |
| import os | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import ( | |
| pipeline, | |
| AutoConfig, | |
| AutoImageProcessor | |
| ) | |
# ========== Environment configuration ==========
# Requested on-disk location for downloaded model files.
# NOTE(review): `transformers` is imported earlier in this file and is believed
# to resolve its cache-location variables at import time, so this assignment may
# not take effect. Consider moving it above the imports or passing `cache_dir`
# to each loader explicitly. TODO confirm for the pinned transformers version.
os.environ["TRANSFORMERS_CACHE"] = "./model_cache"

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"当前设备: {DEVICE.upper()}")
# ========== Model initialization ==========
def init_models():
    """Load the three pipelines used by the app.

    Returns:
        tuple: ``(classifier, text_generator, tts)``. These are:
        - an image-classification pipeline for bird species
          (chriamue/bird-species-classifier);
        - a text-generation pipeline backed by Qwen/Qwen-7B-Chat;
        - a text-to-speech pipeline (facebook/mms-tts-eng).

    Raises:
        RuntimeError: if any model fails to load. The underlying exception is
            chained as ``__cause__`` so the real traceback is preserved.
    """
    # Give the cache directory to every loader explicitly. Do not rely on the
    # TRANSFORMERS_CACHE environment variable, which is assigned only after
    # `transformers` has been imported.
    cache_dir = "./model_cache"
    try:
        # 1. Bird-species classifier. The image processor is loaded explicitly
        #    and handed over through `image_processor=`, the pipeline's slot
        #    for vision preprocessors (`feature_extractor=` is the legacy one).
        image_processor = AutoImageProcessor.from_pretrained(
            "chriamue/bird-species-classifier",
            use_fast=True,  # request the fast processor when one exists
            cache_dir=cache_dir,
        )
        classifier = pipeline(
            task="image-classification",
            model="chriamue/bird-species-classifier",
            image_processor=image_processor,
            device=DEVICE,
            model_kwargs={"cache_dir": cache_dir},
        )

        # 2. Text generator.
        #    - Qwen-7B-Chat ships custom model/tokenizer code, so
        #      `trust_remote_code=True` must reach the pipeline itself, not
        #      only AutoConfig.
        #    - `revision` is a first-class pipeline argument. Putting it inside
        #      `model_kwargs` as well collides with the pipeline's own
        #      `revision` hub kwarg ("got multiple values for keyword argument").
        # NOTE(review): a 7B-parameter model in bfloat16 needs roughly 14 GB for
        # the weights alone; on a CPU-only host this may not fit. Confirm the
        # target hardware.
        qwen_config = AutoConfig.from_pretrained(
            "Qwen/Qwen-7B-Chat",
            trust_remote_code=True,
            revision="main",
            cache_dir=cache_dir,
        )
        text_generator = pipeline(
            task="text-generation",
            model="Qwen/Qwen-7B-Chat",
            config=qwen_config,
            revision="main",  # single place where the revision is pinned
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            model_kwargs={"cache_dir": cache_dir},
        )

        # 3. Speech synthesis (English MMS TTS model).
        tts = pipeline(
            task="text-to-speech",
            model="facebook/mms-tts-eng",
            device=DEVICE,
            model_kwargs={"cache_dir": cache_dir},
        )
        return classifier, text_generator, tts
    except Exception as e:
        # Re-raise with a user-facing message while keeping the original cause.
        raise RuntimeError(f"模型加载失败: {str(e)}") from e
# ========== Core logic ==========
def generate_child_friendly_text(bird_name, *, generator=None):
    """Ask the language model for a short, child-friendly description of a bird.

    The prompt asks for metaphors, a fun fact, short sentences and no jargon.
    Only the newly generated continuation is returned; the prompt is never
    echoed back.

    Args:
        bird_name: Species label, e.g. the top label from the image classifier.
        generator: Optional text-generation callable using the Hugging Face
            pipeline calling convention (``generator(prompt, **gen_kwargs)``
            returning ``[{"generated_text": ...}]``). Defaults to the
            module-level ``text_generator``.

    Returns:
        str: The generated description with surrounding whitespace stripped.
        If generation fails, returns an error message starting with
        "文本生成失败: " instead of raising.
    """
    # The description is later read aloud by an English-only TTS model, and
    # the prompt already limits sentences by *English* word count, so the
    # answer is explicitly requested in English.
    prompt = "\n".join([
        f"用6-12岁儿童能理解的语言描述{bird_name}:",
        "★ 使用比喻(例如:羽毛像彩虹糖纸)",
        "★ 包含趣味冷知识(例如:每天吃自身体重30%的食物)",
        "★ 语句简短(每句不超过15个英文单词)",
        "★ 避免专业术语",
        "★ 只用英文回答",
    ])
    try:
        generate = text_generator if generator is None else generator
        response = generate(
            prompt,
            max_new_tokens=120,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Return only the continuation. With the default (full text), the
            # prompt's instruction bullets ended up in the description and TTS.
            return_full_text=False,
        )
        return response[0]["generated_text"].strip()
    except Exception as e:
        return f"文本生成失败: {str(e)}"
def process_image(img):
    """Run the end-to-end flow for one image: classify, describe, then speak.

    Args:
        img: PIL image from the Gradio ``Image`` input, or None when the input
            has been cleared.

    Returns:
        tuple: ``(bird_name, description, audio)``, where ``audio`` is a
        ``(sampling_rate, samples)`` tuple that ``gr.Audio`` can render, or
        None. A cleared input yields ``("", "", None)``. Any failure yields
        ``("错误", "处理失败: <reason>", None)``.
    """
    # `.change` also fires when the user clears the image. Reset the outputs
    # instead of surfacing a classifier error for a None input.
    if img is None:
        return "", "", None
    try:
        # 1. Image classification: take the top-ranked label.
        classification = classifier(img)
        bird_name = classification[0]["label"]
        # 2. Child-friendly description of that species.
        description = generate_child_friendly_text(bird_name)
        # 3. Speech synthesis.
        # NOTE(review): facebook/mms-tts-eng appears to be single-speaker, in
        # which case `speaker_id` has no effect. Confirm before relying on it.
        speech = tts(description, forward_params={"speaker_id": 6})
        # gr.Audio cannot render a bare ndarray; it expects
        # (sample_rate, samples). The pipeline's waveform may carry a leading
        # batch axis (e.g. shape (1, N)), which Gradio would read as
        # (samples, channels), so flatten it to 1-D.
        audio = (speech["sampling_rate"], speech["audio"].squeeze())
        return bird_name, description, audio
    except Exception as e:
        return "错误", f"处理失败: {str(e)}", None
# ========== Entry point ==========
if __name__ == "__main__":
    # Load every model up front, so a broken model fails at start-up rather
    # than on the first request. These names are module-level globals that
    # process_image() and generate_child_friendly_text() read.
    classifier, text_generator, tts = init_models()

    # NOTE(review): relative example paths; confirm these files ship with the app.
    example_images = ["test_images/eagle.jpg", "test_images/penguin.jpg"]
    page_css = ".gradio-container {max-width: 800px}"

    # Build the Gradio UI.
    with gr.Blocks(theme=gr.themes.Soft(), css=page_css) as demo:
        gr.Markdown("# 🐦 鸟类知识小课堂(稳定版)")
        with gr.Row():
            img_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
            audio_output = gr.Audio(label="语音讲解", autoplay=True)
        with gr.Column():
            name_output = gr.Textbox(label="识别结果")
            desc_output = gr.Textbox(label="趣味知识", lines=4)
        gr.Examples(examples=example_images, inputs=img_input, label="示例图片")

        # Re-run the whole flow every time the uploaded image changes.
        result_components = [name_output, desc_output, audio_output]
        img_input.change(process_image, inputs=img_input, outputs=result_components)

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)