Initial commit: add code
- README.md +13 -0
- dataset/extract_ocr.py +155 -0
- dataset/process_scitsr.sh +0 -0
- dataset/trans2lrc.py +145 -0
- dataset/utils/extract_table_lines.py +192 -0
- dataset/utils/list_record_cache.py +143 -0
- dataset/utils/utils.py +449 -0
- libs/configs/__init__.py +24 -0
- libs/configs/default.py +77 -0
- libs/data/__init__.py +69 -0
- libs/data/batch_sampler.py +118 -0
- libs/data/dataset.py +164 -0
- libs/data/list_record_cache.py +143 -0
- libs/data/transform.py +70 -0
- libs/data/utils.py +188 -0
- libs/model/__init__.py +16 -0
- libs/model/backbone.py +281 -0
- libs/model/cells_extractor.py +130 -0
- libs/model/decoder.py +277 -0
- libs/model/divide_predictor.py +57 -0
- libs/model/extractor.py +88 -0
- libs/model/fpn.py +37 -0
- libs/model/model.py +65 -0
- libs/model/pan.py +24 -0
- libs/model/sa.py +35 -0
- libs/model/segment_predictor.py +133 -0
- libs/model/utils.py +371 -0
- libs/utils/__init__.py +0 -0
- libs/utils/cal_f1.py +214 -0
- libs/utils/checkpoint.py +47 -0
- libs/utils/comm.py +129 -0
- libs/utils/context_cacher.py +15 -0
- libs/utils/counter.py +43 -0
- libs/utils/format_translate.py +278 -0
- libs/utils/logger.py +64 -0
- libs/utils/metric.py +58 -0
- libs/utils/model_synchronizer.py +75 -0
- libs/utils/scitsr/__init__.py +0 -0
- libs/utils/scitsr/eval.py +179 -0
- libs/utils/scitsr/relation.py +59 -0
- libs/utils/scitsr/table.py +133 -0
- libs/utils/teds.py +212 -0
- libs/utils/teds_multiprocess.py +111 -0
- libs/utils/time_counter.py +108 -0
- libs/utils/utils.py +297 -0
- libs/utils/vocab.py +36 -0
- requirements.txt +93 -0
- runner/train.py +245 -0
- runner/valid.py +116 -0
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Training
emoji: 🦀
colorFrom: pink
colorTo: yellow
sdk: gradio
sdk_version: 5.45.0
app_file: app.py
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
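The block between the `---` markers is HuggingFace Spaces front matter: it pins the Gradio SDK at 5.45.0 and points the Space at an `app.py` entry file that is not included in this commit. Purely as a hedged illustration of what such an entry point could look like (every name and behavior below is hypothetical, not the project's actual app):

```python
# Hypothetical app.py sketch for a Space with `sdk: gradio`; the real app.py
# is not part of this commit, so this is illustrative only.
import gradio as gr

def echo(text: str) -> str:
    return text

demo = gr.Interface(fn=echo, inputs="text", outputs="text", title="Training")

if __name__ == "__main__":
    demo.launch()
```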
dataset/extract_ocr.py
ADDED
@@ -0,0 +1,155 @@
import os
import glob
import tqdm
import numpy as np
from utils.list_record_cache import ListRecordCacher, merge_record_file
from utils.utils import get_paths, get_sub_paths, crop_pdf, extract_ocr, refine_table, visualize_table


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('src_dir', type=str, default=None)
    parser.add_argument('dst_dir', type=str, default=None)
    parser.add_argument('-n', '--num_workers', type=int, default=0)
    args = parser.parse_args()
    return args


def single_process(paths, dst_dir):

    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    if not os.path.exists(output_pdf_dir):
        os.makedirs(output_pdf_dir)
    output_img_dir = os.path.join(dst_dir, 'img')
    if not os.path.exists(output_img_dir):
        os.makedirs(output_img_dir)
    output_error_dir = os.path.join(dst_dir, 'error')
    if not os.path.exists(output_error_dir):
        os.makedirs(output_error_dir)
    output_visual_dir = os.path.join(dst_dir, 'visual')
    if not os.path.exists(output_visual_dir):
        os.makedirs(output_visual_dir)

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table.lrc'))

    error_paths = []
    error_count = 0
    correct_count = 0
    for id, path in enumerate(tqdm.tqdm(paths)):
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]

            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)

            assert os.path.exists(structure_path), 'structure_path does not exist'
            table['label_path'] = structure_path

            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)

            cacher.add_record(table)
            correct_count += 1
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    print("correct num: %d, error num: %d " % (correct_count, error_count))
    if len(error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), error_paths)
    cacher.close()


def _worker(worker_idx, num_workers, paths, dst_dir, result_queue):

    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    for sub_dir in (output_pdf_dir, output_img_dir, output_error_dir, output_visual_dir):
        os.makedirs(sub_dir, exist_ok=True)  # workers may race on creation

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table_%d.lrc' % worker_idx))

    error_paths = []
    error_count = 0
    correct_count = 0
    for id, path in enumerate(tqdm.tqdm(paths)):
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]

            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)

            assert os.path.exists(structure_path), 'structure_path does not exist'
            table['label_path'] = structure_path

            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    cacher.close()  # flush records and write the index before signalling completion
    result_queue.put((correct_count, error_count, error_paths))


def multi_process(path, dst_dir, num_workers):
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()

    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                worker_idx,
                num_workers,
                path[worker_idx::num_workers],
                dst_dir,
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)

    total_correct_count = 0
    total_error_count = 0
    total_error_paths = []
    for _ in range(num_workers):
        correct_count, error_count, error_paths = result_queue.get()
        total_correct_count += correct_count
        total_error_count += error_count
        total_error_paths.extend(error_paths)

    print("correct num: %d, error num: %d " % (total_correct_count, total_error_count))
    if len(total_error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), total_error_paths)

    # merge each worker's lrc into a single table.lrc
    cache_paths = glob.glob(os.path.join(dst_dir, '*.lrc'))
    merge_record_file(cache_paths, os.path.join(dst_dir, 'table.lrc'))
    for cache_path in cache_paths:
        os.remove(cache_path)


def main():
    args = parse_args()

    paths = get_sub_paths(args.src_dir, ["pdf", "chunk", "structure"], ['.pdf', '.chunk', '.json'])
    # paths = get_paths(args.src_dir, ["pdf", "chunk", "structure"], '/yrfs1/intern/zrzhang6/TSR/Dataset/SciTSR/SciTSR-COMP.list', ['.pdf', '.chunk', '.json'])

    if args.num_workers == 0:
        single_process(paths, args.dst_dir)
    else:
        multi_process(paths, args.dst_dir, args.num_workers)


if __name__ == "__main__":
    main()
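The script is driven from the command line: `src_dir` must contain `pdf/`, `chunk/` and `structure/` subfolders (the SciTSR layout implied by `get_sub_paths`), and `dst_dir` receives the rendered crops plus the merged `table.lrc`. A sketch of an invocation, with purely illustrative paths:

```python
# Illustrative invocation of dataset/extract_ocr.py; the paths are placeholders.
import subprocess

subprocess.run([
    "python", "dataset/extract_ocr.py",
    "/data/SciTSR/train",  # src_dir: expects pdf/, chunk/, structure/ subfolders
    "./scitsr_out",        # dst_dir: receives pdf/, img/, visual/, error/, table.lrc
    "-n", "8",             # spawn eight worker processes; 0 runs single-process
], check=True)
```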
dataset/process_scitsr.sh
ADDED
File without changes
dataset/trans2lrc.py
ADDED
@@ -0,0 +1,145 @@
import os
import glob
import tqdm
import numpy as np
from utils.list_record_cache import ListRecordCacher, merge_record_file
from utils.utils import get_sub_paths, crop_pdf, crop_cells, visualize_cell, match_cells


def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('src_dir', type=str, default=None)
    parser.add_argument('dst_dir', type=str, default=None)
    parser.add_argument('-n', '--num_workers', type=int, default=0)
    args = parser.parse_args()
    return args


def single_process(paths, dst_dir):

    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    if not os.path.exists(output_pdf_dir):
        os.makedirs(output_pdf_dir)
    output_img_dir = os.path.join(dst_dir, 'img')
    if not os.path.exists(output_img_dir):
        os.makedirs(output_img_dir)
    output_visual_dir = os.path.join(dst_dir, 'visual')
    if not os.path.exists(output_visual_dir):
        os.makedirs(output_visual_dir)
    output_error_dir = os.path.join(dst_dir, 'error')
    if not os.path.exists(output_error_dir):
        os.makedirs(output_error_dir)

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table.lrc'))

    error_paths = []
    error_count = 0
    correct_count = 0
    for id, path in enumerate(tqdm.tqdm(paths)):
        try:
            pdf_path, chunk_path, structure_path = path
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = match_cells([pdf_path, chunk_path, structure_path], positions, transcripts)
            crop_cells(os.path.join(output_pdf_dir, os.path.splitext(os.path.basename(pdf_path))[0] + '.png'), output_img_dir, table)
            table['id'] = id
            table['image_path'] = os.path.join(output_img_dir, os.path.splitext(os.path.basename(pdf_path))[0] + '.png')
            visualize_cell(os.path.join(output_img_dir, os.path.splitext(os.path.basename(pdf_path))[0] + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    print("correct num: %d, error num: %d " % (correct_count, error_count))
    if len(error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), error_paths)
    cacher.close()


def _worker(worker_idx, num_workers, paths, dst_dir, result_queue):

    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    output_error_dir = os.path.join(dst_dir, 'error')
    for sub_dir in (output_pdf_dir, output_img_dir, output_visual_dir, output_error_dir):
        os.makedirs(sub_dir, exist_ok=True)  # workers may race on creation

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table_%d.lrc' % worker_idx))

    error_paths = []
    error_count = 0
    correct_count = 0
    for id, path in enumerate(tqdm.tqdm(paths)):
        try:
            pdf_path, chunk_path, structure_path = path
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = match_cells([pdf_path, chunk_path, structure_path], positions, transcripts)
            crop_cells(os.path.join(output_pdf_dir, os.path.splitext(os.path.basename(pdf_path))[0] + '.png'), output_img_dir, table)
            # interleave ids so they stay globally unique across workers
            table['id'] = int(id * num_workers + worker_idx)
            table['image_path'] = os.path.join(output_img_dir, os.path.splitext(os.path.basename(pdf_path))[0] + '.png')
            visualize_cell(os.path.join(output_img_dir, os.path.splitext(os.path.basename(pdf_path))[0] + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    cacher.close()  # flush records and write the index before signalling completion
    result_queue.put((correct_count, error_count, error_paths))


def multi_process(path, dst_dir, num_workers):
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()

    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                worker_idx,
                num_workers,
                path[worker_idx::num_workers],
                dst_dir,
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)

    total_correct_count = 0
    total_error_count = 0
    total_error_paths = []
    for _ in range(num_workers):
        correct_count, error_count, error_paths = result_queue.get()
        total_correct_count += correct_count
        total_error_count += error_count
        total_error_paths.extend(error_paths)

    print("correct num: %d, error num: %d " % (total_correct_count, total_error_count))
    if len(total_error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), total_error_paths)

    # merge each worker's lrc into a single table.lrc
    cache_paths = glob.glob(os.path.join(dst_dir, '*.lrc'))
    merge_record_file(cache_paths, os.path.join(dst_dir, 'table.lrc'))
    for cache_path in cache_paths:
        os.remove(cache_path)


def main():
    args = parse_args()

    paths = get_sub_paths(args.src_dir, ["pdf", "chunk", "structure"], ['.pdf', '.chunk', '.json'])

    if args.num_workers == 0:
        single_process(paths, args.dst_dir)
    else:
        multi_process(paths, args.dst_dir, args.num_workers)


if __name__ == "__main__":
    main()
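The only functional difference between the single- and multi-process record ids is the interleaving `id * num_workers + worker_idx`: because worker `w` processes `paths[w::num_workers]`, the assigned ids never collide across workers. A quick check:

```python
# Worker w handles paths[w::num_workers]; its i-th record gets id
# i * num_workers + w, so ids stay globally unique.
num_workers = 3
for w in range(num_workers):
    print(w, [i * num_workers + w for i in range(4)])
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 10]
# 2 [2, 5, 8, 11]
```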
dataset/utils/extract_table_lines.py
ADDED
@@ -0,0 +1,192 @@
import math
import numpy as np


class InvalidFormat(Exception):
    pass


def segmentation_to_bbox(segmentation):
    x1 = min([pt[0] for contour in segmentation for pt in contour])
    y1 = min([pt[1] for contour in segmentation for pt in contour])
    x2 = max([pt[0] for contour in segmentation for pt in contour])
    y2 = max([pt[1] for contour in segmentation for pt in contour])
    return (x1, y1, x2, y2)


def cal_cell_bbox(table):
    cells_bbox = list()
    for cell in table['cells']:
        if 'segmentation' not in cell:
            cell_bbox = None
        else:
            segmentation = list()
            if 'sublines' in cell:
                for subline in cell['sublines']:
                    segmentation.extend(subline['segmentation'])
            if len(segmentation) == 0:
                segmentation = cell['segmentation']
            if len(segmentation) == 0:
                cell_bbox = None
            else:
                cell_bbox = segmentation_to_bbox(segmentation)
        cells_bbox.append(cell_bbox)
    return cells_bbox


def cal_cell_spans(table):
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # spans are inclusive, so cover the full rectangle when checking
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def cal_fg_bg_span(spans, edge):
    num_span = len(spans)
    bg_spans = list()
    for idx in range(num_span):
        if spans[idx] is None:
            continue
        if idx == 0:
            if spans[idx][0] <= 0:
                continue
        else:
            if spans[idx - 1] is None:
                continue
            if spans[idx][0] <= spans[idx - 1][1]:
                continue
        if idx == num_span - 1:
            if spans[idx][1] >= edge:
                continue
        else:
            if spans[idx + 1] is None:
                continue
            if spans[idx][1] >= spans[idx + 1][0]:
                continue

        bg_spans.append(spans[idx])

    fg_spans = list()
    for idx in range(num_span + 1):
        if idx == 0:
            s = 0
        else:
            if spans[idx - 1] is None:
                continue
            s = spans[idx - 1][1]

        if idx == num_span:
            e = edge
        else:
            if spans[idx] is None:
                continue
            e = spans[idx][0]

        if e <= s:
            continue

        fg_spans.append([s, e])

    return fg_spans, bg_spans


def shrink_spans(spans, size):
    new_spans = list()
    for idx, (start, end) in enumerate(spans):
        if idx == 0:
            if start <= 0:
                start = 1
        else:
            _, pre_end = spans[idx - 1]
            if start <= pre_end:
                shrink_distance = pre_end - start + 1
                start = start + math.ceil(shrink_distance / 2)

        if idx == len(spans) - 1:
            if end >= size:
                end = size - 1
        else:
            next_start, _ = spans[idx + 1]
            if end >= next_start:
                shrink_distance = end - next_start + 1
                end = end - math.ceil(shrink_distance / 2)
        if end - start < 1:
            raise InvalidFormat()

        new_spans.append([start, end])
    return new_spans


def cal_row_span(table, cells_span, cells_bbox, height):
    layout = table['layout']
    rows_span = list()
    for row_idx in range(layout.shape[0]):
        row = layout[row_idx, :]
        y1s = list()
        y2s = list()
        for cell_id in row:
            cell_span = cells_span[cell_id]
            cell_bbox = cells_bbox[cell_id]
            if (cell_span[1] == row_idx) and (cell_bbox is not None):
                y1s.append(cell_bbox[1])
            if (cell_span[3] == row_idx) and (cell_bbox is not None):
                y2s.append(cell_bbox[3])

        if (len(y1s) > 0) and (len(y2s) > 0):
            y1 = min(max(1, min(y1s)), height - 1)
            y2 = min(max(1, max(y2s) + 1), height - 1)
            rows_span.append([y1, y2])
        else:
            raise InvalidFormat()
    rows_span = shrink_spans(rows_span, height)
    rows_fg_span, rows_bg_span = cal_fg_bg_span(rows_span, height)
    return rows_fg_span, rows_bg_span


def cal_col_span(table, cells_span, cells_bbox, width):
    layout = table['layout']
    cols_span = list()
    for col_idx in range(layout.shape[1]):
        col = layout[:, col_idx]
        x1s = list()
        x2s = list()
        for cell_id in col:
            cell_span = cells_span[cell_id]
            cell_bbox = cells_bbox[cell_id]
            if (cell_span[0] == col_idx) and (cell_bbox is not None):
                x1s.append(cell_bbox[0])
            if (cell_span[2] == col_idx) and (cell_bbox is not None):
                x2s.append(cell_bbox[2])

        if (len(x1s) > 0) and (len(x2s) > 0):
            x1 = min(max(1, min(x1s)), width - 1)
            x2 = min(max(1, max(x2s) + 1), width - 1)
            cols_span.append([x1, x2])
        else:
            raise InvalidFormat()
    cols_span = shrink_spans(cols_span, width)
    cols_fg_span, cols_bg_span = cal_fg_bg_span(cols_span, width)
    return cols_fg_span, cols_bg_span


def extract_fg_bg_spans(table, image_size):
    width, height = image_size
    cells_bbox = cal_cell_bbox(table)
    cells_span = cal_cell_spans(table)
    # cal rows fg bg span
    rows_fg_span, rows_bg_span = cal_row_span(
        table, cells_span, cells_bbox, height
    )
    # cal cols fg bg span
    cols_fg_span, cols_bg_span = cal_col_span(
        table, cells_span, cells_bbox, width
    )
    return rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span
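`extract_fg_bg_spans` turns a table's logical layout plus cell pixel boxes into separator bands (the `fg` spans) and content bands (the `bg` spans) for both rows and columns. A small self-contained check on a 2x2 table, assuming this file's directory is on `sys.path`:

```python
# Toy 2x2 table on a 100x60 canvas; assumes extract_table_lines is importable.
import numpy as np
from extract_table_lines import extract_fg_bg_spans

def box(x1, y1, x2, y2):
    return [[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]

table = {
    'layout': np.array([[0, 1], [2, 3]]),
    'cells': [
        {'segmentation': box(10, 10, 40, 20)}, {'segmentation': box(60, 10, 90, 20)},
        {'segmentation': box(10, 40, 40, 50)}, {'segmentation': box(60, 40, 90, 50)},
    ],
}
rows_fg, rows_bg, cols_fg, cols_bg, spans = extract_fg_bg_spans(table, (100, 60))
print(rows_fg)  # [[0, 10], [21, 40], [51, 60]] -> separator bands around the rows
print(rows_bg)  # [[10, 21], [40, 51]]          -> the row content bands
```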
dataset/utils/list_record_cache.py
ADDED
@@ -0,0 +1,143 @@
import os
import pickle
import threading


def merge_record_file(files, dst_file):
    # .lrc files are self-delimiting blocks, so merging is plain concatenation
    cmd = 'cat'
    for file in files:
        cmd += ' %s' % file
    cmd += ' > %s' % dst_file
    os.system(cmd)


class ListRecordCacher:
    OFFSET_LENGTH = 8

    def __init__(self, cache_path):
        self._record_pos_list = list()
        self._cache_file = open(cache_path, 'wb')
        # reserve the first 8 bytes for the offset of the position index
        self._cached_bytes = b'\x00' * self.OFFSET_LENGTH

    def add_record(self, record):
        record_bytes = pickle.dumps(record)
        return self.add_record_bytes(record_bytes)

    def add_record_bytes(self, record_bytes):
        bytes_size = len(record_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + record_bytes

        cur_record_pos = None
        if len(self._record_pos_list) == 0:
            cur_record_pos = [self.OFFSET_LENGTH * 2, bytes_size]
        else:
            cur_record_pos = [sum(self._record_pos_list[-1]) + self.OFFSET_LENGTH, bytes_size]
        self._record_pos_list.append(cur_record_pos)

        self._cached_bytes += total_bytes
        if len(self._cached_bytes) > 1024 * 1024:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def flush(self):
        if len(self._cached_bytes) > 0:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def _write_record_pos_list(self):
        self.flush()
        self._cache_file.seek(0, 2)
        offset = self._cache_file.tell()
        offset_bytes = offset.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        self._cache_file.seek(0)
        self._cache_file.write(offset_bytes)

        data_bytes = pickle.dumps(self._record_pos_list)
        bytes_size = len(data_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + data_bytes
        self._cache_file.seek(0, 2)
        self._cache_file.write(total_bytes)

    def close(self):
        if not self._cache_file.closed:
            self._write_record_pos_list()
            self._cache_file.close()

    def __del__(self):
        self.close()


class ListRecordLoader:
    OFFSET_LENGTH = 8

    def __init__(self, load_path):
        self._sync_lock = threading.Lock()
        self._size = os.path.getsize(load_path)
        self._load_path = load_path
        self._open_file()
        self._scan_file()

    def _open_file(self):
        self._pid = os.getpid()
        self._cache_file = open(self._load_path, 'rb')

    def _check_reopen(self):
        # reopen after fork so worker processes do not share a file handle
        if self._pid != os.getpid():
            self._open_file()

    def _scan_file(self):
        record_pos_list = list()
        pos = 0
        while True:
            if pos >= self._size:
                break
            self._cache_file.seek(pos)
            offset = int.from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            offset = pos + offset
            self._cache_file.seek(offset)

            byte_size = int.from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            record_pos_list_bytes = self._cache_file.read(byte_size)
            sub_record_pos_list = pickle.loads(record_pos_list_bytes)
            assert isinstance(sub_record_pos_list, list)
            sub_record_pos_list = [[item[0] + pos, item[1]] for item in sub_record_pos_list]
            record_pos_list.extend(sub_record_pos_list)
            pos = self._cache_file.tell()

        self._record_pos_list = record_pos_list

    def get_record(self, idx):
        self._check_reopen()
        record_bytes = self.get_record_bytes(idx)
        record = pickle.loads(record_bytes)
        return record

    def get_record_bytes(self, idx):
        offset, length = self._record_pos_list[idx]
        self._sync_lock.acquire()
        try:
            self._cache_file.seek(offset)
            record_bytes = self._cache_file.read(length)
        finally:
            self._sync_lock.release()
        return record_bytes

    def __len__(self):
        return len(self._record_pos_list)
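The `.lrc` layout is: an 8-byte header pointing at a pickled `(offset, length)` index, then length-prefixed pickled records, with the index appended on `close()`. Because every block is self-delimiting, `merge_record_file` can simply concatenate files and `_scan_file` walks the chained blocks. A round-trip sketch, assuming this file's directory is on `sys.path`:

```python
# Write three records, close (which appends the index), then read them back.
from list_record_cache import ListRecordCacher, ListRecordLoader

cacher = ListRecordCacher('demo.lrc')
for i in range(3):
    cacher.add_record({'id': i, 'transcript': list('cell%d' % i)})
cacher.close()

loader = ListRecordLoader('demo.lrc')
print(len(loader))           # 3
print(loader.get_record(1))  # {'id': 1, 'transcript': ['c', 'e', 'l', 'l', '1']}
```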
dataset/utils/utils.py
ADDED
@@ -0,0 +1,449 @@
import os
import cv2
import json
import copy
import tqdm
import numpy as np
import fitz
from .extract_table_lines import extract_fg_bg_spans


def get_paths(root_dir, sub_names, names_path, exts, val=None):
    # Check the existence of directories
    assert os.path.isdir(root_dir)

    with open(names_path, "r") as f:
        names = f.readlines()
    names = [name.strip() for name in names]

    # TODO: sub_dirs redundancy
    sub_dirs = []
    for sub_name in sub_names:
        sub_dir = os.path.join(root_dir, sub_name)
        assert os.path.isdir(sub_dir), '"%s" is not a dir.' % sub_dir
        sub_dirs.append(sub_dir)

    paths = []
    names = names[:val]
    for name in tqdm.tqdm(names):
        sub_paths = []
        for sub_dir, ext in zip(sub_dirs, exts):
            sub_path = os.path.join(sub_dir, name + ext)
            assert os.path.exists(sub_path), '%s does not exist' % sub_path
            sub_paths.append(sub_path)
        paths.append(sub_paths)

    return paths


def get_sub_paths(root_dir, sub_names, exts, val=None):
    # Check the existence of directories
    assert os.path.isdir(root_dir)
    # TODO: sub_dirs redundancy
    sub_dirs = []
    for sub_name in sub_names:
        sub_dir = os.path.join(root_dir, sub_name)
        assert os.path.isdir(sub_dir), '"%s" is not a dir.' % sub_dir
        sub_dirs.append(sub_dir)

    paths = []
    d = os.listdir(sub_dirs[0])[:val]
    for file_name in tqdm.tqdm(d):
        sub_paths = [os.path.join(sub_dirs[0], file_name)]
        name = os.path.splitext(file_name)[0]
        for sub_name, ext in zip(sub_names[1:], exts[1:]):
            sub_path = os.path.join(root_dir, sub_name, name + ext)
            assert os.path.exists(sub_path)
            sub_paths.append(sub_path)
        paths.append(sub_paths)

    return paths


def cal_wer(label, rec):
    # character-level Levenshtein similarity: 1 - edit_distance / len(label)
    dist_mat = np.zeros((len(label) + 1, len(rec) + 1), dtype='int32')
    dist_mat[0, :] = range(len(rec) + 1)
    dist_mat[:, 0] = range(len(label) + 1)

    for i in range(1, len(label) + 1):
        for j in range(1, len(rec) + 1):
            hit_score = dist_mat[i - 1, j - 1] + (label[i - 1] != rec[j - 1])
            ins_score = dist_mat[i, j - 1] + 1
            del_score = dist_mat[i - 1, j] + 1
            dist_mat[i, j] = min(hit_score, ins_score, del_score)

    dist = dist_mat[len(label), len(rec)]

    return 1 - dist / len(label)


def visualize(img_path, chunks, structures):
    image = cv2.imread(img_path)
    for chunk in chunks:
        x1, x2, y1, y2 = chunk["pos"]
        transcript = chunk["text"]
        cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255))
        cv2.putText(image, ''.join(transcript), (int(x1), int(max(0, y1 - 1))), cv2.FONT_HERSHEY_COMPLEX, 0.25, (0, 0, 255), 1)
    return image


def visualize_table(img_path, output_dir, table):
    img = cv2.imread(img_path)
    for cell in table['cells']:
        x1, y1, x2, y2 = cell['bbox']
        transcript = cell['transcript']
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255))
        cv2.putText(img, ''.join(transcript), (int(x1), int(max(0, y1 - 1))), cv2.FONT_HERSHEY_COMPLEX, 0.25, (0, 0, 255), 1)
    cv2.imwrite(os.path.join(output_dir, os.path.basename(img_path)), img)


def crop_pdf(path, output_dir, zoom_x=2.0, zoom_y=2.0, rotate=0, expand=10, y_fix=.0):
    '''
    path: [pdf_path, chunk_path]
    Crop the table region in the pdf and save it as pdf_name.png.
    Returns list[x1, x2, y1, y2] and [str]; note these are relative to the cropped pdf.
    '''
    # load data
    with open(path[1], 'r') as f:
        chunks = json.load(f)['chunks']
    doc = fitz.open(path[0])
    pdf_name = os.path.splitext(os.path.basename(path[0]))[0]
    assert doc.pageCount == 1, '%s has more than 1 page!' % pdf_name

    # render the pdf page to an image
    trans = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
    pm = doc[0].getPixmap(matrix=trans, alpha=False)
    pm.writePNG(os.path.join(output_dir, '%s.png' % pdf_name))

    # crop table region
    pdf_img = cv2.imread(os.path.join(output_dir, '%s.png' % pdf_name))
    h, w, *_ = pdf_img.shape
    positions = []
    transcripts = []
    for chunk in chunks:
        positions.append([chunk['pos'][0], chunk['pos'][1], chunk['pos'][3], chunk['pos'][2]])  # x1, x2, y2, y1
        transcripts.append(chunk["text"])

    # the last chunk's transcript repeats its final character, so drop it
    transcripts[-1] = transcripts[-1][:-1]

    positions = np.array(positions)
    positions[:, :2] *= zoom_x
    positions[:, 2:] = h - positions[:, 2:] * zoom_y
    x_min = int(max(0, positions[:, :2].min() - expand))
    y_min = int(max(0, positions[:, 2:].min() - expand))
    x_max = int(min(w, positions[:, :2].max() + expand))
    y_max = int(min(h, positions[:, 2:].max() + expand))

    img_crop = pdf_img[y_min:y_max, x_min:x_max]
    cv2.imwrite(os.path.join(output_dir, '%s.png' % pdf_name), img_crop)

    positions[:, :2] = np.clip(positions[:, :2] - x_min, 0, w)
    positions[:, 2] -= y_fix * zoom_y
    positions[:, 2:] = np.clip(positions[:, 2:] - y_min, 0, h)
    return positions, transcripts


def crop_cells(img_path, output_dir, info, expand=10):
    cells = info['cells']
    img = cv2.imread(img_path)
    h, w, *_ = img.shape
    bboxes = [cell['bbox'] for cell in cells if 'bbox' in cell.keys()]
    bboxes = np.array(bboxes)
    x_min = int(max(bboxes[:, 0].min() - expand, 0))
    y_min = int(max(bboxes[:, 1].min() - expand, 0))
    x_max = int(min(bboxes[:, 2].max() + expand, w))
    y_max = int(min(bboxes[:, 3].max() + expand, h))
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png'), img[y_min:y_max, x_min:x_max])

    # refine cell bbox
    new_cells = []
    for cell in cells:
        if 'bbox' not in cell.keys():
            new_cells.append(cell)
        else:
            cell['bbox'][0] = max(0, cell['bbox'][0] - x_min)
            cell['bbox'][1] = max(0, cell['bbox'][1] - y_min)
            cell['bbox'][2] = max(0, cell['bbox'][2] - x_min)
            cell['bbox'][3] = max(0, cell['bbox'][3] - y_min)
            segmentation = cell['segmentation']
            cell['segmentation'] = [[[pt[0] - x_min, pt[1] - y_min] for pt in contour] for contour in segmentation]
            new_cells.append(cell)
    info['cells'] = new_cells


def visualize_ocr(img_path, output_dir, positions, transcripts):
    img = cv2.imread(img_path)
    for position, transcript in zip(positions, transcripts):
        x1, x2, y1, y2 = position
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255))
        cv2.putText(img, transcript, (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '_ocr.png'), img)


def cal_cell_spans(table):
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # spans are inclusive, so cover the full rectangle when checking
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def visualize_cell(img_path, output_dir, table):
    def spans2lines(spans):
        lines = []
        lines.append(spans[0][0])
        for span in spans[1:-1]:
            t1, t2 = span
            lines.append(int((t1 + t2) / 2))
        lines.append(spans[-1][-1])
        return lines

    img = cv2.imread(img_path)

    # draw table lines
    rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span = extract_fg_bg_spans(table, img.shape[::-1][-2:])
    row_lines = spans2lines(rows_fg_span)
    col_lines = spans2lines(cols_fg_span)
    for span in cells_span:
        x1, y1, x2, y2 = span
        cv2.rectangle(img, (int(col_lines[x1]), int(row_lines[y1])), (int(col_lines[x2 + 1]), int(row_lines[y2 + 1])), (0, 0, 255), 2)

    # draw ocr results
    for cell in table['cells']:
        if 'bbox' not in cell.keys():
            continue
        x1, y1, x2, y2 = cell['bbox']
        transcript = cell['transcript']
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
        cv2.putText(img, ''.join(transcript), (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png'), img)


def match_cells(path, positions, transcripts, k=16, start=0.333, stop=0.1, stop_percent=0.3, gap=0.25):
    '''
    path: [pdf_path, chunk_path, structure_path]
    positions: [x1, x2, y1, y2]
    transcripts: [str]
    return dict(
        'layout': np.array(),
        'bbox': [x1, y1, x2, y2],
        'transcript': str,
        'head_rows': [],
        'body_rows': []
    )
    '''
    # load data
    with open(path[2], 'r') as f:
        cells = json.load(f)['cells']

    # collect cell grid positions and ground-truth contents
    cells_pos = []  # xl1, yl1, xl2, yl2
    contents = []
    for cell in cells:
        cells_pos.append([cell['start_col'], cell['start_row'], cell['end_col'], cell['end_row']])
        contents.append(' '.join(cell['content']))

    # sort cells from left to right, from top to bottom
    sorted_idx = sorted(list(range(len(cells_pos))), key=lambda idx: cells_pos[idx][0] + 1e6 * cells_pos[idx][1])
    cells_pos = [cells_pos[idx] for idx in sorted_idx]
    contents = [contents[idx] for idx in sorted_idx]

    # layout
    n_row = np.array(cells_pos)[:, 3].max() + 1
    n_col = np.array(cells_pos)[:, 2].max() + 1
    layout = np.full((n_row, n_col), -1)

    # head_rows & body_rows: the header is the rows spanned by cells starting at row 0
    head_rows = list(range((np.array(cells_pos)[np.array(cells_pos)[:, 1] == 0][:, 3] - np.array(cells_pos)[np.array(cells_pos)[:, 1] == 0][:, 1]).max() + 1))
    body_rows = list(range((np.array(cells_pos)[np.array(cells_pos)[:, 1] == 0][:, 3] - np.array(cells_pos)[np.array(cells_pos)[:, 1] == 0][:, 1]).max() + 1, n_row))

    lt = [-1, -1]
    cells = []
    valid_idx = list(range(len(transcripts)))

    # init start/end index of ocr results
    start_content = ''
    for content in contents:
        if len(content) > 0:
            start_content = content
            break
    try:
        start_index = [cal_wer(start_content, transcript) > start for transcript in transcripts[:k]].index(True)
    except ValueError:
        start_index = 0

    end_content = ''
    for content in contents[::-1]:
        if len(content) > 0:
            end_content = content
            break
    try:
        end_index = [cal_wer(end_content, transcript) > start for transcript in transcripts[::-1][:k]].index(True)
    except ValueError:
        end_index = 0

    valid_idx = valid_idx[start_index:] if end_index == 0 else valid_idx[start_index: -end_index]

    assert len(contents) >= len(valid_idx), 'OCR results have errors'

    stop_counts = 0
    for index, (cell_pos, content) in enumerate(zip(cells_pos, contents)):
        # confirm the cell positions are increasing
        assert cell_pos[0] > lt[0] or cell_pos[1] > lt[1], 'sorted cells have errors'
        lt = cell_pos[:2]

        xl1, yl1, xl2, yl2 = cell_pos
        layout[yl1:yl2 + 1, xl1:xl2 + 1] = index

        if len(content) == 0:
            cells.append(dict(transcript=[]))
        else:
            is_completed = False
            bboxes_list = [positions[valid_idx[0]]]
            transcripts_list = [transcripts[valid_idx[0]]]
            valid_idx.pop(0)
            wer_last = cal_wer(content, ' '.join(transcripts_list))
            if wer_last < stop:
                bboxes_list = np.array(bboxes_list)
                x1 = int(bboxes_list[:, :2].min())
                x2 = int(bboxes_list[:, :2].max())
                y1 = int(bboxes_list[:, 2:].min())
                y2 = int(bboxes_list[:, 2:].max())
                cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]))
                stop_counts += 1
                continue
            for idx in valid_idx[:k]:
                if content == ' '.join(transcripts_list):
                    bboxes_list = np.array(bboxes_list)
                    x1 = int(bboxes_list[:, :2].min())
                    x2 = int(bboxes_list[:, :2].max())
                    y1 = int(bboxes_list[:, 2:].min())
                    y2 = int(bboxes_list[:, 2:].max())
                    cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]))
                    is_completed = True
                    break
                else:
                    cur_trans = copy.deepcopy(transcripts_list)
                    cur_trans.append(transcripts[idx])
                    wer = cal_wer(content, ' '.join(cur_trans))
                    # if adding the new string does not improve the wer enough, skip it
                    if wer < wer_last + gap:
                        continue
                    else:
                        transcripts_list.append(transcripts[idx])
                        bboxes_list.append(positions[idx])
                        valid_idx.pop(valid_idx.index(idx))
                        if wer == 1.0:
                            break
                        else:
                            wer_last = wer
            if not is_completed:
                bboxes_list = np.array(bboxes_list)
                x1 = int(bboxes_list[:, :2].min())
                x2 = int(bboxes_list[:, :2].max())
                y1 = int(bboxes_list[:, 2:].min())
                y2 = int(bboxes_list[:, 2:].max())
                cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]))

    assert stop_counts / len(contents) < stop_percent, 'this table has many mismatches with the OCR results'
    assert layout.min() == 0, 'this table layout is not completely resolved'
    return dict(
        layout=layout,
        cells=cells,
        head_rows=head_rows,
        body_rows=body_rows,
    )


def extract_ocr(path, positions, transcripts, k=16, start=0.333):
    '''
    path: [pdf_path, chunk_path, structure_path]
    positions: [x1, x2, y1, y2]
    transcripts: [str]
    return dict(
        'cells': {
            'bbox': [x1, y1, x2, y2],
            'transcript': []
        }
    )
    '''
    # load data
    with open(path[2], 'r') as f:
        cells = json.load(f)['cells']

    # collect cell grid positions and ground-truth contents
    cells_pos = []  # xl1, yl1, xl2, yl2
    contents = []
    for cell in cells:
        cells_pos.append([cell['start_col'], cell['start_row'], cell['end_col'], cell['end_row']])
        contents.append(' '.join(cell['content']))

    # sort cells from left to right, from top to bottom
    sorted_idx = sorted(list(range(len(cells_pos))), key=lambda idx: cells_pos[idx][0] + 1e6 * cells_pos[idx][1])
    cells_pos = [cells_pos[idx] for idx in sorted_idx]
    contents = [contents[idx] for idx in sorted_idx]

    # init start/end index; the first/last chunk must not be over-split, and its
    # wer against the first/last non-empty cell must exceed the start threshold
    valid_idx = list(range(len(transcripts)))
    start_content = ''
    for content in contents:
        if len(content) > 0:
            start_content = content
            break
    try:
        start_index = [cal_wer(start_content, transcript) > start for transcript in transcripts[:k]].index(True)
    except ValueError:
        start_index = 0

    end_content = ''
    for content in contents[::-1]:
        if len(content) > 0:
            end_content = content
            break
    try:
        end_index = [cal_wer(end_content, transcript) > start for transcript in transcripts[::-1][:k]].index(True)
    except ValueError:
        end_index = 0

    valid_idx = valid_idx[start_index:] if end_index == 0 else valid_idx[start_index: -end_index]

    cells = []
    for idx in valid_idx:
        x1, x2, y1, y2 = positions[idx].astype('int').tolist()
        cells.append(dict(transcript=list(transcripts[idx]), bbox=[x1, y1, x2, y2], segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]))

    return dict(
        cells=cells
    )


def refine_table(table, img_path, output_dir, expand=10):
    cells = table['cells']
    bboxes = [cell['bbox'] for cell in table['cells'] if 'bbox' in cell.keys()]
    bboxes = np.array(bboxes)
    img = cv2.imread(img_path)
    h, w, *_ = img.shape
    x1 = int(max(0, bboxes[:, 0].min() - expand))
    y1 = int(max(0, bboxes[:, 1].min() - expand))
    x2 = int(min(w, bboxes[:, 2].max() + expand))
    y2 = int(min(h, bboxes[:, 3].max() + expand))
    # refine cells
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2] - x1, 0, 1e6)
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2] - y1, 0, 1e6)
    bboxes = bboxes.tolist()
    for cell, bbox in zip(cells, bboxes):
        cell['bbox'] = bbox

    img = img[y1:y2, x1:x2]
    cv2.imwrite(os.path.join(output_dir, os.path.basename(img_path)), img)
    table['image_path'] = os.path.join(output_dir, os.path.basename(img_path))
    return table
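`match_cells` leans entirely on `cal_wer`, a character-level Levenshtein similarity (`1 - edit_distance / len(label)`) used both to locate where the OCR stream starts and ends and to decide greedily whether appending the next OCR chunk brings a cell's merged transcript closer to the ground-truth content. Worked values (the import path assumes `dataset/` and `dataset/utils/` are importable as packages; adjust to your layout):

```python
# cal_wer is 1 - levenshtein(label, rec) / len(label); it can go negative
# when the recognition is much longer than the label.
from dataset.utils.utils import cal_wer  # package layout permitting

print(cal_wer('table', 'table'))  # 1.0
print(cal_wer('table', 'tabel'))  # 0.6  (distance 2 over a 5-char label)
print(cal_wer('ab', 'abcdef'))    # -1.0 (4 insertions over a 2-char label)
```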
libs/configs/__init__.py
ADDED
@@ -0,0 +1,24 @@
from . import default
import importlib


class CFG:
    def __init__(self):
        self.__dict__['cfg'] = None

    def __getattr__(self, name):
        return getattr(self.__dict__['cfg'], name)

    def __setattr__(self, name, val):
        setattr(self.__dict__['cfg'], name, val)


cfg = CFG()
cfg.__dict__['cfg'] = default


def setup_config(cfg_name):
    global cfg
    module_name = 'libs.configs.' + cfg_name
    cfg_module = importlib.import_module(module_name)
    cfg.__dict__['cfg'] = cfg_module
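`cfg` is a mutable proxy: attribute reads and writes are forwarded to whichever config module is currently active, and `setup_config` swaps that module by name, so the rest of the code can hold one `cfg` reference while experiments switch configs. Expected usage:

```python
from libs.configs import cfg, setup_config

print(cfg.base_lr)       # forwarded to libs.configs.default -> 0.0001
setup_config('default')  # activate any module under libs/configs/ by name
cfg.base_lr = 5e-5       # writes through to the active config module
```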
libs/configs/default.py
ADDED
@@ -0,0 +1,77 @@
import os
import torch
from libs.utils.vocab import Vocab

device = torch.device('cuda')

train_lrcs_path = [
    "/yrfs1/intern/pfhu6/TSR/Dataset/SciTSR/train/table.lrc"
]
train_data_dir = ''
train_max_pixel_nums = 400 * 400 * 5
train_bucket_seps = (50, 50, 50)
train_max_batch_size = 8
train_num_workers = 4

valid_lrc_path = '/yrfs1/intern/pfhu6/TSR/Dataset/SciTSR/test/table.lrc'
valid_data_dir = ''
valid_num_workers = 0
valid_batch_size = 1

vocab = Vocab()

# model params
# backbone
arch = "res34"
pretrained_backbone = True
backbone_out_channels = (64, 128, 256, 512)

# fpn
fpn_out_channels = 256

# pan
pan_num_levels = 4
pan_in_dim = 256
pan_out_dim = 256

# row segment predictor
rs_scale = 1

# col segment predictor
cs_scale = 1

# divide predictor
dp_head_nums = 8
dp_scale = 1

# cells extractor params
ce_scale = 1 / 8
ce_pool_size = (3, 3)
ce_dim = 512
ce_head_nums = 8
ce_heads = 1

# decoder
embed_dim = 512
feat_dim = 512
lm_state_dim = 512
proj_dim = 512
cover_kernel = 7
att_threshold = 0.5
spatial_att_weight_loss_wight = 1.0

# train params
base_lr = 0.0001
min_lr = 1e-6
weight_decay = 0

num_epochs = 20
sync_rate = 20

log_sep = 20

work_dir = './experiments/heads_1'

train_checkpoint = None

eval_checkpoint = os.path.join(work_dir, 'best_f1_model.pth')
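Because `setup_config` imports configs as plain modules, an experiment config presumably just shadows a few of these names. A hypothetical `libs/configs/scitsr_comp.py` (the filename and override values below are illustrative, not from the repo):

```python
# Hypothetical experiment config: inherit every default, override a few fields.
from libs.configs.default import *  # noqa: F401,F403

work_dir = './experiments/scitsr_comp'
train_max_batch_size = 4
num_epochs = 30
# paths derived from work_dir must be recomputed after overriding it
eval_checkpoint = os.path.join(work_dir, 'best_f1_model.pth')
```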
libs/data/__init__.py
ADDED
@@ -0,0 +1,69 @@
import torch
from torch.utils.data.distributed import DistributedSampler
from .batch_sampler import BucketSampler
from .dataset import LRCRecordLoader
from .dataset import Dataset, collate_func
from libs.utils.comm import distributed, get_rank, get_world_size
from . import transform as T


def create_train_dataloader(vocab, lrcs_path, num_workers, max_batch_size, max_pixel_nums, bucket_seps, data_root_dir):
    loaders = list()
    for lrc_path in lrcs_path:
        loader = LRCRecordLoader(lrc_path, data_root_dir)
        loaders.append(loader)

    transforms = T.Compose([
        T.TableToLabel(vocab),
        T.CalRowColSpans(),
        T.CalCellSpans(),
        T.CalHeadBodyDivide(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = Dataset(loaders, transforms)
    batch_sampler = BucketSampler(dataset, get_world_size(), get_rank(), max_pixel_nums=max_pixel_nums, max_batch_size=max_batch_size, seps=bucket_seps)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=collate_func,
        batch_sampler=batch_sampler
    )
    return dataloader


def create_valid_dataloader(vocab, lrc_path, num_workers, batch_size, data_root_dir):
    loader = LRCRecordLoader(lrc_path, data_root_dir)

    transforms = T.Compose([
        T.TableToLabel(vocab),
        T.CalRowColSpans(),
        T.CalCellSpans(),
        T.CalHeadBodyDivide(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = Dataset([loader], transforms)
    if distributed():
        sampler = DistributedSampler(dataset, get_world_size(), get_rank(), True)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            collate_fn=collate_func,
            sampler=sampler,
            drop_last=False
        )
    else:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            collate_fn=collate_func,
            shuffle=False,
            drop_last=False
        )
    return dataloader
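A sketch of wiring the default config into these factories (values come from libs/configs/default.py; the batch keys match collate_func in dataset.py):

# Sketch: building the bucketed training dataloader from the config.
from libs.configs import cfg, setup_config
from libs.data import create_train_dataloader

setup_config('default')
train_loader = create_train_dataloader(
    cfg.vocab, cfg.train_lrcs_path, cfg.train_num_workers,
    cfg.train_max_batch_size, cfg.train_max_pixel_nums,
    cfg.train_bucket_seps, cfg.train_data_dir
)
for batch in train_loader:
    print(batch['images'].shape)  # (B, 3, H, W), padded per bucket
    break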
libs/data/batch_sampler.py
ADDED
@@ -0,0 +1,118 @@
import tqdm
import copy
import random
from collections import defaultdict
from libs.utils import logger
import numpy as np


class BucketSampler:
    def __init__(self, dataset, world_size, rank_id, fix_batch_size=None, max_pixel_nums=None, max_batch_size=8, min_batch_size=1, seps=(100, 100, 20)):
        self.dataset = dataset
        self.world_size = world_size
        self.rank_id = rank_id
        self.seps = seps
        self.fix_batch_size = fix_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.max_pixel_nums = max_pixel_nums
        assert (fix_batch_size is not None) or (max_pixel_nums is not None)
        self.cal_buckets()
        self.seed = 20
        self.epoch = 0

    def count_keys(self):
        infos = []
        for i in tqdm.tqdm(range(len(self.dataset))):
            info = self.dataset.get_info(i)
            infos.append(info)
        return infos

    def cal_buckets(self):
        infos = self.count_keys()
        np.save('count_keys.npy', infos)
        min_sizes = None  # e.g. (64, 18, 2)
        max_sizes = None  # e.g. (1223, 742, 2080)
        for info in infos:
            if min_sizes is None:
                min_sizes = info
                max_sizes = info
            else:  # element-wise min/max over the (w, h, n_cells) tuples
                min_sizes = tuple(min(min_sizes[idx], info[idx]) for idx in range(len(min_sizes)))
                max_sizes = tuple(max(max_sizes[idx], info[idx]) for idx in range(len(max_sizes)))
        assert (min_sizes is not None) and (len(self.seps) == len(min_sizes))
        print('max sizes: {}, min size: {}'.format(max_sizes, min_sizes))
        buckets = defaultdict(list)
        for idx, info in enumerate(infos):
            bucket_idxes = list()
            for sep, size, min_size in zip(self.seps, info, min_sizes):
                bucket_idx = (size - min_size) // sep
                bucket_idxes.append(str(bucket_idx))
            bucket_idxes = '-'.join(bucket_idxes)
            buckets[bucket_idxes].append(idx)
        np.save('buckets.npy', buckets)

        valid_buckets = dict()
        for bucket_key, bucket_samples in buckets.items():
            if len(bucket_samples) < self.min_batch_size:
                continue
            if (self.fix_batch_size is not None) and (len(bucket_samples) < self.fix_batch_size):
                continue

            w, h, *_ = [(int(item) + 1) * sep + min_size for item, min_size, sep in zip(bucket_key.split('-'), min_sizes, self.seps)]
            if self.fix_batch_size is not None:
                if h * w * self.fix_batch_size > self.max_pixel_nums:
                    continue
            else:
                if h * w * self.min_batch_size > self.max_pixel_nums:
                    continue

            if self.fix_batch_size is not None:
                batch_size = self.fix_batch_size
            else:
                batch_size = min(self.max_batch_size, max(self.max_pixel_nums // (w * h), self.min_batch_size), len(bucket_samples))

            valid_buckets[bucket_key] = dict(
                samples=bucket_samples,
                batch_size=batch_size
            )

        self.buckets = [valid_buckets[bucket_key] for bucket_key in sorted(valid_buckets.keys())]
        total_nums = len(infos)
        valid_nums = sum([len(item['samples']) for item in valid_buckets.values()])
        logger.info('Total %d samples, but ignore %d samples' % (total_nums, total_nums - valid_nums))

    def __iter__(self):
        random_inst = random.Random(self.seed + self.epoch)
        batches = list()
        for bucket in self.buckets:
            sample = copy.deepcopy(bucket['samples'])
            batch_size = bucket['batch_size']
            random_inst.shuffle(sample)
            idx = 0
            while idx < len(sample):
                batch = sample[idx:idx + batch_size]
                idx += batch_size
                if len(batch) < self.min_batch_size:
                    continue
                batches.append(batch)
        random_inst.shuffle(batches)

        align_nums = (len(batches) // self.world_size) * self.world_size
        batches = batches[: align_nums]
        for batch_idx in range(self.rank_id, len(batches), self.world_size):
            yield batches[batch_idx]

    def __len__(self):
        batch_nums = 0
        for bucket in self.buckets:
            bucket_sample_nums = len(bucket["samples"])
            bucket_bs = bucket['batch_size']
            batch_nums += bucket_sample_nums // bucket_bs
            if bucket_sample_nums % bucket_bs >= self.min_batch_size:
                batch_nums += 1

        return batch_nums // self.world_size

    def set_epoch(self, epoch):
        self.epoch = epoch
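The bucket key is the per-dimension quantised offset of a sample's (width, height, cell count) from the dataset minimum, so only similarly sized tables share a batch and padding waste stays bounded. A worked example with invented numbers:

# Worked example of the bucket-key computation above (sizes invented).
seps = (50, 50, 50)        # train_bucket_seps in the default config
min_sizes = (64, 18, 2)    # smallest (w, h, n_cells) seen in the dataset
info = (180, 95, 12)       # one sample's (w, h, n_cells)

key = '-'.join(str((size - m) // sep) for size, m, sep in zip(info, min_sizes, seps))
print(key)  # '2-1-0': every sample mapping to this key lands in the same bucket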
libs/data/dataset.py
ADDED
@@ -0,0 +1,164 @@
import os
import torch
from PIL import Image
from .list_record_cache import ListRecordLoader


class LRCRecordLoader:
    def __init__(self, lrc_path, data_dir=''):
        self.loader = ListRecordLoader(lrc_path)
        self.data_root_dir = data_dir

    def __len__(self):
        return len(self.loader)

    def get_info(self, idx):
        table = self.loader.get_record(idx)
        image = Image.open(table['image_path']).convert('RGB')
        w = image.width
        h = image.height
        n_rows, n_cols = table['layout'].shape
        n_cells = n_rows * n_cols
        return w, h, n_cells

    def get_data(self, idx):
        table = self.loader.get_record(idx)
        img_path = os.path.join(self.data_root_dir, table['image_path'])
        image = Image.open(img_path).convert('RGB')
        return image, table


class Dataset:
    def __init__(self, loaders, transforms):
        self.loaders = loaders
        self.transforms = transforms

    def _match_loader(self, idx):
        offset = 0
        for loader in self.loaders:
            if len(loader) + offset > idx:
                return loader, idx - offset
            else:
                offset += len(loader)
        raise IndexError()

    def get_info(self, idx):
        loader, rela_idx = self._match_loader(idx)
        return loader.get_info(rela_idx)

    def __len__(self):
        return sum([len(loader) for loader in self.loaders])

    def __getitem__(self, idx):
        try:
            loader, rela_idx = self._match_loader(idx)
            image, table = loader.get_data(rela_idx)
            image, _, cls_label, \
                rows_fg_span, rows_bg_span, \
                cols_fg_span, cols_bg_span, \
                cells_span, divide = self.transforms(image, table) if 'layout' in table.keys() else self.transforms(image)
            return dict(
                id=idx,
                image_size=(image.shape[2], image.shape[1]),
                image=image,
                cls_label=cls_label,
                rows_fg_span=rows_fg_span,
                rows_bg_span=rows_bg_span,
                cols_fg_span=cols_fg_span,
                cols_bg_span=cols_bg_span,
                cells_span=cells_span,
                layout=table['layout'] if 'layout' in table.keys() else None,
                divide=divide,
                table=table
            )
        except Exception as e:
            print('Error occurred while loading data: %d' % idx)
            raise e


def collate_func(batch_data):
    batch_size = len(batch_data)

    image_dim = batch_data[0]['image'].shape[0]
    max_h = max([data['image'].shape[1] for data in batch_data])
    max_w = max([data['image'].shape[2] for data in batch_data])

    batch_id = list()
    batch_image_size = list()

    batch_image = torch.zeros([batch_size, image_dim, max_h, max_w], dtype=torch.float)
    batch_image_mask = torch.zeros([batch_size, 1, max_h, max_w], dtype=torch.float)
    batch_rows_fg_span = list()
    batch_rows_bg_span = list()
    batch_cols_fg_span = list()
    batch_cols_bg_span = list()
    batch_cells_span = list()
    batch_divide = list()
    tables = list()

    if all([(data['cls_label'] is None) and (data['layout'] is None) for data in batch_data]):
        batch_cls_label = list()
        batch_label_mask = list()
        batch_layout = list()
    else:
        assert not any([(data['cls_label'] is None) or (data['layout'] is None) for data in batch_data])
        max_label_length = max([data['cls_label'].shape[0] for data in batch_data])
        batch_cls_label = torch.zeros([batch_size, max_label_length], dtype=torch.long)
        batch_label_mask = torch.zeros([batch_size, max_label_length], dtype=torch.float)
        max_nr = max([data['layout'].shape[0] for data in batch_data])
        max_nc = max([data['layout'].shape[1] for data in batch_data])
        batch_layout = torch.full([batch_size, max_nr, max_nc], -1, dtype=torch.float)

    for batch_idx, data in enumerate(batch_data):
        batch_id.append(data['id'])
        batch_image_size.append(data['image_size'])

        _, cur_h, cur_w = data['image'].shape
        batch_image[batch_idx, :, :cur_h, :cur_w] = data["image"]
        batch_image_mask[batch_idx, :, :cur_h, :cur_w] = 1

        if (data['cls_label'] is None) and (data['layout'] is None):
            batch_cls_label.append(data["cls_label"])
            batch_label_mask.append(None)
            batch_layout.append(data["layout"])
        else:
            label_length = data['cls_label'].shape[0]
            batch_cls_label[batch_idx, :label_length] = data['cls_label']
            batch_label_mask[batch_idx, :label_length] = 1.0
            layout_nr, layout_nc = data["layout"].shape
            batch_layout[batch_idx, :layout_nr, :layout_nc] = torch.from_numpy(data['layout']).float()

        batch_rows_fg_span.append(data["rows_fg_span"])
        batch_rows_bg_span.append(data['rows_bg_span'])
        batch_cols_fg_span.append(data["cols_fg_span"])
        batch_cols_bg_span.append(data["cols_bg_span"])
        batch_cells_span.append(data["cells_span"])
        batch_divide.append(data["divide"])
        tables.append(data['table'])

    batch_divide = torch.tensor(batch_divide, dtype=torch.long) if batch_divide[0] is not None else batch_divide

    return dict(
        ids=batch_id,
        images_size=batch_image_size,
        images=batch_image,
        images_mask=batch_image_mask,
        cls_labels=batch_cls_label,
        labels_mask=batch_label_mask,
        rows_fg_spans=batch_rows_fg_span,
        rows_bg_spans=batch_rows_bg_span,
        cols_fg_spans=batch_cols_fg_span,
        cols_bg_spans=batch_cols_bg_span,
        cells_spans=batch_cells_span,
        divide_labels=batch_divide,
        layouts=batch_layout,
        tables=tables
    )
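collate_func zero-pads every image to the batch maximum and records validity in images_mask; a toy check of that contract with plain tensors (not real table records):

# Toy check of the padding contract used by collate_func.
import torch

imgs = [torch.rand(3, 60, 100), torch.rand(3, 80, 70)]
batch_image = torch.zeros(2, 3, 80, 100)       # (B, C, max_h, max_w)
batch_image_mask = torch.zeros(2, 1, 80, 100)
for i, img in enumerate(imgs):
    _, h, w = img.shape
    batch_image[i, :, :h, :w] = img
    batch_image_mask[i, :, :h, :w] = 1
assert batch_image_mask.sum() == 60 * 100 + 80 * 70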
libs/data/list_record_cache.py
ADDED
@@ -0,0 +1,143 @@
import os
import pickle
import threading


def merge_record_file(files, dst_file):
    cmd = 'cat'
    for file in files:
        cmd += ' %s' % file
    cmd += ' > %s' % dst_file
    os.system(cmd)


class ListRecordCacher:
    OFFSET_LENGTH = 8

    def __init__(self, cache_path):
        self._record_pos_list = list()
        self._cache_file = open(cache_path, 'wb')
        self._cached_bytes = b'\x00' * self.OFFSET_LENGTH

    def add_record(self, record):
        record_bytes = pickle.dumps(record)
        return self.add_record_bytes(record_bytes)

    def add_record_bytes(self, record_bytes):
        bytes_size = len(record_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + record_bytes

        cur_record_pos = None
        if len(self._record_pos_list) == 0:
            cur_record_pos = [self.OFFSET_LENGTH * 2, bytes_size]
        else:
            cur_record_pos = [sum(self._record_pos_list[-1]) + self.OFFSET_LENGTH, bytes_size]
        self._record_pos_list.append(cur_record_pos)

        self._cached_bytes += total_bytes
        if len(self._cached_bytes) > 1024 * 1024:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def flush(self):
        if len(self._cached_bytes) > 0:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def _write_record_pos_list(self):
        self.flush()
        self._cache_file.seek(0, 2)
        offset = self._cache_file.tell()
        offset_bytes = offset.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        self._cache_file.seek(0)
        self._cache_file.write(offset_bytes)

        data_bytes = pickle.dumps(self._record_pos_list)
        bytes_size = len(data_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + data_bytes
        self._cache_file.seek(0, 2)
        self._cache_file.write(total_bytes)

    def close(self):
        if not self._cache_file.closed:
            self._write_record_pos_list()
            self._cache_file.close()

    def __del__(self):
        self.close()


class ListRecordLoader:
    OFFSET_LENGTH = 8

    def __init__(self, load_path):
        self._sync_lock = threading.Lock()
        self._size = os.path.getsize(load_path)
        self._load_path = load_path
        self._open_file()
        self._scan_file()

    def _open_file(self):
        self._pid = os.getpid()
        self._cache_file = open(self._load_path, 'rb')

    def _check_reopen(self):
        if self._pid != os.getpid():
            self._open_file()

    def _scan_file(self):
        record_pos_list = list()
        pos = 0
        while True:
            if pos >= self._size:
                break
            self._cache_file.seek(pos)
            offset = int().from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            offset = pos + offset
            self._cache_file.seek(offset)

            byte_size = int().from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            record_pos_list_bytes = self._cache_file.read(byte_size)
            sub_record_pos_list = pickle.loads(record_pos_list_bytes)
            assert isinstance(sub_record_pos_list, list)
            sub_record_pos_list = [[item[0] + pos, item[1]] for item in sub_record_pos_list]
            record_pos_list.extend(sub_record_pos_list)
            pos = self._cache_file.tell()

        self._record_pos_list = record_pos_list

    def get_record(self, idx):
        self._check_reopen()
        record_bytes = self.get_record_bytes(idx)
        record = pickle.loads(record_bytes)
        return record

    def get_record_bytes(self, idx):
        offset, length = self._record_pos_list[idx]
        self._sync_lock.acquire()
        try:
            self._cache_file.seek(offset)
            record_bytes = self._cache_file.read(length)
        finally:
            self._sync_lock.release()
        return record_bytes

    def __len__(self):
        return len(self._record_pos_list)
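A round-trip sketch for the .lrc container defined above: pickled records are appended with 8-byte length prefixes, and close() writes the position index plus the header offset that _scan_file reads back:

# Round-trip sketch for the .lrc record format (temporary path).
from libs.data.list_record_cache import ListRecordCacher, ListRecordLoader

cacher = ListRecordCacher('/tmp/demo.lrc')
cacher.add_record({'image_path': 'img/0.png'})
cacher.add_record({'image_path': 'img/1.png'})
cacher.close()                                   # flushes data and writes the index

loader = ListRecordLoader('/tmp/demo.lrc')
assert len(loader) == 2
print(loader.get_record(1)['image_path'])        # img/1.png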
libs/data/transform.py
ADDED
@@ -0,0 +1,70 @@
import torch
from torchvision.transforms import functional as F
from libs.utils.format_translate import table_to_latex
from .utils import extract_fg_bg_spans, cal_cell_spans


class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *data):
        for transform in self.transforms:
            data = transform(*data)
        return data


class TableToLabel:
    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, image, table=None):
        if table is None:
            return image, None, None
        latex = table_to_latex(table)  # image.size = (w, h)
        cls_label = self.vocab.words_to_ids(latex)
        return image, table, cls_label


class CalRowColSpans:
    def __call__(self, image, table=None, cls_label=None):
        if table is None:
            return image, table, None, None, None, None, None
        image_size = (image.width, image.height)
        rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span = extract_fg_bg_spans(table, image_size)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span


class CalCellSpans:
    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None):
        if table is not None:
            cells_span = cal_cell_spans(table)
        else:
            cells_span = None
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span


class CalHeadBodyDivide:
    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None):
        if table is None:
            divide = None
        else:
            head_rows = table['head_rows']
            divide = len(head_rows)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide


class ToTensor:
    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        image = F.to_tensor(image)
        if cls_label is not None:
            cls_label = torch.tensor(cls_label, dtype=torch.long)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide


class Normalize:
    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        image = F.normalize(image, self.mean, self.std, self.inplace)
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
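Every transform takes and returns the same positional tuple, which is how Compose threads the image, table, and the accumulating labels through the pipeline; a toy demonstration of that convention:

# Toy demonstration of the tuple-threading convention used by Compose.
from libs.data.transform import Compose

class AddOne:
    def __call__(self, image, count=0):
        return image, count + 1

pipeline = Compose([AddOne(), AddOne(), AddOne()])
image, count = pipeline('img')
print(count)  # 3 -- each stage unpacks the previous stage's full output tuple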
libs/data/utils.py
ADDED
@@ -0,0 +1,188 @@
import math
import numpy as np


class InvalidFormat(Exception):
    pass


def segmentation_to_bbox(segmentation):
    x1 = min([pt[0] for contour in segmentation for pt in contour])
    y1 = min([pt[1] for contour in segmentation for pt in contour])
    x2 = max([pt[0] for contour in segmentation for pt in contour])
    y2 = max([pt[1] for contour in segmentation for pt in contour])
    return (x1, y1, x2, y2)


def cal_cell_bbox(table):
    cells_bbox = list()
    for cell in table['cells']:
        if 'segmentation' not in cell:
            cell_bbox = None
        else:
            segmentation = list()
            if 'sublines' in cell:
                for subline in cell['sublines']:
                    segmentation.extend(subline['segmentation'])
            if len(segmentation) == 0:
                segmentation = cell['segmentation']
            if len(segmentation) == 0:
                cell_bbox = None
            else:
                cell_bbox = segmentation_to_bbox(segmentation)
        cells_bbox.append(cell_bbox)
    return cells_bbox


def cal_cell_spans(table):
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # spans are inclusive, so verify the full covered rectangle (+1 on the ends)
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def cal_fg_bg_span(spans, edge):
    num_span = len(spans)
    bg_spans = list()
    for idx in range(num_span):
        if spans[idx] is None:
            continue
        if idx == 0:
            if spans[idx][0] <= 0:
                continue
        else:
            if spans[idx - 1] is None:
                continue
            if spans[idx][0] <= spans[idx - 1][1]:
                continue
        if idx == num_span - 1:
            if spans[idx][1] >= edge:
                continue
        else:
            if spans[idx + 1] is None:
                continue
            if spans[idx][1] >= spans[idx + 1][0]:
                continue

        bg_spans.append(spans[idx])

    fg_spans = list()
    for idx in range(num_span + 1):
        if idx == 0:
            s = 0
        else:
            if spans[idx - 1] is None:
                continue
            s = spans[idx - 1][1]

        if idx == num_span:
            e = edge
        else:
            if spans[idx] is None:
                continue
            e = spans[idx][0]

        if e <= s:
            continue

        fg_spans.append([s, e])

    return fg_spans, bg_spans


def shrink_spans(spans, size):
    new_spans = list()
    for idx, (start, end) in enumerate(spans):
        if idx == 0:
            if start <= 0:
                start = 1
        else:
            _, pre_end = spans[idx - 1]
            if start <= pre_end:
                shrink_distance = pre_end - start + 1
                start = start + math.ceil(shrink_distance / 2)

        if idx == len(spans) - 1:
            if end >= size:
                end = size - 1
        else:
            next_start, _ = spans[idx + 1]
            if end >= next_start:
                shrink_distance = end - next_start + 1
                end = end - math.ceil(shrink_distance / 2)
        if end - start < 1:
            raise InvalidFormat()

        new_spans.append([start, end])
    return new_spans


def cal_row_span(table, cells_span, cells_bbox, height):
    layout = table['layout']
    rows_span = list()
    for row_idx in range(layout.shape[0]):
        row = layout[row_idx, :]
        y1s = list()
        y2s = list()
        for cell_id in row:
            cell_span = cells_span[cell_id]
            cell_bbox = cells_bbox[cell_id]
            if (cell_span[1] == row_idx) and (cell_bbox is not None):
                y1s.append(cell_bbox[1])
            if (cell_span[3] == row_idx) and (cell_bbox is not None):
                y2s.append(cell_bbox[3])

        if (len(y1s) > 0) and (len(y2s) > 0):
            y1 = min(max(1, min(y1s)), height - 1)
            y2 = min(max(1, max(y2s) + 1), height - 1)
            rows_span.append([y1, y2])
        else:
            raise InvalidFormat()
    rows_span = shrink_spans(rows_span, height)
    rows_fg_span, rows_bg_span = cal_fg_bg_span(rows_span, height)
    return rows_fg_span, rows_bg_span


def cal_col_span(table, cells_span, cells_bbox, width):
    layout = table['layout']
    cols_span = list()
    for col_idx in range(layout.shape[1]):
        col = layout[:, col_idx]
        x1s = list()
        x2s = list()
        for cell_id in col:
            cell_span = cells_span[cell_id]
            cell_bbox = cells_bbox[cell_id]
            if (cell_span[0] == col_idx) and (cell_bbox is not None):
                x1s.append(cell_bbox[0])
            if (cell_span[2] == col_idx) and (cell_bbox is not None):
                x2s.append(cell_bbox[2])

        if (len(x1s) > 0) and (len(x2s) > 0):
            x1 = min(max(1, min(x1s)), width - 1)
            x2 = min(max(1, max(x2s) + 1), width - 1)
            cols_span.append([x1, x2])
        else:
            raise InvalidFormat()
    cols_span = shrink_spans(cols_span, width)
    cols_fg_span, cols_bg_span = cal_fg_bg_span(cols_span, width)
    return cols_fg_span, cols_bg_span


def extract_fg_bg_spans(table, image_size):
    width, height = image_size
    cells_bbox = cal_cell_bbox(table)
    cells_span = cal_cell_spans(table)
    # cal rows fg bg span
    rows_fg_span, rows_bg_span = cal_row_span(table, cells_span, cells_bbox, height)
    # cal cols fg bg span
    cols_fg_span, cols_bg_span = cal_col_span(table, cells_span, cells_bbox, width)
    return rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span
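cal_cell_spans recovers each cell's inclusive grid rectangle from the id layout; a tiny example with one cell spanning two columns:

# cal_cell_spans on a toy 2x3 layout; cell 1 spans columns 1-2 of row 0.
import numpy as np
from libs.data.utils import cal_cell_spans

table = dict(
    layout=np.array([[0, 1, 1],
                     [2, 3, 4]]),
    cells=[None] * 5,   # only the cell count is used here
)
print(cal_cell_spans(table))
# spans are [x1, y1, x2, y2] inclusive: cell 1 yields [1, 0, 2, 0]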
libs/model/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .model import Model
from torch import nn
from libs.utils.comm import get_world_size


def build_model(cfg):
    # Both branches currently pick BatchNorm2d; the split is kept as a hook for
    # substituting a distributed-aware norm layer when world_size > 1.
    if get_world_size() == 1:
        norm_layer = nn.BatchNorm2d
    else:
        norm_layer = nn.BatchNorm2d
    model = Model(
        cfg,
        norm_layer=norm_layer
    )
    return model
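If synchronized statistics were wanted in the distributed branch, one option (an assumption, not what this commit does) is:

# Hypothetical variant: sync BN statistics across GPUs when world_size > 1.
# Requires torch.distributed to be initialised before model construction.
if get_world_size() == 1:
    norm_layer = nn.BatchNorm2d
else:
    norm_layer = nn.SyncBatchNorm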
libs/model/backbone.py
ADDED
@@ -0,0 +1,281 @@
import torch
from torch import Tensor
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional, Tuple


model_paths = {
    'resnet18': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet18-5c106cde.pth',
    'resnet34': '/Pretrain/resnet_34.pth',
    'resnet50': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet50-19c8e357.pth',
    'resnet101': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet101-5d3b4d8f.pth',
    'resnet152': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation, padding_mode='reflect')


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
                               bias=False, padding_mode='reflect')
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # The stock ResNet stem is modified here: conv1 uses stride 1 and the
        # max-pool is dropped, so the pyramid strides are 1, 2, 4, 8.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        c2 = self.layer1(x)   # stride=1, total_stride=1
        c3 = self.layer2(c2)  # stride=2, total_stride=2
        c4 = self.layer3(c3)  # stride=2, total_stride=4
        c5 = self.layer4(c4)  # stride=2, total_stride=8
        return c2, c3, c4, c5

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    **kwargs: Any
) -> ResNet:
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        checkpoint = torch.load(model_paths[arch], map_location='cpu')
        state_dict = model.state_dict()
        # copy only parameters whose names and shapes match the checkpoint
        for key, val in state_dict.items():
            if key in checkpoint:
                if val.shape == checkpoint[key].shape:
                    state_dict[key] = checkpoint[key]
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)


def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)


def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained: bool = False, **kwargs: Any) -> ResNet:
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)


def resnet152(pretrained: bool = False, **kwargs: Any) -> ResNet:
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, **kwargs)


def build_backbone(arch, pretrained=True, norm_layer=nn.BatchNorm2d):
    arch_map = {
        'res34': resnet34,
        'res50': resnet50,
        'res101': resnet101,
        'res152': resnet152
    }
    if arch not in arch_map:
        raise ValueError('Unknown backbone arch: %s' % arch)
    return arch_map[arch](pretrained=pretrained, norm_layer=norm_layer)
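A quick shape check of the pyramid this backbone emits (random input; pretrained=False so no checkpoint path is needed):

# Sanity check of the backbone feature pyramid.
import torch
from libs.model.backbone import build_backbone

net = build_backbone('res34', pretrained=False)
c2, c3, c4, c5 = net(torch.rand(1, 3, 256, 512))
print([tuple(t.shape[1:]) for t in (c2, c3, c4, c5)])
# conv1 keeps stride 1, so the strides are 1, 2, 4, 8:
# [(64, 256, 512), (128, 128, 256), (256, 64, 128), (512, 32, 64)]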
libs/model/cells_extractor.py
ADDED
@@ -0,0 +1,130 @@
import torch
import math
from torch import nn
from torch.nn import functional as F
from .extractor import RoiPosFeatExtraxtor


class SALayer(nn.Module):
    def __init__(self, in_dim, att_dim, head_nums):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.head_nums = head_nums

        assert self.in_dim % self.head_nums == 0

        self.key_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
        self.query_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
        self.value_layer = nn.Conv1d(self.in_dim, self.in_dim, 1, 1, 0)
        self.scale = 1 / math.sqrt(self.att_dim)

    def forward(self, feats, masks=None):
        bs, c, n = feats.shape
        keys = self.key_layer(feats).reshape(bs, -1, self.head_nums, n)
        querys = self.query_layer(feats).reshape(bs, -1, self.head_nums, n)
        values = self.value_layer(feats).reshape(bs, -1, self.head_nums, n)

        logits = torch.einsum('bchk,bchq->bhkq', keys, querys) * self.scale
        if masks is not None:
            logits = logits - (1 - masks[:, None, :, None]) * 1e8
        weights = torch.softmax(logits, dim=2)

        new_feats = torch.einsum('bchk,bhkq->bchq', values, weights)
        new_feats = new_feats.reshape(bs, -1, n)
        return new_feats + feats


def gen_cells_bbox(row_segments, col_segments, device):
    cells_bbox = list()
    for row_segments_pi, col_segments_pi in zip(row_segments, col_segments):
        num_rows = len(row_segments_pi) - 1
        num_cols = len(col_segments_pi) - 1
        cells_bbox_pi = list()
        for row_idx in range(num_rows):
            for col_idx in range(num_cols):
                bbox = [
                    col_segments_pi[col_idx],
                    row_segments_pi[row_idx],
                    col_segments_pi[col_idx + 1],
                    row_segments_pi[row_idx + 1]
                ]
                cells_bbox_pi.append(bbox)
        cells_bbox_pi = torch.tensor(cells_bbox_pi, dtype=torch.float, device=device)
        cells_bbox.append(cells_bbox_pi)
    return cells_bbox


def align_cells_feat(cells_feat, num_rows, num_cols):
    batch_size = len(cells_feat)
    dtype = cells_feat[0].dtype
    device = cells_feat[0].device

    max_row_nums = max(num_rows)
    max_col_nums = max(num_cols)

    aligned_cells_feat = list()
    masks = torch.zeros([batch_size, max_row_nums, max_col_nums], dtype=dtype, device=device)
    for batch_idx in range(batch_size):
        num_rows_pi = num_rows[batch_idx]
        num_cols_pi = num_cols[batch_idx]
        cells_feat_pi = cells_feat[batch_idx]
        cells_feat_pi = cells_feat_pi.transpose(0, 1).reshape(-1, num_rows_pi, num_cols_pi)
        aligned_cells_feat_pi = F.pad(
            cells_feat_pi,
            (0, max_col_nums - num_cols_pi, 0, max_row_nums - num_rows_pi, 0, 0),
            mode='constant',
            value=0
        )
        aligned_cells_feat.append(aligned_cells_feat_pi)

        masks[batch_idx, :num_rows_pi, :num_cols_pi] = 1
    aligned_cells_feat = torch.stack(aligned_cells_feat, dim=0)
    return aligned_cells_feat, masks


class CellsExtractor(nn.Module):
    def __init__(self, in_dim, cell_dim, heads, head_nums, pool_size, scale=1):
        super().__init__()
        self.in_dim = in_dim
        self.cell_dim = cell_dim
        self.pool_size = pool_size
        self.scale = scale
        self.box_feat_extractor = RoiPosFeatExtraxtor(
            self.scale,
            self.pool_size,
            self.in_dim,
            self.cell_dim
        )
        self.heads = heads
        self.row_sas = nn.ModuleList()
        self.col_sas = nn.ModuleList()
        for _ in range(self.heads):
            self.row_sas.append(SALayer(cell_dim, cell_dim, head_nums))
            self.col_sas.append(SALayer(cell_dim, cell_dim, head_nums))

    def forward(self, feats, row_segments, col_segments, img_sizes):
        device = feats.device
        num_rows = [len(row_segments_pi) - 1 for row_segments_pi in row_segments]
        num_cols = [len(col_segments_pi) - 1 for col_segments_pi in col_segments]

        cells_bbox = gen_cells_bbox(row_segments, col_segments, device)
        cells_feat = self.box_feat_extractor(feats, cells_bbox, img_sizes)

        aligned_cells_feat, masks = align_cells_feat(cells_feat, num_rows, num_cols)

        bs, c, nr, nc = aligned_cells_feat.shape

        for idx in range(self.heads):
            col_cells_feat = aligned_cells_feat.permute(0, 2, 1, 3).contiguous().reshape(bs * nr, c, nc)
            col_masks = masks.reshape(bs * nr, nc)
            col_cells_feat = self.col_sas[idx](col_cells_feat, col_masks)  # self-attention along each row
            aligned_cells_feat = col_cells_feat.reshape(bs, nr, c, nc).permute(0, 2, 1, 3).contiguous()

            row_cells_feat = aligned_cells_feat.permute(0, 3, 1, 2).contiguous().reshape(bs * nc, c, nr)
            row_masks = masks.transpose(1, 2).reshape(bs * nc, nr)
            row_cells_feat = self.row_sas[idx](row_cells_feat, row_masks)  # self-attention along each column
            aligned_cells_feat = row_cells_feat.reshape(bs, nc, c, nr).permute(0, 2, 3, 1).contiguous()

        return aligned_cells_feat, masks
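SALayer is shape-preserving multi-head self-attention over a 1-D sequence of cell features, with a residual add; a quick shape check:

# Shape check for SALayer: (B, C, N) in, (B, C, N) out.
import torch
from libs.model.cells_extractor import SALayer

sa = SALayer(in_dim=512, att_dim=512, head_nums=8)
feats = torch.rand(2, 512, 10)   # 2 sequences of 10 cell features
masks = torch.ones(2, 10)        # 1 = valid cell, 0 = padding
out = sa(feats, masks)
print(out.shape)                 # torch.Size([2, 512, 10])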
libs/model/decoder.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
from torch import nn
from torch.nn import functional as F
from libs.utils.metric import CellMergeAcc, AccMetric
from .utils import gen_proposals


class ImageAttention(nn.Module):
    def __init__(self, key_dim, query_dim, cover_kernel):
        super().__init__()
        self.query_transform = nn.Linear(query_dim, key_dim)
        self.weight_transform = nn.Conv2d(1, key_dim, cover_kernel, 1, padding=cover_kernel // 2)
        self.cum_weight_transform = nn.Conv2d(1, key_dim, cover_kernel, 1, padding=cover_kernel // 2)
        self.logit_transform = nn.Conv2d(key_dim, 1, 1, 1, 0)

    def forward(self, key, key_mask, query, spatial_att_weight, cum_spatial_att_weight, value, state, layouts=None, layouts_cum=None, spatial_att_weight_scores=None):
        query = self.query_transform(query)
        weight_query = self.weight_transform(spatial_att_weight)
        cum_weight_query = self.cum_weight_transform(cum_spatial_att_weight)
        fusion = key + query[:, :, None, None] + weight_query + cum_weight_query
        # cal new_spatial_att_logit
        new_spatial_att_logit = self.logit_transform(torch.tanh(fusion))
        # cal new_spatial_att_weight (masked softmax over all spatial positions)
        new_spatial_att_weight = new_spatial_att_logit - (1 - key_mask) * 1e8
        bs, _, h, w = new_spatial_att_weight.shape
        new_spatial_att_weight = new_spatial_att_weight.reshape(bs, h * w)
        new_spatial_att_weight = torch.softmax(new_spatial_att_weight, dim=1).reshape(bs, 1, h, w)
        # cal new_cum_spatial_att_weight
        if self.training:
            # teacher forcing: pool value features over the ground-truth cell layout
            outputs = list()
            for (value_pi, layout) in zip(value, layouts):
                h, w = torch.where(layout == 1.)
                if len(h) == 0 or len(w) == 0:
                    outputs.append(torch.zeros_like(query[0]))
                else:
                    outputs.append(value_pi[:, h, w].mean(-1))
            outputs = torch.stack(outputs, dim=0)
            new_cum_spatial_att_weight = torch.clamp(layouts.unsqueeze(1).float() + cum_spatial_att_weight, max=1.)
            return state, outputs, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, None, None
        else:
            # beam-style search over rectangular cell proposals
            state_list = list()
            outputs_list = list()
            scores_list = list()
            proposals_list = list()
            new_spatial_att_weight_list = list()
            new_cum_spatial_att_weight_list = list()
            layouts_pred = new_spatial_att_logit.squeeze(1).sigmoid()
            for idx, (value_pi, state_pi, layout) in enumerate(zip(value, state, layouts_pred)):
                if cum_spatial_att_weight[idx].min() == 1:
                    # this beam already covers the whole grid; keep it as-is
                    state_list.append(state_pi)
                    outputs_list.append(torch.zeros_like(query[0]))
                    proposals_list.append(torch.cat((layouts_cum[idx], torch.zeros_like(layout.unsqueeze(0))), dim=0))
                    scores_list.append(spatial_att_weight_scores[idx])
                    new_spatial_att_weight_list.append(new_spatial_att_weight[idx])
                    new_cum_spatial_att_weight_list.append(cum_spatial_att_weight[idx])
                else:
                    # start from the top-left position not yet covered by any cell
                    srow, scol = torch.where(cum_spatial_att_weight[idx].squeeze(0) == cum_spatial_att_weight[idx].squeeze(0).min())
                    scol = scol[srow == srow.min()].min()
                    srow = srow.min()
                    proposals, scores = gen_proposals(layout, srow, scol, score_threshold=0.5)
                    scores = scores + spatial_att_weight_scores[idx]
                    for s in scores:
                        scores_list.append(s)
                    for p in proposals:
                        proposals_list.append(torch.cat((layouts_cum[idx], p.unsqueeze(0)), dim=0))
                        h, w = torch.where(p == 1.)
                        outputs_list.append(value_pi[:, h, w].mean(-1))
                        state_list.append(state_pi)
                        new_spatial_att_weight_list.append(new_spatial_att_weight[idx])
                        new_cum_spatial_att_weight_list.append(torch.clamp(cum_spatial_att_weight[idx] + p.unsqueeze(0), max=1.))
            state_list = torch.stack(state_list, dim=0)
            proposals_list = torch.stack(proposals_list, dim=0)
            scores_list = torch.stack(scores_list, dim=0)
            outputs_list = torch.stack(outputs_list, dim=0)
            new_spatial_att_weight_list = torch.stack(new_spatial_att_weight_list, dim=0)
            new_cum_spatial_att_weight_list = torch.stack(new_cum_spatial_att_weight_list, dim=0)
            # keep the six best-scoring partial layouts
            sorted_scores, sorted_idxes = torch.sort(scores_list, dim=0, descending=True)
            sorted_scores = sorted_scores[:6]
            sorted_idxes = sorted_idxes[:6]
            proposals = proposals_list[sorted_idxes]
            new_spatial_att_weight = new_spatial_att_weight_list[sorted_idxes]
            new_cum_spatial_att_weight = new_cum_spatial_att_weight_list[sorted_idxes]
            outputs = outputs_list[sorted_idxes]
            state = state_list[sorted_idxes]
            return state, outputs, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, proposals, sorted_scores


class Decoder(nn.Module):
    def __init__(self, vocab, embed_dim, feat_dim, lm_state_dim, proj_dim, cover_kernel, att_threshold, spatial_att_logit_loss_wight):
        super().__init__()
        self.vocab = vocab
        self.embed_dim = embed_dim
        self.feat_dim = feat_dim
        self.lm_state_dim = lm_state_dim
        self.proj_dim = proj_dim
        self.cover_kernel = cover_kernel
        self.att_threshold = att_threshold
        self.spatial_att_logit_loss_wight = spatial_att_logit_loss_wight
        self.feat_projection = nn.Conv2d(self.feat_dim, self.proj_dim, 1, 1, 0)
        self.state_init_projection = nn.Conv2d(self.feat_dim, self.lm_state_dim, 1, 1, 0)
        self.lm_rnn1 = nn.GRUCell(input_size=self.feat_dim, hidden_size=self.lm_state_dim)
        self.lm_rnn2 = nn.GRUCell(input_size=self.feat_dim, hidden_size=self.lm_state_dim)
        self.image_attention = ImageAttention(self.proj_dim, self.feat_dim + self.lm_state_dim, cover_kernel)
        self.struct_cls = nn.Sequential(
            nn.Linear(self.feat_dim + self.lm_state_dim, self.lm_state_dim),
            nn.Tanh(),
            nn.Linear(self.lm_state_dim, len(self.vocab))
        )

    def init_state(self, feats, feats_mask):
        bs, _, h, w = feats.shape
        project_feats = self.feat_projection(feats) * feats_mask
        init_state = torch.sum(self.state_init_projection(feats), dim=(2, 3)) / torch.sum(feats_mask, dim=(2, 3))
        init_context = torch.sum(feats, dim=(2, 3)) / torch.sum(feats_mask, dim=(2, 3))
        init_spatial_att_weight = torch.zeros([bs, 1, h, w], dtype=torch.float, device=feats.device)
        init_cum_spatial_att_weight = torch.zeros([bs, 1, h, w], dtype=torch.float, device=feats.device)
        return project_feats, init_state, init_context, init_spatial_att_weight, init_cum_spatial_att_weight

    def step(self, feats, project_feats, feats_mask, state, context, spatial_att_weight, cum_spatial_att_weight, layouts=None, layouts_cum=None, spatial_att_weight_scores=None):
        new_state = self.lm_rnn1(context, state)
        new_state, new_context, new_spatial_att_logit, \
            new_spatial_att_weight, new_cum_spatial_att_weight, \
            layouts_cum, spatial_att_weight_scores = self.image_attention(
                project_feats,
                feats_mask,
                torch.cat([context, new_state], dim=1),
                spatial_att_weight,
                cum_spatial_att_weight,
                feats,
                new_state,
                layouts,
                layouts_cum,
                spatial_att_weight_scores
            )
        new_state = self.lm_rnn2(new_context, new_state)
        cls_feat = torch.cat([new_context, new_state], dim=1)
        cls_logits_pt = self.struct_cls(cls_feat)
        return cls_logits_pt, new_state, new_context, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, layouts_cum, spatial_att_weight_scores

    def forward(self, feats, feats_mask, cls_labels=None, labels_mask=None, layouts=None):
        if self.training:
            return self.forward_backward(feats, feats_mask, cls_labels, labels_mask, layouts)
        else:
            return self.inference(feats, feats_mask)

    def inference(self, feats, feats_mask):
        bs, _, h, w = feats.shape
        device = feats.device
        assert bs == 1, 'batch size should be 1'
        layouts_cum = torch.zeros_like(feats[:, :1])
        spatial_att_weight_scores = torch.zeros(bs).to(device=device, dtype=feats.dtype)

        project_feats, init_state, init_context, spatial_att_weight, cum_spatial_att_weight = self.init_state(feats, feats_mask)
        state = init_state
        context = init_context

        for _ in range(h * w):
            cls_logits_pt, state, context, spatial_att_logit, spatial_att_weight, \
                cum_spatial_att_weight, layouts_cum, spatial_att_weight_scores \
                = self.step(
                    feats, project_feats,
                    feats_mask, state, context,
                    spatial_att_weight, cum_spatial_att_weight, None, layouts_cum, spatial_att_weight_scores)
            # expand the shared tensors to the current beam size
            feats = feats[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
            feats_mask = feats_mask[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
            project_feats = project_feats[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
            if cum_spatial_att_weight.min() == 1:
                break
        spatial_att_logit_preds = layouts_cum[spatial_att_weight_scores.argmax(), 1:].unsqueeze(0)
        return spatial_att_logit_preds, {}

    def forward_backward(self, feats, feats_mask, cls_labels, labels_mask, layouts):
        device = feats.device
        valid_cls_length = torch.sum((labels_mask == 1) & (cls_labels != -1), dim=1).detach()
        valid_spatial_att_logit_length = torch.stack([layout.max() + 1 for layout in layouts])
        max_length = valid_cls_length.max()

        project_feats, init_state, init_context, spatial_att_weight, cum_spatial_att_weight = self.init_state(feats, feats_mask)
        state = init_state
        context = init_context

        loss_cache = dict()

        cls_loss = list()
        cls_preds = list()

        spatial_att_logit_loss = list()
        spatial_att_logit_preds = list()
        spatial_att_logit_masks = list()
        spatial_att_logit_labels = list()
        for time_t in range(max_length):
            cls_logits_pt, state, context, spatial_att_logit, spatial_att_weight, cum_spatial_att_weight, *_ \
                = self.step(
                    feats, project_feats,
                    feats_mask, state, context,
                    spatial_att_weight, cum_spatial_att_weight, layouts == time_t
                )

            cls_label = cls_labels[:, time_t]
            label_mask = labels_mask[:, time_t]
            # cal cls loss
            cls_loss_pt = F.cross_entropy(cls_logits_pt, cls_label, ignore_index=-1, reduction='none') * label_mask
            cls_loss.append(cls_loss_pt)
            # save for acc
            cls_preds.append(torch.argmax(cls_logits_pt, dim=1).detach())

            spatial_att_logit_preds.append(spatial_att_logit.sigmoid() > self.att_threshold)
            spatial_att_logit_masks.append((layouts != -1).unsqueeze(1))
            spatial_att_logit_labels.append((layouts == time_t).unsqueeze(1))
            # cal spatial att loss
            spatial_att_logit_loss_pt = list()
            for spatial_att_logit_pi, layout in zip(spatial_att_logit, layouts):
                target = layout == time_t
                if not torch.any(target):
                    spatial_att_logit_loss_pt_pi = torch.tensor(0.0, dtype=torch.float, device=device)
                else:
                    mask = (layout != -1).float()
                    spatial_att_logit_loss_pt_pi = F.binary_cross_entropy_with_logits(
                        spatial_att_logit_pi,
                        target.float().unsqueeze(0),
                        reduction='none'
                    )
                    spatial_att_logit_loss_pt_pi = (spatial_att_logit_loss_pt_pi * mask).sum()
                spatial_att_logit_loss_pt.append(spatial_att_logit_loss_pt_pi)
            spatial_att_logit_loss_pt = torch.stack(spatial_att_logit_loss_pt, dim=0)
            spatial_att_logit_loss.append(spatial_att_logit_loss_pt)

        cls_loss = torch.mean(torch.sum(torch.stack(cls_loss, dim=1), dim=1) / valid_cls_length)
        spatial_att_logit_loss = self.spatial_att_logit_loss_wight * torch.mean(torch.sum(torch.stack(spatial_att_logit_loss, dim=1), dim=1) / valid_spatial_att_logit_length)

        loss_cache['cls_loss'] = cls_loss
        loss_cache['spatial_att_logit_loss'] = spatial_att_logit_loss

        cls_preds = torch.stack(cls_preds, dim=1)
        spatial_att_logit_preds = torch.stack(spatial_att_logit_preds, dim=1)
        spatial_att_logit_masks = torch.stack(spatial_att_logit_masks, dim=1)
        spatial_att_logit_labels = torch.stack(spatial_att_logit_labels, dim=1)

        acc_metric = AccMetric()
        cell_merge_acc = CellMergeAcc()
        cls_correct, cls_total = acc_metric(cls_preds, cls_labels, labels_mask)
        cls_none_correct, cls_none_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.none_id))
        cls_bold_correct, cls_bold_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.bold_id))
        cls_space_correct, cls_space_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.space_id))
        cls_blank_correct = cls_none_correct + cls_bold_correct + cls_space_correct
        cls_blank_total = cls_none_total + cls_bold_total + cls_space_total
        cells_correct_nums, cells_total_nums = cell_merge_acc(spatial_att_logit_preds, spatial_att_logit_labels, spatial_att_logit_masks)
        loss_cache['cls_acc'] = cls_correct / cls_total
        loss_cache['cls_none_acc'] = cls_none_correct / cls_none_total
        loss_cache['cls_bold_acc'] = cls_bold_correct / cls_bold_total
        loss_cache['cls_space_acc'] = cls_space_correct / cls_space_total
        loss_cache['cls_blank_acc'] = cls_blank_correct / cls_blank_total
        loss_cache['spatial_att_logit_acc'] = cells_correct_nums / cells_total_nums

        return spatial_att_logit_preds, loss_cache


def build_decoder(cfg):
    # keyword arguments follow Decoder.__init__'s signature
    decoder = Decoder(
        vocab=cfg.vocab,
        embed_dim=cfg.embed_dim,
        feat_dim=cfg.encode_dim,
        lm_state_dim=cfg.lm_state_dim,
        proj_dim=cfg.proj_dim,
        cover_kernel=cfg.cover_kernel,
        att_threshold=cfg.att_threshold,
        spatial_att_logit_loss_wight=cfg.spatial_att_weight_loss_wight
    )
    return decoder
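ImageAttention masks invalid positions by subtracting a large constant from their logits before the softmax, the same trick SALayer uses. A minimal standalone sketch (not part of this commit):

import torch

logit = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding
weight = torch.softmax(logit - (1 - mask) * 1e8, dim=1)
print(weight)  # padded entry is ~0; valid entries renormalize among themselves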
libs/model/divide_predictor.py
ADDED
@@ -0,0 +1,57 @@
import torch
from torch import nn
from torch.nn import functional as F
from .sa import SALayer
from libs.utils.metric import cal_cls_acc


def align_segments_feat(segments_feat):
    # right-pad each sample's (channels, num_segments) features to a common
    # length and build a mask over the real segment positions
    dtype = segments_feat[0].dtype
    device = segments_feat[0].device
    batch_size = len(segments_feat)
    max_segment_nums = max([item.shape[1] for item in segments_feat])
    aligned_segments_feat = list()
    masks = torch.zeros([batch_size, max_segment_nums], dtype=dtype, device=device)

    for batch_idx in range(batch_size):
        cur_segment_nums = segments_feat[batch_idx].shape[1]
        masks[batch_idx, :cur_segment_nums] = 1
        aligned_segments_feat.append(
            F.pad(
                segments_feat[batch_idx],
                (0, max_segment_nums - cur_segment_nums, 0, 0),
                mode='constant',
                value=0
            )
        )
    aligned_segments_feat = torch.stack(aligned_segments_feat, dim=0)
    return aligned_segments_feat, masks


class HeadBodyDividePredictor(nn.Module):
    def __init__(self, in_dim, head_nums, scale=1):
        super().__init__()
        self.in_dim = in_dim
        self.scale = scale
        self.fusion_layer = SALayer(in_dim, in_dim, head_nums)
        self.classifier = nn.Conv1d(in_dim, 1, 1, 1, 0)

    def forward(self, feats, segments, divide_labels=None):
        segments = [[int(subitem * self.scale) for subitem in item] for item in segments]
        segments_feat = [feats_pi[:, segments_pi] for feats_pi, segments_pi in zip(feats, segments)]
        aligned_segments_feat, masks = align_segments_feat(segments_feat)
        aligned_segments_feat = self.fusion_layer(aligned_segments_feat, masks)
        divide_logits = self.classifier(aligned_segments_feat).squeeze(1)
        divide_logits = divide_logits - (1 - masks) * 1e8  # mask out padded segments
        divide_preds = torch.argmax(divide_logits, dim=1)

        result_info = dict()
        ext_info = dict()
        if self.training:
            result_info['divide_loss'] = F.cross_entropy(divide_logits, divide_labels)
            correct_nums, total_nums = cal_cls_acc(divide_preds, divide_labels)
            if total_nums != 0:
                result_info['divide_acc'] = correct_nums / total_nums

        divide_preds = divide_preds.detach().cpu().tolist()
        return divide_preds, result_info, ext_info
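align_segments_feat right-pads each sample's (channels, num_segments) features to the batch maximum and records which positions are real. A minimal standalone sketch of the same padding (not part of this commit):

import torch
from torch.nn import functional as F

a = torch.randn(8, 3)  # (channels, num_segments)
b = torch.randn(8, 5)
max_n = max(a.shape[1], b.shape[1])
padded = torch.stack([F.pad(a, (0, max_n - a.shape[1])), b], dim=0)
masks = torch.tensor([[1., 1., 1., 0., 0.],
                      [1., 1., 1., 1., 1.]])
print(padded.shape)  # torch.Size([2, 8, 5])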
libs/model/extractor.py
ADDED
@@ -0,0 +1,88 @@
import torch
from torch import nn
from torchvision.ops import roi_align


def convert_to_roi_format(lines_box):
    concat_boxes = torch.cat(lines_box, dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = torch.cat(
        [
            torch.full((lines_box_pi.shape[0], 1), i, dtype=dtype, device=device)
            for i, lines_box_pi in enumerate(lines_box)
        ],
        dim=0
    )
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois


class RoiFeatExtraxtor(nn.Module):
    def __init__(self, scale, pool_size, input_dim, output_dim):
        super().__init__()
        self.scale = scale
        self.pool_size = pool_size
        self.output_dim = output_dim
        input_dim = input_dim * self.pool_size[0] * self.pool_size[1]
        self.fc = nn.Sequential(
            nn.Linear(input_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim)
        )

    def forward(self, feats, lines_box):
        rois = convert_to_roi_format(lines_box)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=self.pool_size,
            spatial_scale=self.scale,
            sampling_ratio=2
        )

        lines_feat = lines_feat.reshape(lines_feat.shape[0], -1)
        lines_feat = self.fc(lines_feat)
        lines_feat = torch.split(lines_feat, [item.shape[0] for item in lines_box])
        return list(lines_feat)


class RoiPosFeatExtraxtor(nn.Module):
    def __init__(self, scale, pool_size, input_dim, output_dim):
        super().__init__()
        self.scale = scale
        self.pool_size = pool_size
        self.output_dim = output_dim
        input_dim = input_dim * self.pool_size[0] * self.pool_size[1]
        self.fc = nn.Sequential(
            nn.Linear(input_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim)
        )
        self.bbox_ln = nn.LayerNorm(self.output_dim)
        self.bbox_tranform = nn.Linear(4, self.output_dim)

        self.add_ln = nn.LayerNorm(self.output_dim)

    def forward(self, feats, lines_box, img_sizes):
        rois = convert_to_roi_format(lines_box)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=self.pool_size,
            spatial_scale=self.scale,
            sampling_ratio=2
        )
        lines_feat = lines_feat.reshape(lines_feat.shape[0], -1)
        lines_feat = self.fc(lines_feat)
        lines_feat = list(torch.split(lines_feat, [item.shape[0] for item in lines_box]))

        # Add Pos Embedding
        feats_H, feats_W = feats.shape[-2:]
        for idx, (line_box, img_size) in enumerate(zip(lines_box, img_sizes)):
            line_box[:, 0] = line_box[:, 0] * self.scale / feats_W
            line_box[:, 1] = line_box[:, 1] * self.scale / feats_H
            line_box[:, 2] = line_box[:, 2] * self.scale / feats_W
            line_box[:, 3] = line_box[:, 3] * self.scale / feats_H
            lines_feat[idx] = self.add_ln(lines_feat[idx] + self.bbox_ln(self.bbox_tranform(line_box)))

        return list(lines_feat)
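convert_to_roi_format prepends the batch index to each box, producing the (K, 5) ROI tensor torchvision's roi_align expects. A minimal standalone sketch (not part of this commit; the box coordinates are made up):

import torch
from torchvision.ops import roi_align

feats = torch.randn(2, 16, 32, 32)  # (N, C, H, W)
boxes = [torch.tensor([[0., 0., 8., 4.]]),                       # one box in image 0
         torch.tensor([[4., 4., 12., 8.], [0., 0., 32., 32.]])]  # two boxes in image 1
ids = torch.cat([torch.full((b.shape[0], 1), float(i)) for i, b in enumerate(boxes)])
rois = torch.cat([ids, torch.cat(boxes)], dim=1)  # (K, 5): batch_idx, x1, y1, x2, y2
crops = roi_align(feats, rois, output_size=(3, 3), spatial_scale=1.0, sampling_ratio=2)
print(crops.shape)  # torch.Size([3, 16, 3, 3])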
libs/model/fpn.py
ADDED
@@ -0,0 +1,37 @@
import torch
from torch import nn
from torch.nn import functional as F


class FPN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        assert len(in_channels) == 4
        self.in_channels = in_channels

        self.lat_layers = nn.ModuleList()
        self.out_layers = nn.ModuleList()
        for in_channels_pl in in_channels:
            self.lat_layers.append(
                nn.Conv2d(in_channels_pl, out_channels, kernel_size=1, stride=1, padding=0)
            )
            self.out_layers.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
            )

    def forward(self, feats):
        c2, c3, c4, c5 = feats
        p5 = self.lat_layers[3](c5)
        p4 = F.interpolate(p5, size=c4.shape[2:], align_corners=False, mode='bilinear') + self.lat_layers[2](c4)
        p3 = F.interpolate(p4, size=c3.shape[2:], align_corners=False, mode='bilinear') + self.lat_layers[1](c3)
        p2 = F.interpolate(p3, size=c2.shape[2:], align_corners=False, mode='bilinear') + self.lat_layers[0](c2)

        p2 = self.out_layers[0](p2)
        p3 = self.out_layers[1](p3)
        p4 = self.out_layers[2](p4)
        p5 = self.out_layers[3](p5)
        return p2, p3, p4, p5


def build_fpn(in_channels, out_channels):
    return FPN(in_channels, out_channels)
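A minimal usage sketch (not part of this commit), assuming the repo root is on PYTHONPATH and ResNet-50-style channel counts — the actual cfg.backbone_out_channels may differ:

import torch
from libs.model.fpn import build_fpn

fpn = build_fpn([256, 512, 1024, 2048], 256)
c2, c3, c4, c5 = (torch.randn(1, c, s, s) for c, s in
                  [(256, 64), (512, 32), (1024, 16), (2048, 8)])
p2, p3, p4, p5 = fpn((c2, c3, c4, c5))
print(p2.shape, p5.shape)  # 256 channels at each level, input resolutions kept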
libs/model/model.py
ADDED
@@ -0,0 +1,65 @@
import torch
from torch import nn
from .backbone import build_backbone
from .fpn import build_fpn
from .pan import PAN
from .segment_predictor import SegmentPredictor
from .divide_predictor import HeadBodyDividePredictor
from .cells_extractor import CellsExtractor
from .decoder import Decoder
from .utils import extend_segments, spatial_att_to_spans


class Model(nn.Module):
    def __init__(self, cfg, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.backbone = build_backbone(cfg.arch, cfg.pretrained_backbone, norm_layer=norm_layer)
        self.fpn = build_fpn(cfg.backbone_out_channels, cfg.fpn_out_channels)
        self.pan = PAN(cfg.pan_num_levels, cfg.pan_in_dim, cfg.pan_out_dim)
        self.row_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.rs_scale, type='row')
        self.col_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.cs_scale, type='col')
        self.divide_predictor = HeadBodyDividePredictor(cfg.fpn_out_channels, cfg.dp_head_nums, scale=cfg.dp_scale)
        self.cells_extractor = CellsExtractor(cfg.fpn_out_channels, cfg.ce_dim, cfg.ce_heads, cfg.ce_head_nums, cfg.ce_pool_size, cfg.ce_scale)
        self.decoder = Decoder(cfg.vocab, cfg.embed_dim, cfg.feat_dim, cfg.lm_state_dim, cfg.proj_dim, cfg.cover_kernel, cfg.att_threshold, cfg.spatial_att_weight_loss_wight)

    def forward(self, images, images_size, cls_labels=None, labels_mask=None, layouts=None, rows_fg_spans=None,
            rows_bg_spans=None, cols_fg_spans=None, cols_bg_spans=None, cells_spans=None, divide_labels=None):

        feats = self.fpn(self.backbone(images))

        row_feats = torch.mean(feats[0], dim=3)  # pool over width -> per-row features

        result_info = dict()
        ext_info = dict()
        row_segments, rs_result_info, rs_ext_info = self.row_segment_predictor(feats[0], images_size, rows_fg_spans, rows_bg_spans)
        rs_result_info = {'row_%s' % key: val for key, val in rs_result_info.items()}
        rs_ext_info = {'row_%s' % key: val for key, val in rs_ext_info.items()}
        result_info.update(rs_result_info)
        ext_info.update(rs_ext_info)
        col_segments, cs_result_info, cs_ext_info = self.col_segment_predictor(feats[0], images_size, cols_fg_spans, cols_bg_spans)
        cs_result_info = {'col_%s' % key: val for key, val in cs_result_info.items()}
        cs_ext_info = {'col_%s' % key: val for key, val in cs_ext_info.items()}
        result_info.update(cs_result_info)
        ext_info.update(cs_ext_info)

        if self.training:
            row_segments, col_segments, cells_spans, layouts, divide_labels = extend_segments(row_segments, rs_ext_info['row_ext_segments'],
                col_segments, cs_ext_info['col_ext_segments'], cells_spans, layouts, divide_labels)

        divide_preds, dp_result_info, dp_ext_info = self.divide_predictor(row_feats, row_segments, divide_labels=divide_labels)
        result_info.update(dp_result_info)
        ext_info.update(dp_ext_info)

        feat_maps, feats_masks = self.cells_extractor(self.pan(feats), row_segments, col_segments, images_size)
        if self.training:
            assert feat_maps.shape[-2:] == layouts.shape[-2:], 'feat_maps spatial size does not match layouts'

        de_preds, de_result_info = self.decoder(feat_maps, feats_masks.unsqueeze(1), cls_labels, labels_mask, layouts)
        result_info.update(de_result_info)

        if not self.training:
            assert de_preds.shape[0] == 1, 'batch size should be 1'
            de_recog_spans = spatial_att_to_spans(de_preds[0])
            return (row_segments, col_segments, divide_preds, de_recog_spans), result_info
        else:
            return (row_segments, col_segments, divide_preds), result_info
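For reference, the cfg attributes that Model.__init__ reads, gathered into one hypothetical namespace. Every value below is illustrative only — the real defaults live in libs/configs/default.py, which is not shown here:

from types import SimpleNamespace

cfg = SimpleNamespace(
    arch='resnet18', pretrained_backbone=True,          # backbone (values are guesses)
    backbone_out_channels=[64, 128, 256, 512],          # FPN inputs
    fpn_out_channels=256,
    pan_num_levels=4, pan_in_dim=256, pan_out_dim=256,  # PAN
    rs_scale=0.25, cs_scale=0.25,                       # row/col segment predictors
    dp_head_nums=8, dp_scale=0.25,                      # head/body divide predictor
    ce_dim=512, ce_heads=2, ce_head_nums=8,             # cells extractor
    ce_pool_size=(3, 3), ce_scale=0.25,
    vocab=None, embed_dim=256, feat_dim=512,            # decoder
    lm_state_dim=512, proj_dim=512, cover_kernel=7,
    att_threshold=0.5, spatial_att_weight_loss_wight=1.0,
)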
libs/model/pan.py
ADDED
@@ -0,0 +1,24 @@
import torch
from torch import nn
from torch.nn import functional as F


class PAN(nn.Module):
    def __init__(self, num_levels, in_channels, out_channels):
        super().__init__()
        self.num_levels = num_levels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pan_layers = nn.ModuleList()
        for _ in range(num_levels - 1):
            self.pan_layers.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
            )

    def forward(self, feats):
        p2, p3, p4, p5 = feats
        p2_ = p2
        p3_ = self.pan_layers[0](F.interpolate(p2_, size=p3.shape[2:], align_corners=False, mode='bilinear') + p3)
        p4_ = self.pan_layers[1](F.interpolate(p3_, size=p4.shape[2:], align_corners=False, mode='bilinear') + p4)
        p5_ = self.pan_layers[2](F.interpolate(p4_, size=p5.shape[2:], align_corners=False, mode='bilinear') + p5)
        return p5_
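A minimal usage sketch (not part of this commit): PAN folds the FPN pyramid back down and returns only the coarsest level, which Model then feeds to the cells extractor. Shapes below are illustrative:

import torch
from libs.model.pan import PAN

pan = PAN(num_levels=4, in_channels=256, out_channels=256)
pyramid = [torch.randn(1, 256, s, s) for s in (64, 32, 16, 8)]
out = pan(pyramid)
print(out.shape)  # torch.Size([1, 256, 8, 8])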
libs/model/sa.py
ADDED
@@ -0,0 +1,35 @@
import math
import torch
from torch import nn


class SALayer(nn.Module):
    def __init__(self, in_dim, att_dim, head_nums):
        super().__init__()
        self.in_dim = in_dim
        self.att_dim = att_dim
        self.head_nums = head_nums

        # values are split evenly across heads, so in_dim must be divisible
        assert self.in_dim % self.head_nums == 0

        self.key_layer = nn.Conv1d(self.in_dim, self.att_dim * self.head_nums, 1, 1, 0)
        self.query_layer = nn.Conv1d(self.in_dim, self.att_dim * self.head_nums, 1, 1, 0)
        self.value_layer = nn.Conv1d(self.in_dim, self.in_dim, 1, 1, 0)
        self.scale = 1 / math.sqrt(self.att_dim)

    def forward(self, feats, masks=None):
        bs, c, n = feats.shape
        keys = self.key_layer(feats).reshape(bs, -1, self.head_nums, n)
        querys = self.query_layer(feats).reshape(bs, -1, self.head_nums, n)
        values = self.value_layer(feats).reshape(bs, -1, self.head_nums, n)

        logits = torch.einsum('bchk,bchq->bhkq', keys, querys) * self.scale
        if masks is not None:
            # padded key positions get a large negative logit -> zero weight
            logits = logits - (1 - masks[:, None, :, None]) * 1e8
        weights = torch.softmax(logits, dim=2)  # normalize over key positions

        new_feats = torch.einsum('bchk,bhkq->bchq', values, weights)
        new_feats = new_feats.reshape(bs, -1, n)
        return new_feats + feats  # residual connection
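SALayer works on channel-first sequences and returns the input shape thanks to the residual connection. A minimal usage sketch (not part of this commit), assuming the repo root is on PYTHONPATH:

import torch
from libs.model.sa import SALayer

sa = SALayer(in_dim=64, att_dim=32, head_nums=4)
feats = torch.randn(2, 64, 7)  # (batch, channels, sequence)
masks = torch.ones(2, 7)       # all positions valid
print(sa(feats, masks).shape)  # torch.Size([2, 64, 7])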
libs/model/segment_predictor.py
ADDED
@@ -0,0 +1,133 @@
import torch
from torch import nn
from torch.nn import functional as F
from libs.utils.metric import cal_segment_pr
from .utils import draw_spans, save_logitmap

def cal_segments(cls_probs, spans, scale=1.0):
    segments = list()
    for span in spans:
        span_cls_probs = cls_probs[int(span[0] * scale): int(span[1] * scale)]
        segment = torch.argmax(span_cls_probs).item() + int(span[0] * scale)
        segments.append(segment)
    segments = [int(item / scale) for item in segments]
    return segments


def cal_spans(cls_probs, threshold=0.5):
    ids = (cls_probs > threshold).long().tolist()
    spans = list()
    for idx, id in enumerate(ids):
        if id == 1:
            if (idx == 0) or (ids[idx - 1] != 1):
                spans.append([idx, idx + 1])
            else:
                spans[-1][1] = idx + 1
    return spans
    # draw_spans('row_segment_spans.png', 'row_segment.png', spans, 'row')

def cls_logits_to_segments(segments_logit, masks, type, spans=None, scale=1, threshold=0.5):
    if type == 'col':
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=1)
        lengths = [int(mask[0, :].sum().item()) for mask in masks]
    else:
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=2)
        lengths = [int(mask[:, 0].sum().item()) for mask in masks]

    batch_size = cls_probs.shape[0]
    segments = list()
    for batch_idx in range(batch_size):
        length = lengths[batch_idx]
        if spans is None:
            spans_pi = cal_spans(cls_probs[batch_idx, :length], threshold)
            if len(spans_pi) <= 2:
                spans_pi = [[0, 1], [length - 1, length]]
        else:
            spans_pi = spans[batch_idx]
        segments_pi = cal_segments(cls_probs[batch_idx, :length], spans_pi, scale)
        segments.append(segments_pi)
    return segments, cls_probs, lengths


def cal_ext_segments(cls_probs, lengths, bg_spans, scale=1, threshold=0.5):
    """
    Find extra segments: within each background span (the region between
    annotated lines), take the peak of the predicted probabilities and keep
    it only if it exceeds the threshold.
    """
    batch_size = cls_probs.shape[0]
    ext_segments = list()
    for batch_idx in range(batch_size):
        length = lengths[batch_idx]
        ext_segments_pi = cal_segments(cls_probs[batch_idx, :length], bg_spans[batch_idx], scale)
        ext_segments_pi = [segment for segment in ext_segments_pi if cls_probs[batch_idx, segment] > threshold]
        ext_segments.append(ext_segments_pi)
    return ext_segments


def gen_masks(sizes, scale, device):
    batch_size = len(sizes)
    max_size = [int(max(item) * scale) for item in zip(*sizes)]
    masks = torch.zeros([batch_size, *max_size], dtype=torch.float, device=device)
    for batch_idx in range(batch_size):
        masks[batch_idx, :sizes[batch_idx][0], :sizes[batch_idx][1]] = 1.
    return masks


def gen_targets(sizes, scale, device, fg_spans, bg_spans, type):
    batch_size = len(sizes)
    max_size = [int(max(item) * scale) for item in zip(*sizes)]
    targets = torch.zeros([batch_size, *max_size], dtype=torch.float, device=device)
    for batch_idx, fg_spans_pb in enumerate(fg_spans):
        if type == 'col':
            for fg_spans_pi in fg_spans_pb:
                targets[batch_idx, :, int(fg_spans_pi[0] * scale): int(fg_spans_pi[1] * scale)] = 1.
        else:
            for fg_spans_pi in fg_spans_pb:
                targets[batch_idx, int(fg_spans_pi[0] * scale): int(fg_spans_pi[1] * scale), :] = 1.
    return targets


class SegmentPredictor(nn.Module):
    def __init__(self, in_dim, scale=1, threshold=0.5, type=None):
        super().__init__()
        self.scale = scale
        self.in_dim = in_dim
        assert type in ['col', 'row']
        self.type = type
        self.threshold = threshold
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_dim // 2, 1, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        )

    def forward(self, feats, images_size, fg_spans=None, bg_spans=None):
        batch_size = feats.shape[0]
        images_size = [image_size[::-1] for image_size in images_size]
        segments_logit = self.convs(feats)
        masks = gen_masks(images_size, self.scale, feats.device)
        # save_logitmap('row_segment.png', segments_logit[0][0])
        result_info = dict()
        ext_info = dict()

        if self.training:
            targets = gen_targets(images_size, self.scale, feats.device, fg_spans, bg_spans, self.type)
            segments_loss = F.binary_cross_entropy_with_logits(
                segments_logit,
                targets.unsqueeze(1),
                reduction='none'
            )
            segments_loss = (segments_loss * masks[:, None, :, :]).sum() / targets.sum()
            result_info['segments_loss'] = segments_loss

            pred_segments, cls_probs, lengths = cls_logits_to_segments(segments_logit, masks, self.type, spans=None, scale=self.scale, threshold=self.threshold)
            correct_nums, segment_nums, span_nums = cal_segment_pr(pred_segments, fg_spans, bg_spans)
            if segment_nums != 0:
                result_info['precision'] = correct_nums / segment_nums
            if span_nums != 0:
                result_info['recall'] = correct_nums / span_nums
            ext_segments = cal_ext_segments(cls_probs, lengths, bg_spans, self.scale, self.threshold)
            ext_info['ext_segments'] = ext_segments

        pred_segments, *_ = cls_logits_to_segments(segments_logit, masks, self.type, spans=fg_spans, scale=self.scale, threshold=self.threshold)
        return pred_segments, result_info, ext_info
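cal_spans groups consecutive above-threshold positions into [start, end) index spans. A minimal sketch (not part of this commit), assuming the repo's dependencies are installed:

import torch
from libs.model.segment_predictor import cal_spans

probs = torch.tensor([0.1, 0.9, 0.8, 0.2, 0.7])
print(cal_spans(probs, threshold=0.5))  # [[1, 3], [4, 5]]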
libs/model/utils.py
ADDED
@@ -0,0 +1,371 @@
| 1 |
+
import cv2
|
| 2 |
+
import copy
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
def proposal_colspan(layout, layout_score, srow, scol):
|
| 8 |
+
|
| 9 |
+
y, x = torch.where(layout == 1)
|
| 10 |
+
if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
|
| 11 |
+
return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
|
| 12 |
+
else:
|
| 13 |
+
lf_row = srow
|
| 14 |
+
lf_col = scol
|
| 15 |
+
|
| 16 |
+
col_count = 0
|
| 17 |
+
for col_ in range(lf_col, x.max() + 1):
|
| 18 |
+
if layout[lf_row, col_] == 1:
|
| 19 |
+
col_count = col_count + 1
|
| 20 |
+
else:
|
| 21 |
+
break
|
| 22 |
+
row_count = 0
|
| 23 |
+
for row_ in range(lf_row, y.max() + 1):
|
| 24 |
+
if torch.all(layout[row_, lf_col: lf_col + col_count] == 1):
|
| 25 |
+
row_count = row_count + 1
|
| 26 |
+
else:
|
| 27 |
+
break
|
| 28 |
+
|
| 29 |
+
layout[:, :] = 0
|
| 30 |
+
layout[lf_row:lf_row + row_count, lf_col : lf_col + col_count] = 1
|
| 31 |
+
return layout, layout_score[lf_row:lf_row + row_count, lf_col : lf_col + col_count].mean()
|
| 32 |
+
|
| 33 |
+
def proposal_rowspan(layout, layout_score, srow, scol):
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
y, x = torch.where(layout == 1)
|
| 37 |
+
if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
|
| 38 |
+
return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
|
| 39 |
+
else:
|
| 40 |
+
lf_row = srow
|
| 41 |
+
lf_col = scol
|
| 42 |
+
|
| 43 |
+
row_count = 0
|
| 44 |
+
for row_ in range(lf_row, y.max() + 1):
|
| 45 |
+
if layout[row_, lf_col] == 1:
|
| 46 |
+
row_count = row_count + 1
|
| 47 |
+
else:
|
| 48 |
+
break
|
| 49 |
+
col_count = 0
|
| 50 |
+
for col_ in range(lf_col, x.max() + 1):
|
| 51 |
+
if torch.all(layout[lf_row : lf_row + row_count, col_] == 1):
|
| 52 |
+
col_count = col_count + 1
|
| 53 |
+
else:
|
| 54 |
+
break
|
| 55 |
+
|
| 56 |
+
layout[:, :] = 0
|
| 57 |
+
layout[lf_row:lf_row + row_count, lf_col : lf_col + col_count] = 1
|
| 58 |
+
return layout, layout_score[lf_row:lf_row + row_count, lf_col : lf_col + col_count].mean()
|
| 59 |
+
|
| 60 |
+
def proposal_maxcontain(layout, layout_score, srow, scol):
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
y, x = torch.where(layout == 1)
|
| 64 |
+
if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
|
| 65 |
+
return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
|
| 66 |
+
else:
|
| 67 |
+
lf_row = srow
|
| 68 |
+
lf_col = scol
|
| 69 |
+
|
| 70 |
+
layout[:, :] = 0
|
| 71 |
+
layout[lf_row: y.max()+1, lf_col : x.max() + 1] = 1
|
| 72 |
+
return layout, layout_score[lf_row: y.max()+1, lf_col : x.max() + 1].mean()
|
| 73 |
+
|
| 74 |
+
def proposal_maxrowspan(layout, layout_score, srow, scol):
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
y, x = torch.where(layout == 1)
|
| 78 |
+
if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
|
| 79 |
+
return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
|
| 80 |
+
else:
|
| 81 |
+
lf_row = srow
|
| 82 |
+
lf_col = scol
|
| 83 |
+
|
| 84 |
+
row_count = 1
|
| 85 |
+
for row_ in range(lf_row + 1, y.max() + 1):
|
| 86 |
+
if torch.all(layout[lf_row] == layout[row_]):
|
| 87 |
+
row_count = row_count + 1
|
| 88 |
+
else:
|
| 89 |
+
break
|
| 90 |
+
|
| 91 |
+
layout[:, :] = 0
|
| 92 |
+
layout[lf_row : lf_row + row_count, lf_col : x.max() + 1] = 1
|
| 93 |
+
return layout, layout_score[lf_row : lf_row + row_count, lf_col : x.max() + 1].mean()
|
| 94 |
+
|
| 95 |
+
def proposal_maxcolspan(layout, layout_score, srow, scol):
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
y, x = torch.where(layout == 1)
|
| 99 |
+
if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
|
| 100 |
+
return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
|
| 101 |
+
else:
|
| 102 |
+
lf_row = srow
|
| 103 |
+
lf_col = scol
|
| 104 |
+
|
| 105 |
+
col_count = 1
|
| 106 |
+
for col_ in range(lf_col + 1, x.max() + 1):
|
| 107 |
+
if torch.all(layout[:, lf_col] == layout[:, col_]):
|
| 108 |
+
col_count = col_count + 1
|
| 109 |
+
else:
|
| 110 |
+
break
|
| 111 |
+
|
| 112 |
+
layout[:, :] = 0
|
| 113 |
+
layout[lf_row : y.max() + 1, lf_col : lf_col + col_count] = 1
|
| 114 |
+
return layout, layout_score[lf_row : y.max() + 1, lf_col : lf_col + col_count].mean()
|
| 115 |
+
|
| 116 |
+
def gen_proposals(layout_score, srow, scol, score_threshold=0.5):
|
| 117 |
+
layout = layout_score > score_threshold
|
| 118 |
+
layout[srow, scol] = 1
|
| 119 |
+
|
| 120 |
+
y, x = torch.where(layout == 1)
|
| 121 |
+
if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
|
| 122 |
+
return layout.unsqueeze(0), layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean().unsqueeze(0).log()
|
| 123 |
+
else:
|
| 124 |
+
proposal_1, score_1 = proposal_colspan(copy.deepcopy(layout), layout_score, srow, scol)
|
| 125 |
+
proposal_2, score_2 = proposal_rowspan(copy.deepcopy(layout), layout_score, srow, scol)
|
| 126 |
+
proposal_3, score_3 = proposal_maxcontain(copy.deepcopy(layout), layout_score, srow, scol)
|
| 127 |
+
proposal_4, score_4 = proposal_maxrowspan(copy.deepcopy(layout), layout_score, srow, scol)
|
| 128 |
+
proposal_5, score_5 = proposal_maxcolspan(copy.deepcopy(layout), layout_score, srow, scol)
|
| 129 |
+
proposals = torch.stack([proposal_1, proposal_2, proposal_3, proposal_4, proposal_5], dim=0)
|
| 130 |
+
scores = torch.stack([score_1.log(), score_2.log(), score_3.log(), score_4.log(), score_5.log()], dim=0)
|
| 131 |
+
return proposals, scores
|
| 132 |
+
|
| 133 |
+
def extend_segments(row_segments, rows_es, col_segments, cols_es, cells_spans, layouts, divide_labels):
|
| 134 |
+
batch_size = len(row_segments)
|
| 135 |
+
ext_row_segments = list()
|
| 136 |
+
ext_col_segments = list()
|
| 137 |
+
ext_cells_spans = list()
|
| 138 |
+
ext_layouts = list()
|
| 139 |
+
ext_divide_labels = list()
|
| 140 |
+
for batch_idx in range(batch_size):
|
| 141 |
+
row_segments_pi = row_segments[batch_idx]
|
| 142 |
+
col_segments_pi = col_segments[batch_idx]
|
| 143 |
+
rows_es_pi = rows_es[batch_idx]
|
| 144 |
+
cols_es_pi = cols_es[batch_idx]
|
| 145 |
+
cells_spans_pi = cells_spans[batch_idx]
|
| 146 |
+
|
| 147 |
+
ext_row_segments_pi = row_segments_pi + rows_es_pi
|
| 148 |
+
ext_col_segments_pi = col_segments_pi + cols_es_pi
|
| 149 |
+
|
| 150 |
+
row_segments_idx = sorted(list(range(len(ext_row_segments_pi))), key=lambda idx: ext_row_segments_pi[idx])
|
| 151 |
+
col_segments_idx = sorted(list(range(len(ext_col_segments_pi))), key=lambda idx: ext_col_segments_pi[idx])
|
| 152 |
+
|
| 153 |
+
ext_divide_labels.append(row_segments_idx.index(divide_labels[batch_idx].item()))
|
| 154 |
+
|
| 155 |
+
ext_row_segments.append([ext_row_segments_pi[idx] for idx in row_segments_idx])
|
| 156 |
+
ext_col_segments.append([ext_col_segments_pi[idx] for idx in col_segments_idx])
|
| 157 |
+
|
| 158 |
+
ext_layouts_pi = np.full((len(ext_row_segments_pi) - 1, len(ext_col_segments_pi) - 1), -1)
|
| 159 |
+
ext_cells_spans_pi = list()
|
| 160 |
+
for cell_idx, cell_span in enumerate(cells_spans_pi):
|
| 161 |
+
l, t, r, b = cell_span
|
| 162 |
+
l = col_segments_idx.index(l)
|
| 163 |
+
r = col_segments_idx.index(r+1) - 1
|
| 164 |
+
t = row_segments_idx.index(t)
|
| 165 |
+
b = row_segments_idx.index(b+1) - 1
|
| 166 |
+
ext_cells_spans_pi.append([l, t, r, b])
|
| 167 |
+
ext_layouts_pi[t:b+1, l:r+1] = cell_idx
|
| 168 |
+
ext_cells_spans.append(ext_cells_spans_pi)
|
| 169 |
+
ext_layouts.append(ext_layouts_pi)
|
| 170 |
+
|
| 171 |
+
return ext_row_segments, ext_col_segments, ext_cells_spans, aligned_layouts(ext_layouts, layouts), torch.tensor(ext_divide_labels).to(divide_labels.device)
|
| 172 |
+
|
| 173 |
+
def aligned_layouts(layouts_list, layouts):
|
| 174 |
+
batch_size = len(layouts_list)
|
| 175 |
+
dtype = layouts.dtype
|
| 176 |
+
device = layouts.device
|
| 177 |
+
|
| 178 |
+
max_row_nums = max([l.shape[0] for l in layouts_list])
|
| 179 |
+
max_col_nums = max([l.shape[1] for l in layouts_list])
|
| 180 |
+
|
| 181 |
+
aligned_layouts = list()
|
| 182 |
+
for batch_idx in range(batch_size):
|
| 183 |
+
num_rows_pi = layouts_list[batch_idx].shape[0]
|
| 184 |
+
num_cols_pi = layouts_list[batch_idx].shape[1]
|
| 185 |
+
layouts_pi = torch.from_numpy(layouts_list[batch_idx]).to(dtype=dtype, device=device)
|
| 186 |
+
aligned_layouts_pi = F.pad(
|
| 187 |
+
layouts_pi,
|
| 188 |
+
(0, max_col_nums-num_cols_pi, 0, max_row_nums-num_rows_pi),
|
| 189 |
+
mode='constant',
|
| 190 |
+
value=-1
|
| 191 |
+
)
|
| 192 |
+
aligned_layouts.append(aligned_layouts_pi)
|
| 193 |
+
aligned_layouts = torch.stack(aligned_layouts, dim=0)
|
| 194 |
+
return aligned_layouts
|
| 195 |
+
|
| 196 |
+
def parse_layout(spans, num_rows, num_cols):
|
| 197 |
+
layout = np.full([num_rows, num_cols], -1, dtype=np.int)
|
| 198 |
+
cell_count = 0
|
| 199 |
+
for x1, y1, x2, y2, prob in spans:
|
| 200 |
+
layout[y1:y2+1, x1:x2+1] = cell_count
|
| 201 |
+
cell_count += 1
|
| 202 |
+
|
| 203 |
+
cells_id = list()
|
| 204 |
+
for row_idx in range(num_rows):
|
| 205 |
+
for col_idx in range(num_cols):
|
| 206 |
+
cell_id = layout[row_idx, col_idx]
|
| 207 |
+
if cell_id in cells_id:
|
| 208 |
+
layout[row_idx, col_idx] = cells_id.index(cell_id)
|
| 209 |
+
else:
|
| 210 |
+
layout[row_idx, col_idx] = len(cells_id)
|
| 211 |
+
cells_id.append(cell_id)
|
| 212 |
+
return layout
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def parse_cells(layout, row_segments, col_segments):
|
| 216 |
+
cells = list()
|
| 217 |
+
num_cells = np.max(layout) + 1
|
| 218 |
+
for cell_id in range(num_cells):
|
| 219 |
+
cell_positions = np.argwhere(layout == cell_id)
|
| 220 |
+
y1 = np.min(cell_positions[:, 0])
|
| 221 |
+
y2 = np.max(cell_positions[:, 0])
|
| 222 |
+
x1 = np.min(cell_positions[:, 1])
|
| 223 |
+
x2 = np.max(cell_positions[:, 1])
|
| 224 |
+
assert np.all(layout[y1:y2, x1:x2] == cell_id)
|
| 225 |
+
x1 = col_segments[x1]
|
| 226 |
+
x2 = col_segments[x2+1]
|
| 227 |
+
y1 = row_segments[y1]
|
| 228 |
+
y2 = row_segments[y2+1]
|
| 229 |
+
cell = dict(
|
| 230 |
+
segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]
|
| 231 |
+
)
|
| 232 |
+
cells.append(cell)
|
| 233 |
+
return cells
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def process_layout(score, index):
|
| 237 |
+
layout = torch.full_like(index, -1)
|
| 238 |
+
layout_mask = torch.full_like(index, -1)
|
| 239 |
+
nrow, ncol = score.shape
|
| 240 |
+
for cell_id in range(nrow * ncol):
|
| 241 |
+
if layout_mask.min() != -1:
|
| 242 |
+
break
|
| 243 |
+
crow, ccol = torch.where(layout_mask == layout_mask.min())
|
| 244 |
+
ccol = ccol[crow == crow.min()].min()
|
| 245 |
+
crow = crow.min()
|
| 246 |
+
id = index[crow, ccol]
|
| 247 |
+
h, w = torch.where(index == id)
|
| 248 |
+
if h.shape[0] == 1 or w.shape[0] == 1:
|
| 249 |
+
layout_mask[h, w] = 1
|
| 250 |
+
layout[h, w] = cell_id
|
| 251 |
+
continue
|
| 252 |
+
else:
|
| 253 |
+
h_min = h.min()
|
| 254 |
+
h_max = h.max()
|
| 255 |
+
w_min = w.min()
|
| 256 |
+
w_max = w.max()
|
| 257 |
+
if torch.all(index[h_min:h_max+1, w_min:w_max+1] == id):
|
| 258 |
+
layout_mask[h_min:h_max+1, w_min:w_max+1] = 1
|
| 259 |
+
layout[h_min:h_max+1, w_min:w_max+1] = cell_id
|
| 260 |
+
else:
|
| 261 |
+
lf_row = crow
|
| 262 |
+
lf_col = ccol
|
| 263 |
+
|
| 264 |
+
col_mem = -1
|
| 265 |
+
for col_ in range(lf_col, w_max + 1):
|
| 266 |
+
if index[lf_row, col_] == id:
|
| 267 |
+
layout_mask[lf_row, col_] = 1
|
| 268 |
+
layout[lf_row, col_] = cell_id
|
| 269 |
+
col_mem = col_
|
| 270 |
+
else:
|
| 271 |
+
break
|
| 272 |
+
for row_ in range(lf_row + 1, h_max + 1):
|
| 273 |
+
if torch.all(index[row_, lf_col: col_mem + 1] == id):
|
| 274 |
+
layout_mask[row_, lf_col: col_mem + 1] = 1
|
| 275 |
+
layout[row_, lf_col: col_mem + 1] = cell_id
|
| 276 |
+
else:
|
| 277 |
+
break
|
| 278 |
+
return layout
|
| 279 |
+
|
| 280 |
+
def process_layout(score, index, use_score=False, is_merge=True, score_threshold=0.5):
|
| 281 |
+
if use_score:
|
    if is_merge:
        y, x = torch.where(score < score_threshold)
        index[y, x] = index.max() + 1
    else:
        y, x = torch.where(score < score_threshold)
        index[y, x] = torch.arange(index.max() + 1, index.max() + 1 + len(y)).to(index.device, index.dtype)

    layout = torch.full_like(index, -1)
    layout_mask = torch.full_like(index, -1)
    nrow, ncol = score.shape
    for cell_id in range(max(nrow * ncol, index.max() + 1)):
        if layout_mask.min() != -1:
            break
        crow, ccol = torch.where(layout_mask == layout_mask.min())
        ccol = ccol[crow == crow.min()].min()
        crow = crow.min()
        id = index[crow, ccol]
        h, w = torch.where(index == id)
        if h.shape[0] == 1 or w.shape[0] == 1:
            layout_mask[h, w] = 1
            layout[h, w] = cell_id
            continue
        else:
            h_min = h.min()
            h_max = h.max()
            w_min = w.min()
            w_max = w.max()
            if torch.all(index[h_min:h_max+1, w_min:w_max+1] == id):
                layout_mask[h_min:h_max+1, w_min:w_max+1] = 1
                layout[h_min:h_max+1, w_min:w_max+1] = cell_id
            else:
                lf_row = crow
                lf_col = ccol

                col_mem = -1
                for col_ in range(lf_col, w_max + 1):
                    if index[lf_row, col_] == id:
                        layout_mask[lf_row, col_] = 1
                        layout[lf_row, col_] = cell_id
                        col_mem = col_
                    else:
                        break
                for row_ in range(lf_row + 1, h_max + 1):
                    if torch.all(index[row_, lf_col: col_mem + 1] == id):
                        layout_mask[row_, lf_col: col_mem + 1] = 1
                        layout[row_, lf_col: col_mem + 1] = cell_id
                    else:
                        break
    return layout


def layout2spans(layout):
    rows, cols = layout.shape[-2:]
    cells_span = list()
    for cell_id in range(rows * cols):
        cell_positions = np.argwhere(layout == cell_id)
        if len(cell_positions) == 0:
            continue
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return [cells_span]


def spatial_att_to_spans(spatial_att_weight_pred):
    max_score, max_index = spatial_att_weight_pred.max(dim=0)
    layout = process_layout(max_score, max_index, use_score=True, is_merge=False)
    layout = process_layout(max_score, layout)

    layout = layout.cpu().numpy()
    spans = layout2spans(layout)
    return spans


def save_logitmap(filename, logit):
    cv2.imwrite(filename, (logit.sigmoid()*255).cpu().numpy().astype('uint8'))


def draw_spans(dst, src, spans, type):
    image = cv2.imread(src)
    H, W, *_ = image.shape
    for span in spans:
        if type == 'col':
            cv2.rectangle(image, (span[0], 0), (span[1], H), (0, 0, 255), thickness=1)
        elif type == 'row':
            cv2.rectangle(image, (0, span[0]), (W, span[1]), (0, 0, 255), thickness=1)
    cv2.imwrite(dst, image)
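
To make the decoded layout concrete: a minimal sketch of what layout2spans returns on a toy grid (the 2x3 layout below is invented for illustration).

    import numpy as np

    # Cell 0 spans the first two columns of row 0; every other slot is its own cell.
    layout = np.array([[0, 0, 1],
                       [2, 3, 4]])
    # layout2spans(layout) ->
    # [[[0, 0, 1, 0], [2, 0, 2, 0], [0, 1, 0, 1], [1, 1, 1, 1], [2, 1, 2, 1]]]
    # one [x1, y1, x2, y2] span per cell id, wrapped in a single-element list.
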
libs/utils/__init__.py
ADDED
File without changes

libs/utils/cal_f1.py
ADDED
@@ -0,0 +1,214 @@
import cv2
from numpy.core.fromnumeric import sort
import tqdm
import json
import copy
import Polygon
import numpy as np
from .scitsr.eval import json2Relations, eval_relations


def parse_layout(spans, num_rows, num_cols):
    # np.int was removed in NumPy >= 1.24; the builtin int is equivalent here
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    cell_count = 0
    for x1, y1, x2, y2 in spans:
        layout[y1:y2+1, x1:x2+1] = cell_count
        cell_count += 1

    cells_id = list()
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            cell_id = layout[row_idx, col_idx]
            if cell_id in cells_id:
                layout[row_idx, col_idx] = cells_id.index(cell_id)
            else:
                layout[row_idx, col_idx] = len(cells_id)
                cells_id.append(cell_id)
    return layout


def parse_cells(layout, spans, row_segments, col_segments, lines):
    cells = list()
    num_cells = np.max(layout) + 1
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        x1 = col_segments[x1]
        x2 = col_segments[x2+1]
        y1 = row_segments[y1]
        y2 = row_segments[y2+1]
        cell = dict(
            segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]
        )
        cells.append(cell)

    extend_cell_lines(cells, lines)

    return cells


def extend_cell_lines(cells, lines):
    def segmentation_to_polygon(segmentation):
        polygon = Polygon.Polygon()
        for contour in segmentation:
            polygon = polygon + Polygon.Polygon(contour)
        return polygon

    lines = copy.deepcopy(lines)

    cells_poly = [segmentation_to_polygon(item['segmentation']) for item in cells]
    lines_poly = [segmentation_to_polygon(item['segmentation']) for item in lines]

    cells_lines = [[] for _ in range(len(cells))]

    for line_idx, line_poly in enumerate(lines_poly):
        if line_poly.area() == 0:
            continue
        line_area = line_poly.area()
        max_overlap = 0
        max_overlap_idx = None
        for cell_idx, cell_poly in enumerate(cells_poly):
            overlap = (cell_poly & line_poly).area() / line_area
            if overlap > max_overlap:
                max_overlap_idx = cell_idx
                max_overlap = overlap
        if max_overlap > 0:
            cells_lines[max_overlap_idx].append(line_idx)
    lines_y1 = [segmentation_to_bbox(item['segmentation'])[1] for item in lines]
    cells_lines = [sorted(item, key=lambda idx: lines_y1[idx]) for item in cells_lines]

    for cell, cell_lines in zip(cells, cells_lines):
        transcript = []
        for idx in cell_lines:
            transcript.extend(lines[idx]['transcript'])
        cell['transcript'] = transcript


def segmentation_to_bbox(segmentation):
    x1 = min([min([pt[0] for pt in contour]) for contour in segmentation])
    y1 = min([min([pt[1] for pt in contour]) for contour in segmentation])
    x2 = max([max([pt[0] for pt in contour]) for contour in segmentation])
    y2 = max([max([pt[1] for pt in contour]) for contour in segmentation])
    return [x1, y1, x2, y2]


def cal_cell_spans(table):
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def pred_result_to_table(table, pred_result):
    # gt ocr result
    lines = [dict(segmentation=cell['segmentation'], transcript=cell['transcript']) for cell in table['cells'] if 'bbox' in cell.keys()]

    row_segments, col_segments, divide, spans = pred_result
    num_rows = len(row_segments) - 1
    num_cols = len(col_segments) - 1

    layout = parse_layout(spans, num_rows, num_cols)
    cells = parse_cells(layout, spans, row_segments, col_segments, lines)
    head_rows = list(range(0, divide))
    body_rows = list(range(divide, num_rows))

    table = dict(
        layout=layout,
        head_rows=head_rows,
        body_rows=body_rows,
        cells=cells
    )

    return table


def table_to_relations(table):
    cell_spans = cal_cell_spans(table)
    contents = [''.join(cell['transcript']).split() for cell in table['cells']]
    relations = []
    for span, content in zip(cell_spans, contents):
        x1, y1, x2, y2 = span
        relations.append(dict(start_row=y1, end_row=y2, start_col=x1, end_col=x2, content=content))
    return dict(cells=relations)


def cal_f1(label, pred):
    label = json2Relations(label, splitted_content=True)
    pred = json2Relations(pred, splitted_content=True)
    precision, recall = eval_relations(gt=[label], res=[pred], cmp_blank=True)
    f1 = 2.0 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    return [precision, recall, f1]


def single_process(labels, preds):
    scores = dict()
    for key in tqdm.tqdm(labels.keys()):
        pred = preds.get(key, '')
        label = labels.get(key, '')
        score = cal_f1(label, pred)
        scores[key] = score
    return scores


def _worker(labels, preds, keys, result_queue):
    for key in keys:
        label = labels.get(key, '')
        pred = preds.get(key, '')
        score = cal_f1(label, pred)
        result_queue.put((key, score))


def multi_process(labels, preds, num_workers):
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()
    keys = list(labels.keys())
    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                labels,
                preds,
                keys[worker_idx::num_workers],
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)

    scores = dict()
    tq = tqdm.tqdm(total=len(keys))
    for _ in range(len(keys)):
        key, val = result_queue.get()
        scores[key] = val
        P, R, F1 = (100 * np.array(list(scores.values()))).mean(0).tolist()
        tq.set_description('P: %.2f, R: %.2f, F1: %.2f' % (P, R, F1), False)
        tq.update()

    return scores


def evaluate_f1(labels, preds, num_workers=0):
    preds = {idx: pred for idx, pred in enumerate(preds)}
    labels = {idx: label for idx, label in enumerate(labels)}
    if num_workers == 0:
        scores = single_process(labels, preds)
    else:
        scores = multi_process(labels, preds, num_workers)
    sorted_idx = sorted(list(range(len(list(scores)))), key=lambda idx: list(scores.keys())[idx])
    scores = [scores[idx] for idx in sorted_idx]
    return scores
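
A quick smoke test for the pipeline above, with an invented two-cell table (copy is already imported at the top of this file); an identical prediction and label should score a perfect F1.

    label = dict(cells=[
        dict(start_row=0, end_row=0, start_col=0, end_col=0, content=['Name']),
        dict(start_row=0, end_row=0, start_col=1, end_col=1, content=['Score']),
    ])
    pred = copy.deepcopy(label)
    scores = evaluate_f1([label], [pred], num_workers=0)
    # -> [[1.0, 1.0, 1.0]]  (precision, recall, F1)
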
libs/utils/checkpoint.py
ADDED
@@ -0,0 +1,47 @@
import os
import torch
from .comm import get_rank, synchronize


def save_checkpoint(checkpoint, model, optimizer=None, best_metric=None, epoch=None):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    if get_rank() == 0:
        if not os.path.exists(os.path.dirname(checkpoint)):
            os.makedirs(os.path.dirname(checkpoint))

        infos = dict()
        infos['model_param'] = model.state_dict()
        if optimizer is not None:
            infos['opt_param'] = optimizer.state_dict()

        if best_metric is not None:
            infos['best_metric'] = best_metric

        if epoch is not None:
            infos['epoch'] = epoch

        torch.save(infos, checkpoint)
    synchronize()


def load_checkpoint(checkpoint, model, optimizer=None):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    checkpoint = torch.load(checkpoint, map_location='cpu')

    model.load_state_dict(checkpoint['model_param'], strict=False)

    if (optimizer is not None) and ('opt_param' in checkpoint):
        optimizer.load_state_dict(checkpoint['opt_param'])

    if 'best_metric' in checkpoint:
        best_metric = checkpoint['best_metric']
    else:
        best_metric = None

    if 'epoch' in checkpoint:
        epoch = checkpoint['epoch']
    else:
        epoch = None
    return best_metric, epoch
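
A usage sketch for the two helpers above; the checkpoint path and variable names are hypothetical.

    # Save on rank 0 (other ranks wait at the barrier), then resume later.
    save_checkpoint('experiments/run1/latest.pth', model, optimizer, best_metric=best_f1, epoch=epoch)
    best_metric, start_epoch = load_checkpoint('experiments/run1/latest.pth', model, optimizer)
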
libs/utils/comm.py
ADDED
@@ -0,0 +1,129 @@
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import os
import pickle

import torch
import torch.distributed as dist


def distributed():
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    return distributed


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank():
    if 'LOCAL_RANK' not in os.environ:
        return get_rank()
    else:
        return int(os.environ['LOCAL_RANK'])


def is_main_process():
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
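
A sketch of how these primitives are typically used; the names are illustrative, and both calls degrade gracefully to the single-GPU case.

    local_stats = {'num_samples': 128}   # any picklable object
    all_stats = all_gather(local_stats)  # list with one entry per rank

    loss_dict = {'loss_cls': torch.tensor(0.7, device='cuda')}
    reduced = reduce_dict(loss_dict)     # rank 0 holds the averaged values
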
libs/utils/context_cacher.py
ADDED
@@ -0,0 +1,15 @@
class ContextCacher:
    def __init__(self):
        self.infos = dict()

    def reset(self):
        self.infos.clear()

    def cache_info(self, key, info):
        self.infos[key] = info

    def get_info(self, key):
        return self.infos[key]


global_context_cacher = ContextCacher()
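
The intended usage is a process-wide scratch space; the key and value below are hypothetical.

    global_context_cacher.cache_info('row_feats', row_feats)  # stash during the forward pass
    feats = global_context_cacher.get_info('row_feats')       # read back elsewhere
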
libs/utils/counter.py
ADDED
@@ -0,0 +1,43 @@
import torch
from collections import defaultdict
from .comm import distributed, all_gather


def format_dict(res_dict):
    res_strs = []
    for key, val in res_dict.items():
        res_strs.append('%s: %s' % (key, val))
    return ', '.join(res_strs)


class Counter:
    def __init__(self, cache_nums=1000):
        self.cache_nums = cache_nums
        self.reset()

    def update(self, metric):
        for key, val in metric.items():
            if isinstance(val, torch.Tensor):
                val = val.item()
            self.metric_dict[key].append(val)
            if self.cache_nums is not None:
                self.metric_dict[key] = self.metric_dict[key][-1*self.cache_nums:]

    def reset(self):
        self.metric_dict = defaultdict(list)

    def _sync(self):
        metric_dicts = all_gather(self.metric_dict)
        total_metric_dict = defaultdict(list)
        for metric_dict in metric_dicts:
            for key, val in metric_dict.items():
                total_metric_dict[key].extend(val)
        return total_metric_dict

    def format_mean(self, sync=True):
        if sync and distributed():
            metric_dict = self._sync()
        else:
            metric_dict = self.metric_dict
        res_dict = {key: '%.4f' % (sum(val)/len(val)) for key, val in metric_dict.items()}
        return format_dict(res_dict)
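
Usage sketch (the metric names are invented): update once per step and log a windowed mean.

    counter = Counter(cache_nums=1000)
    counter.update({'loss': torch.tensor(0.42), 'acc': 0.91})
    print(counter.format_mean(sync=False))  # e.g. "loss: 0.4200, acc: 0.9100"
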
libs/utils/format_translate.py
ADDED
@@ -0,0 +1,278 @@
import re
import copy
import Polygon
import numpy as np
from bs4 import BeautifulSoup as bs
from .time_counter import format_table


def check_continuous(seq):
    if len(seq) > 0:
        pre_val = seq[0]
        for val in seq[1:]:
            assert pre_val + 1 == val
            pre_val = val

def table_to_latex(table):
    def cal_cls_id(transcript):
        transcript = ''.join(transcript)
        if transcript == '':
            return '</none>'
        elif transcript == '<b> </b>':
            return '</bold>'
        elif transcript == ' ':
            return '</space>'
        else:
            return '</line>'
    assert table['layout'].max() + 1 == len(table['cells'])
    latex = [cal_cls_id(cell['transcript']) for cell in table['cells']]
    return latex

def html_to_table(html):
    tokens = html['html']['structure']['tokens']

    layout = [[]]

    def extend_table(x, y):
        assert (x >= 0) and (y >= 0)
        nonlocal layout

        if x >= len(layout[0]):
            for row in layout:
                row.extend([-1] * (x - len(row) + 1))

        if y >= len(layout):
            for _ in range(y - len(layout) + 1):
                layout.append([-1] * len(layout[0]))

    def set_cell_val(x, y, val):
        assert (x >= 0) and (y >= 0)
        nonlocal layout
        extend_table(x, y)
        layout[y][x] = val

    def get_cell_val(x, y):
        assert (x >= 0) and (y >= 0)
        nonlocal layout
        extend_table(x, y)
        return layout[y][x]

    def parse_span_val(token):
        span_val = int(token[token.index('"') + 1:token.rindex('"')])
        return span_val

    def maskout_left_rows():
        nonlocal row_idx, layout
        layout = layout[:max(row_idx+1, 1)]

    row_idx = -1
    col_idx = -1
    line_idx = -1
    inside_head = False
    inside_body = False
    head_rows = list()
    body_rows = list()
    col_span = 1
    row_span = 1
    for token in tokens:
        if token == '<thead>':
            inside_head = True
            maskout_left_rows()
        elif token == '</thead>':
            inside_head = False
            maskout_left_rows()
        elif token == '<tbody>':
            inside_body = True
            maskout_left_rows()
        elif token == '</tbody>':
            inside_body = False
            maskout_left_rows()
        elif token == '<tr>':
            row_idx += 1
            col_idx = -1
            if inside_head:
                head_rows.append(row_idx)
            if inside_body:
                body_rows.append(row_idx)
        elif token in ['<td>', '<td']:
            line_idx += 1
            col_idx += 1
            row_span = 1
            col_span = 1
            while get_cell_val(col_idx, row_idx) != -1:
                col_idx += 1
        elif 'colspan' in token:
            col_span = parse_span_val(token)
        elif 'rowspan' in token:
            row_span = parse_span_val(token)
        elif token == '</td>':
            for cur_row_idx in range(row_idx, row_idx + row_span):
                for cur_col_idx in range(col_idx, col_idx + col_span):
                    set_cell_val(cur_col_idx, cur_row_idx, line_idx)
            col_idx += col_span - 1

    check_continuous(head_rows)
    check_continuous(body_rows)
    assert len(set(head_rows) | set(body_rows)) == len(layout)
    layout = np.array(layout)
    assert np.all(layout >= 0)

    cells_info = list()
    for cell_idx, cell in enumerate(html['html']['cells']):
        transcript = cell['tokens']
        cell_info = dict(
            transcript=transcript
        )
        if 'bbox' in cell:
            x1, y1, x2, y2 = cell['bbox']
            cell_info['bbox'] = [x1, y1, x2, y2]
            cell_info['segmentation'] = [[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]
        cells_info.append(cell_info)

    table = dict(
        layout=layout,
        cells=cells_info,
        head_rows=head_rows,
        body_rows=body_rows
    )
    return table


def segmentation_to_bbox(segmentation):
    x1 = min([min([pt[0] for pt in contour]) for contour in segmentation])
    y1 = min([min([pt[1] for pt in contour]) for contour in segmentation])
    x2 = max([max([pt[0] for pt in contour]) for contour in segmentation])
    y2 = max([max([pt[1] for pt in contour]) for contour in segmentation])
    return [x1, y1, x2, y2]


def table_to_html(table):
    layout = table['layout']
    head_rows = table['head_rows']
    body_rows = table['body_rows']

    cells_span = list()
    for cell_idx in range(len(table['cells'])):
        cell_positions = np.argwhere(layout == cell_idx)
        row_span = [np.min(cell_positions[:, 0]), np.max(cell_positions[:, 0]) + 1]
        col_span = [np.min(cell_positions[:, 1]), np.max(cell_positions[:, 1]) + 1]
        assert np.all(layout[row_span[0]:row_span[1], col_span[0]:col_span[1]] == cell_idx)
        cells_span.append([row_span, col_span])

    cells = list()
    tokens = ['<thead>']
    inside_head = True
    for row_idx in range(layout.shape[0]):
        if row_idx in body_rows:
            if inside_head:
                tokens.append('</thead>')
                tokens.append('<tbody>')
                inside_head = False
        tokens.append('<tr>')
        for col_idx in range(table['layout'].shape[1]):
            cell_idx = layout[row_idx][col_idx]
            assert cell_idx <= len(cells)
            if cell_idx == len(cells):
                row_span, col_span = cells_span[cell_idx]
                if (row_span[1] - row_span[0]) == 1 and (col_span[1] - col_span[0] == 1):
                    tokens.append('<td>')
                else:
                    tokens.append('<td')
                    if (row_span[1] - row_span[0]) > 1:
                        tokens.append(' rowspan="%d"' % (row_span[1] - row_span[0]))
                    if (col_span[1] - col_span[0]) > 1:
                        tokens.append(' colspan="%d"' % (col_span[1] - col_span[0]))
                    tokens.append('>')
                tokens.append('</td>')

                cell = dict()
                cell['tokens'] = table['cells'][cell_idx]['transcript']
                if 'segmentation' in table['cells'][cell_idx]:
                    cell['bbox'] = segmentation_to_bbox(table['cells'][cell_idx]['segmentation'])
                cells.append(cell)
        tokens.append('</tr>')
    if inside_head:
        tokens.append('</thead>')
        tokens.append('<tbody>')
    tokens.append('</tbody>')

    html = dict(
        html=dict(
            cells=cells,
            structure=dict(
                tokens=tokens
            )
        )
    )
    return html


def format_html_for_vis(html):
    html_string = '''<html>
<head>
<meta charset="UTF-8">
<style>
table, th, td {
  border: 1px solid black;
  font-size: 10px;
}
</style>
</head>
<body>
<table frame="hsides" rules="groups" width="100%%">
%s
</table>
</body>
</html>''' % ''.join(html['html']['structure']['tokens'])
    cell_nodes = list(re.finditer(r'(<td[^<>]*>)(</td>)', html_string))
    assert len(cell_nodes) == len(html['html']['cells']), 'Number of cells defined in tags does not match the length of cells'
    cells = [''.join(c['tokens']) for c in html['html']['cells']]
    offset = 0
    for n, cell in zip(cell_nodes, cells):
        html_string = html_string[:n.end(1) + offset] + cell + html_string[n.start(2) + offset:]
        offset += len(cell)
    # prettify the html
    soup = bs(html_string)
    html_string = soup.prettify()
    return html_string


def format_html(html):
    html_string = '''<html><body><table>%s</table></body></html>''' % ''.join(html['html']['structure']['tokens'])
    cell_nodes = list(re.finditer(r'(<td[^<>]*>)(</td>)', html_string))
    assert len(cell_nodes) == len(html['html']['cells']), 'Number of cells defined in tags does not match the length of cells'
    cells = [''.join(c['tokens']) for c in html['html']['cells']]
    offset = 0
    for n, cell in zip(cell_nodes, cells):
        html_string = html_string[:n.end(1) + offset] + cell + html_string[n.start(2) + offset:]
        offset += len(cell)
    return html_string


def format_table_layout(table):
    layout = table['table']['layout']
    cell_lines = [cell['lines_idx'] for cell in table['table']['cells']]

    table_cells_info = list()
    for row in layout:
        row_cells_info = list()
        for cell_idx in row:
            cell_str = ','.join([str(item) for item in cell_lines[cell_idx]])
            row_cells_info.append(cell_str)
        table_cells_info.append(row_cells_info)

    return format_table(table_cells_info, padding=1)


def remove_blank_cell(html):
    start_idx = 0
    while '<td' in html[start_idx:]:
        start_idx = html[start_idx:].index('<td') + start_idx
        content_start_idx = html[start_idx:].index('>') + 1 + start_idx
        content_end_idx = html[content_start_idx:].index('</td>') + content_start_idx
        end_idx = content_end_idx + len('</td>')
        if content_end_idx == content_start_idx:
            html = html[:start_idx] + html[end_idx:]
        else:
            start_idx = end_idx
    return html
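
To illustrate the round trip above, a minimal sketch with an invented PubTabNet-style annotation (a 2x2 table whose first row is the header):

    ann = dict(html=dict(
        structure=dict(tokens=['<thead>', '<tr>', '<td>', '</td>', '<td>', '</td>', '</tr>', '</thead>',
                               '<tbody>', '<tr>', '<td>', '</td>', '<td>', '</td>', '</tr>', '</tbody>']),
        cells=[dict(tokens=['a', 'b']) for _ in range(4)],
    ))
    table = html_to_table(ann)         # layout: 2x2 ndarray, head_rows=[0], body_rows=[1]
    html_back = table_to_html(table)
    html_str = format_html(html_back)  # '<html><body><table>...</table></body></html>' with 'ab' in each cell
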
libs/utils/logger.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import os
import sys

from .comm import get_rank


_default_logger = None


def __init_logger():
    global _default_logger
    if get_rank() == 0:
        logger = logging.getLogger('default')
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

        if not any([isinstance(item, logging.StreamHandler) for item in logger.handlers]):
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(formatter)
            logger.addHandler(ch)
        _default_logger = logger


__init_logger()


def setup_logger(name, save_dir, filename="log.txt"):
    global _default_logger
    # don't log results for the non-master process
    if get_rank() == 0:
        logger = logging.getLogger(name)
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

        if not any([isinstance(item, logging.StreamHandler) for item in logger.handlers]):
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(formatter)
            logger.addHandler(ch)

        logger.handlers = [item for item in logger.handlers if not isinstance(item, logging.FileHandler)]
        if save_dir:
            log_path = os.path.join(save_dir, filename)
            if not os.path.exists(os.path.dirname(log_path)):
                os.makedirs(os.path.dirname(log_path))
            fh = logging.FileHandler(log_path)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

        _default_logger = logger


def info(*args, **kwargs):
    if get_rank() == 0:
        _default_logger.info(*args, **kwargs)


def error(*args, **kwargs):
    if get_rank() == 0:
        _default_logger.error(*args, **kwargs)
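
Usage sketch (the directory name is hypothetical); only rank 0 emits anything.

    setup_logger('train', save_dir='experiments/run1')
    info('epoch %d: loss %.4f', 3, 0.1234)
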
libs/utils/metric.py
ADDED
@@ -0,0 +1,58 @@
import torch
from .utils import match_segment_spans, find_unmatch_segment_spans
from .teds import TEDS


class CellMergeAcc:
    def __call__(self, preds, labels, labels_mask):
        preds = preds & labels_mask
        labels = labels & labels_mask
        flag = preds == labels
        flag = flag.reshape(flag.shape[0], flag.shape[1], -1).min(-1)[0]
        mask = labels.reshape(labels.shape[0], labels.shape[1], -1).max(-1)[0]
        correct_nums = float(torch.sum(flag & mask).detach().cpu().item())
        total_nums = max(float(torch.sum(mask).detach().cpu().item()), 1e-6)
        return correct_nums, total_nums


class AccMetric:
    def __call__(self, preds, labels, labels_mask):
        mask = (labels_mask != 0) & (labels != -1)
        correct_nums = float(torch.sum((preds == labels) & mask).detach().cpu().item())
        total_nums = max(float(torch.sum(mask).detach().cpu().item()), 1e-6)
        return correct_nums, total_nums


def cal_cls_acc(cls_preds, cls_labels):
    mask = (cls_labels != -1)
    total_nums = float(torch.sum(mask).item())
    pred_nums = float(torch.sum((cls_preds == cls_labels) & mask).item())
    return pred_nums, total_nums


def cal_segment_pr(pred_segments, fg_spans, bg_spans):
    correct_nums = 0
    segment_nums = 0
    span_nums = 0
    for pred_segments_pi, fg_spans_pi, bg_spans_pi in zip(pred_segments, fg_spans, bg_spans):
        matched_segments_idx, _ = match_segment_spans(pred_segments_pi, fg_spans_pi)
        unmatched_segments_idx = find_unmatch_segment_spans(pred_segments_pi, fg_spans_pi + bg_spans_pi)

        correct_nums += len(matched_segments_idx)
        segment_nums += len(pred_segments_pi) - len(unmatched_segments_idx)
        span_nums += len(fg_spans_pi)

    return correct_nums, segment_nums, span_nums


class TEDSMetric:
    def __init__(self, num_workers=1, structure_only=False):
        self.evaluator = TEDS(n_jobs=num_workers, structure_only=structure_only)

    def __call__(self, pred_htmls, label_htmls):
        assert len(pred_htmls) == len(label_htmls)
        pred_jsons = {idx: pred_html for idx, pred_html in enumerate(pred_htmls)}
        label_jsons = {idx: dict(html=label_html) for idx, label_html in enumerate(label_htmls)}
        scores = self.evaluator.batch_evaluate(pred_jsons, label_jsons)
        scores = [scores[idx] for idx in range(len(pred_htmls))]
        return scores
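
A small worked example for AccMetric (values invented): -1 labels are ignored, so two of the three positions count.

    preds = torch.tensor([[1, 2, 3]])
    labels = torch.tensor([[1, 2, -1]])
    mask = torch.tensor([[1, 1, 1]])
    correct, total = AccMetric()(preds, labels, mask)  # -> (2.0, 2.0)
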
libs/utils/model_synchronizer.py
ADDED
@@ -0,0 +1,75 @@
import torch
from .comm import get_world_size
import torch.distributed as dist


class ModelSynchronizer:
    bm_map = {
        2: 0.65,
        4: 0.75,
        8: 0.875,
        12: 0.8875,
        16: 0.9,
        32: 0.9
    }

    def __init__(self, model, sync_rate, bm=None, blr=1.0, rescale_grad=1.0):
        if bm is None:
            self.bm = self.bm_map[get_world_size()]
        else:
            self.bm = bm
        self.blr = blr
        self.model = model
        self.sync_rate = sync_rate
        self.rescale_grad = rescale_grad
        self.count = 0

        self.param_align()

        self.momentums = dict()
        self.global_params = dict()
        for k, v in self.model.named_parameters():
            temp = torch.zeros_like(v, requires_grad=False)
            temp.copy_(v.data)
            # the detached copy serves as the global (block-level) parameter state
            self.global_params[k] = temp
            self.momentums[k] = torch.zeros_like(v, requires_grad=False)

    def param_align(self):
        for v in self.model.parameters():
            dist.broadcast_multigpu([v.data], src=0)

        for k, v in self.model.named_buffers():
            if 'num_batches_tracked' in k:
                continue
            dist.broadcast_multigpu([v.data], src=0)

    def sync_params(self):
        size = float(get_world_size())
        for v in self.model.parameters():
            dist.all_reduce(v.data, op=dist.ReduceOp.SUM)
            v.data /= size

        for k, v in self.model.named_buffers():
            if 'num_batches_tracked' in k:
                continue
            dist.all_reduce(v.data, op=dist.ReduceOp.SUM)
            v.data /= size

    def __call__(self, final_align=False):
        self.count += 1
        if (self.count % self.sync_rate == 0) or final_align:
            with torch.no_grad():
                if final_align:
                    self.param_align()
                else:
                    self.sync_params()

                for k, v in self.model.named_parameters():
                    global_param = self.global_params[k]
                    momentum = self.momentums[k]
                    grad = v.data * self.rescale_grad - global_param
                    momentum *= self.bm
                    global_param -= momentum
                    momentum += self.blr * grad
                    global_param += (1.0 + self.bm) * momentum
                    v.detach().copy_(global_param.detach())
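
The periodic averaging plus block momentum resembles blockwise model-update filtering (BMUF); a training-loop sketch in which model, optimizer and loader are hypothetical:

    synchronizer = ModelSynchronizer(model, sync_rate=4)
    for batch in loader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        synchronizer()  # every sync_rate calls: all-reduce params, then block-momentum update
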
libs/utils/scitsr/__init__.py
ADDED
File without changes

libs/utils/scitsr/eval.py
ADDED
@@ -0,0 +1,179 @@
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

from .relation import Relation
from .table import Table, Chunk


DIR_HORIZ = 1
DIR_VERT = 2
DIR_SAME_CELL = 3


def normalize(s:str, rule=0):
    if rule == 0:
        s = s.replace("\r", "")
        s = s.replace("\n", "")
        s = s.replace(" ", "")
        s = s.replace("\t", "")
        return s.upper()
    else:
        raise NotImplementedError


def eval_relations(gt:List[List], res:List[List], cmp_blank=True):
    """Evaluate results

    Args:
        gt: a list of list of Relation
        res: a list of list of Relation
    """

    #TODO to know how to calculate the total recall and prec

    assert len(gt) == len(res)
    tot_prec = 0
    tot_recall = 0
    total = 0
    # print("evaluating result...")

    # for _gt, _res in tqdm(zip(gt, res)):
    # for _gt, _res in tqdm(zip(gt, res), total=len(gt), desc='eval'):
    idx, t = 0, len(gt)
    for _gt, _res in zip(gt, res):
        idx += 1
        print('Eval %d/%d (%d%%)' % (idx, t, idx / t * 100), ' ' * 45, end='\r')
        corr = compare_rel(_gt, _res, cmp_blank)
        precision = corr / len(_res) if len(_res) != 0 else 0
        recall = corr / len(_gt) if len(_gt) != 0 else 0
        tot_prec += precision
        tot_recall += recall
        total += 1
    # print()

    precision = tot_prec / total
    recall = tot_recall / total
    # print("Test on %d instances. Precision: %.2f, Recall: %.2f" % (
    #     total, precision, recall))
    return precision, recall

def compare_rel(gt_rel:List[Relation], res_rel:List[Relation], cmp_blank=True):
    count = 0

    #print("compare_rel =======================")
    #for gt in gt_rel:
    #    print("rel gt:", gt.from_text, gt.to_text, gt.direction)
    #for gt in res_rel:
    #    print("rel res:", gt.from_text, gt.to_text, gt.direction)
    #print("\n\n\n\n\n")

    dup_res_rel = [r for r in res_rel]
    for gt in gt_rel:
        to_rm = None
        for i, res in enumerate(dup_res_rel):
            if gt.equal(res, cmp_blank):
                to_rm = i
                count += 1
                break
        if to_rm is not None:
            dup_res_rel = dup_res_rel[:i] + dup_res_rel[i + 1:]

    return count

def Table2Relations(t:Table):
    """Convert a Table object to a List of Relation.
    """
    ret = []
    cl = t.coo2cell_id
    # remove duplicates with pair set
    used = set()

    # look right
    for r in range(t.row_n):
        for cFrom in range(t.col_n - 1):
            cTo = cFrom + 1
            loop = True
            while loop and cTo < t.col_n:
                fid, tid = cl[r][cFrom], cl[r][cTo]
                if fid != -1 and tid != -1 and fid != tid:
                    if (fid, tid) not in used:
                        ret.append(Relation(
                            from_text=t.cells[fid].text,
                            to_text=t.cells[tid].text,
                            direction=DIR_HORIZ,
                            from_id=fid,
                            to_id=tid,
                            no_blanks=cTo - cFrom - 1
                        ))
                        used.add((fid, tid))
                    loop = False
                else:
                    if fid != -1 and tid != -1 and fid == tid:
                        cFrom = cTo
                    cTo += 1

    # look down
    for c in range(t.col_n):
        for rFrom in range(t.row_n - 1):
            rTo = rFrom + 1
            loop = True
            while loop and rTo < t.row_n:
                fid, tid = cl[rFrom][c], cl[rTo][c]
                if fid != -1 and tid != -1 and fid != tid:
                    if (fid, tid) not in used:
                        ret.append(Relation(
                            from_text=t.cells[fid].text,
                            to_text=t.cells[tid].text,
                            direction=DIR_VERT,
                            from_id=fid,
                            to_id=tid,
                            no_blanks=rTo - rFrom - 1
                        ))
                        used.add((fid, tid))
                    loop = False
                else:
                    if fid != -1 and tid != -1 and fid == tid:
                        rFrom = rTo
                    rTo += 1

    return ret

def json2Table(json_obj, tid="", splitted_content=False):
    """Construct a Table object from json object

    Args:
        json_obj: a json object
    Returns:
        a Table object
    """
    jo = json_obj["cells"]
    row_n, col_n = 0, 0
    cells = []
    for co in jo:
        content = co["content"]
        if content is None: continue
        if splitted_content:
            content = " ".join(content)
        else:
            content = content.strip()
        if content == "": continue
        start_row = co["start_row"]
        end_row = co["end_row"]
        start_col = co["start_col"]
        end_col = co["end_col"]
        row_n = max(row_n, end_row)
        col_n = max(col_n, end_col)
        cell = Chunk(content, (start_row, end_row, start_col, end_col))
        cells.append(cell)
    return Table(row_n + 1, col_n + 1, cells, tid)

def json2Relations(json_obj, splitted_content):
    return Table2Relations(json2Table(json_obj, "", splitted_content))
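
A sketch of end-to-end relation scoring on an invented one-row annotation; identical inputs give perfect precision and recall.

    gt_json = dict(cells=[
        dict(content=['a'], start_row=0, end_row=0, start_col=0, end_col=0),
        dict(content=['b'], start_row=0, end_row=0, start_col=1, end_col=1),
    ])
    rel = json2Relations(gt_json, splitted_content=True)        # one horizontal Relation
    p, r = eval_relations(gt=[rel], res=[rel], cmp_blank=True)  # -> (1.0, 1.0)
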
libs/utils/scitsr/relation.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

def normalize(s:str, rule=0):

    if rule == 0:
        s = s.replace("\r", "")
        s = s.replace("\n", "")
        s = s.replace(" ", "")
        s = s.replace("\t", "")
        return s.upper()
    else:
        raise NotImplementedError


class Relation(object):

    def __init__(self, from_text, to_text, direction, from_id=0, to_id=0, no_blanks=0):
        self.from_text = from_text
        self.to_text = to_text
        self.direction = direction
        self.no_blanks = no_blanks
        self.from_id = from_id
        self.to_id = to_id

    def __eq__(self, rl):
        this_ft = normalize(self.from_text)
        this_tt = normalize(self.to_text)
        rl_ft = normalize(rl.from_text)
        rl_tt = normalize(rl.to_text)
        if len(this_ft) == 0 or len(this_tt) == 0 or \
                len(rl_ft) == 0 or len(rl_tt) == 0:
            print("Warning: Text comparison of 0-length strings after normalization",
                  file=sys.stderr)

        return this_ft == rl_ft and this_tt == rl_tt and \
            self.direction == rl.direction and self.no_blanks == rl.no_blanks

    def equal(self, rl, cmp_blank=True):
        this_ft = normalize(self.from_text)
        this_tt = normalize(self.to_text)
        rl_ft = normalize(rl.from_text)
        rl_tt = normalize(rl.to_text)
        if len(this_ft) == 0 or len(this_tt) == 0 or \
                len(rl_ft) == 0 or len(rl_tt) == 0:
            print("Warning: Text comparison of 0-length strings after normalization",
                  file=sys.stderr)

        return this_ft == rl_ft and this_tt == rl_tt and \
            self.direction == rl.direction and \
            (self.no_blanks == rl.no_blanks if cmp_blank else True)

    def __str__(self):
        return "%d:%d" % (self.direction, self.no_blanks)
libs/utils/scitsr/table.py
ADDED
@@ -0,0 +1,133 @@
# Copyright (c) 2019-present, Zewen Chi
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json

from typing import Iterable, List, Tuple


def load_chunks(chunk_path):
    with open(chunk_path, 'r') as f:
        chunks = json.load(f)['chunks']
    # NOTE remove the chunk with 0 len
    ret = []
    for chunk in chunks:
        if chunk["pos"][1] < chunk["pos"][0]:
            chunk["pos"][0], chunk["pos"][1] = chunk["pos"][1], chunk["pos"][0]
            print("Warning load illegal chunk.")
        c = Chunk.load_from_dict(chunk)
        #if c.x2 == c.x1 or c.y2 == c.y1 or c.text == "":
        #    continue
        ret.append(c)
    return ret


class Box(object):

    def __init__(self, pos):
        """pos: (x1, x2, y1, y2)"""
        self.set_pos(pos)

    def set_pos(self, pos):
        assert pos[0] <= pos[1]
        assert pos[2] <= pos[3]
        self.x1 = pos[0]
        self.x2 = pos[1]
        self.y1 = pos[2]
        self.y2 = pos[3]
        self.w = self.x2 - self.x1
        self.h = self.y2 - self.y1
        self.pos = pos

    def __lt__(self, other):
        return self.pos.__lt__(other.pos)

    def __contains__(self, other):
        if other.x1 >= self.x1 and other.x2 <= self.x2 and \
                other.y1 >= self.y1 and other.y2 <= self.y2:
            return True
        return False

    def __str__(self):
        return 'Box(%d, %d, %d, %d)' % self.pos

    def __hash__(self):
        return self.pos.__hash__()


class Chunk(Box):

    def __init__(self, text:str, pos:Tuple, size:float=0.0, cell_id=None):
        super(Chunk, self).__init__(pos)
        self.text = text
        self.size = size
        self.cell_id = cell_id

    def __str__(self):
        return 'Chunk(text="%s", pos=(%d, %d, %d, %d))' % (self.text, *self.pos)

    def __repr__(self):
        return self.__str__()

    def dump_as_json_obj(self):
        return {"text":self.text, "pos":self.pos, "cell_id":self.cell_id}

    @classmethod
    def load_from_dict(cls, d):
        assert type(d) == dict
        assert type(d["text"]) == str
        assert len(d["pos"]) == 4
        cell_id = d["cell_id"] if "cell_id" in d else None
        return cls(d["text"].strip(), d["pos"], cell_id=cell_id)


class Table(object):

    """
    The output of table segmentation.
    With the Table object, we can get the set of cells
    and their corresponding text.
    """
    def __init__(self, row_n, col_n, cells:Iterable[Chunk]=None, tid=""):
        # NOTE the Chunk object here represents the coordinate of
        # the cell in the table.
        # NOTE x in cell object represents the row id
        self.tid = tid
        self.row_n = row_n
        self.col_n = col_n
        self.coo2cell_id = [
            [ -1 for _ in range(col_n) ] for _ in range(row_n) ]
        self.cells:List[Chunk] = []
        for cell in cells:
            self.add_cell(cell)

    def reverse(self, is_col=True):
        cells = self.cells
        self.cells = []
        cell:Chunk = None
        for cell in cells:
            if is_col:
                _c = Chunk(cell.text, (
                    self.row_n - cell.x2, self.row_n - cell.x1, cell.y1, cell.y2))
            else:
                _c = Chunk(cell.text, (
                    cell.x1, cell.x2, self.col_n - cell.y1, self.col_n - cell.y2))
            self.add_cell(_c)

    def add_cell(self, cell:Chunk):
        # TODO Check conflicts of cells
        assert cell.y2 < self.col_n
        assert cell.x2 < self.row_n

        for x in range(cell.x1, cell.x2 + 1, 1):
            for y in range(cell.y1, cell.y2 + 1, 1):
                self.coo2cell_id[x][y] = len(self.cells)
        self.cells.append(cell)

    def __getitem__(self, id_tuple):
        row_id, col_id = id_tuple
        assert row_id < self.row_n and col_id < self.col_n
        return self.cells[self.coo2cell_id[row_id][col_id]]
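
A sketch of building and indexing a Table by hand (values invented); recall that x is the row coordinate in a Chunk's (x1, x2, y1, y2) position.

    cells = [Chunk('a', (0, 0, 0, 0)), Chunk('b', (0, 0, 1, 1))]
    t = Table(row_n=1, col_n=2, cells=cells)
    t[0, 1].text  # -> 'b'
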
libs/utils/teds.py
ADDED
@@ -0,0 +1,212 @@
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.

import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    A parallel version of the map function with a progress bar.

    Args:
        array (array-like): An array to iterate over.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of cores to use
        use_kwargs (boolean, default=False): Whether to consider the elements of array
            as dictionaries of keyword arguments to function
        front_num (int, default=0): The number of iterations to run serially before
            kicking off the parallel job. Useful for catching bugs
    Returns:
        [function(array[0]), function(array[1]), ...]
    """
    # We run the first few iterations serially to catch bugs
    if front_num > 0:
        front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
    else:
        front = []
    # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
    if n_jobs == 1:
        return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
    # Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
        kwargs = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        # Print out the progress as tasks complete
        for f in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    # Get the results from the futures.
    for i, future in tqdm(enumerate(futures)):
        try:
            out.append(future.result())
        except Exception as e:
            out.append(e)
    return front + out


class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)


class CustomConfig(Config):
    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value"""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1"""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.


class TEDS(object):
    '''Tree Edit Distance based Similarity'''

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        '''Tokenizes table cells'''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        '''Converts HTML tree to the format required by apted'''
        global __tokens__
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell, *deque())
        else:
            new_node = TableTree(node.tag, None, None, None, *deque())
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        '''Computes TEDS score between the prediction and the ground truth of a
        given sample
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath('body/table') and true.xpath('body/table'):
            pred = pred.xpath('body/table')[0]
            true = true.xpath('body/table')[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, pred_json, true_json):
        '''Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
            scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores


if __name__ == '__main__':
    import json
    import pprint
    with open('sample_pred.json') as fp:
        pred_json = json.load(fp)
    with open('sample_gt.json') as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
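A minimal sanity check (editor's illustration, not part of the committed files), assuming the well-formed <html><body><table> inputs that evaluate() expects:

    teds = TEDS(structure_only=False)
    gt = '<html><body><table><tr><td>a</td><td>b</td></tr></table></body></html>'
    pred = '<html><body><table><tr><td>a</td><td>c</td></tr></table></body></html>'
    print(teds.evaluate(gt, gt))    # identical trees -> 1.0
    print(teds.evaluate(pred, gt))  # one changed cell -> below 1.0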
libs/utils/teds_multiprocess.py
ADDED
@@ -0,0 +1,111 @@
import json
import tqdm
from libs.utils.teds import TEDS
from collections import defaultdict


# def parse_args():
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument('pred_path', type=str, default=None)
#     parser.add_argument('label_path', type=str, default=None)
#     parser.add_argument('-s', '--structure_only', action='store_true')
#     parser.add_argument('-n', '--num_workers', type=int, default=1)
#     args = parser.parse_args()
#     return args


def is_simple(data):
    if ('colspan' in data) or ('rowspan' in data):
        return False
    else:
        return True


def judge_type(data):
    if is_simple(data):
        return 'Simple'
    else:
        return 'Complex'


def single_process(pred_htmls, label_htmls, structure_only=False):
    evaluator = TEDS(structure_only=structure_only)
    scores = dict()
    for key in tqdm.tqdm(label_htmls.keys()):
        pred_html = pred_htmls.get(key, '')
        label_html = label_htmls[key]['html']
        score = evaluator.evaluate(pred_html, label_html)
        scores[key] = score
    return scores


def _worker(pred_htmls, label_htmls, keys, result_queue, structure_only=False):
    evaluator = TEDS(structure_only=structure_only)
    for key in keys:
        pred_html = pred_htmls.get(key, '')
        label_html = label_htmls[key]['html']
        score = evaluator.evaluate(pred_html, label_html)
        result_queue.put((key, score))


def multi_process(pred_htmls, label_htmls, num_workers, structure_only=False):
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()
    keys = list(label_htmls.keys())
    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                pred_htmls,
                label_htmls,
                keys[worker_idx::num_workers],
                result_queue,
                structure_only  # forward the flag so workers honor structure-only mode
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)
    scores = dict()
    tq = tqdm.tqdm(total=len(keys))
    for _ in range(len(keys)):
        key, val = result_queue.get()
        scores[key] = val
        teds = sum(scores.values()) / len(scores)
        tq.set_description('Teds: %s' % teds, False)
        tq.update()
    tq.close()
    return scores


def evaluate(pred_htmls, label_htmls, num_workers, structure_only=False):
    if num_workers <= 1:
        scores = single_process(pred_htmls, label_htmls, structure_only)
    else:
        scores = multi_process(pred_htmls, label_htmls, num_workers, structure_only)
    teds = sum(scores.values()) / len(scores)

    typed_teds = defaultdict(list)
    for key, score in scores.items():
        data_type = judge_type(label_htmls[key]['html'])
        typed_teds[data_type].append(score)

    typed_teds = {key: sum(val) / len(val) for key, val in typed_teds.items()}
    return teds, typed_teds


# def main():
#     args = parse_args()
#     pred_data = json.load(open(args.pred_path))
#     label_data = json.load(open(args.label_path))

#     teds, typed_teds = evaluate(pred_data, label_data, args.num_workers, args.structure_only)
#     print('Teds: %s' % teds)
#     for key, val in typed_teds.items():
#         print('  %s Teds: %s' % (key, val))


# if __name__ == '__main__':
#     main()
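A quick usage sketch (editor's illustration, not part of the committed files); the HTML bodies are placeholders, and keys must match between the two dicts:

    preds = {'t1.png': '<html><body><table><tr><td>a</td></tr></table></body></html>'}
    labels = {'t1.png': {'html': '<html><body><table><tr><td>a</td></tr></table></body></html>'}}
    teds, typed_teds = evaluate(preds, labels, num_workers=1)
    print(teds, typed_teds)  # overall TEDS plus per-type ('Simple'/'Complex') means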
libs/utils/time_counter.py
ADDED
@@ -0,0 +1,108 @@
from collections import defaultdict
import time
import datetime


class TimeCounter:
    def __init__(self, start_epoch, num_epochs, epoch_iters):
        self.start_epoch = start_epoch
        self.num_epochs = num_epochs
        self.epoch_iters = epoch_iters
        self.start_time = None

    def reset(self):
        self.start_time = time.time()

    def step(self, epoch, batch):
        used = time.time() - self.start_time
        finished_batch_nums = (epoch - self.start_epoch) * self.epoch_iters + batch
        batch_time_cost = used / finished_batch_nums
        total = (self.num_epochs - self.start_epoch) * self.epoch_iters * batch_time_cost
        left = total - used
        return str(datetime.timedelta(seconds=left))


def format_table(table, padding=1):
    table = [[str(subitem) for subitem in item] for item in table]
    num_cols = max([len(item) for item in table])
    cols_width = [0] * num_cols

    for row in table:
        for col_idx, cell in enumerate(row):
            cols_width[col_idx] = max(cols_width[col_idx], len(cell))

    # top border
    string = '┌'
    for col_idx in range(num_cols):
        string += '─' * (padding * 2 + cols_width[col_idx])
        if col_idx == num_cols - 1:
            string += '┐'
        else:
            string += '┬'
    string += '\n'

    for row_idx, row in enumerate(table):
        # row content, centered within each column
        string += '│'
        for col_idx in range(num_cols):
            if col_idx < len(row):
                word = row[col_idx]
            else:
                word = ''
            col_width = cols_width[col_idx]
            left_pad = (col_width - len(word)) // 2
            right_pad = col_width - len(word) - left_pad
            string += ' ' * (padding + left_pad)
            string += word
            string += ' ' * (padding + right_pad)
            string += '│'

        string += '\n'

        # separator after inner rows, bottom border after the last row
        if row_idx < len(table) - 1:
            string += '├'
        else:
            string += '└'
        for col_idx in range(num_cols):
            string += '─' * (padding * 2 + cols_width[col_idx])
            if col_idx == num_cols - 1:
                if row_idx < len(table) - 1:
                    string += '┤'
                else:
                    string += '┘'
            else:
                if row_idx < len(table) - 1:
                    string += '┼'
                else:
                    string += '┴'

        string += '\n'
    return string


class TicTocCounter:
    def __init__(self):
        self.tics = dict()
        self.seps = defaultdict(list)

    def tic(self, name):
        self.tics[name] = time.time()

    def toc(self, name):
        toc = time.time()
        if name in self.tics:
            self.seps[name].append(toc - self.tics[name])

    def __repr__(self):
        string = 'TicTocCount Result:\n'
        infos = [['Name', 'Mean Time', 'Total Time']]
        for key, val in self.seps.items():
            mean = sum(val) / len(val)
            total = sum(val)
            infos.append([key, '%0.4f' % mean, '%0.4f' % total])
        string += format_table(infos)
        return string

    def reset(self):
        self.tics.clear()
        self.seps.clear()


global_tictoc_counter = TicTocCounter()
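A quick usage sketch (editor's illustration, not part of the committed files):

    global_tictoc_counter.tic('forward')
    # ... timed section ...
    global_tictoc_counter.toc('forward')
    print(global_tictoc_counter)  # renders a boxed table of mean/total times per name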
libs/utils/utils.py
ADDED
@@ -0,0 +1,297 @@
import cv2
import copy
import Polygon
import numpy as np


def cal_mean_lr(optimizer):
    lrs = [group['lr'] for group in optimizer.param_groups]
    return sum(lrs) / len(lrs)


def cal_pr_f1(pr_info):
    precision = pr_info[0] / pr_info[1]
    recall = pr_info[0] / pr_info[2]
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def match_segment_spans(segments, spans):
    matched_segments = list()
    matched_spans = list()

    for segment_idx, segment in enumerate(segments):
        for span_idx, span in enumerate(spans):
            if span_idx not in matched_spans:
                if (segment >= span[0]) and (segment < span[1]):
                    matched_segments.append(segment_idx)
                    matched_spans.append(span_idx)

    return matched_segments, matched_spans


def find_unmatch_segment_spans(segments, spans):
    unmatched_segments = list()
    for segment_idx, segment in enumerate(segments):
        matched = False
        for span in spans:
            if (segment >= span[0]) and (segment < span[1]):
                matched = True
                break
        if not matched:
            unmatched_segments.append(segment_idx)

    return unmatched_segments


def parse_layout(spans, num_rows, num_cols):
    # np.int was removed in recent NumPy releases; plain int keeps the same behavior
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    cell_count = 0
    for x1, y1, x2, y2 in spans:
        layout[y1:y2 + 1, x1:x2 + 1] = cell_count
        cell_count += 1

    cells_id = list()
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            cell_id = layout[row_idx, col_idx]
            if cell_id in cells_id:
                layout[row_idx, col_idx] = cells_id.index(cell_id)
            else:
                layout[row_idx, col_idx] = len(cells_id)
                cells_id.append(cell_id)
    return layout


def parse_cells(layout, spans, row_segments, col_segments):
    cells = list()
    num_cells = np.max(layout) + 1
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # bounds are inclusive, so the whole cell region must be uniform
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        x1 = col_segments[x1]
        x2 = col_segments[x2 + 1]
        y1 = row_segments[y1]
        y2 = row_segments[y2 + 1]
        cell = dict(
            segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]
        )
        cells.append(cell)
    for span in spans:
        cell_id = layout[span[1], span[0]]
        cells[cell_id]['transcript'] = 'None'
    return cells


def segmentation_to_bbox(segmentation):
    x1 = min([min([pt[0] for pt in contour]) for contour in segmentation])
    y1 = min([min([pt[1] for pt in contour]) for contour in segmentation])
    x2 = max([max([pt[0] for pt in contour]) for contour in segmentation])
    y2 = max([max([pt[1] for pt in contour]) for contour in segmentation])
    return [x1, y1, x2, y2]


def extend_cell_lines(cells, lines):
    def segmentation_to_polygon(segmentation):
        polygon = Polygon.Polygon()
        for contour in segmentation:
            polygon = polygon + Polygon.Polygon(contour)
        return polygon

    lines = copy.deepcopy(lines)

    cells_poly = [segmentation_to_polygon(item['segmentation']) for item in cells]
    lines_poly = [segmentation_to_polygon(item['segmentation']) for item in lines]

    cells_lines = [[] for _ in range(len(cells))]

    for line_idx, line_poly in enumerate(lines_poly):
        if line_poly.area() == 0:
            continue
        line_area = line_poly.area()
        max_overlap = 0
        max_overlap_idx = None
        for cell_idx, cell_poly in enumerate(cells_poly):
            overlap = (cell_poly & line_poly).area() / line_area
            if overlap > max_overlap:
                max_overlap_idx = cell_idx
                max_overlap = overlap
        if max_overlap > 0:
            cells_lines[max_overlap_idx].append(line_idx)
    lines_y1 = [segmentation_to_bbox(item['segmentation'])[1] for item in lines]
    cells_lines = [sorted(item, key=lambda idx: lines_y1[idx]) for item in cells_lines]

    for cell, cell_lines in zip(cells, cells_lines):
        cell['lines_idx'] = cell_lines


def rerange_layout(table):
    layout = table['layout']
    cells = table['cells']
    valid_cells_id = list()
    for row_idx in range(layout.shape[0]):
        for col_idx in range(layout.shape[1]):
            cell_id = layout[row_idx, col_idx]
            if cell_id not in valid_cells_id:
                valid_cells_id.append(cell_id)
            layout[row_idx, col_idx] = valid_cells_id.index(cell_id)
    cells = [cells[cell_id] for cell_id in valid_cells_id]
    table['layout'] = layout
    table['cells'] = cells


def cal_cell_spans(table):
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # bounds are inclusive, so the whole cell region must be uniform
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def remove_repeat_rcs(table):
    layout = table['layout']
    head_rows = table['head_rows']
    body_rows = table['body_rows']
    while True:
        num_rows = layout.shape[0]
        num_cols = layout.shape[1]
        valid_rows_idx = list()
        valid_rows_key = list()

        for row_idx in range(num_rows):
            row = layout[row_idx, :]
            if len(np.unique(row)) == 1 and row_idx in body_rows:  # remove repeated row
                continue
            row_key = ','.join([str(item) for item in row])
            if row_key not in valid_rows_key:
                valid_rows_idx.append(row_idx)
                valid_rows_key.append(row_key)

        valid_cols_idx = list()
        valid_cols_key = list()
        for col_idx in range(num_cols):
            col = layout[:, col_idx]
            if len(np.unique(col)) == 1:  # remove repeated col
                continue
            col_key = ','.join([str(item) for item in col])
            if col_key not in valid_cols_key:
                valid_cols_idx.append(col_idx)
                valid_cols_key.append(col_key)
        if (len(valid_rows_idx) == num_rows) and (len(valid_cols_idx) == num_cols):
            break
        layout = layout[valid_rows_idx][:, valid_cols_idx]
        head_rows = [n_idx for n_idx, o_idx in enumerate(valid_rows_idx) if o_idx in head_rows]
        body_rows = [n_idx for n_idx, o_idx in enumerate(valid_rows_idx) if o_idx in body_rows]

    table['layout'] = layout
    table['head_rows'] = head_rows
    table['body_rows'] = body_rows
    rerange_layout(table)


def pred_result_to_table(pred_result):
    row_segments, col_segments, divide, spans = pred_result
    num_rows = len(row_segments) - 1
    num_cols = len(col_segments) - 1

    layout = parse_layout(spans, num_rows, num_cols)
    cells = parse_cells(layout, spans, row_segments, col_segments)
    head_rows = list(range(0, divide))
    body_rows = list(range(divide, num_rows))

    table = dict(
        layout=layout,
        head_rows=head_rows,
        body_rows=body_rows,
        cells=cells
    )

    # remove_repeat_rcs(table)

    return table


def is_simple_table(table):
    layout = table['layout']
    num_rows, num_cols = layout.shape
    if num_rows * num_cols == len(table['cells']):
        return True
    else:
        return False


def tensor_to_image(tensor):
    image = tensor.detach().cpu().numpy()
    if (len(image.shape) == 3) and (image.shape[0] != 3) and (image.shape[0] != 1):
        image = np.sqrt(np.sum(np.power(image, 2), axis=0, keepdims=True))
    image = 255 * (image - np.min(image)) / (np.max(image) - np.min(image))
    image = image.astype(np.uint8)
    if len(image.shape) == 3:
        image = np.transpose(image, (1, 2, 0)).copy()
        if image.shape[2] == 1:
            image = image[:, :, 0]
    return image


def visualize_layout(image, table):
    def draw_segmentation(image, segmentation, color):
        for contour in segmentation:
            contour = np.array(contour, dtype=np.int32)
            image = cv2.polylines(image, [contour], True, color)
        return image
    for cell in table['cells']:
        if 'segmentation' in cell:
            image = draw_segmentation(image, cell['segmentation'], (255, 0, 0))
    return image


virtual_chars = ["<b>", "</b>", "<i>", "</i>", "<sup>", "</sup>", "<sub>", "</sub>", "<overline>", "</overline>", "<underline>", "</underline>", "<strike>", "</strike>"]


def is_blank(content):
    global virtual_chars

    new_content = content
    for item in virtual_chars:
        new_content = new_content.replace(item, '')
    return new_content.strip() == ''


def filt_content(content, filt_blank=False, filt_virtual=False, filt_pad=False):
    global virtual_chars
    if filt_blank:
        if is_blank(content):
            content = ''

    if filt_virtual:
        # iterate over the virtual tags, not over the characters of content
        for item in virtual_chars:
            content = content.replace(item, '')

    if filt_pad:
        content = content.strip()

    return content


def filt_transcript(html, filt_blank=False, filt_virtual=False, filt_pad=False):
    start_idx = 0
    while '<td' in html[start_idx:]:
        start_idx = html[start_idx:].index('<td') + start_idx
        content_start_idx = html[start_idx:].index('>') + 1 + start_idx
        content_end_idx = html[content_start_idx:].index('</td>') + content_start_idx
        end_idx = content_end_idx + len('</td>')

        content = html[content_start_idx:content_end_idx]
        content = filt_content(content, filt_blank, filt_virtual, filt_pad)
        html = html[:content_start_idx] + content + html[content_end_idx:]
        start_idx = end_idx - (content_end_idx - content_start_idx - len(content))
    return html
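A worked example for parse_layout (editor's illustration, not part of the committed files): spans are (x1, y1, x2, y2) in grid coordinates with inclusive bounds, so a header spanning both columns of a 2x2 grid gives:

    spans = [(0, 0, 1, 0), (0, 1, 0, 1), (1, 1, 1, 1)]
    print(parse_layout(spans, num_rows=2, num_cols=2))
    # [[0 0]
    #  [1 2]]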
libs/utils/vocab.py
ADDED
@@ -0,0 +1,36 @@
class Vocab:
    key_words = [
        '</line>',
        '</none>',  # []
        '</bold>',  # ['<b>', ' ', '</b>']
        '</space>'  # [' ']
    ]

    def __init__(self):
        self._words_ids_map = dict()
        self._ids_words_map = dict()

        for word_id, word in enumerate(self.key_words):
            self._words_ids_map[word] = word_id
            self._ids_words_map[word_id] = word

        self.line_id = self._words_ids_map['</line>']
        self.none_id = self._words_ids_map['</none>']
        self.bold_id = self._words_ids_map['</bold>']
        self.space_id = self._words_ids_map['</space>']
        self.blank_ids = [self.none_id, self.bold_id, self.space_id]

    def __len__(self):
        return len(self._words_ids_map)

    def word_to_id(self, word):
        return self._words_ids_map[word]

    def words_to_ids(self, words):
        return [self.word_to_id(word) for word in words]

    def id_to_word(self, word_id):
        return self._ids_words_map[word_id]

    def ids_to_words(self, words_id):
        return [self.id_to_word(word_id) for word_id in words_id]
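A quick usage sketch (editor's illustration, not part of the committed files):

    vocab = Vocab()
    print(len(vocab))                                   # 4
    print(vocab.words_to_ids(['</line>', '</space>']))  # [0, 3]
    print(vocab.ids_to_words([1, 2]))                   # ['</none>', '</bold>']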
requirements.txt
ADDED
@@ -0,0 +1,93 @@
addict==2.4.0
aliyun-python-sdk-core==2.16.0
aliyun-python-sdk-kms==2.16.5
anyio==4.10.0
apted==1.0.3
beautifulsoup4==4.12.2
blinker==1.4
certifi==2025.8.3
charset-normalizer==3.4.3
click==8.1.8
colorama==0.4.6
contourpy==1.3.0
crcmod==1.7
cryptography==3.4.8
cycler==0.12.1
dbus-python==1.2.18
Distance==0.1.3
distro==1.7.0
exceptiongroup==1.3.0
filelock==3.14.0
fonttools==4.59.2
h11==0.16.0
httpcore==1.0.9
httplib2==0.20.2
httpx==0.28.1
idna==3.10
importlib-metadata==4.6.4
importlib_resources==6.5.2
jeepney==0.7.1
jmespath==0.10.0
keyring==23.5.0
kiwisolver==1.4.7
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lxml==4.9.2
Mako==1.1.3
Markdown==3.3.6
markdown-it-py==3.0.0
MarkupSafe==2.0.1
matplotlib==3.7.1
mdurl==0.1.2
mmcv-full==1.6.2
mmdet==2.28.2
model-index==0.1.11
more-itertools==8.10.0
numpy==1.26.4
oauthlib==3.2.0
opencv-python==4.7.0.72
opendatalab==0.0.10
openmim==0.3.9
openxlab==0.1.2
ordered-set==4.1.0
oss2==2.17.0
packaging==24.2
pandas==2.0.2
Pillow==10.0.0
platformdirs==4.4.0
polygon==1.1.0
Polygon3==3.0.9.1
pycocotools==2.0.10
pycryptodome==3.23.0
Pygments==2.19.2
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==3.2.3
python-apt==2.4.0+ubuntu4
python-dateutil==2.9.0.post0
pytz==2023.4
PyYAML==6.0.2
requests==2.28.2
rich==13.4.2
scipy==1.13.1
seaborn==0.12.2
SecretStorage==3.3.1
shapely==2.0.1
six==1.17.0
sniffio==1.3.1
soupsieve==2.8
tabulate==0.9.0
terminaltables==3.1.10
tomli==2.2.1
torch==1.12.0
torchvision==0.13.0
tqdm==4.65.0
typing_extensions==4.15.0
tzdata==2025.2
urllib3==1.26.20
wadllib==1.3.6
websocket-client==1.8.0
websockets==15.0.1
yapf==0.43.0
zipp==3.23.0
runner/train.py
ADDED
@@ -0,0 +1,245 @@
import shutil
import torch
import tqdm
import json
import os
import sys
sys.path.append('./')
sys.path.append('../')
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import defaultdict
from libs.utils.cal_f1 import pred_result_to_table, table_to_relations, evaluate_f1
from libs.utils.comm import distributed, synchronize
from libs.utils.checkpoint import load_checkpoint, save_checkpoint
from libs.data import create_train_dataloader, create_valid_dataloader
from libs.utils.model_synchronizer import ModelSynchronizer
from libs.utils.time_counter import TimeCounter
from libs.utils.utils import is_simple_table
from libs.utils.utils import cal_mean_lr
from libs.utils.counter import Counter
from libs.utils import logger
from libs.model import build_model
from libs.configs import cfg, setup_config


metrics_name = ['f1']
best_metrics = [0.0]


def init():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='debug')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    setup_config(args.cfg)
    os.environ['LOCAL_RANK'] = str(args.local_rank)
    num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        synchronize()
    logger.setup_logger('Line Detect Model', cfg.work_dir, 'train.log')
    logger.info('Use config: %s' % args.cfg)


def train(cfg, epoch, dataloader, model, optimizer, scheduler, time_counter, synchronizer=None):
    model.train()
    counter = Counter(cache_nums=1000)
    for it, data_batch in enumerate(dataloader):
        ids = data_batch['ids']
        images_size = data_batch['images_size']
        images = data_batch['images'].to(cfg.device)
        cls_labels = data_batch['cls_labels'].to(cfg.device)
        labels_mask = data_batch['labels_mask'].to(cfg.device)
        rows_fg_spans = data_batch['rows_fg_spans']
        rows_bg_spans = data_batch['rows_bg_spans']
        cols_fg_spans = data_batch['cols_fg_spans']
        cols_bg_spans = data_batch['cols_bg_spans']
        cells_spans = data_batch['cells_spans']
        divide_labels = data_batch['divide_labels'].to(cfg.device)
        layouts = data_batch['layouts'].to(cfg.device)

        try:
            optimizer.zero_grad()
            pred_result, result_info = model(
                images, images_size,
                cls_labels, labels_mask, layouts,
                rows_fg_spans, rows_bg_spans,
                cols_fg_spans, cols_bg_spans,
                cells_spans, divide_labels,
            )
            loss = sum([val for key, val in result_info.items() if 'loss' in key])
            loss.backward()
            optimizer.step()
            scheduler.step()
            counter.update(result_info)
        except Exception:
            # keep training when a step fails, typically CUDA OOM on an oversized batch
            logger.info('CUDA Out Of Memory')

        if it % cfg.log_sep == 0:
            logger.info(
                '[Train][Epoch %03d Iter %04d][Memory: %.0f ][Mean LR: %f ][Left: %s] %s' %
                (
                    epoch,
                    it,
                    torch.cuda.max_memory_allocated() / 1024 / 1024,
                    cal_mean_lr(optimizer),
                    time_counter.step(epoch, it + 1),
                    counter.format_mean(sync=False)
                )
            )

        if synchronizer is not None:
            synchronizer()
    if synchronizer is not None:
        synchronizer(final_align=True)


def valid(cfg, dataloader, model):
    model.eval()
    total_label_relations = list()
    total_pred_relations = list()
    total_relations_metric = list()

    for it, data_batch in enumerate(tqdm.tqdm(dataloader)):
        ids = data_batch['ids']
        images_size = data_batch['images_size']
        images = data_batch['images'].to(cfg.device)
        tables = data_batch['tables']
        pred_result, _ = model(images, images_size)
        pred_tables = [
            pred_result_to_table(tables[batch_idx],
                (pred_result[0][batch_idx], pred_result[1][batch_idx],
                 pred_result[2][batch_idx], pred_result[3][batch_idx])
            )
            for batch_idx in range(len(ids))
        ]
        pred_relations = [table_to_relations(table) for table in pred_tables]
        total_pred_relations.extend(pred_relations)
        # label
        label_relations = []
        for table in tables:
            label_path = os.path.join(cfg.valid_data_dir, table['label_path'])
            with open(label_path, 'r') as f:
                label_relations.append(json.load(f))
        total_label_relations.extend(label_relations)

    # cal P, R, F1
    total_relations_metric = evaluate_f1(total_label_relations, total_pred_relations, num_workers=40)
    P, R, F1 = np.array(total_relations_metric).mean(0).tolist()
    F1 = 2 * P * R / (P + R)
    logger.info('[Valid] Total Type Metric: Precision: %s, Recall: %s, F1-Score: %s' % (P, R, F1))
    return (F1,)


def build_optimizer(cfg, model):
    params = list()
    for _, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.base_lr
        weight_decay = cfg.weight_decay
        params += [{'params': [value], 'lr': lr, 'weight_decay': weight_decay}]
    optimizer = torch.optim.Adam(params, cfg.base_lr)
    return optimizer


def build_scheduler(cfg, optimizer, epoch_iters, start_epoch=0):
    scheduler = CosineAnnealingLR(
        optimizer=optimizer,
        T_max=cfg.num_epochs * epoch_iters,
        eta_min=cfg.min_lr,
        last_epoch=-1 if start_epoch == 0 else start_epoch * epoch_iters
    )
    return scheduler


def main():
    init()

    train_dataloader = create_train_dataloader(
        cfg.vocab,
        cfg.train_lrcs_path,
        cfg.train_num_workers,
        cfg.train_max_batch_size,
        cfg.train_max_pixel_nums,
        cfg.train_bucket_seps,
        cfg.train_data_dir
    )

    logger.info(
        'Train dataset has %d samples, %d batches' %
        (
            len(train_dataloader.dataset),
            len(train_dataloader.batch_sampler)
        )
    )

    valid_dataloader = create_valid_dataloader(
        cfg.vocab,
        cfg.valid_lrc_path,
        cfg.valid_num_workers,
        cfg.valid_batch_size,
        cfg.valid_data_dir
    )

    logger.info(
        'Valid dataset has %d samples, %d batches with batch_size=%d' %
        (
            len(valid_dataloader.dataset),
            len(valid_dataloader.batch_sampler),
            valid_dataloader.batch_size
        )
    )

    model = build_model(cfg)
    model.cuda()

    if distributed():
        synchronizer = ModelSynchronizer(model, cfg.sync_rate)
    else:
        synchronizer = None

    epoch_iters = len(train_dataloader.batch_sampler)
    optimizer = build_optimizer(cfg, model)

    global metrics_name
    global best_metrics
    start_epoch = 0

    resume_path = os.path.join(cfg.work_dir, 'latest_model.pth')
    if os.path.exists(resume_path):
        best_metrics, start_epoch = load_checkpoint(resume_path, model, optimizer)
        start_epoch += 1
        logger.info('resume from: %s' % resume_path)
    elif cfg.train_checkpoint is not None:
        load_checkpoint(cfg.train_checkpoint, model)
        logger.info('load checkpoint from: %s' % cfg.train_checkpoint)

    scheduler = build_scheduler(cfg, optimizer, epoch_iters, start_epoch)

    time_counter = TimeCounter(start_epoch, cfg.num_epochs, epoch_iters)
    time_counter.reset()

    for epoch in range(start_epoch, cfg.num_epochs):
        if hasattr(train_dataloader.sampler, 'set_epoch'):
            train_dataloader.sampler.set_epoch(epoch)
        train(cfg, epoch, train_dataloader, model, optimizer, scheduler, time_counter, synchronizer)

        with torch.no_grad():
            metrics = valid(cfg, valid_dataloader, model)

        for metric_idx in range(len(metrics_name)):
            if metrics[metric_idx] > best_metrics[metric_idx]:
                best_metrics[metric_idx] = metrics[metric_idx]
                save_checkpoint(os.path.join(cfg.work_dir, 'best_%s_model.pth' % metrics_name[metric_idx]), model, optimizer, best_metrics, epoch)
                logger.info('Save current model as best_%s_model' % metrics_name[metric_idx])

        save_checkpoint(os.path.join(cfg.work_dir, 'latest_model.pth'), model, optimizer, best_metrics, epoch)


if __name__ == '__main__':
    main()
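The script resolves distributed mode from the WORLD_SIZE environment variable and expects a --local_rank argument, which matches the pre-torchrun launcher shipped with the pinned torch==1.12. A hypothetical two-GPU invocation (the config name is a placeholder) would be:

    python -m torch.distributed.launch --nproc_per_node=2 runner/train.py --cfg default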
runner/valid.py
ADDED
@@ -0,0 +1,116 @@
import sys
import json
sys.path.append('./')
sys.path.append('../')
import os
import tqdm
import torch
import numpy as np
from libs.configs import cfg, setup_config
from libs.model import build_model
from libs.data import create_valid_dataloader
from libs.utils import logger
from libs.utils.cal_f1 import pred_result_to_table, table_to_relations, evaluate_f1
from libs.utils.checkpoint import load_checkpoint
from libs.utils.comm import synchronize, all_gather


def init():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--lrc", type=str, default=None)
    parser.add_argument("--cfg", type=str, default='default')
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    setup_config(args.cfg)
    if args.lrc is not None:
        cfg.valid_lrc_path = args.lrc

    os.environ['LOCAL_RANK'] = str(args.local_rank)

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    logger.setup_logger('Line Detect Model', cfg.work_dir, 'valid.log')
    logger.info('Use config: %s' % args.cfg)
    logger.info('Evaluate Dataset: %s' % cfg.valid_lrc_path)


def valid(cfg, dataloader, model):
    model.eval()
    total_label_relations = list()
    total_pred_relations = list()
    total_relations_metric = list()

    for it, data_batch in enumerate(tqdm.tqdm(dataloader)):
        ids = data_batch['ids']
        images_size = data_batch['images_size']
        images = data_batch['images'].to(cfg.device)
        tables = data_batch['tables']

        pred_result, _ = model(images, images_size)

        # pred
        pred_tables = [
            pred_result_to_table(tables[batch_idx],
                (pred_result[0][batch_idx], pred_result[1][batch_idx],
                 pred_result[2][batch_idx], pred_result[3][batch_idx])
            )
            for batch_idx in range(len(ids))
        ]
        pred_relations = [table_to_relations(table) for table in pred_tables]
        total_pred_relations.extend(pred_relations)
        # label
        label_relations = []
        for table in tables:
            with open(table['label_path'], 'r') as f:
                label_relations.append(json.load(f))
        total_label_relations.extend(label_relations)

    # cal P, R, F1
    total_relations_metric = evaluate_f1(total_label_relations, total_pred_relations, num_workers=40)
    P, R, F1 = np.array(total_relations_metric).mean(0).tolist()
    F1 = 2 * P * R / (P + R)
    logger.info('[Valid] Total Type Metric: Precision: %s, Recall: %s, F1-Score: %s' % (P, R, F1))

    return (F1, )


def main():
    init()

    valid_dataloader = create_valid_dataloader(
        cfg.vocab,
        cfg.valid_lrc_path,
        cfg.valid_num_workers,
        cfg.valid_batch_size
    )
    logger.info(
        'Valid dataset has %d samples, %d batches with batch_size=%d' %
        (
            len(valid_dataloader.dataset),
            len(valid_dataloader.batch_sampler),
            valid_dataloader.batch_size
        )
    )

    model = build_model(cfg)
    model.cuda()

    load_checkpoint(cfg.eval_checkpoint, model)
    logger.info('Load checkpoint from: %s' % cfg.eval_checkpoint)

    with torch.no_grad():
        valid(cfg, valid_dataloader, model)


if __name__ == '__main__':
    main()