Spaces:
Sleeping
Sleeping
File size: 3,962 Bytes
6b290da | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import collections
class QuantizedPlantClassifier(nn.Module):
    """CNN plant-disease classifier wrapped with eager-mode quantization stubs.

    The network is ``n_conv_layers`` stages of
    Conv(3x3, stride 1, no padding) -> BatchNorm -> ReLU -> MaxPool(2x2, stride 2),
    followed by Flatten -> Linear -> ReLU -> Dropout(0.5) -> Linear.
    ``QuantStub``/``DeQuantStub`` mark where eager-mode quantization converts
    between float and quantized tensors.

    Args:
        num_classes: Number of output logits.
        in_features: Number of input image channels.
        base_filters: Output channels of the first conv stage; stage ``i``
            has ``base_filters * 2**i`` output channels.
        n_conv_layers: Number of conv stages.
        num_fc_units: Width of the hidden fully-connected layer.
        input_size: Side length of the square input image that the classifier
            head is sized for (256 by default, the value previously hard-coded).

    Raises:
        ValueError: If ``input_size`` is too small to pass through
            ``n_conv_layers`` conv/pool stages.
    """

    def __init__(self, num_classes=39, in_features=3, base_filters=16,  # conv stride stays 1, kernel size stays 3
                 n_conv_layers=6, num_fc_units=256, input_size=256):
        super().__init__()

        def make_conv_block(in_channels, out_channels, layer_idx=""):
            """Return one Conv -> BN -> ReLU -> MaxPool stage with suffixed submodule names."""
            # The submodule names (conv0, bn0, ...) become state_dict keys, so
            # they must stay stable for saved checkpoints to keep loading.
            return nn.Sequential(
                collections.OrderedDict([
                    ('conv' + layer_idx, nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0)),
                    ('bn' + layer_idx, nn.BatchNorm2d(out_channels)),
                    ('relu' + layer_idx, nn.ReLU(inplace=True)),
                    ('pool' + layer_idx, nn.MaxPool2d(kernel_size=2, stride=2)),
                ])
            )

        self.quant = torch.quantization.QuantStub()

        layers = []
        current_in_channels = in_features
        spatial = input_size  # side length of the feature map after each stage
        for i in range(n_conv_layers):
            out_channels = base_filters * (2 ** i)
            layers.append(make_conv_block(current_in_channels, out_channels, layer_idx=str(i)))
            current_in_channels = out_channels
            # A 3x3 conv without padding shrinks each side by 2; the 2x2,
            # stride-2 max-pool (floor mode) then halves it, rounding down.
            spatial -= 2
            if spatial < 2:
                raise ValueError(
                    f"input_size={input_size} is too small for "
                    f"{n_conv_layers} conv/pool stages"
                )
            spatial //= 2
        self.features = nn.Sequential(*layers)

        # The flattened feature size is computed arithmetically instead of by
        # pushing a dummy tensor through self.features. The module is in train
        # mode at this point, so a dummy pass would overwrite every BatchNorm's
        # running statistics with random-noise values, even under no_grad().
        flattened_size = current_in_channels * spatial * spatial

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, num_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(num_fc_units, num_classes),
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Run the network on a batch of images and return class logits.

        ``x`` is presumably a float tensor of shape
        (N, in_features, input_size, input_size); the first Linear layer only
        accepts the flattened size of that spatial resolution. The
        quant/dequant stubs act as float pass-throughs until the model is
        converted with ``torch.quantization.convert``.
        """
        x = self.quant(x)
        x = self.features(x)
        x = self.classifier(x)
        x = self.dequant(x)
        return x
# Load the model: build the float skeleton, convert it to the eager-mode
# quantized structure, then load the quantized checkpoint into it so that the
# module layout matches the saved state_dict.
model = QuantizedPlantClassifier()
model.eval()
# NOTE(review): the checkpoint is assumed to have been produced with this same
# 'fbgemm' qconfig, and the host must support that quantized engine — confirm.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# No calibration data is run between prepare() and convert(); the quantization
# parameters are expected to come from the state_dict loaded below.
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
# map_location="cpu" keeps loading independent of the device the checkpoint
# was saved from; quantized modules run on the CPU backend anyway.
model.load_state_dict(torch.load("model_dict.pth", map_location="cpu"))
# Output labels, indexed by logit position in the classifier head (39 entries,
# matching the model's default num_classes). The list is in sorted order,
# presumably the order the training loader assigned class indices — confirm
# against training. Spellings such as "Haunglongbing" look like the dataset's
# own folder names and are shown verbatim in the UI, so keep them as-is.
classes = [
    # Apple
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    # Non-plant background
    'Background_without_leaves',
    # Blueberry
    'Blueberry___healthy',
    # Cherry
    'Cherry___Powdery_mildew',
    'Cherry___healthy',
    # Corn
    'Corn___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn___Common_rust',
    'Corn___Northern_Leaf_Blight',
    'Corn___healthy',
    # Grape
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    # Orange
    'Orange___Haunglongbing_(Citrus_greening)',
    # Peach
    'Peach___Bacterial_spot',
    'Peach___healthy',
    # Bell pepper
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    # Potato
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    # Raspberry / Soybean / Squash
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    # Strawberry
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    # Tomato
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Preprocessing for every uploaded image: resize to 256x256 (the input size the
# classifier head is built for), then convert to a C x H x W float tensor with
# values in [0, 1]. No mean/std normalization is applied here — presumably
# matching how the model was trained; confirm before changing.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)
def predict(img: Image.Image):
    """Classify one uploaded image and return per-class probabilities.

    Args:
        img: The image from the Gradio ``Image`` component, or ``None`` when
            the user submits without providing an image.

    Returns:
        A dict mapping every label in ``classes`` to its softmax probability,
        in the format ``gr.Label`` accepts; an empty dict when ``img`` is None.
    """
    if img is None:
        # Nothing to classify: an empty dict clears the Label output instead
        # of raising inside the transform.
        return {}
    # The model is built with 3 input channels; converting guards against
    # grayscale ("L") or RGBA images, which would yield 1- or 4-channel tensors.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    return dict(zip(classes, probs.tolist()))
# Gradio UI: a PIL image goes in, and the three most probable labels from
# predict() come out. launch() starts the web server.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=3)
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="Image Classifier Demo",
    description="Upload an image to get predictions with confidence scores.",
)
demo.launch()
|