# penguin218's picture
# a
# c34b829 verified
# app.py
import json
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
from model import get_mobilenet_model
# Pick the GPU automatically when CUDA is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the model.
weights_path = "best_model.pth"
# get_mobilenet_model returns three values; only the first (the model) is used.
model, _, _ = get_mobilenet_model(num_classes=16)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# so this must only ever be pointed at a trusted checkpoint file.
checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
# The weights are stored under the checkpoint's "state_dict" key, not at the
# top level of the loaded object.
model.load_state_dict(checkpoint["state_dict"])
# Move the model to the selected device and switch it to evaluation mode.
model.to(device).eval()
def _read_json(path):
    """Return the parsed contents of the UTF-8 encoded JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Lookup tables used by predict():
#   class_eg_to_cn - indexed by the label taken from idx2label; the result is
#                    shown as the class name (presumably English -> Chinese,
#                    judging by the name; confirm against the JSON file).
#   idx2label      - indexed by the predicted class index converted to a string.
#   plant_info     - indexed by the same label; holds the description text.
class_eg_to_cn = _read_json("class_eg_to_cn.json")
idx2label = _read_json("num_to_class.json")
plant_info = _read_json("class_inf.json")
# Image preprocessing: resize to a fixed 224x224 input and convert to a tensor.
# No normalization step is applied here.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
# Inference function
def predict(img):
    """Classify one image and return its label, confidence and description.

    Args:
        img: The image to classify, as passed by the Gradio image input
            (configured with ``type="pil"``), or ``None`` when no image has
            been selected.

    Returns:
        A ``(label, confidence, intro)`` tuple:

        * ``label`` - the display name looked up in ``class_eg_to_cn``; when
          no entry exists, the label from ``idx2label`` is returned as-is.
        * ``confidence`` - softmax probability of the predicted class, as a
          ``float``.
        * ``intro`` - the description from ``plant_info``, or ``"暂无介绍"``
          when the class has none.

        When ``img`` is ``None`` a placeholder tuple with confidence ``0.0``
        is returned instead of raising.
    """
    # The button can be clicked before an image has been chosen; without this
    # guard the preprocessing pipeline fails on None.
    if img is None:
        return "未选择图像", 0.0, "请先选择一张图像再开始检测。"

    # Force three channels for PIL inputs in other modes (alpha, grayscale,
    # palette). Assumes the model expects 3-channel RGB input - TODO confirm.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    # Add the batch dimension and move the tensor to the model's device.
    batch = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        conf, pred = torch.max(probs, 1)

    # idx2label comes from JSON, so its keys are strings.
    label = idx2label[str(pred.item())]
    # Fall back to the raw label rather than raising when a translation is
    # missing, matching the tolerant lookup used for the description.
    cn_label = class_eg_to_cn.get(label, label)
    intro = plant_info.get(label, "暂无介绍")
    return cn_label, float(conf.item()), intro
# Gradio interface
_page_css = "body { background-color: #90caf9; font-family: sans-serif; }"

with gr.Blocks(css=_page_css) as demo:
    gr.Markdown("## 🌱 基于深度学习的植物识别系统", elem_id="title")

    with gr.Row():
        # Left column (scale 2): image picker and the button that starts
        # detection.
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="选择图像")
            btn = gr.Button("开始检测", elem_id="detect")

        # Right column (scale 3): the three result fields filled by predict().
        with gr.Column(scale=3):
            label_output = gr.Textbox(label="类别")
            confidence_output = gr.Textbox(label="置信度")
            description_output = gr.Textbox(label="相关介绍", lines=5)

    # predict() returns (label, confidence, intro), in the same order as the
    # output components listed here.
    btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, confidence_output, description_output],
    )

demo.launch()