# Jasur05's picture
# Update app.py
# 4757677 verified
import gradio as gr
import os
import torch
from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict
import torchvision
# Human-readable names of the traffic-sign classes, one per model output.
# Order matters: a predicted output index i is displayed as class_classes[i],
# so this list must follow the label order used during training
# (presumably the GTSRB class ids 0-42 — confirm against the training data).
class_classes = [
    # Speed limits
    "Speed limit (20km/h)",
    "Speed limit (30km/h)",
    "Speed limit (50km/h)",
    "Speed limit (60km/h)",
    "Speed limit (70km/h)",
    "Speed limit (80km/h)",
    "End of speed limit (80km/h)",
    "Speed limit (100km/h)",
    "Speed limit (120km/h)",
    # Prohibitions and priority
    "No passing",
    "No passing for vehicles over 3.5 metric tons",
    "Right-of-way at the next intersection",
    "Priority road",
    "Yield",
    "Stop",
    "No vehicles",
    "Vehicles over 3.5 metric tons prohibited",
    "No entry",
    # Warnings
    "General caution",
    "Dangerous curve to the left",
    "Dangerous curve to the right",
    "Double curve",
    "Bumpy road",
    "Slippery road",
    "Road narrows on the right",
    "Road work",
    "Traffic signals",
    "Pedestrians",
    "Children crossing",
    "Bicycles crossing",
    "Beware of ice/snow",
    "Wild animals crossing",
    # Mandatory directions and end-of-restriction signs
    "End of all speed and passing limits",
    "Turn right ahead",
    "Turn left ahead",
    "Ahead only",
    "Go straight or right",
    "Go straight or left",
    "Keep right",
    "Keep left",
    "Roundabout mandatory",
    "End of no passing",
    "End of no passing by vehicles over 3.5 metric tons",
]
# Build the classifier with one output per label in `class_classes` (43), plus
# the preprocessing transform that the project helper returns alongside it.
effnetb2, effnetb2_transforms = create_effnetb2_model(len(class_classes))

# Inference-time preprocessing. This deliberately contains NO random
# augmentation: torchvision's TrivialAugmentWide is a training-time augmentation
# that randomly perturbs each image. Applying it here made every prediction
# non-deterministic and less accurate. The name is kept because predict()
# uses it as its default transform.
effnetb2_transforms_new = effnetb2_transforms

# Load the trained weights onto the CPU so the app runs without a GPU.
effnetb2.load_state_dict(
    torch.load(
        f="effnetb2_traffic_sign_recognition.pth",
        map_location=torch.device("cpu"),
    )
)
# Put the model in evaluation mode once at startup (affects layers such as
# dropout / batch-norm, if the architecture has them).
effnetb2.eval()
def predict(
    img,
    model=effnetb2,
    transform=effnetb2_transforms_new,
    class_classes=class_classes,  # 43 human-readable names
    k: int = 3,
) -> Tuple[Dict[str, float], float]:
    """Classify an image and return its top-k most likely labels.

    Args:
        img: Image accepted by ``transform`` (a PIL image when called from the
            Gradio interface). ``None`` — e.g. the user pressed Submit without
            uploading anything — yields an empty prediction instead of an error.
        model: Classifier whose output logits are indexed like ``class_classes``.
        transform: Preprocessing applied to ``img`` before adding a batch dim.
        class_classes: Label names ordered by model output index.
        k: Number of predictions to return. Clamped to the number of model
            outputs so that ``topk`` cannot fail.

    Returns:
        A tuple of:
          * dict of top-k ``{label: probability}`` in descending probability order
          * inference time in seconds, rounded to 4 decimals
    """
    if img is None:
        return {}, 0.0

    start = timer()

    # Uploaded PIL images may be RGBA, greyscale or palette images. Normalise
    # them to RGB, since the model presumably expects 3-channel input
    # (confirm against create_effnetb2_model). Non-PIL inputs are left as-is.
    if hasattr(img, "convert"):
        img = img.convert("RGB")

    img_t = transform(img).unsqueeze(0)  # add a batch dimension of size 1

    model.eval()
    with torch.inference_mode():
        logits = model(img_t)
    probs = torch.softmax(logits, dim=1).squeeze(0)

    # topk raises if asked for more elements than exist, so clamp k.
    top_n = max(0, min(k, probs.numel()))
    top_probs, top_idxs = probs.topk(top_n)
    pred_topk = {
        class_classes[int(idx)]: float(prob)
        for idx, prob in zip(top_idxs, top_probs)
    }

    pred_time = round(timer() - start, 4)
    return pred_topk, pred_time
# Gradio user interface. `gr` is already imported at the top of the file.
title = "Traffic Sign Classifier 🚦⛖ "
description = "The model predicts the top 3 likely signs using EfficientNet"
article = "Haven't got your driving license yet? don't worry. Here we are!"

# Example images shown beneath the interface. The listing is:
#   * sorted, for a stable order (os.listdir order is arbitrary);
#   * limited to visible regular files, skipping sub-directories and dotfiles
#     such as .DS_Store / .gitkeep that are not images;
#   * tolerant of a missing "examples" directory.
_EXAMPLES_DIR = "examples"
if os.path.isdir(_EXAMPLES_DIR):
    example_list = [
        [os.path.join(_EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLES_DIR))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(_EXAMPLES_DIR, name))
    ]
else:
    example_list = []

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    # Pass None rather than an empty list when there are no examples.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=False)