Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms as T | |
| from PIL import Image | |
| import gradio as gr | |
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load the class WordNet IDs (one per line), skipping blank lines.
# The context manager guarantees the file handle is closed.
with open("wnids.txt", encoding="utf-8") as f:
    wnids = [w for w in (line.strip() for line in f) if w]

# Map wnid -> readable label. Each line of words.txt is
# "<wnid>\t<comma-separated synonyms>"; only the first synonym is kept.
words = {}
with open("words.txt", "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        # maxsplit=1 so a stray tab inside the synonym list cannot break
        # the two-way unpacking.
        wnid, name = line.split("\t", 1)
        # strip() drops the line terminator, which would otherwise remain on
        # labels that have a single synonym (no comma to split it off).
        words[wnid.strip()] = name.split(",")[0].strip()

# Class order must match the order used at training time: output index i of
# the model corresponds to the i-th wnid in alphabetical order (presumably
# because training used torchvision's ImageFolder, which sorts class
# directory names -- TODO confirm against the training script).
sorted_wnids = sorted(wnids)

# Final label list, indexed by the model's predicted class id.
id_to_label = [words[wnid] for wnid in sorted_wnids]
# Build a ResNet-18 with no pretrained weights, replace its classifier head
# with a 200-way linear layer, then restore the fine-tuned weights from the
# local checkpoint onto the selected device.
model = models.resnet18(weights=None)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=200)
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (weights_only=True is an option on torch versions that have it).
model.load_state_dict(
    torch.load("best_resnet18_tinyimagenet.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself; eval() switches
# BatchNorm/Dropout to inference behavior.
model = model.to(device).eval()
# Inference preprocessing; it must mirror the pipeline used during training:
# resize to 224x224, convert to a float tensor, then normalize each channel
# with the commonly used ImageNet mean/std.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
def predict(image):
    """Classify a single image and return its human-readable label.

    Args:
        image: A PIL image (as delivered by the ``gr.Image(type="pil")``
            input), or ``None`` when no image was supplied.

    Returns:
        The label of the highest-scoring class, or an empty string when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing uploaded; bail out before the transform dereferences it.
        return ""
    # `transform` normalizes exactly three channels and the ResNet stem expects
    # three, so grayscale ("L"), palette ("P") or RGBA images must be converted.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        logits = model(batch)
    pred = logits.argmax(dim=1).item()
    return id_to_label[pred]
# Wire the classifier into a simple Gradio UI: one PIL image in, one text
# label out. Keeping a handle on the Interface lets tooling reference it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Tiny ImageNet Classifier (Corrected Labels)",
)
demo.launch()