File size: 5,455 Bytes
cb0ad2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import glob
import tqdm
import numpy as np
from utlis.list_record_cache import ListRecordCacher, merge_record_file
from utlis.utlis import get_paths, get_sub_paths, crop_pdf, extract_ocr, refine_table, visualize_table
def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with `src_dir`, `dst_dir` and `num_workers`.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Crop PDFs, extract OCR tables and cache them as .lrc records.')
    # Positional arguments are always required, so the original
    # `default=None` was dead weight; attach help text instead.
    parser.add_argument('src_dir', type=str,
                        help='input root containing pdf/chunk/structure sub-dirs')
    parser.add_argument('dst_dir', type=str,
                        help='output root for pdf/img/error/visual and table.lrc')
    parser.add_argument('-n', '--num_workers', type=int, default=0,
                        help='number of worker processes (0 = run in-process)')
    return parser.parse_args()
def single_process(paths, dst_dir):
    """Process all (pdf, chunk, structure) path triples sequentially.

    For each triple: crop the PDF, run OCR extraction, refine the table,
    attach the structure label path, dump a visualization, and append the
    record to `table.lrc`. Failed triples are re-cropped into the error
    directory and collected into `error_paths.npy` for inspection.

    Args:
        paths: iterable of (pdf_path, chunk_path, structure_path) tuples.
        dst_dir: output root; pdf/img/error/visual sub-dirs are created here.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    # exist_ok replaces the racy check-then-create pairs of the original.
    for sub_dir in (output_pdf_dir, output_img_dir, output_error_dir, output_visual_dir):
        os.makedirs(sub_dir, exist_ok=True)
    cacher = ListRecordCacher(os.path.join(dst_dir, 'table.lrc'))
    error_paths = []
    error_count = 0
    correct_count = 0
    for path in tqdm.tqdm(paths):
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)
            # The original `assert cond, print(...)` produced a None assertion
            # message and disappears under `python -O`; raise explicitly.
            if not os.path.exists(structure_path):
                raise FileNotFoundError('structure_path does not exist: %s' % structure_path)
            table['label_path'] = structure_path
            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            # Best-effort pipeline: record the failure and continue, but no
            # longer swallow KeyboardInterrupt/SystemExit (the bare except did).
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)
    print("correct num: %d, error num: %d " % (correct_count, error_count))
    if len(error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), error_paths)
    cacher.close()
def _worker(worker_idx, num_workers, paths, dst_dir, result_queue):
    """Multiprocessing worker: process a slice of path triples.

    Mirrors `single_process` but writes records to a per-worker cache file
    `table_<idx>.lrc` (merged later by `multi_process`) and reports counters
    through `result_queue` instead of printing/saving directly.

    Args:
        worker_idx: index of this worker, used to name its cache file.
        num_workers: total worker count (kept for interface compatibility; unused).
        paths: this worker's slice of (pdf, chunk, structure) tuples.
        dst_dir: shared output root with pdf/img/error/visual sub-dirs.
        result_queue: receives one (correct_count, error_count, error_paths) tuple.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    # The original assumed these sub-dirs already existed; create them so the
    # multi-process path works on a fresh dst_dir too (exist_ok makes the
    # concurrent creation by several workers safe).
    for sub_dir in (output_pdf_dir, output_img_dir, output_error_dir, output_visual_dir):
        os.makedirs(sub_dir, exist_ok=True)
    cacher = ListRecordCacher(os.path.join(dst_dir, 'table_%d.lrc' % worker_idx))
    error_paths = []
    error_count = 0
    correct_count = 0
    for path in tqdm.tqdm(paths):
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)
            # Explicit raise replaces the broken `assert cond, print(...)`
            # (None message, stripped under `python -O`).
            if not os.path.exists(structure_path):
                raise FileNotFoundError('structure_path does not exist: %s' % structure_path)
            table['label_path'] = structure_path
            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            # Record the failure and keep going; no longer swallows
            # KeyboardInterrupt/SystemExit like the bare except did.
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)
    cacher.close()
    result_queue.put((correct_count, error_count, error_paths))
def multi_process(path, dst_dir, num_workers):
    """Fan the path triples out over `num_workers` processes and merge results.

    Each worker writes its own `table_<idx>.lrc`; after all workers report
    their counters through the queue, the per-worker caches are merged into
    a single `table.lrc` and removed.

    Args:
        path: list of (pdf, chunk, structure) tuples, striped across workers.
        dst_dir: shared output root.
        num_workers: number of processes to spawn (must be > 0).
    """
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()
    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                worker_idx,
                num_workers,
                path[worker_idx::num_workers],  # stripe the workload evenly
                dst_dir,
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)
    total_correct_count = 0
    total_error_count = 0
    total_error_paths = []
    # One result tuple per worker; blocks until every worker has finished.
    for _ in range(num_workers):
        correct_count, error_count, error_paths = result_queue.get()
        total_correct_count += correct_count
        total_error_count += error_count
        total_error_paths.extend(error_paths)
    # Reap the processes now that all results are in.
    for worker in workers:
        worker.join()
    print("correct num: %d, error num: %d " % (total_correct_count, total_error_count))
    if len(total_error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), total_error_paths)
    # Merge only the per-worker caches: the original '*.lrc' glob would also
    # match a stale table.lrc from a previous run, merge it into itself and
    # then delete it.
    cache_paths = glob.glob(os.path.join(dst_dir, 'table_*.lrc'))
    merge_record_file(cache_paths, os.path.join(dst_dir, 'table.lrc'))
    for cache_path in cache_paths:
        os.remove(cache_path)
def main():
    """Entry point: collect (pdf, chunk, structure) triples and dispatch.

    With --num_workers 0 the work runs in-process; otherwise it is spread
    over a pool of worker processes.
    """
    args = parse_args()
    sub_dirs = ["pdf", "chunk", "structure"]
    extensions = ['.pdf', '.chunk', '.json']
    paths = get_sub_paths(args.src_dir, sub_dirs, extensions)
    # paths = get_paths(args.src_dir, ["pdf", "chunk", "structure"], '/yrfs1/intern/zrzhang6/TSR/Dataset/SciTSR/SciTSR-COMP.list', ['.pdf', '.chunk', '.json'])
    if args.num_workers:
        multi_process(paths, args.dst_dir, args.num_workers)
    else:
        single_process(paths, args.dst_dir)


if __name__ == "__main__":
    main()