Create prediction.py
Browse files- prediction.py +67 -0
prediction.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# display image with masks and bounding boxes
|
| 2 |
+
from os import listdir
|
| 3 |
+
import json
|
| 4 |
+
import cv2
|
| 5 |
+
import easyocr
|
| 6 |
+
reader = easyocr.Reader(['en'])
|
| 7 |
+
# fit a bounding box cnn on the EID dataset
|
| 8 |
+
from numpy import zeros
|
| 9 |
+
from numpy import asarray
|
| 10 |
+
from numpy import expand_dims
|
| 11 |
+
from matplotlib import pyplot
|
| 12 |
+
from matplotlib.patches import Rectangle
|
| 13 |
+
from bboxcnn.config import Config
|
| 14 |
+
from bboxcnn.model import BBoxCNN
|
| 15 |
+
from bboxcnn.model import mold_image
|
| 16 |
+
from bboxcnn.utils import Dataset
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PredictionConfig(Config):
|
| 21 |
+
# define the name of the configuration
|
| 22 |
+
NAME = "eid_cfg"
|
| 23 |
+
# number of classes (background + EID Field classes(10))
|
| 24 |
+
NUM_CLASSES = 1 + 10
|
| 25 |
+
# simplify GPU config(Here the GPU and CPU Config are same: It works if Dex sys doesnt have GPU)
|
| 26 |
+
GPU_COUNT = 1
|
| 27 |
+
IMAGES_PER_GPU = 1
|
| 28 |
+
# create config
|
| 29 |
+
cfg = PredictionConfig()
|
| 30 |
+
# define the model
|
| 31 |
+
model = BBoxCNN(mode='inference', model_dir='./', config=cfg)
|
| 32 |
+
# load model weights
|
| 33 |
+
model_path = 'model/bboxcnn_eid_cfg_0033.h5'
|
| 34 |
+
model.load_weights(model_path, by_name=True)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class_ids_to_class_name = {1: 'Sex', 2: 'DOB', 3: 'Country',
|
| 38 |
+
4: 'DOE', 5: 'Card Number', 6: 'Document Type',
|
| 39 |
+
7: 'Id Number', 8: 'MRZ', 9: 'Name', 10: 'Nationality'}
|
| 40 |
+
|
| 41 |
+
def load_image(path):
|
| 42 |
+
source_image = cv2.imread(path, cv2.IMREAD_COLOR)
|
| 43 |
+
scaled_image = mold_image(source_image, cfg)
|
| 44 |
+
# convert image into one sample
|
| 45 |
+
sample = expand_dims(scaled_image, 0)
|
| 46 |
+
return source_image, sample
|
| 47 |
+
|
| 48 |
+
def extract_text(image_path):
|
| 49 |
+
# piece images using bboxes
|
| 50 |
+
dict_ = {}
|
| 51 |
+
class_ids_to_class_name = {1: 'Sex', 2: 'DOB', 3: 'Country',
|
| 52 |
+
4: 'DOE', 5: 'Card Number', 6: 'Document Type',
|
| 53 |
+
7: 'Id Number', 8: 'MRZ', 9: 'Name', 10: 'Nationality'}
|
| 54 |
+
source_image, scaled_samp_image = load_image(image_path)
|
| 55 |
+
yhat = model.detect(scaled_samp_image, verbose=0)[0]
|
| 56 |
+
bboxes, class_ids, mask = yhat['rois'], yhat['class_ids'], yhat['masks']
|
| 57 |
+
for bbox, class_id in zip(bboxes, class_ids):
|
| 58 |
+
xmin, ymin, xmax, ymax = bbox
|
| 59 |
+
piece_image = source_image[xmin:xmax, ymin:ymax]
|
| 60 |
+
classname = class_ids_to_class_name[class_id]
|
| 61 |
+
text = reader.readtext(piece_image, detail=0)
|
| 62 |
+
dict_.update({class_ids_to_class_name[class_id]: text})
|
| 63 |
+
return dict_
|
| 64 |
+
|
| 65 |
+
def process(im_path):
|
| 66 |
+
eid_data = extract_text(im_path)
|
| 67 |
+
return eid_data
|