# app.py
"""Gradio web app that classifies plant images with a MobileNet model.

The model weights, the label mappings and the per-class descriptions are all
loaded once at import time. ``predict`` is wired to the "开始检测" button of the
UI built at the bottom of this file, and the UI is launched on import.
"""
import json
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
from model import get_mobilenet_model

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model.
weights_path = "best_model.pth"
model, _, _ = get_mobilenet_model(num_classes=16)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# from the file; only ever point weights_path at a checkpoint you trust.
checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
# The checkpoint is a dict; the model weights live under its 'state_dict' key.
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()

# Class name -> Chinese display label (shown in the "类别" box).
with open("class_eg_to_cn.json", "r", encoding="utf-8") as f:
    class_eg_to_cn = json.load(f)
# Predicted class index, as a *string* key (e.g. "3") -> class name.
with open("num_to_class.json", "r", encoding="utf-8") as f:
    idx2label = json.load(f)
# Class name -> description text (shown in the "相关介绍" box).
with open("class_inf.json", "r", encoding="utf-8") as f:
    plant_info = json.load(f)

# Image preprocessing: resize to 224x224, then convert to a float tensor.
# NOTE(review): there is no Normalize step; this must match the preprocessing
# the model was trained with -- confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


# Inference function.
def predict(img):
    """Classify a single image and look up its label and description.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is clicked before an image has been selected.

    Returns:
        A tuple ``(cn_label, confidence, intro)``: the Chinese label of the
        predicted class (the raw class name if no translation exists), the
        softmax probability of that class as a float, and the class
        description (``"暂无介绍"`` when none is available).

    Raises:
        gr.Error: if no image was provided; Gradio shows it in the UI.
    """
    if img is None:
        # Report a readable message instead of crashing inside transform().
        raise gr.Error("请先选择图像")
    # Uploads may be RGBA (PNG), palette or grayscale images, for which
    # ToTensor yields 4 or 1 channels. The model presumably expects 3-channel
    # RGB input (standard for MobileNet), so normalise the mode first.
    image = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
        probs = torch.softmax(output, dim=1)
        conf, pred = torch.max(probs, 1)
    class_name = idx2label[str(pred.item())]
    # Fall back to the raw class name when no Chinese translation exists.
    cn_label = class_eg_to_cn.get(class_name, class_name)
    intro = plant_info.get(class_name, "暂无介绍")
    return cn_label, float(conf.item()), intro


# Gradio UI: image input and button on the left, results on the right.
with gr.Blocks(css="body { background-color: #90caf9; font-family: sans-serif; }") as demo:
    gr.Markdown("## 🌱 基于深度学习的植物识别系统", elem_id="title")
    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="选择图像")
            btn = gr.Button("开始检测", elem_id="detect")
        with gr.Column(scale=3):
            label_output = gr.Textbox(label="类别")
            confidence_output = gr.Textbox(label="置信度")
            description_output = gr.Textbox(label="相关介绍", lines=5)
    # Output order must match predict's return tuple.
    btn.click(fn=predict, inputs=image_input,
              outputs=[label_output, confidence_output, description_output])

demo.launch()