# training_sem/dataset/extract_ocr.py
# Author: kai-2054 — "Initial commit: add code" (cb0ad2d)
import os
import glob
import tqdm
import numpy as np
from utlis.list_record_cache import ListRecordCacher, merge_record_file
from utlis.utlis import get_paths, get_sub_paths, crop_pdf, extract_ocr, refine_table, visualize_table
def parse_args():
    """Parse command-line arguments for the OCR extraction script.

    Returns:
        argparse.Namespace with ``src_dir``, ``dst_dir`` and ``num_workers``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Extract OCR tables from pdf/chunk/structure triplets.')
    # Positional arguments are always required, so argparse ignores any
    # ``default`` given for them — none is supplied here.
    parser.add_argument('src_dir', type=str,
                        help='input directory containing pdf/, chunk/ and structure/ sub-dirs')
    parser.add_argument('dst_dir', type=str,
                        help='output directory for extracted records and images')
    parser.add_argument('-n', '--num_workers', type=int, default=0,
                        help='number of worker processes (0 = run in the main process)')
    args = parser.parse_args()
    return args
def single_process(paths, dst_dir):
    """Extract OCR tables from ``paths`` sequentially in the current process.

    Args:
        paths: iterable of (pdf_path, chunk_path, structure_path) tuples.
        dst_dir: output directory; ``pdf/``, ``img/``, ``error/`` and
            ``visual/`` sub-dirs plus the ``table.lrc`` record file are
            created inside it. Failing samples are listed in
            ``error_paths.npy``.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    # exist_ok avoids the racy exists()-then-makedirs pattern and makes
    # re-runs over an existing output directory safe.
    for sub_dir in (output_pdf_dir, output_img_dir,
                    output_error_dir, output_visual_dir):
        os.makedirs(sub_dir, exist_ok=True)
    cacher = ListRecordCacher(os.path.join(dst_dir, 'table.lrc'))
    error_paths = []
    error_count = 0
    correct_count = 0
    for path in tqdm.tqdm(paths):
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)
            # Raise instead of ``assert ..., print(...)``: the assert message
            # was always None and asserts vanish under ``python -O``.
            if not os.path.exists(structure_path):
                raise FileNotFoundError('structure_path does not exist: %s' % structure_path)
            table['label_path'] = structure_path
            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            # Best-effort per-sample processing: record the failure and keep
            # going, but do not swallow KeyboardInterrupt/SystemExit the way
            # the previous bare ``except:`` did.
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)
    print("correct num: %d, error num: %d " % (correct_count, error_count))
    if len(error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), error_paths)
    cacher.close()
def _worker(worker_idx, num_workers, paths, dst_dir, result_queue):
    """Process one shard of ``paths`` in a child process.

    Writes records to a per-worker ``table_<idx>.lrc`` file (merged later by
    ``multi_process``) and puts ``(correct_count, error_count, error_paths)``
    on ``result_queue`` when done.

    Args:
        worker_idx: index of this worker (also used in the cache file name).
        num_workers: total worker count (unused here; kept for interface
            compatibility with existing callers).
        paths: this worker's shard of (pdf_path, chunk_path, structure_path).
        dst_dir: shared output directory.
        result_queue: multiprocessing queue for reporting the tallies.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    # Bug fix: in multi-process mode nothing created these directories
    # (only single_process did). exist_ok=True is safe when several
    # workers race to create the same directory.
    for sub_dir in (output_pdf_dir, output_img_dir,
                    output_error_dir, output_visual_dir):
        os.makedirs(sub_dir, exist_ok=True)
    cacher = ListRecordCacher(os.path.join(dst_dir, 'table_%d.lrc' % worker_idx))
    error_paths = []
    error_count = 0
    correct_count = 0
    for path in tqdm.tqdm(paths):
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)
            # Raise instead of ``assert ..., print(...)``: the assert message
            # was always None and asserts vanish under ``python -O``.
            if not os.path.exists(structure_path):
                raise FileNotFoundError('structure_path does not exist: %s' % structure_path)
            table['label_path'] = structure_path
            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            # Narrowed from a bare ``except:`` so Ctrl-C still stops the worker.
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)
    result_queue.put((correct_count, error_count, error_paths))
def multi_process(path, dst_dir, num_workers):
    """Extract OCR tables from ``path`` using ``num_workers`` child processes.

    Each worker handles a strided shard of the input and writes its own
    ``table_<idx>.lrc``; afterwards the per-worker files are merged into a
    single ``table.lrc`` and removed.

    Args:
        path: list of (pdf_path, chunk_path, structure_path) tuples.
        dst_dir: shared output directory.
        num_workers: number of child processes to spawn (must be > 0).
    """
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()
    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                worker_idx,
                num_workers,
                # Strided sharding spreads the inputs evenly across workers.
                path[worker_idx::num_workers],
                dst_dir,
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)
    total_correct_count = 0
    total_error_count = 0
    total_error_paths = []
    for _ in range(num_workers):
        correct_count, error_count, error_paths = result_queue.get()
        total_correct_count += correct_count
        total_error_count += error_count
        total_error_paths.extend(error_paths)
    # Reap the children before touching their output files; each worker has
    # already reported via the queue, so join() returns promptly.
    for worker in workers:
        worker.join()
    print("correct num: %d, error num: %d " % (total_correct_count, total_error_count))
    if len(total_error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), total_error_paths)
    # Merge each worker's lrc file into table.lrc. Match only the per-worker
    # 'table_*.lrc' files: the old '*.lrc' pattern would also pick up a stale
    # 'table.lrc' from a previous run and merge the output into itself.
    cache_paths = sorted(glob.glob(os.path.join(dst_dir, 'table_*.lrc')))
    merge_record_file(cache_paths, os.path.join(dst_dir, 'table.lrc'))
    for cache_path in cache_paths:
        os.remove(cache_path)
def main():
    """Entry point: parse CLI arguments and dispatch the extraction run."""
    args = parse_args()
    sample_paths = get_sub_paths(
        args.src_dir,
        ["pdf", "chunk", "structure"],
        ['.pdf', '.chunk', '.json'],
    )
    # num_workers == 0 means "run everything in this process".
    if args.num_workers == 0:
        single_process(sample_paths, args.dst_dir)
        return
    multi_process(sample_paths, args.dst_dir, args.num_workers)


if __name__ == "__main__":
    main()