|
|
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
|
|
|
|
|
# Human-readable labels. Position i names the i-th output of the classifier:
# predict_cifar_image looks labels up by the index of each probability.
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# All inference runs on CPU; the checkpoint is remapped to this device on load.
device = torch.device("cpu")
|
|
# Preprocessing applied to every image before it reaches the network:
# convert the PIL image to a tensor, then standardise each RGB channel.
# NOTE(review): mean/std look like the commonly used CIFAR-10 training
# statistics; they must match what the checkpoint was trained with — confirm.
inference_transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)
|
|
|
|
|
class SmallCifarCNN(nn.Module):
    """Three-stage convolutional classifier for small 3-channel images.

    ``features`` stacks three identical stages (3x3 conv -> batch norm ->
    ReLU -> 2x2 max-pool) with 32, 64 and 128 output channels. Each pool
    halves the spatial size, so a 32x32 input leaves ``features`` as a
    128x4x4 map, which is the size the first linear layer is built for.
    ``classifier`` flattens that map and applies a 256-unit hidden layer
    with dropout, followed by the final linear layer.

    Args:
        num_classes: Number of output logits produced by the last layer.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        # All stages go into one flat list so the nn.Sequential indices, and
        # therefore the state_dict keys (e.g. "features.4.weight"), are the
        # same as when every layer is listed by hand. Existing checkpoints
        # keep loading unchanged.
        feature_layers = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            feature_layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
            in_channels = out_channels
        self.features = nn.Sequential(*feature_layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 128 channels x 4 x 4 spatial positions after three 2x2 pools.
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (pre-softmax) class logits for a batch of images.

        Args:
            x: Image batch; the layer sizes expect ``(N, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        return self.classifier(self.features(x))
|
|
|
|
|
# Build the network with one output per label, already placed on `device`.
deployed_model = SmallCifarCNN(num_classes=len(class_names)).to(device)

# Checkpoint file name; a relative path, resolved against the process's
# current working directory.
model_path = 'cifar_cnn_best.pt'

# map_location remaps the stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file, so only point this at a
# trusted checkpoint; consider passing weights_only=True if the installed
# torch version supports it.
deployed_model.load_state_dict(
    torch.load(model_path, map_location=device)
)

# Inference mode for the Dropout and BatchNorm layers. The model was already
# moved to `device` when it was constructed, so no second .to() is needed.
deployed_model.eval()
|
|
|
|
|
def predict_cifar_image(img: Optional[Image.Image]) -> dict:
    """Classify an image and return its three most likely labels.

    Used as the Gradio callback. The image is converted to RGB, resized to
    32x32 with bilinear resampling, passed through ``inference_transform``
    and ``deployed_model``, and the logits are turned into probabilities
    with a softmax.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when
            the callback is invoked without an image.

    Returns:
        Mapping of class name to probability for the top three classes,
        inserted most likely first. Empty when ``img`` is ``None``.
    """
    if img is None:
        # No image to classify: return an empty result instead of failing
        # with AttributeError on ``None.convert``.
        return {}

    img = img.convert("RGB").resize((32, 32), Image.BILINEAR)

    # unsqueeze(0) adds the batch dimension the model expects.
    x = inference_transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = deployed_model(x)
        probs = F.softmax(logits, dim=1).cpu().numpy().ravel()

    topk = 3
    # Negate so that argsort's ascending order yields largest-first indices.
    idxs = np.argsort(-probs)[:topk]
    return {class_names[i]: float(probs[i]) for i in idxs}
|
|
|
|
|
# Web UI definition: one image input feeding predict_cifar_image, with the
# returned label -> probability mapping rendered by a Label component.
demo = gr.Interface(
    fn=predict_cifar_image,
    inputs=gr.Image(
        type="pil",
        label="Upload an RGB image (will be resized to 32×32)",
    ),
    outputs=gr.Label(
        num_top_classes=3,
        label="Top-3 CIFAR-10 predictions",
    ),
    title="CIFAR-10 CNN Classifier",
    description="Small CNN trained on CIFAR-10. Upload an image and see top-3 class probabilities.",
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed as a script,
    # not when it is imported as a module.
    demo.launch()