iasjkk committed on
Commit
5298ec2
·
verified ·
1 Parent(s): f5961e7

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +131 -0
train.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train a bounding-box CNN on the passport field dataset
2
+ from os import listdir
3
+ from xml.etree import ElementTree
4
+ import json
5
+ from numpy import zeros
6
+ from numpy import asarray
7
+ from bboxcnn.utils import Dataset
8
+ from bboxcnn.config import Config
9
+ from bboxcnn.model import BBoxCNN
10
+
11
class PASSPORT_Dataset(Dataset):
    """Passport field dataset built from image files and JSON annotations.

    Images are read from ``<dataset_dir>/images/`` and are expected to be
    named with a numeric id (for example ``042.jpg``); the annotation for an
    image is ``<dataset_dir>/annots/<id>.json``.
    """

    # Field labels in class-id order; ids 1..20 are assigned on registration.
    PASSPORT_FIELDS = (
        "Country Name",
        "Document Type",
        "Country Code",
        "Passport Number",
        "Surname",
        "Given Name",
        "Nationality",
        "Sex",
        "DOB",
        "Place Of Birth",
        "Place Of Issue",
        "DOI",
        "DOE",
        "MRZ",
        "Name Of Father",
        "Name Of Mother",
        "Name Of Spouse",
        "Address",
        "Old Passport Information",
        "File Number",
    )
    # Image ids that are known to be bad and are never registered.
    BAD_IMAGE_IDS = frozenset({'017'})
    # Numeric ids below this value form the train split, the rest test/val.
    VAL_START_ID = 79

    def load_dataset(self, dataset_dir, is_train=True):
        """Register the field classes and the images of one split.

        Args:
            dataset_dir: Directory containing the ``images`` and ``annots``
                sub-directories.
            is_train: When true, register images whose numeric id is below
                ``VAL_START_ID``; otherwise register the remaining images.
        """
        # Local import: the file only imports ``listdir`` from ``os``.
        from os.path import splitext

        for class_id, field in enumerate(self.PASSPORT_FIELDS, start=1):
            self.add_class("dataset", class_id, field)

        # define data locations
        images_dir = dataset_dir + '/images/'
        annotations_dir = dataset_dir + '/annots/'
        want_train = bool(is_train)

        # listdir order is arbitrary; sort so registration order is stable
        for filename in sorted(listdir(images_dir)):
            # strip the extension whatever its length (.jpg, .jpeg, ...)
            image_id = splitext(filename)[0]
            if image_id in self.BAD_IMAGE_IDS:
                continue
            try:
                numeric_id = int(image_id)
            except ValueError:
                # not a numbered image (hidden/system file etc.): ignore it
                continue
            # keep only the images that belong to the requested split
            if want_train != (numeric_id < self.VAL_START_ID):
                continue
            self.add_image(
                'dataset',
                image_id=image_id,
                path=images_dir + filename,
                annotation=annotations_dir + image_id + '.json',
            )

    def extract_boxes(self, filename):
        """Read bounding boxes, labels and image size from a JSON annotation.

        Args:
            filename: Path of the JSON annotation file.

        Returns:
            A tuple ``(boxes, class_names, width, height)`` where ``boxes`` is
            a list of ``[xmin, ymin, xmax, ymax]`` integer lists and
            ``class_names`` holds the ``name`` of each object in the same
            order.

        Raises:
            KeyError: If an expected key (``object``, ``bndbox``, ``name``,
                ``size`` ...) is missing from the annotation.
        """
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)

        objects = data['object']
        # tolerate a single object stored as a bare mapping instead of a list
        if isinstance(objects, dict):
            objects = [objects]

        boxes = []
        class_names = []
        for obj in objects:
            bndbox = obj['bndbox']
            boxes.append([
                int(bndbox['xmin']),
                int(bndbox['ymin']),
                int(bndbox['xmax']),
                int(bndbox['ymax']),
            ])
            class_names.append(obj['name'])

        # extract image dimensions
        width = int(data['size']['width'])
        height = int(data['size']['height'])
        return boxes, class_names, width, height

    def load_mask(self, image_id):
        """Build one rectangular mask per annotated box of an image.

        Args:
            image_id: Index into ``self.image_info``.

        Returns:
            A tuple ``(masks, class_ids)``: ``masks`` is a ``uint8`` array of
            shape ``[height, width, n_boxes]`` holding 1 inside each box and 0
            elsewhere; ``class_ids`` is an ``int32`` array with the class
            index of each channel.

        Raises:
            ValueError: If an annotation label is not in ``self.class_names``.
        """
        info = self.image_info[image_id]
        boxes, class_names, w, h = self.extract_boxes(info['annotation'])

        # one array for all masks, each box on its own channel
        masks = zeros([h, w, len(boxes)], dtype='uint8')
        class_ids = []
        for channel, (box, class_name) in enumerate(zip(boxes, class_names)):
            xmin, ymin, xmax, ymax = box
            # Binary mask: the channel already identifies the instance, and
            # storing the instance number would not fit in uint8 for the
            # 256th box onwards.
            masks[ymin:ymax, xmin:xmax, channel] = 1
            # NOTE(review): class_names is not defined in this class;
            # presumably provided by the base Dataset - confirm.
            class_ids.append(self.class_names.index(class_name))
        return masks, asarray(class_ids, dtype='int32')

    def image_reference(self, image_id):
        """Return the stored file path of the image at ``image_id``."""
        return self.image_info[image_id]['path']
101
+
102
+ # define a configuration for the model
103
class PASSPORT_Config(Config):
    """Model configuration used for the passport dataset."""

    # name given to this configuration
    NAME = "passport_cfg"

    # 20 object classes plus the background class
    NUM_CLASSES = 20 + 1

    # how many training steps make up one epoch
    STEPS_PER_EPOCH = 81
110
+
111
+
112
# build the training split
train_set = PASSPORT_Dataset()
train_set.load_dataset('passport_data', is_train=True)
train_set.prepare()
print('Train: %d' % len(train_set.image_ids))

# build the test/validation split
test_set = PASSPORT_Dataset()
test_set.load_dataset('passport_data', is_train=False)
test_set.prepare()
print('Test: %d' % len(test_set.image_ids))

# create the configuration and print it
config = PASSPORT_Config()
config.display()

# create the model in training mode; model_dir is the current directory
model = BBoxCNN(mode='training', model_dir='./', config=config)

# Alternative starting point (mscoco base weights), currently disabled:
# model.load_weights('bboxcnn_base.h5', by_name=True, exclude=["bboxcnn_class_logits", "bboxcnn_bbox_fc", "bboxcnn_bbox", "bboxcnn_mask"])
# NOTE(review): the checkpoint loaded below is named ..._0090 and its output
# layers are excluded, then training runs with epochs=90 - confirm this is
# the intended way to resume.
model.load_weights(
    'passport_cfg20220520T2226/bboxcnn_passport_cfg_0090.h5',
    by_name=True,
    exclude=[
        "bboxcnn_class_logits",
        "bboxcnn_bbox_fc",
        "bboxcnn_bbox",
        "bboxcnn_mask",
    ],
)

# train the output layers ('heads') only
model.train(
    train_set,
    test_set,
    learning_rate=config.LEARNING_RATE,
    epochs=90,
    layers='heads',
)