Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import scipy.io | |
| import numpy as np | |
# ---- CONFIG ----
MODEL_PATH = "twin_car_best_model_v2.pth"  # trained weights, loaded below as a state_dict
META_PATH = "cars_meta.mat"  # MATLAB metadata file holding the class-name table
DEVICE = torch.device("cpu")  # inference device

# ---- LOAD CLASS NAMES ----
# The 'class_names' table is indexed by row 0; each entry is presumably a
# 1-element array wrapping the name string (hence entry[0]) — confirm against
# the .mat file if its layout ever changes.
meta = scipy.io.loadmat(META_PATH)
name_row = meta['class_names'][0]
class_names = [entry[0] for entry in name_row]
# ---- DEFINE MODEL ----
def get_model(num_classes=196):
    """Build a randomly initialised ResNet-50 with a small MLP classifier head.

    The backbone is created with ``weights=None``. Callers are expected to load
    trained weights into the returned module themselves.

    Args:
        num_classes: number of output logits (196 = Stanford Cars classes).

    Returns:
        A torchvision ResNet-50 whose ``fc`` layer is replaced by
        Linear(in_features, 512) -> ReLU -> Dropout(0.2) -> Linear(512, num_classes).
    """
    backbone = models.resnet50(weights=None)
    feature_dim = backbone.fc.in_features
    # Keep an unnamed nn.Sequential: parameter keys must stay fc.0.* / fc.3.*
    # so a state_dict saved from this architecture loads without remapping.
    head_layers = [
        nn.Linear(feature_dim, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, num_classes),
    ]
    backbone.fc = nn.Sequential(*head_layers)
    return backbone
# ---- BUILD MODEL & LOAD WEIGHTS ----
# One output logit per class name read from the metadata file.
model = get_model(num_classes=len(class_names))
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs; it avoids executing arbitrary pickled objects from the
# checkpoint. NOTE(review): requires torch >= 1.13.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
# load_state_dict copies values into the existing parameters, so the module
# itself must be moved for the configured DEVICE to actually take effect.
model.to(DEVICE)
# Inference mode: disables the head's Dropout and freezes normalisation stats.
model.eval()
# ---- TRANSFORM ----
# ImageNet channel statistics. Presumably the same normalisation used when the
# model was trained — confirm against the training pipeline.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

_eval_steps = [
    transforms.Resize(256),  # shorter side -> 256 px, aspect ratio preserved
    transforms.CenterCrop(224),  # central 224x224 patch fed to the network
    transforms.ToTensor(),  # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
test_transform = transforms.Compose(_eval_steps)
# ---- PREDICTION FUNCTION ----
def predict(img):
    """Classify a car image and return its five most likely classes.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping class name -> softmax probability for the (up to) five
        highest-scoring classes, most likely first.

    Raises:
        gr.Error: if no image was supplied, so Gradio shows a readable message
            instead of an opaque AttributeError traceback.
    """
    if img is None:
        raise gr.Error("Please upload a car image.")
    # Force 3 channels: uploads may be RGBA, palette or grayscale images.
    img_pil = img.convert("RGB")
    # Add a batch dimension and put the input on the same device as the model's
    # weights, so inference keeps working if the model is moved off the CPU.
    device = next(model.parameters()).device
    x = test_transform(img_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
    probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    # argsort is ascending: keep the last five indices, reversed for descending order.
    top5_idx = np.argsort(probs)[-5:][::-1]
    return {class_names[i]: float(probs[i]) for i in top5_idx}
# ---- GRADIO APP ----
description = (
    "Upload a car image. The model returns top-5 fine-grained make/model predictions "
    "using Stanford Cars 196. <br><br>Model: ResNet-50 + custom head, trained by Kiril Mickovski."
)

# Keep a handle on the interface (conventionally named ``demo``) so it can be
# reused, e.g. by Gradio's reload mode, before launching it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Car Image"),
    outputs=gr.Label(num_top_classes=5),
    # NOTE(review): the emoji here was mis-encoded in the source (it rendered as
    # the Thai digit "๐"); restored as 🚗 — confirm the intended emoji.
    title="🚗 TwinCar-196: Stanford Cars Classifier",
    description=description,
    # NOTE(review): `allow_flagging` is deprecated in newer Gradio releases in
    # favour of `flagging_mode`; kept for compatibility with the installed version.
    allow_flagging="never",
)
demo.launch()