math2tex / ScanSSD /gtdb /adjust_boxes.py
duycse1603's picture
[Add] source
6163604
# Author: Parag Mali
# This script fits (adjusts) detected math or character bounding boxes to the content of the page images
# read the image
import sys
sys.path.extend(['/home/psm2208/code', '/home/psm2208/code'])
import cv2
import os
import csv
import numpy as np
import utils.visualize as visualize
from multiprocessing import Pool
from cv2.dnn import NMSBoxes
from scipy.ndimage.measurements import label
import scipy.ndimage as ndimage
import copy
from gtdb import fit_box
from gtdb import box_utils
from gtdb import feature_extractor
import shutil
import time
from collections import OrderedDict
from collections import deque
import argparse
def parse_args():
    """
    Build the command-line parser for this script and parse ``sys.argv``.

    Only ``--math_dir`` is mandatory; every other option falls back to the
    default listed in the table below.

    Returns:
        argparse.Namespace with the attributes data_file, output_dir,
        math_dir, math_ext, home_data, home_eval, home_images, home_anno,
        home_char, num_workers and type.
    """
    cli = argparse.ArgumentParser(description='Stitching method')

    # (flag, keyword arguments) pairs. They are registered in this order so
    # the generated --help text lists the options in the same sequence.
    option_specs = [
        ('--data_file', dict(default='test', type=str, help='choose one')),
        ('--output_dir', dict(default='.', help='Output directory path')),
        ('--math_dir', dict(required=True, type=str, help='detections dir')),
        ('--math_ext', dict(default='.csv',
                            help='Extention of detection files')),
        ('--home_data', dict(default='/home/psm2208/data/GTDB/', type=str,
                             help='database dir')),
        ('--home_eval', dict(default='/home/psm2208/code/eval/', type=str,
                             help='Eval dir')),
        ('--home_images', dict(default='/home/psm2208/data/GTDB/images/',
                               type=str, help='Images dir')),
        ('--home_anno', dict(default='/home/psm2208/data/GTDB/annotations/',
                             type=str, help='Annotations dir')),
        ('--home_char', dict(default='/home/psm2208/data/GTDB/char_annotations/',
                             type=str, help='Char anno dir')),
        ('--num_workers', dict(default=4, type=int,
                               help='Number of workers')),
        ('--type', dict(default='math', type=str, help='Math or text')),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def read_math(args, pdf_name):
    """
    Read math bounding boxes for given PDF.

    The rows are loaded from ``<args.math_dir>/<pdf_name><args.math_ext>``
    as comma-separated numeric values. When ``args.math_ext`` is ``'.char'``
    the column at index 1 is dropped first. In every case only the first
    five columns are kept and the values are truncated to integers.

    Parameters:
        args: object providing the ``math_dir`` and ``math_ext`` attributes.
        pdf_name: file name stem of the detection file.

    Returns:
        2-D integer numpy array with at most five columns, one row per
        entry. If the file does not exist or holds no data, an empty array
        of shape ``(0, 5)`` is returned so that callers can still index
        columns (e.g. ``result[:, 0]``) without an IndexError.
    """
    math_file = os.path.join(args.math_dir, pdf_name + args.math_ext)
    no_regions = np.empty((0, 5), dtype=int)

    if not os.path.exists(math_file):
        return no_regions

    data = np.genfromtxt(math_file, delimiter=',')
    if data.size == 0:
        # File exists but contains no rows.
        return no_regions

    # if there is only one entry convert it to correct form required
    if data.ndim < 2:
        data = data.reshape(1, -1)

    if args.math_ext == '.char':
        # Drop the column at index 1 before taking the leading five columns.
        data = np.delete(data, 1, 1)

    data = data[:, :5]

    return data.astype(int)
def read_char(args, pdf_name):
    """
    Load the character annotation rows of one PDF.

    Reads ``<args.home_char>/<pdf_name>.char`` as comma-separated text and
    returns every row unchanged (fields are kept as text, nothing is
    filtered or converted).

    Parameters:
        args: object providing the ``home_char`` directory attribute.
        pdf_name: file name stem of the ``.char`` file.

    Returns:
        numpy array built from the list of rows read from the file.
    """
    char_path = os.path.join(args.home_char, pdf_name + ".char")

    with open(char_path, 'r') as char_file:
        rows = list(csv.reader(char_file, delimiter=','))

    return np.array(rows)
def adjust(params):
    """
    Fit the math bounding boxes of a single page to the page content.

    Parameters:
        params: 4-item sequence ``(args, math_regions, pdf_name, page_num)``.
            ``args.home_images`` is the image root directory and
            ``math_regions`` is an iterable of boxes for that page.

    Returns:
        list of the boxes returned by ``fit_box.adjust_box`` whose width and
        height (as reported by ``feature_extractor``) are both positive.
    """
    args, math_regions, pdf_name, page_num = params
    print('Processing ', pdf_name, ' > ', page_num)

    # The page image is looked up as <home_images>/<pdf_name>/<page_num + 1>.png
    page_png = str(int(page_num + 1)) + ".png"
    page_image = cv2.imread(os.path.join(args.home_images, pdf_name, page_png))
    binary_page = fit_box.convert_to_binary(page_image)

    # Lazily adjust each region; degenerate results are filtered out below.
    fitted = (fit_box.adjust_box(binary_page, region) for region in math_regions)

    return [
        candidate
        for candidate in fitted
        if feature_extractor.width(candidate) > 0
        and feature_extractor.height(candidate) > 0
    ]
def adjust_char(params):
    """
    Adjust the character bounding boxes of a single page.

    This is a best-effort worker: any exception raised while processing the
    page is reported on stdout and an empty list is returned instead of
    propagating the error.

    Parameters:
        params: 4-item sequence ``(args, char_regions, pdf_name, page_num)``.
            ``args.home_images`` is the image root directory and
            ``char_regions`` is an iterable of indexable, mutable rows.

    Returns:
        list of the rows (modified in place) whose adjusted box has positive
        width and height, or ``[]`` if processing failed.
    """
    # Pre-bind the names used by the error report so that the handler cannot
    # itself fail with an unbound-variable error when ``params`` is malformed
    # and the unpacking below raises.
    pdf_name = page_num = None

    try:
        args, char_regions, pdf_name, page_num = params
        print('Char processing ', pdf_name, ' > ', page_num)

        # The page image is looked up as <home_images>/<pdf_name>/<page_num + 1>.png
        image = cv2.imread(os.path.join(args.home_images, pdf_name, str(int(page_num) + 1) + ".png"))
        im_bw = fit_box.convert_to_binary(image)

        new_chars = []

        for char in char_regions:
            # Box coordinates are taken from columns 2..5 of the row; values
            # may be numeric strings, hence the float -> int conversion.
            bb_char = [char[2], char[3], char[4], char[5]]
            bb_char = [int(float(k)) for k in bb_char]

            box = fit_box.adjust_box(im_bw, bb_char)
            if feature_extractor.width(box) > 0 and feature_extractor.height(box) > 0:
                # NOTE(review): the box is read from columns 2..5 but written
                # back to columns 1..4 (column 5 keeps its old value). Kept
                # as-is because downstream consumers may rely on this layout
                # -- TODO confirm whether this shift is intentional.
                char[1] = box[0]
                char[2] = box[1]
                char[3] = box[2]
                char[4] = box[3]
                new_chars.append(char)

        return new_chars

    except Exception:
        # Deliberate best effort: report the failure and drop this page.
        print('Error while processing ', pdf_name, ' > ', page_num, sys.exc_info())
        return []
def adjust_boxes(args):
'''
Driving function for adjusting the boxes
'''
pdf_list = []
pdf_names_file = open(args.data_file, 'r')
for pdf_name in pdf_names_file:
pdf_name = pdf_name.strip()
if pdf_name != '':
pdf_list.append(pdf_name)
regions = {}
for pdf_name in pdf_list:
if args.type == 'char':
regions[pdf_name] = read_char(args, pdf_name)
else:
regions[pdf_name] = read_math(args, pdf_name)
voting_ip_list = []
for pdf_name in pdf_list:
pages = np.unique(regions[pdf_name][:, 0])
#args, math_regions, pdf_name, page_num
for page_num in pages:
current_math = regions[pdf_name][regions[pdf_name][:,0] == page_num]
if args.type == 'math':
current_math = np.delete(current_math, 0, 1)
voting_ip_list.append([args, current_math, pdf_name, page_num])
pool = Pool(processes=args.num_workers)
if args.type == 'math':
out = pool.map(adjust, voting_ip_list)
else:
out = pool.map(adjust_char, voting_ip_list)
for ip, final_math in zip(voting_ip_list, out):
pdf_name = ip[2]
page_num = ip[3]
if args.type == 'math':
col = np.array([int(page_num)] * len(final_math))
final_math = np.concatenate((col[:, np.newaxis], final_math), axis=1)
math_file_path = os.path.join(args.output_dir, pdf_name + '.csv')
if not os.path.exists(os.path.dirname(math_file_path)):
os.makedirs(os.path.dirname(math_file_path))
if args.type == 'math':
math_file = open(math_file_path, 'a')
np.savetxt(math_file, final_math, fmt='%.2f', delimiter=',')
math_file.close()
else:
with open(math_file_path, 'a') as csvfile:
writer = csv.writer(csvfile, delimiter=",")
for math_region in final_math:
writer.writerow(math_region)
if __name__ == '__main__':
    # Script entry point: parse the command line, then adjust all boxes.
    args = parse_args()
    adjust_boxes(args)