# Author: edmos7 — commit 6b290da ("first commit")
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import collections
class QuantizedPlantClassifier(nn.Module):
    """CNN image classifier wrapped in quantization stubs.

    Architecture: ``n_conv_layers`` blocks of
    Conv2d(3x3, stride 1, no padding) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    where block ``i`` has ``base_filters * 2**i`` output channels, followed by
    Flatten -> Linear -> ReLU -> Dropout(0.5) -> Linear.

    ``self.quant`` / ``self.dequant`` mark where tensors enter and leave the
    quantized region once the model is prepared/converted with
    ``torch.quantization``.

    Args:
        num_classes: Number of output logits.
        in_features: Number of input image channels.
        base_filters: Output channels of the first conv block.
        n_conv_layers: Number of conv blocks.
        num_fc_units: Width of the hidden fully-connected layer.
        input_size: Height and width of the square input image the classifier
            head is sized for (default 256, the previously hard-coded value).
    """

    def __init__(self, num_classes=39, in_features=3, base_filters=16,  # conv stride stays 1, kernel size stays 3
                 n_conv_layers=6, num_fc_units=256, input_size=256):
        super().__init__()

        def make_conv_block(in_channels, out_channels, layer_idx=""):
            # Sub-layer names embed the block index (conv0, bn0, ...), so they
            # appear in state_dict keys; keep them stable for checkpoint loading.
            return nn.Sequential(
                collections.OrderedDict([
                    ('conv' + layer_idx, nn.Conv2d(in_channels, out_channels,
                                                   kernel_size=3, stride=1, padding=0)),
                    ('bn' + layer_idx, nn.BatchNorm2d(out_channels)),
                    ('relu' + layer_idx, nn.ReLU(inplace=True)),
                    ('pool' + layer_idx, nn.MaxPool2d(kernel_size=2, stride=2)),
                ])
            )

        self.quant = torch.quantization.QuantStub()

        layers = []
        current_in_channels = in_features
        for i in range(n_conv_layers):
            out_channels = base_filters * (2 ** i)
            layers.append(make_conv_block(current_in_channels, out_channels, layer_idx=str(i)))
            current_in_channels = out_channels
        self.features = nn.Sequential(*layers)

        # The first Linear layer's width depends on the conv stack's output
        # size, which is easiest to obtain by probing it once.
        flattened_size = self._infer_flattened_size(in_features, input_size)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, num_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(num_fc_units, num_classes),
        )
        self.dequant = torch.quantization.DeQuantStub()

    def _infer_flattened_size(self, in_channels, input_size):
        """Return the per-image feature count produced by ``self.features``.

        The probe runs in eval mode so BatchNorm layers use, rather than
        update, their running statistics; a training-mode probe would
        overwrite the fresh running_mean/running_var and bump
        num_batches_tracked. A zero tensor is used so the probe does not
        consume the global RNG. The previous train/eval mode is restored.
        """
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, in_channels, input_size, input_size)
                return self.features(probe).flatten(1).shape[1]
        finally:
            self.features.train(was_training)

    def forward(self, x):
        """Map a batch of images to class logits.

        ``x`` is expected to be (N, in_features, H, W) with H/W such that the
        conv stack yields the same flattened size as an
        ``input_size`` x ``input_size`` image; otherwise the first Linear
        layer rejects it.
        """
        x = self.quant(x)
        x = self.features(x)
        x = self.classifier(x)
        x = self.dequant(x)
        return x
# loading model
def _build_quantized_model(state_dict_path):
    """Construct the quantized classifier and load its saved weights.

    The float model is switched to eval mode, given the 'fbgemm' default
    qconfig, then prepared and converted in place. No calibration data is run
    between prepare and convert. The weights and quantization parameters are
    expected to come from the checkpoint, which is loaded only after
    conversion. So it must match the converted module's state_dict layout.
    """
    net = QuantizedPlantClassifier()
    net.eval()
    net.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(net, inplace=True)
    torch.quantization.convert(net, inplace=True)
    checkpoint = torch.load(state_dict_path)
    net.load_state_dict(checkpoint)
    return net


model = _build_quantized_model("model_dict.pth")
classes = ['Apple___Apple_scab',
'Apple___Black_rot',
'Apple___Cedar_apple_rust',
'Apple___healthy',
'Background_without_leaves',
'Blueberry___healthy',
'Cherry___Powdery_mildew',
'Cherry___healthy',
'Corn___Cercospora_leaf_spot Gray_leaf_spot',
'Corn___Common_rust',
'Corn___Northern_Leaf_Blight',
'Corn___healthy',
'Grape___Black_rot',
'Grape___Esca_(Black_Measles)',
'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
'Grape___healthy',
'Orange___Haunglongbing_(Citrus_greening)',
'Peach___Bacterial_spot',
'Peach___healthy',
'Pepper,_bell___Bacterial_spot',
'Pepper,_bell___healthy',
'Potato___Early_blight',
'Potato___Late_blight',
'Potato___healthy',
'Raspberry___healthy',
'Soybean___healthy',
'Squash___Powdery_mildew',
'Strawberry___Leaf_scorch',
'Strawberry___healthy',
'Tomato___Bacterial_spot',
'Tomato___Early_blight',
'Tomato___Late_blight',
'Tomato___Leaf_Mold',
'Tomato___Septoria_leaf_spot',
'Tomato___Spider_mites Two-spotted_spider_mite',
'Tomato___Target_Spot',
'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
'Tomato___Tomato_mosaic_virus',
'Tomato___healthy']
# Preprocessing applied to each uploaded image before it reaches the model.
# First resize to 256x256, the input size the classifier's head is built for.
# Then convert the PIL image to a tensor.
_resize_to_model_input = transforms.Resize(size=(256, 256))
_pil_to_tensor = transforms.ToTensor()
transform = transforms.Compose([_resize_to_model_input, _pil_to_tensor])
def predict(img: Image.Image):
    """Classify one image and return a probability for every label.

    Args:
        img: The input PIL image, or ``None`` if no image was supplied.

    Returns:
        A dict mapping each name in ``classes`` to its softmax probability,
        or an empty dict when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to classify (e.g. the form was submitted without an image);
        # return an empty result instead of failing inside ``transform``.
        return {}
    # The model expects 3-channel input. RGBA, grayscale ('L') or palette ('P')
    # images would otherwise yield a tensor with the wrong channel count.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    # Logit i corresponds to classes[i].
    return dict(zip(classes, probs.tolist()))
# Web UI: one image upload, passed to predict() as a PIL image. The output
# label widget shows the three highest-probability classes.
_image_input = gr.Image(type="pil")
_top3_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_top3_output,
    title="Image Classifier Demo",
    description="Upload an image to get predictions with confidence scores.",
)

# Launches the app whenever this module is executed.
demo.launch()