File size: 3,484 Bytes
5ed9dbe
d563f38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def predict(image_file):
    import torch
    import numpy as np
    import cv2
    import matplotlib.pyplot as plt
    import torchvision.models as models
    import os
    from PIL import Image
    from IPython.display import display
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader, Subset

    #load model with params
    model = models.efficientnet_b0(weights=None)
    model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))
    device = torch.device('cpu')

    classes = [
        "Apple___Apple_scab",
        "Apple___Black_rot",
        "Apple___Cedar_apple_rust",
        "Apple___healthy",
        "Blueberry___healthy",
        "Cherry_(including_sour)___Powdery_mildew",
        "Cherry_(including_sour)___healthy",
        "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
        "Corn_(maize)___Common_rust_",
        "Corn_(maize)___Northern_Leaf_Blight",
        "Corn_(maize)___healthy",
        "Grape___Black_rot",
        "Grape___Esca_(Black_Measles)",
        "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
        "Grape___healthy",
        "Orange___Haunglongbing_(Citrus_greening)",
        "Peach___Bacterial_spot",
        "Peach___healthy",
        "Pepper,_bell___Bacterial_spot",
        "Pepper,_bell___healthy",
        "Potato___Early_blight",
        "Potato___Late_blight",
        "Potato___healthy",
        "Raspberry___healthy",
        "Soybean___healthy",
        "Squash___Powdery_mildew",
        "Strawberry___Leaf_scorch",
        "Strawberry___healthy",
        "Tomato___Bacterial_spot",
        "Tomato___Early_blight",
        "Tomato___Late_blight",
        "Tomato___Leaf_Mold",
        "Tomato___Septoria_leaf_spot",
        "Tomato___Spider_mites Two-spotted_spider_mite",
        "Tomato___Target_Spot",
        "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
        "Tomato___Tomato_mosaic_virus",
        "Tomato___healthy"
    ]

    def pred_image(image_path, model):
        topk = 3 

        image = Image.open(image_path).convert('RGB')
        transform = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])])
        img_normalized = transform(image).unsqueeze(0)
        
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        img_normalized = img_normalized.to(device)

        with torch.no_grad():
            model.eval()  
            output = model(img_normalized)
            probs, indices = torch.topk(torch.softmax(output, dim=1), topk)
            # index = output.data.cpu().numpy().argmax()
        tmp_lst = []
        # print(indices)
        # print(probs)
        for j in range(topk):
            tmp_dct = {}
            label_indx = indices[0][j]
            # print(label_indx)
            class_name = class_names[label_indx]
            tmp_dct["predicted"] =  class_name
            tmp_dct["probability"] =  probs[0][j]
            tmp_lst.append(tmp_dct)
                
                # print(f"Prediction {j+1}: label index: {indices[i][j]}, probability: {probs[i][j]:.4f}")

        # class_name = class_names[index]
        return tmp_lst

    predicted_label = pred_image(image_file,model)
    return predicted_label