Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| # ---------------------------- | |
| # Load class names | |
| # ---------------------------- | |
# Sign names are read from the CSV and keyed by its "ClassId" column.
df = pd.read_csv("signnames.csv").set_index("ClassId")
class_ids = df["SignName"].to_dict()

# Two lookup tables derived from the CSV row order:
#   id2int: sign name -> position of that name in the CSV
#   int2id: position  -> sign name (inverse of id2int)
id2int = {name: idx for idx, name in enumerate(class_ids.values())}
int2id = {idx: name for name, idx in id2int.items()}
| # ---------------------------- | |
| # Define Model | |
| # ---------------------------- | |
def conv_func(in_channels, out_channels):
    """Build one convolutional stage of the classifier.

    The stage applies, in order: dropout (p=0.2), a 3x3 convolution with
    padding 1 (spatial size preserved), ReLU, batch normalisation and a
    2x2 max-pool (spatial size halved).

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the five layers above.
    """
    # The layer order fixes the sub-module indices (and therefore the
    # parameter names in a saved state dict), so it must not be shuffled.
    layers = [
        nn.Dropout(0.2),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(2),
    ]
    return nn.Sequential(*layers)
class RoadSignClassifierModel(nn.Module):
    """CNN classifier: four ``conv_func`` stages plus a two-layer head.

    Args:
        num_classes: Size of the output layer. Defaults to the number of
            entries in ``id2int`` at the time the class is defined.
    """

    def __init__(self, num_classes=len(id2int)):
        super().__init__()

        # (in_channels, out_channels) for each convolutional stage.
        channel_plan = [(3, 64), (64, 64), (64, 128), (128, 256)]
        stages = [conv_func(c_in, c_out) for c_in, c_out in channel_plan]

        # Each stage halves the spatial size, so the 256 * 2 * 2 input of
        # the first linear layer corresponds to a 32x32 input image.
        head = [
            nn.Flatten(),
            nn.Linear(256 * 2 * 2, 256),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        ]

        # Kept as a single flat Sequential named `model` so that parameter
        # names (model.<index>...) are unchanged for saved checkpoints.
        self.model = nn.Sequential(*stages, *head)

    def forward(self, x):
        """Run the batch ``x`` through the network and return its output."""
        return self.model(x)
| # ---------------------------- | |
| # Load trained model | |
| # ---------------------------- | |
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, restore the trained weights, then move it to the
# chosen device and switch to inference mode (both calls return the module).
# NOTE(review): torch.load unpickles the file, so the checkpoint must come
# from a trusted source; consider weights_only=True if the installed
# PyTorch version supports it.
model = RoadSignClassifierModel()
model.load_state_dict(torch.load("traffic_sign_model.pth", map_location=device))
model.to(device).eval()
| # ---------------------------- | |
| # Preprocessing | |
| # ---------------------------- | |
# Per-channel normalisation constants (these are the commonly used
# ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Inference-time pipeline: array -> PIL image -> 32x32 centre crop ->
# float tensor -> normalised tensor.
_preprocess_steps = [
    T.ToPILImage(),
    T.Resize(32),
    T.CenterCrop(32),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
val_tf = T.Compose(_preprocess_steps)
| # ---------------------------- | |
| # Prediction Function | |
| # ---------------------------- | |
def predict(img):
    """Classify a traffic-sign image and return per-class confidences.

    Args:
        img: The ``PIL.Image`` supplied by the Gradio ``Image`` input
            (configured with ``type="pil"``), or ``None`` when the user
            submitted without uploading anything.

    Returns:
        A ``{sign name: probability}`` dict with one entry per model
        output, which is the format ``gr.Label`` accepts; the label widget
        decides how many of the top entries to display.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    # Converting to RGB already yields the channel order the preprocessing
    # pipeline is fed, so no OpenCV colour round trip is needed.
    rgb = np.array(img.convert("RGB"))
    batch = val_tf(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Report the model's actual softmax confidence rather than a fixed 1.0.
    probs = torch.softmax(logits, dim=1).squeeze(0).tolist()

    # int2id is keyed by output position (the same enumeration that sizes
    # the model's final layer), so it is used to turn indices into names.
    return {int2id[idx]: float(p) for idx, p in enumerate(probs)}
| # ---------------------------- | |
| # Gradio UI | |
| # ---------------------------- | |
# Input: an uploaded image delivered to `predict` as a PIL image.
_image_input = gr.Image(type="pil")
# Output: a label widget showing only the top class.
_label_output = gr.Label(num_top_classes=1)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="🚦 Traffic Sign Classifier",
    description="Upload a traffic sign image and the model will predict its category.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()