# traffic / app.py
# anonymous
# Initial app
# dcecd2b
# Raw
# History Blame Contribute Delete
# 3.51 kB
import os
import gradio as gr
from huggingface_hub import login
from huggingface_hub import snapshot_download
import numpy as np
import tensorflow as tf
import cv2
# Width and height (in pixels) that input images are resized to before
# being passed to the model.
IMG_WIDTH, IMG_HEIGHT = 32, 32

# Human-readable sign names, ordered by class index (position 0 is class 0).
_SIGN_NAMES = (
    'Speed limit (20km/h)',
    'Speed limit (30km/h)',
    'Speed limit (50km/h)',
    'Speed limit (60km/h)',
    'Speed limit (70km/h)',
    'Speed limit (80km/h)',
    'End of speed limit (80km/h)',
    'Speed limit (100km/h)',
    'Speed limit (120km/h)',
    'No passing',
    'No passing veh over 3.5 tons',
    'Right-of-way at intersection',
    'Priority road',
    'Yield',
    'Stop',
    'No vehicles',
    'Veh > 3.5 tons prohibited',
    'No entry',
    'General caution',
    'Dangerous curve left',
    'Dangerous curve right',
    'Double curve',
    'Bumpy road',
    'Slippery road',
    'Road narrows on the right',
    'Road work',
    'Traffic signals',
    'Pedestrians',
    'Children crossing',
    'Bicycles crossing',
    'Beware of ice/snow',
    'Wild animals crossing',
    'End speed + passing limits',
    'Turn right ahead',
    'Turn left ahead',
    'Ahead only',
    'Go straight or right',
    'Go straight or left',
    'Keep right',
    'Keep left',
    'Roundabout mandatory',
    'End of no passing',
    'End no passing veh > 3.5 tons',
)

# Mapping from integer class index (0..42) to the sign name.
classes = dict(enumerate(_SIGN_NAMES))
def image_mod(image):
    """Classify a traffic-sign image and return per-class confidences.

    Args:
        image: Image as a NumPy array in RGB channel order (the Gradio
            image input is configured with ``type="numpy"``), or ``None``
            when no image is available.

    Returns:
        A dict mapping ``"<index>: <class name>"`` to the confidence the
        model assigned to that class (as a plain ``float``), suitable for
        a Gradio ``Label`` output. Returns ``None`` when ``image`` is
        ``None`` so that no prediction is attempted.
    """
    # The interface runs in live mode, so this callback may be invoked
    # without an image (e.g. after the input is cleared); cv2.resize
    # would raise on None, so bail out early instead.
    if image is None:
        return None

    # Resize image to the dimensions used when training.
    # cv2.resize expects the target size as (width, height).
    size = (IMG_WIDTH, IMG_HEIGHT)
    res = cv2.resize(image, size, interpolation=cv2.INTER_AREA)

    # Convert image from PIL format (RGB) to the cv2 format
    # (BGR) that was used when training the model
    res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)

    # Convert to float and normalize to match the training
    res = res.astype("float32") / 255.0

    # Convert single image to a batch of one for prediction
    batch = res[np.newaxis, ...]

    # Carry out prediction; the first row holds the scores for our image
    confidences = model.predict(batch)[0]

    # Convert to the format expected by the Label Gradio component,
    # using built-in floats rather than NumPy scalars.
    return {
        f"{index}: {classes[index]}": float(score)
        for index, score in enumerate(confidences)
    }
# Authenticate with the Hugging Face Hub using the token from the
# environment, then download the model repository and load it with Keras.
_hub_token = os.environ['TOKEN_TRAFFIC']
login(token=_hub_token)
_model_repo = os.environ['REPO_TRAFFIC_MODEL']
model_path = snapshot_download(repo_id=_model_repo)
model = tf.keras.models.load_model(model_path)

# Gradio components: a NumPy image as input, the five top-ranked
# classes as output.
input_image_component = gr.Image(type="numpy")
output_label_component = gr.Label(num_top_classes=5)

# Assemble the user interface around the prediction function.
iface = gr.Interface(
    fn=image_mod,
    inputs=input_image_component,
    outputs=output_label_component,
    live=True,
    title="German traffic sign recognizer",
    description=(
        "A convolutional neural network to categorize images of "
        "German traffic signs."
    ),
    article=(
        "# Reference\n"
        "J. Stallkamp, M. Schlipsing, J. Salmen, and C. Igel. "
        "The German Traffic Sign Recognition Benchmark: "
        "A multi-class classification competition. "
        "In Proceedings of the IEEE International Joint Conference on "
        "Neural Networks, pages 1453–1460. 2011."
    ),
    examples="examples",
)

# Start the frontend server (no public share link).
iface.launch(share=False)