# ImageClassifier / app.py
# Last commit: sidharthg - "Upload 7 files" (a91f34b, verified)
import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
from model import resnet18 # Ensure this matches your model definition file
# Load CIFAR-100 class names: one label per line, surrounding whitespace
# stripped. A label's position in the list is the class index used to decode
# predictions. The encoding is pinned so decoding does not depend on the
# host's locale.
with open("cifar100_classes.txt", encoding="utf-8") as f:
    CIFAR100_CLASSES = [line.strip() for line in f]
# Load trained model
# Prefer the GPU when PyTorch reports one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): depending on the PyTorch version / weights_only setting,
# torch.load may unpickle arbitrary objects - only load checkpoints you trust.
checkpoint = torch.load("resnet18_cifar100_best.pth", map_location=DEVICE)

# Build the 100-class network and restore the weights stored under the
# "model_state_dict" key of the checkpoint.
model = resnet18(num_classes=100)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE)
model.eval()  # put the model in evaluation mode for inference
# Preprocessing applied to every input image: resize to 32x32, convert to a
# tensor, then normalize each of the three channels with a fixed mean / std.
# NOTE(review): the constants look like the usual CIFAR-100 statistics -
# confirm they match what was used during training.
transform = T.Compose(
    [
        T.Resize(size=(32, 32)),
        T.ToTensor(),
        T.Normalize(
            mean=(0.5071, 0.4867, 0.4408),
            std=(0.2675, 0.2565, 0.2761),
        ),
    ]
)
def predict(image):
    """Classify an image and return its top class with a confidence score.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (the input component
            is configured with ``type="numpy"``), or ``None`` when the user
            submitted without providing an image.

    Returns:
        A single-entry dict mapping the predicted class name to its softmax
        probability (a fraction in [0, 1], rounded to two decimals), or
        ``None`` when ``image`` is ``None``.
    """
    # Submitting with no image yields None; bail out instead of crashing
    # inside Image.fromarray.
    if image is None:
        return None

    img = Image.fromarray(image).convert("RGB")
    # unsqueeze(0) adds the leading batch dimension before the forward pass.
    batch = transform(img).unsqueeze(0).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(batch)
        probs = torch.softmax(outputs, dim=1)
        conf, pred = torch.max(probs, dim=1)

    class_name = CIFAR100_CLASSES[pred.item()]
    # Softmax probability in [0, 1] - it is NOT scaled to a percentage.
    confidence = conf.item()
    return {class_name: round(confidence, 2)}
# Gradio UI: one image input, one label output, wired to predict().
title = "CIFAR-100 Image Classifier"
description = (
    "Upload an image (32x32 or larger). "
    "The model will predict the top class with confidence score."
)

# The image is handed to predict() as a numpy array; only the single best
# class is displayed in the label widget.
image_input = gr.Image(type="numpy", label="Upload Image")
label_output = gr.Label(num_top_classes=1, label="Prediction")

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    examples=[["examples/1.jpg"], ["examples/2.jpg"]],
    allow_flagging="never",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()