"""Gradio demo serving a ResNet-18 checkpoint with a 200-way head (Tiny ImageNet labels)."""

import torch
import torch.nn as nn
from torchvision import models, transforms as T
from PIL import Image
import gradio as gr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Width of the classifier head the checkpoint was saved with.
NUM_CLASSES = 200


def _read_wnids(path):
    """Return the stripped, non-empty WordNet IDs listed one per line in *path*."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def _read_words(path):
    """Map each WordNet ID in *path* to the first name of its description.

    Each line is expected to look like ``<wnid>\\t<name>[, <synonym>...]``.
    Blank lines are skipped.
    """
    mapping = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            wnid, _, names = line.partition("\t")
            # Strip so the line's trailing newline never leaks into a label
            # (it used to whenever the description contained no comma).
            mapping[wnid.strip()] = names.split(",")[0].strip()
    return mapping


# Load label files
wnids = _read_wnids("wnids.txt")

# Map wnid -> readable label
words = _read_words("words.txt")

# Class indices follow sorted WNID order, not wnids.txt order; this is how
# torchvision's ImageFolder assigns indices from directory names.
# NOTE(review): assumes the checkpoint was trained with ImageFolder-style
# sorted classes — confirm against the training script.
sorted_wnids = sorted(wnids)

# Final label list: id_to_label[i] is the name for logit index i.
id_to_label = [words[wnid] for wnid in sorted_wnids]

if len(id_to_label) != NUM_CLASSES:
    # A mismatch would silently attach the wrong names to predictions.
    raise ValueError(
        f"wnids.txt lists {len(id_to_label)} classes but the model head has {NUM_CLASSES}"
    )

# Load Model
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict needs, so a tampered checkpoint cannot execute arbitrary code.
model.load_state_dict(
    torch.load("best_resnet18_tinyimagenet.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()

# Preprocessing: resize to 224x224, then normalize with the standard ImageNet
# mean/std. NOTE(review): assumed to match training — confirm.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def predict(image):
    """Classify a PIL image and return the predicted class name.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        The human-readable label of the top-scoring class, or a prompt string
        when no image was supplied.
    """
    if image is None:
        return "Please upload an image."
    # Uploads may be RGBA (PNG with alpha) or grayscale; Normalize above has
    # exactly three channel statistics, so force 3-channel RGB.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    pred = logits.argmax(1).item()
    return id_to_label[pred]


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Tiny ImageNet Classifier (Corrected Labels)",
)

if __name__ == "__main__":
    demo.launch()