CHCZHC's picture
Upload app.py
059d827
import torch
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
import gradio as gr
import json
# Pick the device used for inference: a CUDA GPU when one is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# NOTE(review): sample image path left over from the original (Colab) notebook;
# nothing in this script reads it -- predict() receives its image from the UI.
image_path = '/content/mogu_dataset/有毒类/Amanita/000_jY-KIKiHVjQ.jpg'
# Human-readable output labels, indexed by the model's predicted class index.
# Presumably this answers "is the mushroom poisonous?" (the sample path above
# lives in the poisonous-class folder) -- TODO confirm.
label_name_list = ['No', 'Yes']
# Derive the class count from the labels so the classifier head size and the
# label list can never drift apart.
num_classes = len(label_name_list)
# Step 1: load the model
# model = model_big
class BaseModel(nn.Module):
    """Small four-stage convolutional classifier.

    Each stage is a 3x3 convolution (no padding) -> ReLU -> 2x2 max-pool. For a
    224x224 RGB input the spatial size after each stage is 111, 54, 26 and 12,
    so the final feature map holds 32 * 12 * 12 = 4608 values, matching the
    input width of ``fc1``. Three fully connected layers then map those
    features to ``num_classes`` raw logits (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Attribute names (conv1..conv4, pool, fc1..fc3) must stay as-is so the
        # keys of a saved state_dict keep matching; creation order is kept too.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.fc1 = nn.Linear(in_features=4608, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        for conv_layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv_layer(x)))
        # Keep the batch dimension, collapse channels and spatial dims.
        features = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Build the network on the selected device and restore its trained weights.
cnn_model = BaseModel().to(device)
# map_location remaps tensors saved on another device (e.g. a GPU training box)
# onto the current one, so the checkpoint also loads on CPU-only hosts.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
cnn_model.load_state_dict(torch.load("cnn_model.pth", map_location=device))
model = cnn_model
# Switch to evaluation mode before serving predictions.
model.eval()
# Step 2: image preprocessing.
# Training-time pipeline: light random augmentation, then resize to the
# 224x224 size the network's fc1 layer was dimensioned for, convert to a
# tensor, and normalize each channel with the standard ImageNet mean/std.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_train_steps = [
    transforms.RandomRotation(degrees=5),  # rotate by a random angle in [-5, 5] degrees
    transforms.RandomHorizontalFlip(),     # mirror with the default probability
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)
# Deterministic preprocessing for inference: the same resize / tensor /
# normalization as train_transform, but without the random rotation and flip,
# so the same image always produces the same prediction.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(inp):
    """Classify a single image and return the result as a printable string.

    Args:
        inp: PIL image supplied by the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without an image.

    Returns:
        ``str`` of a dict with keys ``'pred_score'`` (softmax probability of the
        predicted class, as a Python float), ``'pred_index'`` (int) and
        ``'pred_label'`` (the matching entry of ``label_name_list``). If no
        image was given, a short message string is returned instead.
    """
    if inp is None:
        return "No image provided."

    # Uploaded images may be RGBA (PNG) or single-channel; the network and the
    # 3-value Normalize expect exactly 3 channels.
    image = eval_transform(inp.convert("RGB")).unsqueeze(0).to(device)

    # Step 3: run the model without tracking gradients.
    with torch.no_grad():
        output = model(image)
    print(output)

    # Step 4: post-process. Softmax over the single batch row; stay in torch
    # (instead of .numpy()) so this also works when the tensor is on the GPU.
    probs = F.softmax(output[0], dim=0)
    pred_index = int(torch.argmax(probs).item())
    pred_label = label_name_list[pred_index]
    result_dict = {
        'pred_score': probs[pred_index].item(),
        'pred_index': pred_index,
        'pred_label': pred_label,
    }
    return str(result_dict)
# Web UI: one PIL image input, plain-text output showing predict()'s result
# string, and a bundled example image.
demo = gr.Interface(
    fn=predict,
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4; the
    # top-level gr.Image component works in both.
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=[["111212.jpg"]],
)
demo.launch()