File size: 3,962 Bytes
6b290da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import collections

class QuantizedPlantClassifier(nn.Module):
    """CNN plant-disease classifier wrapped with eager-mode quantization stubs.

    The feature extractor is ``n_conv_layers`` blocks of
    Conv2d(3x3, stride 1, no padding) -> BatchNorm2d -> ReLU -> MaxPool2d(2x2),
    where block ``i`` outputs ``base_filters * 2**i`` channels. The head is
    Flatten -> Linear -> ReLU -> Dropout(0.5) -> Linear. ``QuantStub`` and
    ``DeQuantStub`` mark the model's entry and exit points for quantization.

    Args:
        num_classes: Number of output logits.
        in_features: Number of input image channels (3 for RGB).
        base_filters: Output channels of the first conv block.
        n_conv_layers: Number of conv blocks.
        num_fc_units: Hidden width of the classifier head.
        input_size: Side length of the square images the model is built for.
            The input size of the first ``nn.Linear`` is derived from it, so
            inputs of another spatial size will not fit the head. Defaults to
            256, the value that used to be hard-coded.
    """

    def __init__(self, num_classes=39, in_features=3, base_filters=16,
                 n_conv_layers=6, num_fc_units=256, input_size=256):
        super().__init__()

        def make_conv_block(in_channels, out_channels, layer_idx=""):
            # Sub-module names carry the block index ("conv0", "bn0", ...); they
            # become part of the state_dict keys, so they must stay stable.
            return nn.Sequential(collections.OrderedDict([
                (f"conv{layer_idx}",
                 nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0)),
                (f"bn{layer_idx}", nn.BatchNorm2d(out_channels)),
                (f"relu{layer_idx}", nn.ReLU(inplace=True)),
                (f"pool{layer_idx}", nn.MaxPool2d(kernel_size=2, stride=2)),
            ]))

        self.quant = torch.quantization.QuantStub()

        layers = []
        current_in_channels = in_features
        for i in range(n_conv_layers):
            out_channels = base_filters * (2 ** i)
            layers.append(make_conv_block(current_in_channels, out_channels, layer_idx=str(i)))
            current_in_channels = out_channels
        self.features = nn.Sequential(*layers)

        # Push one dummy image through the conv stack to learn the flattened
        # feature size. This must run in eval mode: in training mode BatchNorm
        # would fold the dummy batch into running_mean/running_var and bump
        # num_batches_tracked, silently altering a freshly constructed model.
        # Only the output shape matters, so zeros are used instead of random
        # data (this also leaves the global RNG state untouched).
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.zeros(1, in_features, input_size, input_size)
                flattened_size = self.features(dummy_input).flatten(1).shape[1]
        finally:
            self.features.train(was_training)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, num_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(num_fc_units, num_classes)
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        ``x`` is expected to be (N, in_features, input_size, input_size), the
        shape used to size the classifier head in ``__init__``.
        """
        x = self.quant(x)
        x = self.features(x)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

# loading model
# Rebuild the quantized module structure first, then load the saved weights.
# NOTE(review): assumes model_dict.pth holds the state_dict of a model that was
# already prepared + converted with this same qconfig, so its keys (quantized
# weights, scale/zero_point entries) only line up after the prepare()/convert()
# calls below -- keep these steps in this order. TODO confirm against the export script.
model = QuantizedPlantClassifier()
# Inference only: disables Dropout and makes BatchNorm use its running stats.
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# No calibration data is fed through the observers inserted by prepare(); the
# quantization parameters are expected to come from the checkpoint loaded below.
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)

# Relative path: resolved against the current working directory, not this
# file's directory. Only load checkpoints from a trusted source (torch.load
# deserializes the file).
model.load_state_dict(torch.load("model_dict.pth")) 

# Display labels for the classifier's outputs: predict() pairs classes[i] with
# output probability i, so the order here must match the training label order
# and the list must have exactly num_classes (39) entries. Strings are kept
# verbatim (including spellings such as "Haunglongbing"); they presumably match
# the training dataset's label names -- confirm before editing any of them.
classes = [
    # Apple
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    # No plant in frame
    'Background_without_leaves',
    # Blueberry
    'Blueberry___healthy',
    # Cherry
    'Cherry___Powdery_mildew',
    'Cherry___healthy',
    # Corn
    'Corn___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn___Common_rust',
    'Corn___Northern_Leaf_Blight',
    'Corn___healthy',
    # Grape
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    # Orange
    'Orange___Haunglongbing_(Citrus_greening)',
    # Peach
    'Peach___Bacterial_spot',
    'Peach___healthy',
    # Bell pepper
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    # Potato
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    # Raspberry, soybean, squash
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    # Strawberry
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    # Tomato
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]

# Preprocessing applied to every uploaded image before inference: resize to the
# 256x256 resolution the classifier head was sized for, then convert the PIL
# image to a (C, H, W) float tensor.
# NOTE(review): there is no Normalize step -- confirm this matches the
# preprocessing used during training.
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

def predict(img: Image.Image):
    """Classify one uploaded image and return a probability for every class.

    Args:
        img: Image delivered by the ``gr.Image(type="pil")`` input; ``None``
            when the user submits without uploading anything.

    Returns:
        Mapping from each label in ``classes`` to its softmax probability
        (a Python float), the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a stack trace from the preprocessing pipeline.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    # The model's first conv expects 3 input channels. PNG uploads are often
    # RGBA, and grayscale/palette images carry a single channel, so normalise
    # the mode before converting to a tensor.
    if img.mode != "RGB":
        img = img.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a batch dimension: (1, C, H, W)
    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)[0]

    return dict(zip(classes, probs.tolist()))

# Web UI: one image upload (handed to predict() as a PIL image) and a label
# widget that shows the three highest-probability classes.
_image_input = gr.Image(type="pil")
_top3_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_top3_output,
    title="Image Classifier Demo",
    description="Upload an image to get predictions with confidence scores.",
)

# Start the Gradio server.
demo.launch()