# Author: kamaldhakad
# Last commit: "Update app.py" (17f2199, verified)
import torch
import torch.nn as nn
from torchvision import models, transforms as T
from PIL import Image
import gradio as gr
# Choose where inference runs: the default CUDA device when PyTorch can see
# a GPU, otherwise the CPU. The model and input tensors are moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load label files.
# wnids.txt: one WordNet ID per line, in file order. Blank lines are skipped
# so a trailing newline cannot introduce an empty ID.
with open("wnids.txt", encoding="utf-8") as f:
    stripped_lines = (line.strip() for line in f)
    wnids = [wnid for wnid in stripped_lines if wnid]

# Map wnid -> readable label. Each words.txt line is parsed as
# "<wnid>\t<comma-separated names>" and only the first name is kept.
words = {}
with open("words.txt", "r", encoding="utf-8") as f:
    for raw_line in f:
        line = raw_line.rstrip("\r\n")
        if not line:
            continue
        wnid, _, name = line.partition("\t")
        # strip() removes surrounding whitespace; previously an entry with no
        # comma kept its trailing newline in the displayed label.
        words[wnid] = name.split(",")[0].strip()

# Class order: model output index i corresponds to the i-th wnid in sorted
# (alphabetical) order.
# NOTE(review): this assumes training used alphabetically sorted class folders
# (as torchvision's ImageFolder does) -- confirm against the training code.
sorted_wnids = sorted(wnids)
# Final label list, indexed by the model's predicted class id.
# Raises KeyError if a wnid from wnids.txt is missing from words.txt.
id_to_label = [words[wnid] for wnid in sorted_wnids]
# Load Model: a ResNet-18 whose final layer is replaced with a
# NUM_CLASSES-way linear head, restored from the fine-tuned checkpoint and
# switched to inference mode on `device`.
NUM_CLASSES = 200

# Fail fast if the label files disagree with the classifier head. Otherwise
# predict() could index past id_to_label (IndexError) or show the wrong names.
if len(id_to_label) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class labels from wnids.txt, "
        f"got {len(id_to_label)}"
    )

model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# NOTE(security): torch.load unpickles the file, so only load checkpoints
# from a trusted source. map_location lets a GPU-saved checkpoint load on CPU.
state_dict = torch.load("best_resnet18_tinyimagenet.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # disable dropout / use running BatchNorm statistics
# Preprocessing pipeline (stated by the original author to match training):
# resize to 224x224, convert to a float tensor in [0, 1], then normalise each
# RGB channel with the mean/std below (the usual ImageNet statistics).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
def predict(image):
    """Classify a single image and return its human-readable label.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio input, or ``None`` when
            the user submits without uploading an image.

    Returns:
        The ``id_to_label`` entry for the highest-scoring class.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message in the UI.
    """
    if image is None:
        # Gradio passes None for an empty image input. Show a readable message
        # instead of an opaque failure from the transform.
        raise gr.Error("Please upload an image.")

    # Force three channels: grayscale, RGBA or palette images would otherwise
    # give a 1- or 4-channel tensor that fails in Normalize / the RGB model.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    pred = logits.argmax(dim=1).item()
    return id_to_label[pred]
# Web UI: uploads arrive as PIL images, are passed to predict(), and the
# returned label is shown as plain text. Launched as soon as the script runs.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Tiny ImageNet Classifier (Corrected Labels)",
)
demo.launch()