# noahzhy's picture
# feat: demo
# d1cce4c
import os, glob
from pathlib import Path
import cv2
import numpy as np
import gradio as gr
from ai_edge_litert.interpreter import Interpreter
def get_samples(data_dir=None):
    """Return the example image paths in a stable, human-friendly order.

    Args:
        data_dir: Directory to scan for ``*.jpg`` files. When omitted, the
            ``data`` folder located next to this file is used.

    Returns:
        A list of path strings. Files whose stem parses as an integer come
        first in ascending numeric order (so ``2.jpg`` precedes ``10.jpg``);
        any other ``*.jpg`` files follow, ordered by stem.
    """
    if data_dir is None:
        data_dir = os.path.join(os.path.dirname(__file__), 'data')

    def sort_key(path):
        stem = Path(path).stem
        try:
            return (0, int(stem), '')
        except ValueError:
            # A stray non-numeric file name must not take the whole demo
            # down; such files are simply listed after the numbered ones.
            return (1, 0, stem)

    return sorted(glob.glob(os.path.join(data_dir, '*.jpg')), key=sort_key)
def cv2_imread(path):
    """Decode the image file at *path* and return what ``cv2.imdecode`` yields.

    The file is first read as a flat ``uint8`` buffer with ``np.fromfile`` and
    then decoded with ``cv2.IMREAD_UNCHANGED``, instead of handing the path to
    OpenCV directly.
    """
    # NOTE(review): presumably done so that paths OpenCV itself cannot open
    # (e.g. non-ASCII names on Windows) still load -- confirm.
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, cv2.IMREAD_UNCHANGED)
def load_dict(dict_path='label.names'):
    """Read the label file and map each line number to that line's text.

    Args:
        dict_path: Label file path, joined onto the directory that contains
            this file before being opened as UTF-8.

    Returns:
        A dict whose keys are 0-based line indices and whose values are the
        corresponding lines with their line endings removed.
    """
    label_file = os.path.join(os.path.dirname(__file__), dict_path)
    with open(label_file, 'r', encoding='utf-8') as handle:
        entries = handle.read().splitlines()
    return dict(enumerate(entries))
class TFliteDemo:
def __init__(self, model_path, blank=0):
self.blank = blank
self.interpreter = Interpreter(model_path=model_path, num_threads=os.cpu_count())
self.interpreter.allocate_tensors()
self.inputs = self.interpreter.get_input_details()
self.outputs = self.interpreter.get_output_details()
def inference(self, x):
self.interpreter.set_tensor(self.inputs[0]['index'], x)
self.interpreter.invoke()
return self.interpreter.get_tensor(self.outputs[0]['index'])
def preprocess(self, img):
if isinstance(img, str):
image = cv2_imread(img)
else:
if img is None:
raise ValueError('img is None')
image = img.copy()
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = cv2.resize(image, (96, 32), interpolation=cv2.INTER_LINEAR)
image = np.reshape(image, (1, 1, *image.shape)).astype(np.float32) / 255.0
return image
def postprocess(self, pred):
label_dict = load_dict()
pred_probs = pred[0]
pred_indices = np.argmax(pred_probs, axis=-1)
pred_label = [label_dict[i] for i in pred_indices]
label = ''.join(pred_label)
conf = np.min(np.max(pred_probs, axis=-1))
conf = float(f'{conf:.4f}')
return label, conf
def get_results(self, img):
img = self.preprocess(img)
pred = self.inference(img)
return self.postprocess(pred)
if __name__ == '__main__':
    # Text shown at the top of the Gradio page.
    _TITLE = 'Lightweight South Korean Multi-line License Plate Recognition'
    _DESCRIPTION = '''
<div>
<p style="text-align: center; font-size: 1.3em">This is a demo of Lightweight South Korean Multi-line License Plate Recognition.
<a style="display:inline-block; margin-left: .5em" href='https://github.com/noahzhy/SALPR'><img src='https://img.shields.io/github/stars/noahzhy/SALPR?style=social' /></a>
</p>
</div>
'''

    # The model file is looked up next to this script, not in the working
    # directory.
    here = os.path.dirname(__file__)
    demo = TFliteDemo(os.path.join(here, 'tinyLPR.tflite'))

    # One text box per value returned by TFliteDemo.get_results.
    result_boxes = [
        gr.Textbox(label="Plate Number", type="text"),
        gr.Textbox(label="Confidence", type="text"),
    ]

    app = gr.Interface(
        fn=demo.get_results,
        inputs="image",
        outputs=result_boxes,
        title=_TITLE,
        description=_DESCRIPTION,
        examples=get_samples(),
    )
    app.launch()