import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms

# Number of highest-probability classes reported per image.
_TOP_K = 3

# Output-index -> label table; position i names the class for model logit i.
_CLASS_NAMES = (
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    "Blueberry___healthy",
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot",
    "Peach___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
)

# Resize -> 224x224 centre crop -> tensor -> per-channel normalisation.
# Built once at import time instead of on every prediction.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _load_model(device):
    """Build an EfficientNet-B0 and load the weights stored in 'model.pth'.

    Args:
        device: torch.device the weights are mapped to and the model is
            placed on.

    Returns:
        The model, on ``device`` and switched to eval mode.
    """
    model = models.efficientnet_b0(weights=None)
    state_dict = torch.load('model.pth', map_location=device)
    # NOTE(review): strict=False silently skips checkpoint keys that do not
    # match the architecture, so a renamed classifier head would be left
    # randomly initialised. Confirm the checkpoint really needs this.
    model.load_state_dict(state_dict, strict=False)
    return model.to(device).eval()


def _pred_image(image_path, model):
    """Run ``model`` on one image and return its top-k class predictions.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts.
        model: Classifier whose output indices line up with ``_CLASS_NAMES``.

    Returns:
        A list of ``_TOP_K`` dicts, most probable first, each holding
        ``"predicted"`` (class name) and ``"probability"`` (0-dim tensor).

    Raises:
        IndexError: If the model emits a top-k index beyond ``_CLASS_NAMES``.
    """
    # The context manager releases the underlying file handle promptly;
    # convert() returns an independent, fully loaded RGB copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    batch = _PREPROCESS(image).unsqueeze(0)

    # The input must live on the same device as the model's weights; picking
    # a device independently of the model makes the forward pass fail with a
    # device-mismatch error whenever the two disagree.
    batch = batch.to(next(model.parameters()).device)

    model.eval()
    with torch.no_grad():
        output = model(batch)
        probs, indices = torch.topk(torch.softmax(output, dim=1), _TOP_K)

    # Debug output, kept so existing console logs stay unchanged.
    print(indices)
    print(probs)

    predictions = []
    for label_indx, probability in zip(indices[0], probs[0]):
        print("index:", label_indx)
        predictions.append({
            "predicted": _CLASS_NAMES[label_indx],
            "probability": probability,
        })
    return predictions


def predict(image_file):
    """Classify a plant-leaf image and return the three most likely labels.

    The checkpoint is read from 'model.pth' in the current working directory
    on every call, and inference runs on the CPU.

    Args:
        image_file: Path or file-like object readable by ``PIL.Image.open``.

    Returns:
        A list of three dicts ordered by descending probability. Each dict
        has the keys ``"predicted"`` (class name string) and
        ``"probability"`` (0-dim ``torch.Tensor``; call ``.item()`` to get a
        Python float).
    """
    device = torch.device('cpu')
    model = _load_model(device)
    return _pred_image(image_file, model)