File size: 3,286 Bytes
55286c8
 
 
 
 
5ed9dbe
d563f38
 
 
015558b
d563f38
 
55286c8
d563f38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55286c8
d563f38
 
 
 
 
 
 
 
 
55286c8
 
d563f38
 
 
415374b
d563f38
015558b
 
d563f38
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import warnings

import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms

# Plant-disease labels, in the order of the model's output logits:
# output index i is reported as _CLASS_NAMES[i].
_CLASS_NAMES = [
    "Apple___Apple_scab",
    "Apple___Black_rot",
    "Apple___Cedar_apple_rust",
    "Apple___healthy",
    "Blueberry___healthy",
    "Cherry_(including_sour)___Powdery_mildew",
    "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Common_rust_",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___healthy",
    "Grape___Black_rot",
    "Grape___Esca_(Black_Measles)",
    "Grape___Leaf_blight_(Isariopsis_Leaf_Spot)",
    "Grape___healthy",
    "Orange___Haunglongbing_(Citrus_greening)",
    "Peach___Bacterial_spot",
    "Peach___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Pepper,_bell___healthy",
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Potato___healthy",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___Leaf_scorch",
    "Strawberry___healthy",
    "Tomato___Bacterial_spot",
    "Tomato___Early_blight",
    "Tomato___Late_blight",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot",
    "Tomato___Spider_mites Two-spotted_spider_mite",
    "Tomato___Target_Spot",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___healthy",
]

# Preprocessing applied to every input image. It resizes the short side to 256,
# center-crops to 224x224, converts to a float tensor, and normalizes with the
# standard ImageNet per-channel mean/std values.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _load_model(model_path, device):
    """Build an EfficientNet-B0, load weights from *model_path*, move it to *device*.

    The classifier width is read from the checkpoint's ``classifier.1.weight``
    when that key is present, so a fine-tuned (e.g. 38-class) checkpoint loads
    without a size-mismatch error. Otherwise the torchvision default of 1000
    outputs is kept. Loading stays non-strict, as before. Any parameters missing
    from the checkpoint are reported with ``warnings.warn``, because they remain
    randomly initialised (``weights=None``) and predictions would be meaningless.

    Returns the model in eval mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    head_weight = state_dict.get('classifier.1.weight')
    num_classes = head_weight.shape[0] if head_weight is not None else 1000

    model = models.efficientnet_b0(weights=None, num_classes=num_classes)
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        warnings.warn(
            f"{model_path}: {len(incompatible.missing_keys)} model parameters "
            f"missing from checkpoint and left randomly initialised "
            f"(e.g. {incompatible.missing_keys[:3]})",
            RuntimeWarning,
        )
    return model.to(device).eval()


def _preprocess_image(image_file):
    """Open *image_file* (path or file object) as RGB; return a (1, 3, 224, 224) batch."""
    # The context manager releases the file handle PIL opened. convert() returns a
    # new, fully loaded image, so it stays usable after the source is closed.
    with Image.open(image_file) as img:
        rgb = img.convert('RGB')
    return _PREPROCESS(rgb).unsqueeze(0)


def predict(image_file, topk=3, model_path='model.pth'):
    """Classify a leaf image and return the top-*topk* predicted labels.

    Args:
        image_file: Path or file object accepted by ``PIL.Image.open``.
        topk: Number of predictions to return. It is capped at the number of
            model outputs.
        model_path: Path to the saved state dict loaded into EfficientNet-B0.

    Returns:
        A list of up to *topk* dicts, ordered by decreasing probability. Each dict
        has ``"predicted"`` (label string from ``_CLASS_NAMES``) and
        ``"probability"`` (a 0-d CPU tensor holding the softmax probability).

    Raises:
        IndexError: If the model predicts an index outside ``_CLASS_NAMES``
            (e.g. a 1000-class head).
    """
    # Use one device for both the model and the input tensor; mixing them
    # raises a RuntimeError on CUDA hosts.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the model is reloaded from disk on every call; consider
    # caching it if predict() is on a request path.
    model = _load_model(model_path, device)
    batch = _preprocess_image(image_file).to(device)

    with torch.no_grad():
        logits = model(batch)
        # torch.topk raises if k exceeds the number of outputs.
        k = min(topk, logits.shape[1])
        probs, indices = torch.topk(torch.softmax(logits, dim=1), k)

    # Move results to the CPU so callers get the same tensors regardless of device.
    probs, indices = probs.cpu(), indices.cpu()
    return [
        {"predicted": _CLASS_NAMES[int(idx)], "probability": prob}
        for idx, prob in zip(indices[0], probs[0])
    ]