# Source: Hugging Face Space file app.py by mgbam (commit 058024c, "Create app.py", verified, 1.66 kB).
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import gradio as gr
# Prefer the GPU when PyTorch reports one is available; otherwise use the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load class names (make sure this file is in the Space).
# One label per line; a label's position in the list is used as its class
# index by the rest of the script. Blank lines are skipped so that a stray
# empty line in the file does not turn into an extra, empty-named class
# (which would also change the number of model outputs).
with open("cifar10_classes.txt", encoding="utf-8") as f:
    CLASSES = [name for name in map(str.strip, f) if name]
def build_model(num_classes: int, device: str = "cpu"):
    """Build a ResNet-18 whose final layer emits ``num_classes`` logits.

    The backbone is created with torchvision's ImageNet-pretrained weights and
    its fully connected head is replaced by a fresh ``nn.Linear`` sized for
    ``num_classes``.

    Args:
        num_classes: Number of output units for the replacement ``fc`` layer.
        device: Device string handed to ``model.to`` (e.g. ``"cpu"``, ``"cuda"``).

    Returns:
        The ResNet-18 module after being moved to ``device``.
    """
    # Keep the try body down to the attribute lookup, so an AttributeError
    # raised while constructing the network is not mistaken for a missing
    # weights enum.
    try:
        weights = models.ResNet18_Weights.DEFAULT
    except AttributeError:
        # torchvision releases that lack the weights enum (before 0.13) do not
        # accept a ``weights=`` keyword either; their API for the same
        # ImageNet checkpoint is ``pretrained=True``.
        model = models.resnet18(pretrained=True)
    else:
        model = models.resnet18(weights=weights)

    # Swap the 1000-way ImageNet head for one sized to this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)
# Inference-time preprocessing: resize to 224x224, then convert to a tensor.
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Build the network with one output per class name, restore the fine-tuned
# weights from the checkpoint file, and switch to evaluation mode.
num_classes = len(CLASSES)
model = build_model(num_classes, DEVICE)
state_dict = torch.load("ast_cifar10_resnet18.pth", map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()
def predict(image: Image.Image):
    """Classify one image and return a probability for every class name.

    Args:
        image: PIL image to classify, or ``None`` when nothing was supplied.

    Returns:
        A dict mapping each entry of ``CLASSES`` to its softmax probability
        (a Python float). An empty dict is returned when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # Force three channels: a grayscale or RGBA image would otherwise become a
    # 1- or 4-channel tensor, which the ResNet input convolution rejects.
    rgb_image = image.convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)

    # Softmax over the class dimension of the single batch row, then convert
    # to Python floats in one call instead of one conversion per class.
    scores = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(CLASSES, scores))
# Gradio UI: a single image upload wired to `predict`, with the result shown
# as a label widget limited to the three highest-scoring classes.
demo = gr.Interface(
    title="AST CIFAR-10 Classifier",
    description="ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.",
    fn=predict,
    inputs=gr.Image(label="Upload CIFAR-like image", type="pil"),
    outputs=gr.Label(label="Top-3 Predictions", num_top_classes=3),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()