plantdx / predictor.py
KiKi
updated name of predict function
5ed9dbe
raw
history blame
3.48 kB
def predict(image_file):
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torchvision.models as models
import os
from PIL import Image
from IPython.display import display
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
#load model with params
model = models.efficientnet_b0(weights=None)
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
device = torch.device('cpu')
classes = [
"Apple___Apple_scab",
"Apple___Black_rot",
"Apple___Cedar_apple_rust",
"Apple___healthy",
"Blueberry___healthy",
"Cherry_(including_sour)___Powdery_mildew",
"Cherry_(including_sour)___healthy",
"Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
"Corn_(maize)___Common_rust_",
"Corn_(maize)___Northern_Leaf_Blight",
"Corn_(maize)___healthy",
"Grape___Black_rot",
"Grape___Esca_(Black_Measles)",
"Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
"Grape___healthy",
"Orange___Haunglongbing_(Citrus_greening)",
"Peach___Bacterial_spot",
"Peach___healthy",
"Pepper,_bell___Bacterial_spot",
"Pepper,_bell___healthy",
"Potato___Early_blight",
"Potato___Late_blight",
"Potato___healthy",
"Raspberry___healthy",
"Soybean___healthy",
"Squash___Powdery_mildew",
"Strawberry___Leaf_scorch",
"Strawberry___healthy",
"Tomato___Bacterial_spot",
"Tomato___Early_blight",
"Tomato___Late_blight",
"Tomato___Leaf_Mold",
"Tomato___Septoria_leaf_spot",
"Tomato___Spider_mites Two-spotted_spider_mite",
"Tomato___Target_Spot",
"Tomato___Tomato_Yellow_Leaf_Curl_Virus",
"Tomato___Tomato_mosaic_virus",
"Tomato___healthy"
]
def pred_image(image_path, model):
topk = 3
image = Image.open(image_path).convert('RGB')
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
img_normalized = transform(image).unsqueeze(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
img_normalized = img_normalized.to(device)
with torch.no_grad():
model.eval()
output = model(img_normalized)
probs, indices = torch.topk(torch.softmax(output, dim=1), topk)
# index = output.data.cpu().numpy().argmax()
tmp_lst = []
# print(indices)
# print(probs)
for j in range(topk):
tmp_dct = {}
label_indx = indices[0][j]
# print(label_indx)
class_name = class_names[label_indx]
tmp_dct["predicted"] = class_name
tmp_dct["probability"] = probs[0][j]
tmp_lst.append(tmp_dct)
# print(f"Prediction {j+1}: label index: {indices[i][j]}, probability: {probs[i][j]:.4f}")
# class_name = class_names[index]
return tmp_lst
predicted_label = pred_image(image_file,model)
return predicted_label