"""Gradio demo serving a statically quantized plant image classifier.

The quantized weights are read from ``model_dict.pth`` when this module runs,
and the web UI is launched at the bottom of the module.
"""

import collections

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# Side length (pixels) of the square images the network is built for. Shared by
# the model's flatten-size probe and the preprocessing transform so the two
# cannot drift apart.
IMAGE_SIZE = 256


class QuantizedPlantClassifier(nn.Module):
    """CNN classifier wrapped in QuantStub/DeQuantStub for eager-mode quantization.

    Architecture: ``n_conv_layers`` stages of Conv(3x3, stride 1, no padding)
    -> BatchNorm -> ReLU -> MaxPool(2x2), with the filter count doubling at each
    stage starting from ``base_filters``, followed by an MLP head
    (Flatten -> Linear -> ReLU -> Dropout -> Linear).

    Args:
        num_classes: Number of output logits.
        in_features: Number of input image channels.
        base_filters: Filter count of the first conv stage.
        n_conv_layers: Number of conv stages.
        num_fc_units: Width of the hidden fully connected layer.
        input_size: Side length of the square input images. It is only used to
            size the first Linear layer, so inputs that flatten to a different
            feature size will raise a shape mismatch there.
    """

    def __init__(self, num_classes=39, in_features=3, base_filters=16,
                 # conv stride stays 1, kernel size stays 3
                 n_conv_layers=6, num_fc_units=256, input_size=IMAGE_SIZE):
        super().__init__()

        def make_conv_block(in_channels, out_channels, layer_idx=""):
            """Return one Conv -> BN -> ReLU -> MaxPool stage.

            Submodule names carry ``layer_idx`` (e.g. ``conv0``, ``bn0``), and
            these names are part of the state-dict keys.
            """
            return nn.Sequential(collections.OrderedDict([
                ("conv" + layer_idx,
                 nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0)),
                ("bn" + layer_idx, nn.BatchNorm2d(out_channels)),
                ("relu" + layer_idx, nn.ReLU(inplace=True)),
                ("pool" + layer_idx, nn.MaxPool2d(kernel_size=2, stride=2)),
            ]))

        self.quant = torch.quantization.QuantStub()

        layers = []
        current_in_channels = in_features
        for i in range(n_conv_layers):
            out_channels = base_filters * (2 ** i)
            layers.append(make_conv_block(current_in_channels, out_channels, layer_idx=str(i)))
            current_in_channels = out_channels
        self.features = nn.Sequential(*layers)

        # Probe the flattened feature size with a dummy forward pass. It runs in
        # eval mode so the BatchNorm layers neither update their running
        # statistics nor count a batch. Zeros are used so the global RNG is not
        # consumed. Training mode, the nn.Module default at construction, is
        # restored afterwards.
        self.features.eval()
        with torch.no_grad():
            dummy_output = self.features(
                torch.zeros(1, in_features, input_size, input_size))
        self.features.train()
        flattened_size = dummy_output.flatten(1).shape[1]

        # Index layout (0..4) is part of the state-dict keys; keep it stable.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, num_fc_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(num_fc_units, num_classes),
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize ``x``, run features and classifier, and return float logits."""
        x = self.quant(x)
        x = self.features(x)
        x = self.classifier(x)
        x = self.dequant(x)
        return x


# Output labels, in the index order of the model's logits.
classes = ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust',
           'Apple___healthy', 'Background_without_leaves', 'Blueberry___healthy',
           'Cherry___Powdery_mildew', 'Cherry___healthy',
           'Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust',
           'Corn___Northern_Leaf_Blight', 'Corn___healthy', 'Grape___Black_rot',
           'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
           'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)',
           'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot',
           'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight',
           'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy',
           'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy',
           'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight',
           'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot',
           'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot',
           'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus',
           'Tomato___healthy']

# loading model
# The head size is derived from the label list so the two cannot disagree.
model = QuantizedPlantClassifier(num_classes=len(classes))
model.eval()
# Convert to the quantized module structure without calibration. The scales and
# zero points are then overwritten by load_state_dict below (presumably the
# checkpoint was saved from a model converted the same way; verify).
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
torch.quantization.convert(model, inplace=True)
# Quantized kernels run on CPU; map_location keeps loading working even if the
# checkpoint was written from another device.
model.load_state_dict(torch.load("model_dict.pth", map_location="cpu"))

# NOTE(review): no Normalize step here. Confirm this matches the
# training-time preprocessing.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])


def predict(img: Image.Image):
    """Classify ``img`` and return a ``{class_name: probability}`` mapping.

    Returns an empty dict when no image is supplied (e.g. Submit pressed with an
    empty input) instead of raising inside the preprocessing transform.
    """
    if img is None:
        return {}
    # Force 3 channels: RGBA / grayscale / palette images would otherwise yield a
    # tensor whose channel count does not match the first conv layer.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]
    return {name: float(p) for name, p in zip(classes, probs)}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="Image Classifier Demo",
    description="Upload an image to get predictions with confidence scores."
)

demo.launch()