Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import torch | |
| from model import create_effnetb2_model | |
| from timeit import default_timer as timer | |
| from typing import Tuple, Dict | |
| import torchvision | |
# Human-readable label for each classifier output index: ``predict`` turns a
# top-k logit index ``i`` into its label via ``class_classes[i]``, so the
# order of this list must match the order the model was trained with.
# NOTE(review): the names look like the GTSRB 43-class set -- confirm
# against the training data.
class_classes = [
    # Speed limits
    "Speed limit (20km/h)",  # 0
    "Speed limit (30km/h)",  # 1
    "Speed limit (50km/h)",  # 2
    "Speed limit (60km/h)",  # 3
    "Speed limit (70km/h)",  # 4
    "Speed limit (80km/h)",  # 5
    "End of speed limit (80km/h)",  # 6
    "Speed limit (100km/h)",  # 7
    "Speed limit (120km/h)",  # 8
    # Prohibitions / priority
    "No passing",  # 9
    "No passing for vehicles over 3.5 metric tons",  # 10
    "Right-of-way at the next intersection",  # 11
    "Priority road",  # 12
    "Yield",  # 13
    "Stop",  # 14
    "No vehicles",  # 15
    "Vehicles over 3.5 metric tons prohibited",  # 16
    "No entry",  # 17
    # Warnings
    "General caution",  # 18
    "Dangerous curve to the left",  # 19
    "Dangerous curve to the right",  # 20
    "Double curve",  # 21
    "Bumpy road",  # 22
    "Slippery road",  # 23
    "Road narrows on the right",  # 24
    "Road work",  # 25
    "Traffic signals",  # 26
    "Pedestrians",  # 27
    "Children crossing",  # 28
    "Bicycles crossing",  # 29
    "Beware of ice/snow",  # 30
    "Wild animals crossing",  # 31
    # End-of-restriction and mandatory directions
    "End of all speed and passing limits",  # 32
    "Turn right ahead",  # 33
    "Turn left ahead",  # 34
    "Ahead only",  # 35
    "Go straight or right",  # 36
    "Go straight or left",  # 37
    "Keep right",  # 38
    "Keep left",  # 39
    "Roundabout mandatory",  # 40
    "End of no passing",  # 41
    "End of no passing by vehicles over 3.5 metric tons",  # 42
]
# Build the EfficientNet-B2 classifier with one output per entry in
# ``class_classes``. The size comes from the label list instead of a
# hard-coded 43, so the two cannot drift apart. ``create_effnetb2_model`` is
# project-local (``model.py``) and returns a ``(model, transform)`` pair.
effnetb2, effnetb2_transforms = create_effnetb2_model(len(class_classes))

# NOTE(review): ``TrivialAugmentWide`` is a *random* training-time
# augmentation. Any inference path that uses this pipeline distorts the input
# differently on every call, which makes predictions non-deterministic.
# Prefer the plain ``effnetb2_transforms`` for inference. The name is kept
# only for backward compatibility.
effnetb2_transforms_new = torchvision.transforms.Compose([
    torchvision.transforms.TrivialAugmentWide(),
    effnetb2_transforms,
])

# Load the fine-tuned weights onto the CPU. torch.load unpickles the file,
# so only point it at a trusted checkpoint.
effnetb2.load_state_dict(
    torch.load(
        f="effnetb2_traffic_sign_recognition.pth",
        map_location=torch.device("cpu"),
    )
)
# Switch to evaluation mode once at start-up. This affects layers such as
# dropout and BatchNorm. ``predict`` also calls ``eval()`` defensively.
effnetb2.eval()
def predict(
    img,
    model=effnetb2,
    transform=effnetb2_transforms,
    class_classes=class_classes,  # 43 human-readable names
    k: int = 3,
) -> Tuple[Dict[str, float], float]:
    """Classify one image and return its top-k labels with probabilities.

    Args:
        img: Input image. The Gradio UI passes a PIL image
            (``gr.Image(type="pil")``). ``None`` means no image was
            uploaded and raises a user-facing ``gr.Error``.
        model: Classifier producing one logit per class.
        transform: Preprocessing applied to ``img`` before batching. It
            defaults to the deterministic EfficientNet-B2 transform.
            Randomly augmenting pipelines should not be used here, because
            they make predictions non-deterministic.
        class_classes: Label for each output index of ``model``.
        k: Number of top classes to return, clamped to the number of classes.

    Returns:
        A ``(pred_topk, pred_time)`` tuple:
            * ``pred_topk``: ``{label: probability}`` for the top-``k``
              classes, in descending probability order.
            * ``pred_time``: wall-clock time of preprocessing plus inference,
              in seconds, rounded to 4 decimals.

    Raises:
        gr.Error: if ``img`` is ``None``.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    start = timer()

    # unsqueeze(0) adds the batch axis. This assumes transform(img) yields a
    # single (C, H, W) tensor.
    img_t = transform(img).unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        logits = model(img_t)
        # Logits are assumed to be (1, num_classes). Softmax over classes,
        # then drop the batch axis.
        probs = torch.softmax(logits, dim=1).squeeze(0)

    # Clamp k so that asking for more classes than exist does not make
    # topk() raise. topk returns values sorted in descending order.
    top_probs, top_idxs = probs.topk(min(k, probs.numel()))
    pred_topk = {
        class_classes[idx]: prob
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    }

    pred_time = round(timer() - start, 4)
    return pred_topk, pred_time
| import gradio as gr | |
# --- Gradio UI -------------------------------------------------------------
title = "Traffic Sign Classifier 🚦⛖ "
description = "The model predicts the top 3 likely signs using EfficientNet"
article = "Haven't got your driving license yet? don't worry. Here we are!"

# Example images shown under the input:
# - sorted, so the order is stable (os.listdir order is arbitrary);
# - restricted to image files, so stray entries such as ``.DS_Store`` or a
#   README cannot break the examples;
# - empty when the folder is missing, instead of crashing at start-up.
_EXAMPLES_DIR = "examples"
_IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".ppm", ".tif", ".tiff",
)
example_list = (
    [
        [os.path.join(_EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(_EXAMPLES_DIR))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
    if os.path.isdir(_EXAMPLES_DIR)
    else []
)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    # With no example files, pass None (the default) so no examples section
    # is configured.
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=False)