# HuggingFace Space: plant-recognition demo (app.py)
# app.py — Gradio inference app for a 16-class plant-recognition model.
import json
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
from model import get_mobilenet_model

# Automatically use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the trained model.
weights_path = "best_model.pth"
model, _, _ = get_mobilenet_model(num_classes=16)
# NOTE(review): weights_only=False unpickles arbitrary objects — only ever
# load checkpoints from a trusted source.
checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['state_dict'])  # weights are nested under 'state_dict'
model.to(device)
model.eval()  # inference mode: disables dropout/batch-norm updates

# Label and metadata mappings:
#   class_eg_to_cn: English class name -> Chinese display name
#   idx2label:      class index (as a string key) -> English class name
#   plant_info:     English class name -> descriptive text
with open("class_eg_to_cn.json", "r", encoding="utf-8") as f:
    class_eg_to_cn = json.load(f)
with open("num_to_class.json", "r", encoding="utf-8") as f:
    idx2label = json.load(f)
with open("class_inf.json", "r", encoding="utf-8") as f:
    plant_info = json.load(f)

# Image preprocessing: resize to the 224x224 input the network was trained on.
# NOTE(review): no Normalize() here — presumably training used raw [0, 1]
# tensors too; confirm against the training pipeline.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Inference function.
def predict(img):
    """Classify a PIL image and return (Chinese label, confidence, description).

    Args:
        img: PIL.Image uploaded through the Gradio image component.

    Returns:
        Tuple of (Chinese class name, softmax confidence as float,
        descriptive text or the fallback "暂无介绍").
    """
    # Force 3-channel RGB: uploads may be RGBA or greyscale, which would
    # otherwise produce a tensor with the wrong channel count for the model.
    image = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
        probs = torch.softmax(output, dim=1)
        conf, pred = torch.max(probs, 1)
    # Resolve the class key once instead of re-indexing idx2label twice.
    class_key = idx2label[str(pred.item())]
    cn_label = class_eg_to_cn[class_key]
    intro = plant_info.get(class_key, "暂无介绍")
    return cn_label, float(conf.item()), intro
# Gradio interface: image upload + detect button on the left,
# prediction results (class / confidence / description) on the right.
with gr.Blocks(css="body { background-color: #90caf9; font-family: sans-serif; }") as demo:
    gr.Markdown("## 🌱 基于深度学习的植物识别系统", elem_id="title")
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="选择图像")
            btn = gr.Button("开始检测", elem_id="detect")
        with gr.Column(scale=3):
            label_output = gr.Textbox(label="类别")
            confidence_output = gr.Textbox(label="置信度")
            description_output = gr.Textbox(label="相关介绍", lines=5)
    # Wire the button to the inference function; outputs map 1:1 to the
    # three textboxes above.
    btn.click(fn=predict, inputs=image_input,
              outputs=[label_output, confidence_output, description_output])

demo.launch()