# bbox_detection / train.py
# iasjkk's picture
# Upload train.py
# 5298ec2 verified
# train a BBoxCNN model to locate passport fields (bounding boxes rendered as box-shaped masks)
from os import listdir
from xml.etree import ElementTree
import json
from numpy import zeros
from numpy import asarray
from bboxcnn.utils import Dataset
from bboxcnn.config import Config
from bboxcnn.model import BBoxCNN
class PASSPORT_Dataset(Dataset):
    """Passport-field dataset backed by JSON bounding-box annotations.

    Images are read from ``<dataset_dir>/images/`` and are expected to have
    numeric file names; the annotation of image ``<id>`` is read from
    ``<dataset_dir>/annots/<id>.json``.  ``add_class``, ``add_image``,
    ``image_info`` and ``class_names`` are not defined here and are used as
    inherited from ``Dataset``.
    """

    def load_dataset(self, dataset_dir, is_train=True, split_id=79):
        """Register the 20 field classes and the images of one split.

        Args:
            dataset_dir: Root directory containing ``images/`` and ``annots/``.
            is_train: When true, keep images whose numeric id is below
                ``split_id``; otherwise keep images at or above it.
            split_id: First numeric image id that belongs to the test/val
                split.  Defaults to 79.
        """
        # Function-scope import so this class does not depend on a change to
        # the module's import block.
        from os.path import splitext

        # Class ids are assigned 1..20 in this order.
        field_names = (
            "Country Name",
            "Document Type",
            "Country Code",
            "Passport Number",
            "Surname",
            "Given Name",
            "Nationality",
            "Sex",
            "DOB",
            "Place Of Birth",
            "Place Of Issue",
            "DOI",
            "DOE",
            "MRZ",
            "Name Of Father",
            "Name Of Mother",
            "Name Of Spouse",
            "Address",
            "Old Passport Information",
            "File Number",
        )
        for class_id, field_name in enumerate(field_names, start=1):
            self.add_class("dataset", class_id, field_name)

        # define data locations
        images_dir = dataset_dir + '/images/'
        annotations_dir = dataset_dir + '/annots/'

        # image ids that are known to be bad and are never loaded
        bad_image_ids = {'017'}

        # Sorted so that images are registered in a reproducible order;
        # listdir() returns entries in arbitrary order.
        for filename in sorted(listdir(images_dir)):
            # The image id is the file name without its extension, whatever
            # the extension length is.
            image_id = splitext(filename)[0]
            if image_id in bad_image_ids:
                continue
            try:
                numeric_id = int(image_id)
            except ValueError:
                # Not a numbered image (e.g. a hidden system file): it cannot
                # be assigned to a split, so leave it out.
                continue
            # Ids below split_id form the train set, the rest the test/val set.
            if is_train != (numeric_id < split_id):
                continue
            self.add_image(
                'dataset',
                image_id=image_id,
                path=images_dir + filename,
                annotation=annotations_dir + image_id + '.json',
            )

    def extract_boxes(self, filename):
        """Parse one JSON annotation file.

        Args:
            filename: Path of the annotation file.

        Returns:
            A tuple ``(boxes, class_names, width, height)``: ``boxes`` is a
            list of ``[xmin, ymin, xmax, ymax]`` integer lists, one per entry
            of ``data['object']``; ``class_names`` holds the ``'name'`` of each
            entry in the same order; ``width`` and ``height`` are the integer
            values of ``data['size']``.
        """
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)

        boxes = []
        class_names = []
        for obj in data['object']:
            bndbox = obj['bndbox']
            boxes.append(
                [int(bndbox[key]) for key in ('xmin', 'ymin', 'xmax', 'ymax')]
            )
            class_names.append(obj['name'])

        # extract image dimensions
        width = int(data['size']['width'])
        height = int(data['size']['height'])
        return boxes, class_names, width, height

    def load_mask(self, image_id):
        """Build box-shaped masks for every annotated field of an image.

        Args:
            image_id: Index into ``self.image_info``.

        Returns:
            A tuple ``(masks, class_ids)``: ``masks`` is a ``uint8`` array of
            shape ``[height, width, number_of_boxes]`` with one channel per
            box; ``class_ids`` is an ``int32`` array holding, for each box,
            the position of its class name in ``self.class_names``.

        Raises:
            ValueError: If an annotation names a class that is not present in
                ``self.class_names``.
        """
        info = self.image_info[image_id]
        # the annotation is a JSON file (see extract_boxes)
        boxes, class_names, w, h = self.extract_boxes(info['annotation'])

        # one array for all masks, each box on its own channel
        masks = zeros([h, w, len(boxes)], dtype='uint8')
        class_ids = []
        for i, (box, class_name) in enumerate(zip(boxes, class_names)):
            xmin, ymin, xmax, ymax = box
            # The rectangle is filled with the non-zero value i + 1.
            # NOTE(review): presumably the consumer treats any non-zero value
            # as foreground - confirm before changing this to a plain 1.
            masks[ymin:ymax, xmin:xmax, i] = i + 1
            # NOTE(review): self.class_names is presumably populated by
            # Dataset.prepare() - confirm it is called before load_mask().
            class_ids.append(self.class_names.index(class_name))
        return masks, asarray(class_ids, dtype='int32')

    def image_reference(self, image_id):
        """Return the file path stored for ``image_id``."""
        return self.image_info[image_id]['path']
# define a configuration for the model
class PASSPORT_Config(Config):
    """Configuration values for training on the passport dataset."""

    # Name of this configuration.
    NAME = "passport_cfg"

    # Steps executed in every training epoch.
    STEPS_PER_EPOCH = 81

    # Total class count: the background class plus the 20 passport fields.
    NUM_CLASSES = 1 + 20
# root directory of the passport data, shared by both splits
data_dir = 'passport_data'

# build the training split and report its size
train_set = PASSPORT_Dataset()
train_set.load_dataset(data_dir, is_train=True)
train_set.prepare()
train_count = len(train_set.image_ids)
print('Train: %d' % train_count)

# build the test/val split and report its size
test_set = PASSPORT_Dataset()
test_set.load_dataset(data_dir, is_train=False)
test_set.prepare()
test_count = len(test_set.image_ids)
print('Test: %d' % test_count)

# create the configuration and print it
config = PASSPORT_Config()
config.display()

# create the model in training mode; model_dir is the current directory
model = BBoxCNN(mode='training', model_dir='./', config=config)

# output ('head') layers that are left out when weights are loaded
head_layers = ["bboxcnn_class_logits", "bboxcnn_bbox_fc", "bboxcnn_bbox", "bboxcnn_mask"]

# Weights come from an earlier passport_cfg checkpoint.  The alternative
# starting point used previously was the base weights file
# 'bboxcnn_base.h5' (described as mscoco weights), with the same exclude list.
# NOTE(review): the head layers are excluded even though the checkpoint comes
# from this same configuration - confirm that discarding them is intended.
checkpoint = 'passport_cfg20220520T2226/bboxcnn_passport_cfg_0090.h5'
model.load_weights(checkpoint, by_name=True, exclude=head_layers)

# train only the output layers ('heads')
# NOTE(review): the checkpoint name ends in 0090 and epochs is 90; if BBoxCNN
# derives its starting epoch from the checkpoint file name, there may be no
# epochs left to run - confirm.
model.train(
    train_set,
    test_set,
    learning_rate=config.LEARNING_RATE,
    epochs=90,
    layers='heads',
)