iasjkk commited on
Commit
11f89a2
·
verified ·
1 Parent(s): b1acd2c

Create prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +67 -0
prediction.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # display image with masks and bounding boxes
2
+ from os import listdir
3
+ import json
4
+ import cv2
5
+ import easyocr
6
+ reader = easyocr.Reader(['en'])
7
+ # fit a bounding box cnn on the EID dataset
8
+ from numpy import zeros
9
+ from numpy import asarray
10
+ from numpy import expand_dims
11
+ from matplotlib import pyplot
12
+ from matplotlib.patches import Rectangle
13
+ from bboxcnn.config import Config
14
+ from bboxcnn.model import BBoxCNN
15
+ from bboxcnn.model import mold_image
16
+ from bboxcnn.utils import Dataset
17
+
18
+
19
+
20
+ class PredictionConfig(Config):
21
+ # define the name of the configuration
22
+ NAME = "eid_cfg"
23
+ # number of classes (background + EID Field classes(10))
24
+ NUM_CLASSES = 1 + 10
25
+ # simplify GPU config(Here the GPU and CPU Config are same: It works if Dex sys doesnt have GPU)
26
+ GPU_COUNT = 1
27
+ IMAGES_PER_GPU = 1
28
+ # create config
29
+ cfg = PredictionConfig()
30
+ # define the model
31
+ model = BBoxCNN(mode='inference', model_dir='./', config=cfg)
32
+ # load model weights
33
+ model_path = 'model/bboxcnn_eid_cfg_0033.h5'
34
+ model.load_weights(model_path, by_name=True)
35
+
36
+
37
+ class_ids_to_class_name = {1: 'Sex', 2: 'DOB', 3: 'Country',
38
+ 4: 'DOE', 5: 'Card Number', 6: 'Document Type',
39
+ 7: 'Id Number', 8: 'MRZ', 9: 'Name', 10: 'Nationality'}
40
+
41
+ def load_image(path):
42
+ source_image = cv2.imread(path, cv2.IMREAD_COLOR)
43
+ scaled_image = mold_image(source_image, cfg)
44
+ # convert image into one sample
45
+ sample = expand_dims(scaled_image, 0)
46
+ return source_image, sample
47
+
48
+ def extract_text(image_path):
49
+ # piece images using bboxes
50
+ dict_ = {}
51
+ class_ids_to_class_name = {1: 'Sex', 2: 'DOB', 3: 'Country',
52
+ 4: 'DOE', 5: 'Card Number', 6: 'Document Type',
53
+ 7: 'Id Number', 8: 'MRZ', 9: 'Name', 10: 'Nationality'}
54
+ source_image, scaled_samp_image = load_image(image_path)
55
+ yhat = model.detect(scaled_samp_image, verbose=0)[0]
56
+ bboxes, class_ids, mask = yhat['rois'], yhat['class_ids'], yhat['masks']
57
+ for bbox, class_id in zip(bboxes, class_ids):
58
+ xmin, ymin, xmax, ymax = bbox
59
+ piece_image = source_image[xmin:xmax, ymin:ymax]
60
+ classname = class_ids_to_class_name[class_id]
61
+ text = reader.readtext(piece_image, detail=0)
62
+ dict_.update({class_ids_to_class_name[class_id]: text})
63
+ return dict_
64
+
65
+ def process(im_path):
66
+ eid_data = extract_text(im_path)
67
+ return eid_data