# bbox_detection/prediction.py
# Added by iasjkk in commit 11f89a2 ("Create prediction.py").
# Detect EID fields in an image with a bounding-box CNN and OCR the text inside each detected box.
from os import listdir
import json
import cv2
import easyocr
reader = easyocr.Reader(['en'])
# numpy/matplotlib helpers and the bounding-box CNN (bboxcnn) package used to run the EID field detector
from numpy import zeros
from numpy import asarray
from numpy import expand_dims
from matplotlib import pyplot
from matplotlib.patches import Rectangle
from bboxcnn.config import Config
from bboxcnn.model import BBoxCNN
from bboxcnn.model import mold_image
from bboxcnn.utils import Dataset
class PredictionConfig(Config):
    """Inference-time configuration for the EID field detector.

    Processes a single image per forward pass on a single device. Per the
    original author, the GPU and CPU settings are identical, so this config
    also works on a machine without a GPU.
    """

    # Single device, one image per batch.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Configuration name ("eid_cfg" also appears in the weights filename).
    NAME = "eid_cfg"

    # Background class plus the 10 EID field classes.
    NUM_CLASSES = 1 + 10
# Inference configuration shared by the detector and the preprocessing step.
cfg = PredictionConfig()

# Detector in inference mode, with model_dir set to the current directory.
model = BBoxCNN(config=cfg, model_dir='./', mode='inference')

# Trained weights; the path is relative to the current working directory.
model_path = 'model/bboxcnn_eid_cfg_0033.h5'
model.load_weights(model_path, by_name=True)
# Detector class ID -> EID field name. IDs are numbered from 1 in list order;
# ID 0 (the background class) intentionally has no entry.
class_ids_to_class_name = dict(enumerate(
    [
        'Sex', 'DOB', 'Country', 'DOE', 'Card Number',
        'Document Type', 'Id Number', 'MRZ', 'Name', 'Nationality',
    ],
    start=1,
))
def load_image(path):
    """Read an image from disk and prepare it as a one-image batch for the detector.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        A tuple ``(source_image, sample)``. ``source_image`` is the colour
        array returned by ``cv2.imread`` (OpenCV's BGR channel order);
        ``sample`` is ``mold_image(source_image, cfg)`` with a leading batch
        axis of size 1 added.

    Raises:
        ValueError: if OpenCV cannot read an image from ``path`` (e.g. missing
            file, unsupported format, or insufficient permissions).
    """
    source_image = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure by returning None rather than raising; fail
    # here with a clear message instead of deep inside mold_image.
    if source_image is None:
        raise ValueError(f"could not read an image from {path!r}")
    # NOTE(review): the image is in BGR order; confirm the model was trained on
    # BGR input (many Mask R-CNN-style pipelines train on RGB).
    # NOTE(review): if model.detect() already normalises its inputs internally
    # (as Matterport Mask R-CNN's detect() does via mold_inputs), calling
    # mold_image here applies the normalisation twice -- verify against bboxcnn.
    scaled_image = mold_image(source_image, cfg)
    # Add a leading batch axis so the image forms a batch of one.
    sample = expand_dims(scaled_image, 0)
    return source_image, sample
def extract_text(image_path):
    """Detect EID fields in an image and OCR the text inside each detection.

    Args:
        image_path: Path of the image to process (read via ``load_image``).

    Returns:
        dict mapping a field name (a value of the module-level
        ``class_ids_to_class_name``) to the list of strings returned by
        ``reader.readtext(crop, detail=0)`` for that field's crop. If the same
        field is detected more than once, the detection processed last
        overwrites earlier ones. Detections whose crop is empty are skipped.
    """
    fields = {}
    source_image, sample = load_image(image_path)
    detections = model.detect(sample, verbose=0)[0]
    for bbox, class_id in zip(detections['rois'], detections['class_ids']):
        # The first coordinate pair slices rows and the second slices columns,
        # i.e. rois are presumably (y1, x1, y2, x2) -- confirm against BBoxCNN.
        row_start, col_start, row_end, col_end = bbox
        piece_image = source_image[row_start:row_end, col_start:col_end]
        # A degenerate box produces an empty crop, which is not a usable OCR
        # input; skip it rather than passing it to the reader.
        if piece_image.size == 0:
            continue
        # NOTE(review): if detections are ordered by descending score, this
        # last-wins assignment keeps the least confident duplicate.
        fields[class_ids_to_class_name[class_id]] = reader.readtext(piece_image, detail=0)
    return fields
def process(im_path):
    """Return the OCR results for every EID field detected in ``im_path``.

    Public entry point; delegates directly to ``extract_text`` and returns its
    field-name -> recognised-text mapping unchanged.
    """
    return extract_text(im_path)