Spaces:
Runtime error
Runtime error
File size: 3,484 Bytes
5ed9dbe d563f38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import functools

# Class labels in the order of the model's output logits ("<crop>___<condition>").
# NOTE(review): this order must match the label order used at training time
# (e.g. ImageFolder's sorted directory order) — confirm against the training code.
_CLASSES = (
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    "Blueberry___healthy",
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot",
    "Peach___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
)

# Checkpoint (a state_dict) loaded by ``predict``; resolved relative to the
# current working directory.
_WEIGHTS_PATH = "best_model.pth"


@functools.lru_cache(maxsize=4)
def _load_model(weights_path, device):
    """Build EfficientNet-B0, load ``weights_path`` into it, and cache the result.

    The classifier width is read from the checkpoint's ``classifier.1.weight``
    when that key exists, so a head fine-tuned for a different number of
    classes loads without a size-mismatch error. Without that key the default
    torchvision constructor is used unchanged.

    Args:
        weights_path: path to a state_dict saved with ``torch.save``.
        device: ``torch.device`` the returned model is moved to.

    Returns:
        The model in eval mode on ``device``. The instance is shared between
        calls with the same arguments.
    """
    import torch
    import torchvision.models as models

    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    head = state_dict.get("classifier.1.weight")
    extra = {} if head is None else {"num_classes": head.shape[0]}
    model = models.efficientnet_b0(weights=None, **extra)
    model.load_state_dict(state_dict)
    # Inference only: put dropout / batch-norm into eval mode once, here.
    model.eval()
    return model.to(device)


def _preprocess(image_file):
    """Return a normalized ``(1, 3, 224, 224)`` float tensor for ``image_file``.

    ``image_file`` may be anything ``PIL.Image.open`` accepts (path or binary
    file object) or an already-open ``PIL.Image.Image``.
    """
    from PIL import Image
    from torchvision import transforms

    if isinstance(image_file, Image.Image):
        image = image_file.convert("RGB")
    else:
        # The context manager closes the underlying file handle promptly.
        with Image.open(image_file) as img:
            image = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add the leading batch dimension expected by the model.
    return transform(image).unsqueeze(0)


def _top_predictions(logits, class_names, topk):
    """Map the ``topk`` highest softmax scores of ``logits[0]`` to class names.

    Args:
        logits: 2-D tensor of raw scores; only the first row is used.
        class_names: sequence indexed by output position.
        topk: number of results requested; capped at ``logits.shape[1]``.

    Returns:
        List of ``{"predicted": str, "probability": 0-d CPU tensor}`` dicts,
        highest probability first.
    """
    import torch

    k = min(topk, logits.shape[1])
    probs, indices = torch.topk(torch.softmax(logits, dim=1), k)
    probs, indices = probs.cpu(), indices.cpu()
    return [
        {"predicted": class_names[int(idx)], "probability": prob}
        for prob, idx in zip(probs[0], indices[0])
    ]


def predict(image_file, topk=3):
    """Classify a leaf image and return the ``topk`` most likely classes.

    Args:
        image_file: path or binary file object readable by ``PIL.Image.open``,
            or a ``PIL.Image.Image``.
        topk: number of predictions to return (default 3).

    Returns:
        List of dicts ``{"predicted": <class name>, "probability": <0-d CPU
        tensor>}`` sorted by descending probability.
    """
    import torch

    # Model and input must live on the same device, so choose it once and
    # use it for both.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = _load_model(_WEIGHTS_PATH, device)
    batch = _preprocess(image_file).to(device)
    with torch.no_grad():
        logits = model(batch)
    return _top_predictions(logits, _CLASSES, topk)
|