# ImageNet1k / app.py
# saneshashank's picture
# Update app.py
# c8b3e9d verified
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import json
from torchvision import transforms
from model import resnet50
# Class names loaded from a JSON file that ships next to the app; predict()
# indexes this object by integer class id.
# JSON is UTF-8 by specification, so the encoding is pinned instead of being
# left to the platform's locale default.
with open("imagenet_classes.json", "r", encoding="utf-8") as f:
    class_labels = json.load(f)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network with the project-local factory, then restore its weights.
model = resnet50(num_classes=1000, drop_path_rate=0.0, use_blurpool=True)

# The checkpoint is a dict holding the weights under "model_state_dict".
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider passing weights_only=True if the checkpoint holds
# nothing but tensors and plain containers -- confirm before changing.
checkpoint = torch.load("best_resnet50_imagenet_1k.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # evaluation mode: this app only runs inference

# Image preprocessing applied to every uploaded picture before inference.
transform = transforms.Compose([
    transforms.Resize(256),      # int size: the shorter edge is scaled to 256
    transforms.CenterCrop(224),  # central 224x224 crop fed to the network
    transforms.ToTensor(),       # PIL image -> float tensor
    # Per-channel mean/std (the values commonly used for ImageNet models).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def predict(image):
    """Classify an image and return its five most probable classes.

    Args:
        image: PIL image supplied by the Gradio image input. May be ``None``
            when the request is submitted without an image.

    Returns:
        A dict mapping class label to probability (a plain ``float``) for the
        five highest-scoring classes, in the form consumed by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail with a readable message instead of an opaque error from inside
    # the preprocessing pipeline.
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Normalize() in `transform` is configured for three channels, so force
    # RGB in case a grayscale or RGBA image reaches this function.
    # unsqueeze(0) adds the batch dimension the model call expects.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(img_tensor)
        # Softmax over dim 1, then take the single item of the batch.
        probabilities = F.softmax(outputs, dim=1)[0]

    top5_prob, top5_idx = torch.topk(probabilities, 5)

    # tolist() yields plain Python ints/floats, so labels are looked up with
    # ordinary integers and the returned probabilities are not tensors.
    return {
        class_labels[idx]: prob
        for idx, prob in zip(top5_idx.tolist(), top5_prob.tolist())
    }
# Gradio UI: a single image input wired to predict(), with the result shown
# in a label widget limited to five classes.
demo = gr.Interface(
    title="ImageNet ResNet50 Classifier (71% Accuracy)",
    description=(
        "ResNet50 trained on ImageNet with improved stem, BlurPool, and "
        "progressive resizing. Achieved 71% top-1 accuracy under $30 budget."
    ),
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    examples=[],
)

# Start the app only when this file is executed directly as a script.
if __name__ == "__main__":
    demo.launch()