Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| import collections | |
class QuantizedPlantClassifier(nn.Module):
    """CNN image classifier wrapped in quantization stubs for eager-mode static quantization.

    Architecture: ``n_conv_layers`` blocks of
    Conv2d(3x3, stride 1, padding 0) -> BatchNorm2d -> ReLU -> MaxPool2d(2x2),
    where block ``i`` outputs ``base_filters * 2**i`` channels, followed by
    Flatten -> Linear -> ReLU -> Dropout(0.5) -> Linear. The input passes through
    ``QuantStub`` and the output through ``DeQuantStub`` so the network can be
    converted with ``torch.quantization.prepare`` / ``convert``.

    Args:
        num_classes: Number of output logits.
        in_features: Number of input image channels (3 for RGB). Despite the
            name this is a channel count, not a flattened feature size.
        base_filters: Output channels of the first conv block.
        n_conv_layers: Number of conv blocks.
        num_fc_units: Width of the hidden fully connected layer.
        input_size: Side length of the square images the classifier head is
            sized for. Defaults to 256, the size previously hard-coded.

    Note:
        If ``input_size`` is too small for ``n_conv_layers`` blocks, the shape
        probe in ``__init__`` fails inside PyTorch.
    """

    def __init__(self, num_classes=39, in_features=3, base_filters=16,  # conv stride stays 1, kernel size stays 3
                 n_conv_layers=6, num_fc_units=256, input_size=256):
        super().__init__()

        def make_conv_block(in_channels, out_channels, layer_idx=""):
            """Return one Conv -> BN -> ReLU -> MaxPool stage with indexed sub-module names."""
            # The names (conv0, bn0, ...) become part of the state_dict keys.
            return nn.Sequential(
                collections.OrderedDict([
                    ("conv" + layer_idx,
                     nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0)),
                    ("bn" + layer_idx, nn.BatchNorm2d(out_channels)),
                    ("relu" + layer_idx, nn.ReLU(inplace=True)),
                    ("pool" + layer_idx, nn.MaxPool2d(kernel_size=2, stride=2)),
                ])
            )

        self.quant = torch.quantization.QuantStub()

        layers = []
        current_in_channels = in_features
        for i in range(n_conv_layers):
            out_channels = base_filters * (2 ** i)
            layers.append(make_conv_block(current_in_channels, out_channels, layer_idx=str(i)))
            current_in_channels = out_channels
        self.features = nn.Sequential(*layers)

        # Probe the flattened feature size with a dummy forward pass. Run it in
        # eval mode so BatchNorm does not fold the dummy batch into its running
        # statistics. Use zeros so the global RNG is not consumed between
        # layer initialisations. Restore training mode afterwards, which is the
        # default for a freshly constructed module.
        self.features.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, in_features, input_size, input_size)
            dummy_output = self.features(dummy_input)
        self.features.train()
        flattened_size = dummy_output.flatten(1).shape[1]

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, num_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(num_fc_units, num_classes),
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` must be (batch, in_features, input_size, input_size), because
        the first Linear layer was sized for that spatial size.
        """
        x = self.quant(x)
        x = self.features(x)
        x = self.classifier(x)
        x = self.dequant(x)
        return x
# Load the model. Build the float architecture, then convert it to its
# quantized form *before* loading weights. load_state_dict() is strict by
# default, so the checkpoint's keys must match the converted (quantized)
# module structure, not the float one.
model = QuantizedPlantClassifier()
model.eval()
# Default static-quantization config for the 'fbgemm' backend.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# prepare() inserts observers and convert() swaps in quantized modules. No
# calibration data is run in between, so the observers may warn about
# uncalibrated qparams. Those placeholder scales/zero-points are overwritten by
# the loaded state_dict.
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
# map_location="cpu": the converted model runs on CPU, and this lets a
# checkpoint that was saved from a GPU session load on a CPU-only host.
model.load_state_dict(torch.load("model_dict.pth", map_location="cpu"))
# Class labels, indexed in parallel with the model's output:
# predict() pairs classes[i] with output probability i.
# The list is in ASCII-sorted order (as is typical for folder-per-class
# datasets — NOTE(review): confirm it matches the training label order).
# Do not reorder, rename, or "fix" spellings without retraining.
classes = [
    'Apple___Apple_scab',
    'Apple___Black_rot',
    'Apple___Cedar_apple_rust',
    'Apple___healthy',
    'Background_without_leaves',
    'Blueberry___healthy',
    'Cherry___Powdery_mildew',
    'Cherry___healthy',
    'Corn___Cercospora_leaf_spot Gray_leaf_spot',
    'Corn___Common_rust',
    'Corn___Northern_Leaf_Blight',
    'Corn___healthy',
    'Grape___Black_rot',
    'Grape___Esca_(Black_Measles)',
    'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    'Grape___healthy',
    'Orange___Haunglongbing_(Citrus_greening)',
    'Peach___Bacterial_spot',
    'Peach___healthy',
    'Pepper,_bell___Bacterial_spot',
    'Pepper,_bell___healthy',
    'Potato___Early_blight',
    'Potato___Late_blight',
    'Potato___healthy',
    'Raspberry___healthy',
    'Soybean___healthy',
    'Squash___Powdery_mildew',
    'Strawberry___Leaf_scorch',
    'Strawberry___healthy',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites Two-spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___healthy',
]
# Preprocessing for every uploaded image:
# - Resize to the fixed 256x256 size the classifier head was built for.
# - Convert to a CHW float tensor; for 8-bit images this scales pixels to [0, 1].
# No mean/std normalization is applied here.
transform = transforms.Compose(
    [transforms.Resize(size=(256, 256)), transforms.ToTensor()]
)
def predict(img: Image.Image):
    """Classify one image and return a probability for every class.

    Args:
        img: Image from the Gradio ``Image`` component (``type="pil"``), or
            ``None`` when the form is submitted without an image.

    Returns:
        Dict mapping each class name to its softmax probability, for
        ``gr.Label``. Returns ``None`` when no image was provided, which leaves
        the label output empty.
    """
    if img is None:
        # Nothing uploaded: avoid crashing inside transform().
        return None
    # Uploads may be RGBA, grayscale or palette images. The model is built with
    # 3 input channels, so normalise the mode before converting to a tensor.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(classes, probs))
# Gradio UI: a PIL image goes in, and the top-3 class probabilities come out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Image Classifier Demo",
    description="Upload an image to get predictions with confidence scores.",
)

# Start the server only when this file is run as a script (e.g. `python app.py`).
# Importing the module, from tests or another script, should not block on a
# running web server.
if __name__ == "__main__":
    demo.launch()