kai-2054 commited on
Commit
cb0ad2d
·
1 Parent(s): 71f63bc

Initial commit: add code

Browse files
Files changed (49) hide show
  1. README.md +13 -0
  2. dataset/extract_ocr.py +155 -0
  3. dataset/process_scitsr.sh +0 -0
  4. dataset/trans2lrc.py +145 -0
  5. dataset/utils/extract_table_lines.py +192 -0
  6. dataset/utils/list_record_cache.py +143 -0
  7. dataset/utils/utils.py +449 -0
  8. libs/configs/__init__.py +24 -0
  9. libs/configs/default.py +77 -0
  10. libs/data/__init__.py +69 -0
  11. libs/data/batch_sampler.py +118 -0
  12. libs/data/dataset.py +164 -0
  13. libs/data/list_record_cache.py +143 -0
  14. libs/data/transform.py +70 -0
  15. libs/data/utils.py +188 -0
  16. libs/model/__init__.py +16 -0
  17. libs/model/backbone.py +281 -0
  18. libs/model/cells_extractor.py +130 -0
  19. libs/model/decoder.py +277 -0
  20. libs/model/divide_predictor.py +57 -0
  21. libs/model/extractor.py +88 -0
  22. libs/model/fpn.py +37 -0
  23. libs/model/model.py +65 -0
  24. libs/model/pan.py +24 -0
  25. libs/model/sa.py +35 -0
  26. libs/model/segment_predictor.py +133 -0
  27. libs/model/utils.py +371 -0
  28. libs/utils/__init__.py +0 -0
  29. libs/utils/cal_f1.py +214 -0
  30. libs/utils/checkpoint.py +47 -0
  31. libs/utils/comm.py +129 -0
  32. libs/utils/context_cacher.py +15 -0
  33. libs/utils/counter.py +43 -0
  34. libs/utils/format_translate.py +278 -0
  35. libs/utils/logger.py +64 -0
  36. libs/utils/metric.py +58 -0
  37. libs/utils/model_synchronizer.py +75 -0
  38. libs/utils/scitsr/__init__.py +0 -0
  39. libs/utils/scitsr/eval.py +179 -0
  40. libs/utils/scitsr/relation.py +59 -0
  41. libs/utils/scitsr/table.py +133 -0
  42. libs/utils/teds.py +212 -0
  43. libs/utils/teds_multiprocess.py +111 -0
  44. libs/utils/time_counter.py +108 -0
  45. libs/utils/utils.py +297 -0
  46. libs/utils/vocab.py +36 -0
  47. requirements.txt +93 -0
  48. runner/train.py +245 -0
  49. runner/valid.py +116 -0
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Training
3
+ emoji: 🦀
4
+ colorFrom: pink
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 5.45.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
dataset/extract_ocr.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import tqdm
4
+ import numpy as np
5
+ from utlis.list_record_cache import ListRecordCacher, merge_record_file
6
+ from utlis.utlis import get_paths, get_sub_paths, crop_pdf, extract_ocr, refine_table, visualize_table
7
+
8
+
9
def parse_args():
    """Parse command-line arguments: source dir, destination dir, worker count."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('src_dir', type=str, default=None)
    parser.add_argument('dst_dir', type=str, default=None)
    parser.add_argument('-n', '--num_workers', type=int, default=0)
    return parser.parse_args()
17
+
18
+
19
def single_process(paths, dst_dir):
    """Process all tables sequentially.

    For each (pdf, chunk, structure) triple: render/crop the PDF, run the OCR
    extraction, refine the table, attach the structure label path and cache
    the record. Failures are counted and their crops saved for inspection.

    Args:
        paths: list of [pdf_path, chunk_path, structure_path] triples.
        dst_dir: output root; pdf/, img/, error/, visual/ sub-dirs are created.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    # exist_ok avoids the check-then-create race of the previous version
    for sub_dir in (output_pdf_dir, output_img_dir, output_error_dir, output_visual_dir):
        os.makedirs(sub_dir, exist_ok=True)

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table.lrc'))

    error_paths = []
    error_count = 0
    correct_count = 0
    for idx, path in enumerate(tqdm.tqdm(paths)):  # renamed: `id` shadowed the builtin
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]

            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)

            # fixed: the assert message used to be `print(...)`, which yields None
            assert os.path.exists(structure_path), 'structure_path does not exist: %s' % structure_path
            table['label_path'] = structure_path

            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)

            cacher.add_record(table)
            correct_count += 1
        # fixed: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    print("correct num: %d, error num: %d " % (correct_count, error_count))
    if len(error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), error_paths)
    cacher.close()
64
+
65
+
66
def _worker(worker_idx, num_workers, paths, dst_dir, result_queue):
    """Worker-process body for multi_process: handles one shard of `paths`.

    Writes records to a per-worker cache file table_<worker_idx>.lrc and
    reports (correct_count, error_count, error_paths) via `result_queue`.
    The output sub-dirs under dst_dir must already exist.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_error_dir = os.path.join(dst_dir, 'error')
    output_visual_dir = os.path.join(dst_dir, 'visual')

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table_%d.lrc' % worker_idx))

    error_paths = []
    error_count = 0
    correct_count = 0
    for idx, path in enumerate(tqdm.tqdm(paths)):  # renamed: `id` shadowed the builtin
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]

            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = extract_ocr(path, positions, transcripts)
            table = refine_table(table, os.path.join(output_pdf_dir, name + '.png'), output_img_dir)

            # fixed: the assert message used to be `print(...)`, which yields None
            assert os.path.exists(structure_path), 'structure_path does not exist: %s' % structure_path
            table['label_path'] = structure_path

            visualize_table(os.path.join(output_img_dir, name + '.png'), output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        # fixed: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    # fixed: close the cacher so its index block reaches disk before the
    # parent process merges the per-worker .lrc files
    cacher.close()
    result_queue.put((correct_count, error_count, error_paths))
99
+
100
+
101
def multi_process(path, dst_dir, num_workers):
    """Shard the path triples across `num_workers` processes.

    Each worker writes its own table_<idx>.lrc; afterwards the per-worker
    files are merged into a single table.lrc and the shards removed.

    Args:
        path: list of [pdf_path, chunk_path, structure_path] triples.
        dst_dir: output root directory.
        num_workers: number of worker processes (> 0).
    """
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()

    # fixed: the output sub-dirs were only created by single_process, so a
    # multi-process run on a fresh dst_dir failed inside every worker
    for sub_name in ('pdf', 'img', 'error', 'visual'):
        os.makedirs(os.path.join(dst_dir, sub_name), exist_ok=True)

    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                worker_idx,
                num_workers,
                path[worker_idx::num_workers],
                dst_dir,
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)

    total_correct_count = 0
    total_error_count = 0
    total_error_paths = []
    for _ in range(num_workers):
        correct_count, error_count, error_paths = result_queue.get()
        total_correct_count += correct_count
        total_error_count += error_count
        total_error_paths.extend(error_paths)

    # fixed: join the (daemon) workers so their .lrc files are completely
    # written and closed before we merge them below
    for worker in workers:
        worker.join()

    print("correct num: %d, error num: %d " % (total_correct_count, total_error_count))
    if len(total_error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), total_error_paths)

    # merge each worker lrc into table.lrc, then delete the shards
    cache_paths = glob.glob(os.path.join(dst_dir, '*.lrc'))
    merge_record_file(cache_paths, os.path.join(dst_dir, 'table.lrc'))
    for cache_path in cache_paths:
        os.remove(cache_path)
140
+
141
+
142
def main():
    """Entry point: collect (pdf, chunk, structure) path triples from src_dir
    and run OCR extraction single- or multi-process per --num_workers."""
    args = parse_args()

    paths = get_sub_paths(args.src_dir, ["pdf", "chunk", "structure"], ['.pdf', '.chunk', '.json'])
    # paths = get_paths(args.src_dir, ["pdf", "chunk", "structure"], '/yrfs1/intern/zrzhang6/TSR/Dataset/SciTSR/SciTSR-COMP.list', ['.pdf', '.chunk', '.json'])

    # num_workers == 0 means "run in this process"; anything else forks workers
    if args.num_workers == 0:
        single_process(paths, args.dst_dir)
    else:
        multi_process(paths, args.dst_dir, args.num_workers)


if __name__ == "__main__":
    main()
dataset/process_scitsr.sh ADDED
File without changes
dataset/trans2lrc.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import tqdm
4
+ import numpy as np
5
+ from utlis.list_record_cache import ListRecordCacher, merge_record_file
6
+ from utlis.utlis import get_sub_paths, crop_pdf, crop_cells, visualize_cell, match_cells
7
+
8
+
9
def parse_args():
    """Parse command-line arguments: source dir, destination dir, worker count."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('src_dir', type=str, default=None)
    parser.add_argument('dst_dir', type=str, default=None)
    parser.add_argument('-n', '--num_workers', type=int, default=0)
    return parser.parse_args()
17
+
18
+
19
def single_process(paths, dst_dir):
    """Sequentially convert each (pdf, chunk, structure) triple into a cached
    table record (cells matched to the structure, image cropped to the cells).

    Args:
        paths: list of [pdf_path, chunk_path, structure_path] triples.
        dst_dir: output root; pdf/, img/, visual/, error/ sub-dirs are created.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    output_error_dir = os.path.join(dst_dir, 'error')
    # exist_ok avoids the check-then-create race of the previous version
    for sub_dir in (output_pdf_dir, output_img_dir, output_visual_dir, output_error_dir):
        os.makedirs(sub_dir, exist_ok=True)

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table.lrc'))

    error_paths = []
    error_count = 0
    correct_count = 0
    for idx, path in enumerate(tqdm.tqdm(paths)):  # renamed: `id` shadowed the builtin
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = match_cells([pdf_path, chunk_path, structure_path], positions, transcripts)
            img_path = os.path.join(output_img_dir, name + '.png')
            crop_cells(os.path.join(output_pdf_dir, name + '.png'), output_img_dir, table)
            table['id'] = idx
            table['image_path'] = img_path
            visualize_cell(img_path, output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        # fixed: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    print("correct num: %d, error num: %d " % (correct_count, error_count))
    if len(error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), error_paths)
    cacher.close()
59
+
60
+
61
def _worker(worker_idx, num_workers, paths, dst_dir, result_queue):
    """Worker-process body for multi_process: converts one shard of `paths`.

    Writes records to a per-worker cache file table_<worker_idx>.lrc and
    reports (correct_count, error_count, error_paths) via `result_queue`.
    The output sub-dirs under dst_dir must already exist.
    """
    output_pdf_dir = os.path.join(dst_dir, 'pdf')
    output_img_dir = os.path.join(dst_dir, 'img')
    output_visual_dir = os.path.join(dst_dir, 'visual')
    output_error_dir = os.path.join(dst_dir, 'error')

    cacher = ListRecordCacher(os.path.join(dst_dir, 'table_%d.lrc' % worker_idx))

    error_paths = []
    error_count = 0
    correct_count = 0
    for idx, path in enumerate(tqdm.tqdm(paths)):  # renamed: `id` shadowed the builtin
        try:
            pdf_path, chunk_path, structure_path = path
            name = os.path.splitext(os.path.basename(pdf_path))[0]
            positions, transcripts = crop_pdf(path, output_pdf_dir)
            table = match_cells([pdf_path, chunk_path, structure_path], positions, transcripts)
            img_path = os.path.join(output_img_dir, name + '.png')
            crop_cells(os.path.join(output_pdf_dir, name + '.png'), output_img_dir, table)
            # globally-unique id across the interleaved shards
            table['id'] = int(idx * num_workers + worker_idx)
            table['image_path'] = img_path
            visualize_cell(img_path, output_visual_dir, table)
            cacher.add_record(table)
            correct_count += 1
        # fixed: a bare `except:` also swallowed KeyboardInterrupt/SystemExit
        except Exception:
            error_count += 1
            error_paths.append(path)
            crop_pdf(path, output_error_dir)

    # fixed: close the cacher so its index block reaches disk before the
    # parent process merges the per-worker .lrc files
    cacher.close()
    result_queue.put((correct_count, error_count, error_paths))
90
+
91
+
92
def multi_process(path, dst_dir, num_workers):
    """Shard the path triples across `num_workers` processes.

    Each worker writes its own table_<idx>.lrc; afterwards the per-worker
    files are merged into a single table.lrc and the shards removed.

    Args:
        path: list of [pdf_path, chunk_path, structure_path] triples.
        dst_dir: output root directory.
        num_workers: number of worker processes (> 0).
    """
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()

    # fixed: the output sub-dirs were only created by single_process, so a
    # multi-process run on a fresh dst_dir failed inside every worker
    for sub_name in ('pdf', 'img', 'visual', 'error'):
        os.makedirs(os.path.join(dst_dir, sub_name), exist_ok=True)

    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                worker_idx,
                num_workers,
                path[worker_idx::num_workers],
                dst_dir,
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)

    total_correct_count = 0
    total_error_count = 0
    total_error_paths = []
    for _ in range(num_workers):
        correct_count, error_count, error_paths = result_queue.get()
        total_correct_count += correct_count
        total_error_count += error_count
        total_error_paths.extend(error_paths)

    # fixed: join the (daemon) workers so their .lrc files are completely
    # written and closed before we merge them below
    for worker in workers:
        worker.join()

    print("correct num: %d, error num: %d " % (total_correct_count, total_error_count))
    if len(total_error_paths) > 0:
        np.save(os.path.join(dst_dir, 'error_paths.npy'), total_error_paths)

    # merge each worker lrc into table.lrc, then delete the shards
    cache_paths = glob.glob(os.path.join(dst_dir, '*.lrc'))
    merge_record_file(cache_paths, os.path.join(dst_dir, 'table.lrc'))
    for cache_path in cache_paths:
        os.remove(cache_path)
131
+
132
+
133
def main():
    """Entry point: collect (pdf, chunk, structure) path triples from src_dir
    and run the .lrc conversion single- or multi-process per --num_workers."""
    args = parse_args()

    paths = get_sub_paths(args.src_dir, ["pdf", "chunk", "structure"], ['.pdf', '.chunk', '.json'])

    # num_workers == 0 means "run in this process"; anything else forks workers
    if args.num_workers == 0:
        single_process(paths, args.dst_dir)
    else:
        multi_process(paths, args.dst_dir, args.num_workers)


if __name__ == "__main__":
    main()
dataset/utils/extract_table_lines.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+
4
+
5
class InvalidFormat(Exception):
    """Raised when a table annotation cannot yield a consistent line layout."""
    pass
7
+
8
+
9
def segmentation_to_bbox(segmentation):
    """Return the axis-aligned bounding box (x1, y1, x2, y2) enclosing every
    point of every contour in `segmentation`."""
    points = [pt for contour in segmentation for pt in contour]
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    return (min(xs), min(ys), max(xs), max(ys))
15
+
16
+
17
def cal_cell_bbox(table):
    """Compute one pixel-space bbox per cell from its (sub)line segmentations.

    Prefers the union of 'sublines' segmentations when present; falls back to
    the cell's own 'segmentation'. Cells without a 'segmentation' key or with
    nothing usable map to None. Returns a list parallel to table['cells'].
    """
    cells_bbox = list()
    for cell in table['cells']:
        if 'segmentation' not in cell:
            cells_bbox.append(None)
            continue
        contours = list()
        for subline in cell.get('sublines', []):
            contours.extend(subline['segmentation'])
        if not contours:
            contours = cell['segmentation']
        cells_bbox.append(segmentation_to_bbox(contours) if contours else None)
    return cells_bbox
35
+
36
+
37
def cal_cell_spans(table):
    """Compute each cell's grid span [col1, row1, col2, row2] (inclusive)
    from the layout matrix, where layout[r, c] holds the id of the cell
    occupying grid position (r, c).

    Raises:
        AssertionError: if a cell's occupied positions do not form a solid
            rectangle in the layout.
    """
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # fixed: the check used exclusive slices (layout[y1:y2, x1:x2]), which
        # skipped the last row/column and was vacuous for 1x1 cells; use
        # inclusive bounds so the whole claimed rectangle is verified
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span
50
+
51
+
52
def cal_fg_bg_span(spans, edge):
    """Split an ordered list of line spans into foreground and background spans.

    `spans` is an ordered list of [start, end] extents along one axis (rows or
    columns); an entry may be None when that span could not be measured.
    `edge` is the image size along the axis.

    Background spans are the input spans themselves, kept only when strictly
    inside [0, edge] and not touching/overlapping a neighbour (neighbour must
    be known, i.e. not None). Foreground spans are the gaps between
    consecutive input spans, including the leading [0, first_start] and the
    trailing [last_end, edge] gap; gaps next to a None span or of
    non-positive width are dropped.

    Returns:
        (fg_spans, bg_spans)
    """
    num_span = len(spans)
    bg_spans = list()
    for idx in range(num_span):
        if spans[idx] is None:
            continue
        # reject a span touching the left border / its predecessor
        if idx == 0:
            if spans[idx][0] <= 0:
                continue
        else:
            if spans[idx-1] is None:
                continue
            if spans[idx][0] <= spans[idx-1][1]:
                continue
        # reject a span touching the right border / its successor
        if idx == num_span - 1:
            if spans[idx][1] >= edge:
                continue
        else:
            if spans[idx+1] is None:
                continue
            if spans[idx][1] >= spans[idx+1][0]:
                continue

        bg_spans.append(spans[idx])

    fg_spans = list()
    # one candidate gap before each span, plus one after the last
    for idx in range(num_span+1):
        if idx == 0:
            s = 0
        else:
            if spans[idx-1] is None:
                continue
            s = spans[idx-1][1]

        if idx == num_span:
            e = edge
        else:
            if spans[idx] is None:
                continue
            e = spans[idx][0]

        # skip degenerate/inverted gaps
        if e <= s:
            continue

        fg_spans.append([s, e])

    return fg_spans, bg_spans
99
+
100
+
101
def shrink_spans(spans, size):
    """Shrink neighbouring spans so that overlaps are resolved symmetrically.

    For each [start, end] in the ordered `spans`: the first span's start is
    clamped to >= 1 and the last span's end to <= size - 1; when a span
    overlaps its predecessor/successor, its boundary is moved inward by half
    the overlap (the neighbour gives up the other half on its own iteration).

    Raises:
        InvalidFormat: if a span collapses to less than one unit of width.
            (NOTE(review): indentation of this check was ambiguous in the
            original paste; it is applied to every span here — confirm.)
    """
    new_spans = list()
    for idx, (start, end) in enumerate(spans):
        if idx == 0:
            if start <= 0:
                start = 1
        else:
            _, pre_end = spans[idx - 1]
            if start <= pre_end:
                shrink_distance = pre_end - start + 1
                start = start + math.ceil(shrink_distance / 2)

        if idx == len(spans) - 1:
            if end >= size:
                end = size - 1
        else:
            next_start, _ = spans[idx + 1]
            if end >= next_start:
                shrink_distance = end - next_start + 1
                end = end - math.ceil(shrink_distance / 2)
        if end - start < 1:
            raise InvalidFormat()

        new_spans.append([start, end])
    return new_spans
126
+
127
+
128
def cal_row_span(table, cells_span, cells_bbox, height):
    """Derive one pixel-space [y1, y2] span per layout row, then split them
    into foreground/background spans over the image height.

    A row's top edge comes from cells whose grid span starts on that row and
    its bottom edge from cells whose span ends on it; both edges are clamped
    to [1, height - 1].

    Raises:
        InvalidFormat: when a row has no measurable top or bottom edge.

    Returns:
        (rows_fg_span, rows_bg_span)
    """
    layout = table['layout']
    rows_span = list()
    for row_idx in range(layout.shape[0]):
        top_edges = []
        bottom_edges = []
        for cell_id in layout[row_idx, :]:
            bbox = cells_bbox[cell_id]
            if bbox is None:
                continue
            span = cells_span[cell_id]
            if span[1] == row_idx:
                top_edges.append(bbox[1])
            if span[3] == row_idx:
                bottom_edges.append(bbox[3])

        if not top_edges or not bottom_edges:
            raise InvalidFormat()
        y1 = min(max(1, min(top_edges)), height - 1)
        y2 = min(max(1, max(bottom_edges) + 1), height - 1)
        rows_span.append([y1, y2])

    rows_span = shrink_spans(rows_span, height)
    return cal_fg_bg_span(rows_span, height)
152
+
153
+
154
def cal_col_span(table, cells_span, cells_bbox, width):
    """Derive one pixel-space [x1, x2] span per layout column, then split
    them into foreground/background spans over the image width.

    A column's left edge comes from cells whose grid span starts on that
    column and its right edge from cells whose span ends on it; both edges
    are clamped to [1, width - 1].

    Raises:
        InvalidFormat: when a column has no measurable left or right edge.

    Returns:
        (cols_fg_span, cols_bg_span)
    """
    layout = table['layout']
    cols_span = list()
    for col_idx in range(layout.shape[1]):
        left_edges = []
        right_edges = []
        for cell_id in layout[:, col_idx]:
            bbox = cells_bbox[cell_id]
            if bbox is None:
                continue
            span = cells_span[cell_id]
            if span[0] == col_idx:
                left_edges.append(bbox[0])
            if span[2] == col_idx:
                right_edges.append(bbox[2])

        if not left_edges or not right_edges:
            raise InvalidFormat()
        x1 = min(max(1, min(left_edges)), width - 1)
        x2 = min(max(1, max(right_edges) + 1), width - 1)
        cols_span.append([x1, x2])

    cols_span = shrink_spans(cols_span, width)
    return cal_fg_bg_span(cols_span, width)
178
+
179
+
180
def extract_fg_bg_spans(table, image_size):
    """Compute row/column foreground & background spans plus cell grid spans.

    Args:
        table: dict with 'layout' (2D cell-id array) and 'cells'.
        image_size: (width, height) of the table image.

    Returns:
        (rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span)
    """
    width, height = image_size
    cells_bbox = cal_cell_bbox(table)
    cells_span = cal_cell_spans(table)
    rows_fg_span, rows_bg_span = cal_row_span(table, cells_span, cells_bbox, height)
    cols_fg_span, cols_bg_span = cal_col_span(table, cells_span, cells_bbox, width)
    return rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span
dataset/utils/list_record_cache.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import threading
4
+
5
+
6
def merge_record_file(files, dst_file):
    """Concatenate the record files in `files` (in order) into `dst_file`.

    The .lrc format is self-delimiting per segment, so plain byte
    concatenation yields a valid merged file (see ListRecordLoader).

    Previously implemented as os.system('cat ...'), which broke on paths
    containing spaces or shell metacharacters and silently ignored failures;
    copy the bytes in Python instead.
    """
    import shutil
    with open(dst_file, 'wb') as dst:
        for file in files:
            with open(file, 'rb') as src:
                shutil.copyfileobj(src, dst)
12
+
13
+
14
class ListRecordCacher:
    """Append-only writer for .lrc record files.

    File layout: an 8-byte big-endian header (patched on close) holding the
    absolute offset of the index block, followed by the records, each stored
    as an 8-byte big-endian length prefix plus the pickled payload. On close,
    the pickled list of [offset, length] record positions is appended as the
    index block (itself length-prefixed) and the header is patched to point
    at it. ListRecordLoader reads this format back.
    """
    OFFSET_LENGTH = 8

    def __init__(self, cache_path):
        # [absolute_offset, payload_length] per record
        self._record_pos_list = list()
        self._cache_file = open(cache_path, 'wb')
        # 8-byte placeholder for the header; patched in _write_record_pos_list
        self._cached_bytes = b'\x00' * self.OFFSET_LENGTH

    def add_record(self, record):
        """Pickle `record` and append it to the cache."""
        record_bytes = pickle.dumps(record)
        return self.add_record_bytes(record_bytes)

    def add_record_bytes(self, record_bytes):
        """Append an already-serialized record (bytes)."""
        bytes_size = len(record_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + record_bytes

        if len(self._record_pos_list) == 0:
            # header (8 bytes) + first record's length prefix (8 bytes)
            cur_record_pos = [self.OFFSET_LENGTH * 2, bytes_size]
        else:
            cur_record_pos = [sum(self._record_pos_list[-1]) + self.OFFSET_LENGTH, bytes_size]
        self._record_pos_list.append(cur_record_pos)

        self._cached_bytes += total_bytes
        # buffer writes and push them to disk in ~1 MiB batches
        if len(self._cached_bytes) > 1024 * 1024:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def flush(self):
        """Write any buffered bytes to the end of the file."""
        if len(self._cached_bytes) > 0:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def _write_record_pos_list(self):
        """Append the pickled index block and patch the file header.

        (fixed typo: was `_wirte_record_pos_list`; internal helper with no
        callers outside this class)
        """
        self.flush()
        self._cache_file.seek(0, 2)
        offset = self._cache_file.tell()
        offset_bytes = offset.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        self._cache_file.seek(0)
        self._cache_file.write(offset_bytes)

        data_bytes = pickle.dumps(self._record_pos_list)
        bytes_size = len(data_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + data_bytes
        self._cache_file.seek(0, 2)
        self._cache_file.write(total_bytes)

    def close(self):
        """Finalize the file (write index + patch header) and close it.

        Idempotent: safe to call more than once.
        """
        if not self._cache_file.closed:
            self._write_record_pos_list()
            self._cache_file.close()

    def __del__(self):
        # fixed: guard against __init__ having failed before _cache_file
        # was assigned (AttributeError inside __del__ is noisy and masked)
        if getattr(self, '_cache_file', None) is not None:
            self.close()
80
+
81
+
82
class ListRecordLoader:
    """Random-access reader for .lrc files written by ListRecordCacher.

    Also handles files produced by concatenating several .lrc files: each
    segment carries its own index block, and segment-relative offsets are
    rebased to absolute file offsets while scanning. The file handle is
    reopened after a fork (pid change) so child processes do not share a
    file position, and each seek+read pair is serialized with a lock.
    """
    OFFSET_LENGTH = 8

    def __init__(self, load_path):
        self._sync_lock = threading.Lock()
        self._size = os.path.getsize(load_path)
        self._load_path = load_path
        self._open_file()
        self._scan_file()

    def _open_file(self):
        self._pid = os.getpid()
        self._cache_file = open(self._load_path, 'rb')

    def _check_reopen(self):
        # after a fork the descriptor would be shared with the parent
        if os.getpid() != self._pid:
            self._open_file()

    def _scan_file(self):
        """Collect [absolute_offset, length] for every record in every segment."""
        positions = list()
        segment_start = 0
        while segment_start < self._size:
            self._cache_file.seek(segment_start)
            header = self._cache_file.read(self.OFFSET_LENGTH)
            index_offset = segment_start + int.from_bytes(header, byteorder='big', signed=False)
            self._cache_file.seek(index_offset)

            index_size = int.from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            segment_positions = pickle.loads(self._cache_file.read(index_size))
            assert isinstance(segment_positions, list)
            # rebase segment-relative offsets to absolute file offsets
            positions.extend(
                [offset + segment_start, length]
                for offset, length in segment_positions
            )
            segment_start = self._cache_file.tell()

        self._record_pos_list = positions

    def get_record(self, idx):
        """Load and unpickle record `idx`."""
        self._check_reopen()
        return pickle.loads(self.get_record_bytes(idx))

    def get_record_bytes(self, idx):
        """Return the raw pickled bytes of record `idx`."""
        offset, length = self._record_pos_list[idx]
        with self._sync_lock:
            self._cache_file.seek(offset)
            return self._cache_file.read(length)

    def __len__(self):
        return len(self._record_pos_list)
dataset/utils/utils.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import json
4
+ import copy
5
+ import tqdm
6
+ import numpy as np
7
+ import fitz
8
+ from .extract_table_lines import extract_fg_bg_spans
9
+
10
+
11
def get_paths(root_dir, sub_names, names_path, exts, val=None):
    """Build per-sample path groups for the sample names listed in a file.

    Args:
        root_dir: dataset root containing one sub-dir per entry of sub_names.
        sub_names: sub-directory names, e.g. ["pdf", "chunk", "structure"].
        names_path: text file with one sample name per line.
        exts: file extension per sub-dir, aligned with sub_names.
        val: optional cap on the number of names taken (None = all).

    Returns:
        list of lists; each inner list holds one path per sub-dir.
    """
    assert os.path.isdir(root_dir)

    with open(names_path, "r") as f:
        names = [name.strip() for name in f.readlines()]

    # TODO: sub_dirs redundancy
    sub_dirs = []
    for sub_name in sub_names:
        sub_dir = os.path.join(root_dir, sub_name)
        assert os.path.isdir(sub_dir), '"%s" is not dir.' % sub_dir
        sub_dirs.append(sub_dir)

    paths = []
    for name in tqdm.tqdm(names[:val]):
        sub_paths = []
        for sub_dir, ext in zip(sub_dirs, exts):
            sub_path = os.path.join(sub_dir, name + ext)
            # fixed: the assert message used to be `print(...)`, which yields None
            assert os.path.exists(sub_path), '%s does not exist' % sub_path
            sub_paths.append(sub_path)
        paths.append(sub_paths)

    return paths
37
+
38
+
39
def get_sub_paths(root_dir, sub_names, exts, val=None):
    """Build per-sample path groups keyed by the files in the first sub-dir.

    Sibling paths are derived by swapping the directory and extension; every
    sibling must exist. `val` optionally caps the number of samples.

    Returns:
        list of lists; each inner list holds one path per sub-dir.
    """
    assert os.path.isdir(root_dir)
    # TODO: sub_dirs redundancy
    sub_dirs = []
    for sub_name in sub_names:
        sub_dir = os.path.join(root_dir, sub_name)
        assert os.path.isdir(sub_dir), '"%s" is not dir.' % sub_dir
        sub_dirs.append(sub_dir)

    paths = []
    file_names = os.listdir(sub_dirs[0])[:val]
    for file_name in tqdm.tqdm(file_names):
        stem = os.path.splitext(file_name)[0]
        group = [os.path.join(sub_dirs[0], file_name)]
        for sub_name, ext in zip(sub_names[1:], exts[1:]):
            sibling = os.path.join(root_dir, sub_name, stem + ext)
            assert os.path.exists(sibling)
            group.append(sibling)
        paths.append(group)

    return paths
61
+
62
+
63
def cal_wer(label, rec):
    """Return 1 - (Levenshtein distance / len(label)) between two sequences.

    Standard dynamic-programming edit distance over the elements of `label`
    (reference) and `rec` (hypothesis). 1.0 means a perfect match; the value
    can go negative when the distance exceeds len(label).

    Args:
        label: reference sequence (str or list).
        rec: recognized/hypothesis sequence.
    """
    # fixed: guard against ZeroDivisionError on an empty reference
    if len(label) == 0:
        return 1.0 if len(rec) == 0 else 0.0

    dist_mat = np.zeros((len(label) + 1, len(rec) + 1), dtype='int32')
    dist_mat[0, :] = range(len(rec) + 1)
    dist_mat[:, 0] = range(len(label) + 1)

    for i in range(1, len(label) + 1):
        for j in range(1, len(rec) + 1):
            # substitution (free on match), insertion, deletion
            sub_score = dist_mat[i - 1, j - 1] + (label[i - 1] != rec[j - 1])
            ins_score = dist_mat[i, j - 1] + 1
            del_score = dist_mat[i - 1, j] + 1
            dist_mat[i, j] = min(sub_score, ins_score, del_score)

    dist = dist_mat[len(label), len(rec)]

    return 1 - dist / len(label)
78
+
79
+
80
def visualize(img_path, chunks, structures):
    """Draw each chunk's box and text onto the image at `img_path` and return
    the annotated image (the `structures` argument is currently unused)."""
    canvas = cv2.imread(img_path)
    red = (0, 0, 255)
    for chunk in chunks:
        x1, x2, y1, y2 = chunk["pos"]
        text = ''.join(chunk["text"])
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), red)
        cv2.putText(canvas, text, (int(x1), int(max(0, y1 - 1))),
                    cv2.FONT_HERSHEY_COMPLEX, 0.25, red, 1)
    return canvas
88
+
89
+
90
def visualize_table(img_path, output_dir, table):
    """Draw every cell's bbox and transcript on the image and save the result
    under `output_dir` with the same file name."""
    canvas = cv2.imread(img_path)
    red = (0, 0, 255)
    for cell in table['cells']:
        x1, y1, x2, y2 = cell['bbox']
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), red)
        cv2.putText(canvas, ''.join(cell['transcript']), (int(x1), int(max(0, y1 - 1))),
                    cv2.FONT_HERSHEY_COMPLEX, 0.25, red, 1)
    cv2.imwrite(os.path.join(output_dir, os.path.basename(img_path)), canvas)
98
+
99
+
100
def crop_pdf(path, output_dir, zoom_x = 2.0, zoom_y = 2.0, rotate=0, expand=10, y_fix=.0):
    '''
    Render a single-page PDF to PNG, crop it to the table region implied by
    its chunk annotations, and map the chunk coordinates into the crop.

    path: [pdf_path, chunk_path] (extra entries are ignored)
    output_dir: where '<pdf_name>.png' is written (first the full render,
        then overwritten by the cropped image)
    zoom_x / zoom_y: render scale factors
    rotate: render rotation in degrees
    expand: padding in pixels added around the chunks' bounding box
    y_fix: vertical correction (in PDF units, scaled by zoom_y) on y1

    Returns (positions, transcripts): positions is an (N, 4) array of
    chunk boxes in cropped-image pixels, column order [x1, x2, y2, y1];
    transcripts is the list of chunk text strings.

    NOTE(review): uses the deprecated PyMuPDF camelCase API (pageCount,
    preRotate, getPixmap, writePNG) — requires an old `fitz` version.
    '''
    # load data
    with open(path[1], 'r') as f:
        chunks = json.load(f)['chunks']
    doc = fitz.open(path[0])
    pdf_name = os.path.splitext(os.path.basename(path[0]))[0]
    # NOTE(review): the assert message is `print(...)`, which evaluates to
    # None; the print only fires if the assert fails
    assert doc.pageCount == 1, print(pdf_name, ' has more than 1 page!')

    # transfer pdf to img
    trans = fitz.Matrix(zoom_x, zoom_y).preRotate(rotate)
    pm = doc[0].getPixmap(matrix=trans, alpha=False)
    pm.writePNG(os.path.join(output_dir, '%s.png' % pdf_name))

    # crop table region
    pdf_img = cv2.imread(os.path.join(output_dir, '%s.png' % pdf_name))
    h, w, *_ = pdf_img.shape
    positions = []
    transcripts = []
    for chunk in chunks:
        # reorder chunk['pos'] into x1, x2, y2, y1 (y still bottom-up here)
        positions.append([chunk['pos'][0], chunk['pos'][1], chunk['pos'][3], chunk['pos'][2]])
        transcripts.append(chunk["text"])

    # the last chunk transcrip is repeated — drop the duplicated trailing char
    transcripts[-1] = transcripts[-1][:-1]

    positions = np.array(positions)
    positions[:, :2] *= zoom_x
    # flip y: PDF origin is bottom-left, image origin is top-left
    positions[:, 2:] = h - positions[:, 2:] * zoom_y
    x_min = int(max(0, positions[:, :2].min() - expand))
    y_min = int(max(0, positions[:, 2:].min() - expand))
    x_max = int(min(w, positions[:, :2].max() + expand))
    y_max = int(min(h, positions[:, 2:].max() + expand))

    img_crop = pdf_img[y_min:y_max, x_min:x_max]
    cv2.imwrite(os.path.join(output_dir, '%s.png' % pdf_name), img_crop)

    # shift chunk coordinates into the cropped image
    positions[:, :2] = np.clip(positions[:, :2] - x_min, 0, w)
    positions[:, 2] -= y_fix * zoom_y
    positions[:, 2:] = np.clip(positions[:, 2:] - y_min, 0, h)
    return positions, transcripts
146
+
147
+
148
def crop_cells(img_path, output_dir, info, expand=10):
    """Crop the table image to the union of all cell bboxes (plus `expand`
    padding) and shift each cell's bbox/segmentation into the crop.

    Saves the cropped image under `output_dir` (same stem, .png) and mutates
    `info['cells']` in place.
    """
    cells = info['cells']
    img = cv2.imread(img_path)
    h, w, *_ = img.shape
    bboxes = [cell['bbox'] for cell in cells if 'bbox' in cell.keys()]
    bboxes = np.array(bboxes)
    # union of all cell boxes, padded and clamped to the image
    x_min = int(max(bboxes[:, 0].min() - expand, 0))
    y_min = int(max(bboxes[:, 1].min() - expand, 0))
    x_max = int(min(bboxes[:, 2].max() + expand, w))
    y_max = int(min(bboxes[:, 3].max() + expand, h))
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png'), img[y_min:y_max, x_min:x_max])

    # refine cell bbox
    new_cells = []
    for cell in cells:
        if 'bbox' not in cell.keys():
            new_cells.append(cell)
        else:
            cell['bbox'][0] = max(0, cell['bbox'][0] - x_min)
            cell['bbox'][1] = max(0, cell['bbox'][1] - y_min)
            cell['bbox'][2] = max(0, cell['bbox'][2] - x_min)
            cell['bbox'][3] = max(0, cell['bbox'][3] - y_min)
            # NOTE(review): assumes every cell that has a 'bbox' also carries
            # 'segmentation' — KeyError otherwise; confirm annotation format
            segmentation = cell['segmentation']
            cell['segmentation'] = [[[pt[0] - x_min, pt[1] - y_min] for pt in contour] for contour in segmentation]
            new_cells.append(cell)
    info['cells'] = new_cells
174
+
175
+
176
def visualize_ocr(img_path, output_dir, positions, transcripts):
    """Draw OCR boxes and their text on the image; save as '<name>_ocr.png'."""
    canvas = cv2.imread(img_path)
    for (x1, x2, y1, y2), text in zip(positions, transcripts):
        cv2.rectangle(canvas, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255))
        cv2.putText(canvas, text, (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    out_name = os.path.splitext(os.path.basename(img_path))[0] + '_ocr.png'
    cv2.imwrite(os.path.join(output_dir, out_name), canvas)
183
+
184
+
185
def cal_cell_spans(table):
    """Compute each cell's grid span [x1, y1, x2, y2] (inclusive column/row
    indices) from table['layout'], a (rows, cols) array of cell ids.

    Raises AssertionError when a cell's occupied region is not a solid
    rectangle exclusively filled with its own id.
    """
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # y2/x2 are inclusive bounds, so the slice must extend one past them.
        # The previous `layout[y1:y2, x1:x2]` skipped the last row/column and
        # let non-rectangular cells slip through the check.
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span
198
+
199
+
200
def visualize_cell(img_path, output_dir, table):
    """Render the table structure (cell rectangles derived from row/column
    separator spans) plus per-cell OCR boxes/text and save the result as
    <image stem>.png in output_dir.
    """
    def spans2lines(spans):
        # Collapse [start, end] separator spans into single line coordinates:
        # the outer edges keep the span ends, inner separators use midpoints.
        lines = []
        lines.append(spans[0][0])
        for span in spans[1:-1]:
            t1, t2 = span
            lines.append(int((t1 + t2) / 2))
        lines.append(spans[-1][-1])
        return lines

    img = cv2.imread(img_path)

    # draw table lines
    # img.shape[::-1][-2:] == (w, h)
    # NOTE(review): this unpacks 5 values incl. cells_span, i.e. a different
    # extract_fg_bg_spans variant than the 4-value one in libs/data/utils —
    # confirm which implementation is imported here.
    rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span = extract_fg_bg_spans(table, img.shape[::-1][-2:])
    row_lines = spans2lines(rows_fg_span)
    col_lines = spans2lines(cols_fg_span)
    for span in cells_span:
        # spans are inclusive grid indices; +1 selects the far separator line
        x1, y1, x2, y2 = span
        cv2.rectangle(img, (int(col_lines[x1]), int(row_lines[y1])), (int(col_lines[x2 + 1]), int(row_lines[y2 + 1])), (0, 0, 255), 2)

    # draw ocr results
    for cell in table['cells']:
        if 'bbox' not in cell.keys():
            continue
        x1, y1, x2, y2 = cell['bbox']
        transcript = cell['transcript']
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
        cv2.putText(img, ''.join(transcript), (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,0,0), 1)
    cv2.imwrite(os.path.join(output_dir, os.path.splitext(os.path.basename(img_path))[0] + '.png'), img)
229
+
230
+
231
def match_cells(path, positions, transcripts, k=16, start=0.333, stop=0.1, stop_percent=0.3, gap=0.25):
    '''
    Greedily align OCR line results with the ground-truth cells of one table.

    path: [pdf_path, chunk_path, structure_path]; only structure_path is read here
    positions: OCR line boxes, each [x1, x2, y1, y2]
    transcripts: [str], one text per OCR box, in reading order
    k: max OCR candidates inspected when anchoring / growing a cell match
    start: similarity threshold used to anchor the first/last valid OCR line
    stop: a first-candidate similarity below this is accepted immediately
          (counted as a weak match)
    stop_percent: max tolerated fraction of weak matches for the whole table
    gap: min similarity gain required to append another OCR line to a cell

    return dict(
        layout: np.array (n_row, n_col) of cell indices,
        cells: [dict(transcript=[...], optional bbox=[x1, y1, x2, y2], segmentation=...)],
        head_rows: [row indices of the header],
        body_rows: [row indices of the body],
    )

    NOTE(review): despite the name, cal_wer is used as a similarity score here
    (1.0 on exact match, larger = better) — confirm its actual contract.
    '''
    # load data (ground-truth structure annotation)
    with open(path[2], 'r') as f:
        cells = json.load(f)['cells']

    # first sort cells from left to right, from top to down
    cells_pos = []  # xl1, yl1, xl2, yl2 (logical grid coordinates)
    contents = []
    for cell in cells:
        cells_pos.append([cell['start_col'], cell['start_row'], cell['end_col'], cell['end_row']])
        contents.append(' '.join(cell['content']))

    # sorted cells from left to right, from top to down
    # (row index dominates the key via the 1e6 factor)
    sorted_idx = sorted(list(range(len(cells_pos))), key=lambda idx: cells_pos[idx][0] + 1e6 * cells_pos[idx][1])
    cells_pos = [cells_pos[idx] for idx in sorted_idx]
    contents = [contents[idx] for idx in sorted_idx]

    # layout grid sized by the max end row/col
    n_row = np.array(cells_pos)[:, 3].max() + 1
    n_col = np.array(cells_pos)[:, 2].max() + 1
    layout = np.full((n_row, n_col), -1)

    # head_rows & body_rows: the header height is the max row span of the
    # cells that start in row 0
    head_rows = list(range((np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 3] - np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 1]).max() + 1))
    body_rows = list(range((np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 3] - np.array(cells_pos)[np.array(cells_pos)[:,1] == 0][:, 1]).max() + 1, n_row))

    lt = [-1, -1]
    cells = []
    valid_idx = list(range(len(transcripts)))

    # init start/end index of ocr results: drop leading/trailing OCR lines
    # that don't look similar to the first/last non-empty cell content
    start_content = ''
    for content in contents:
        if len(content) > 0:
            start_content = content
            break
    try:
        start_index = [cal_wer(start_content, transcript) > start for transcript in transcripts[:k]].index(True)
    except:
        # no anchor found among the first k lines: keep everything
        start_index = 0

    end_content = ''
    for content in contents[::-1]:
        if len(content) > 0:
            end_content = content
            break
    try:
        end_index = [cal_wer(end_content, transcript) > start for transcript in transcripts[::-1][:k]].index(True)
    except:
        end_index = 0

    valid_idx = valid_idx[start_index:] if end_index == 0 else valid_idx[start_index: -end_index]

    assert len(contents) >= len(valid_idx), print('OCR Results Have Error')

    stop_counts = 0
    for index, (cell_pos, content) in enumerate(zip(cells_pos, contents)):
        # confirm the cell pos is increase (sorted order sanity check)
        assert cell_pos[0] > lt[0] or cell_pos[1] > lt[1], print('Sorted Cells Have Error')
        lt = cell_pos[:2]

        xl1, yl1, xl2, yl2 = cell_pos
        layout[yl1:yl2+1, xl1:xl2+1] = index

        if len(content) == 0:
            # empty cell: no OCR line is consumed
            cells.append(dict(transcript=[]))
        else:
            is_completed = False
            # seed the match with the next unconsumed OCR line
            bboxes_list = [positions[valid_idx[0]]]
            transcripts_list = [transcripts[valid_idx[0]]]
            valid_idx.pop(0)
            wer_last = cal_wer(content, ' '.join(transcripts_list))
            if wer_last < stop:
                # very weak first match: accept as-is, but count it so the
                # table can be rejected if this happens too often
                bboxes_list = np.array(bboxes_list)
                x1 = int(bboxes_list[:, :2].min())
                x2 = int(bboxes_list[:, :2].max())
                y1 = int(bboxes_list[:, 2:].min())
                y2 = int(bboxes_list[:, 2:].max())
                cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))
                stop_counts += 1
                continue
            # greedily absorb further OCR lines while similarity improves
            for idx in valid_idx[:k]:
                if content == ' '.join(transcripts_list):
                    # exact text match: finalize the merged bbox
                    bboxes_list = np.array(bboxes_list)
                    x1 = int(bboxes_list[:, :2].min())
                    x2 = int(bboxes_list[:, :2].max())
                    y1 = int(bboxes_list[:, 2:].min())
                    y2 = int(bboxes_list[:, 2:].max())
                    cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))
                    is_completed = True
                    break
                else:
                    cur_trans = copy.deepcopy(transcripts_list)
                    cur_trans.append(transcripts[idx])
                    wer = cal_wer(content, ' '.join(cur_trans))
                    # if add new str, and wer is not increase a lot, it should not be added in
                    if wer < wer_last + gap:
                        continue
                    else:
                        transcripts_list.append(transcripts[idx])
                        bboxes_list.append(positions[idx])
                        valid_idx.pop(valid_idx.index(idx))
                        if wer == 1.0:
                            break
                        else:
                            wer_last = wer
            if not is_completed:
                # loop exhausted without exact match: finalize what was gathered
                bboxes_list = np.array(bboxes_list)
                x1 = int(bboxes_list[:, :2].min())
                x2 = int(bboxes_list[:, :2].max())
                y1 = int(bboxes_list[:, 2:].min())
                y2 = int(bboxes_list[:, 2:].max())
                cells.append(dict(transcript=list(content), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))

    assert stop_counts / len(contents) < stop_percent, print('This Table Has Many Error Match with OCR Results')
    assert layout.min() == 0, print('This Table Layout is not Completely Resolved')
    return dict(
        layout=layout,
        cells=cells,
        head_rows=head_rows,
        body_rows=body_rows,
    )
365
+
366
+
367
def extract_ocr(path, positions, transcripts, k=16, start=0.333):
    '''
    Trim raw OCR results to the span belonging to this table and wrap each
    surviving OCR line as a pseudo-cell.

    path: [pdf_path, chunk_path, structure_path]; only structure_path is read
    positions: OCR line boxes, each [x1, x2, y1, y2]
    transcripts: [str], one text per OCR box, in reading order
    k: number of leading/trailing OCR lines scanned for the anchors
    start: similarity threshold used to anchor the first/last valid OCR line

    return dict(
        cells=[{
            'bbox': [x1, y1, x2, y2],
            'transcript': [char, ...],
            'segmentation': rectangle contour of bbox,
        }]
    )
    '''
    # load data (ground-truth structure annotation)
    with open(path[2], 'r') as f:
        cells = json.load(f)['cells']

    # first sort cells from left to right, from top to down
    cells_pos = []  # xl1, yl1, xl2, yl2
    contents = []
    for cell in cells:
        cells_pos.append([cell['start_col'], cell['start_row'], cell['end_col'], cell['end_row']])
        contents.append(' '.join(cell['content']))

    # sorted cells from left to right, from top to down
    sorted_idx = sorted(list(range(len(cells_pos))), key=lambda idx: cells_pos[idx][0] + 1e6 * cells_pos[idx][1])
    cells_pos = [cells_pos[idx] for idx in sorted_idx]
    contents = [contents[idx] for idx in sorted_idx]

    # init start/end index, condition is the first/last index must not over split, and wer should be larger than start threshold
    valid_idx = list(range(len(transcripts)))
    start_content = ''
    for content in contents:
        if len(content) > 0:
            start_content = content
            break
    try:
        start_index = [cal_wer(start_content, transcript) > start for transcript in transcripts[:k]].index(True)
    except ValueError:
        # list.index raises exactly ValueError when no candidate passes the
        # threshold; a bare `except:` here previously hid real errors
        # (including typos and KeyboardInterrupt).
        start_index = 0

    end_content = ''
    for content in contents[::-1]:
        if len(content) > 0:
            end_content = content
            break
    try:
        end_index = [cal_wer(end_content, transcript) > start for transcript in transcripts[::-1][:k]].index(True)
    except ValueError:
        end_index = 0

    # drop OCR lines before the start anchor and after the end anchor
    valid_idx = valid_idx[start_index:] if end_index == 0 else valid_idx[start_index: -end_index]

    cells = []
    for idx in valid_idx:
        x1, x2, y1, y2 = positions[idx].astype('int').tolist()
        cells.append(dict(transcript=list(transcripts[idx]), bbox=[x1, y1, x2, y2], segmentation=[[[x1,y1],[x2,y1],[x2,y2],[x1,y2]]]))

    return dict(
        cells=cells
    )
427
+
428
+
429
def refine_table(table, img_path, output_dir, expand=10):
    """Crop the image tightly around all cell bboxes (plus `expand` px margin),
    shift the cell bboxes into the cropped frame, write the crop to output_dir
    and repoint table['image_path'] at it. Mutates and returns `table`.
    """
    cells = table['cells']
    bboxes = [cell['bbox'] for cell in table['cells'] if 'bbox' in cell.keys()]
    bboxes = np.array(bboxes)
    img = cv2.imread(img_path)
    h, w, *_ = img.shape
    # union of all bboxes, expanded and clamped to the image bounds
    x1 = int(max(0, bboxes[:, 0].min() - expand))
    y1 = int(max(0, bboxes[:, 1].min() - expand))
    x2 = int(min(w, bboxes[:, 2].max() + expand))
    y2 = int(min(h, bboxes[:, 3].max() + expand))
    # refine cells: translate x (even columns) and y (odd columns) coordinates
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2] - x1, 0, 1e6)
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2] - y1, 0, 1e6)
    bboxes = bboxes.tolist()
    # NOTE(review): `bboxes` only contains cells that HAD a bbox, but zip pairs
    # them with ALL of `cells` in order — if any cell lacks a bbox, later
    # assignments shift onto the wrong cells. Confirm every cell has a bbox here.
    for cell, bbox in zip(cells, bboxes):
        cell['bbox'] = bbox

    img = img[y1:y2, x1:x2]
    cv2.imwrite(os.path.join(output_dir, os.path.basename(img_path)), img)
    table['image_path'] = os.path.join(output_dir, os.path.basename(img_path))
    return table
libs/configs/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from . import default
import importlib


class CFG:
    """Mutable proxy around the currently active config module.

    Attribute reads/writes are forwarded to the wrapped module, so code that
    imported the global `cfg` keeps seeing the new values after
    `setup_config` swaps the underlying module.
    """
    def __init__(self):
        # Write straight into __dict__ to bypass __setattr__ (which would
        # otherwise forward the assignment to the not-yet-set module).
        self.__dict__['cfg'] = None

    def __getattr__(self, name):
        # Forward attribute lookups to the wrapped config module.
        return getattr(self.__dict__['cfg'], name)

    def __setattr__(self, name, val):
        # Forward attribute assignments to the wrapped config module.
        setattr(self.__dict__['cfg'], name, val)


# Global config handle; starts out backed by libs.configs.default.
cfg = CFG()
cfg.__dict__['cfg'] = default


def setup_config(cfg_name):
    """Replace the active configuration with module libs.configs.<cfg_name>."""
    global cfg
    module_name = 'libs.configs.' + cfg_name
    cfg_module = importlib.import_module(module_name)
    cfg.__dict__['cfg'] = cfg_module
libs/configs/default.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
'''Default configuration module.

Plain module-level constants; served through libs.configs.cfg and replaceable
at runtime via libs.configs.setup_config().
'''
import os
import torch
from libs.utils.vocab import Vocab

device = torch.device('cuda')

# training data: list of LRC record files
train_lrcs_path = [
    "/yrfs1/intern/pfhu6/TSR/Dataset/SciTSR/train/table.lrc"
]
train_data_dir = ''
train_max_pixel_nums = 400 * 400 * 5  # per-batch pixel budget for the bucket sampler
train_bucket_seps = (50, 50, 50)  # bucket widths passed to BucketSampler
train_max_batch_size = 8
train_num_workers = 4

# validation data
valid_lrc_path = '/yrfs1/intern/pfhu6/TSR/Dataset/SciTSR/test/table.lrc'
valid_data_dir = ''
valid_num_workers = 0
valid_batch_size = 1

vocab = Vocab()

# model params
# backbone
arch = "res34"
pretrained_backbone = True
backbone_out_channels = (64, 128, 256, 512)

# fpn
fpn_out_channels = 256

# pan
pan_num_levels = 4
pan_in_dim = 256
pan_out_dim = 256

# row segment predictor
rs_scale = 1

# col segment predictor
cs_scale = 1

# divide predictor
dp_head_nums = 8
dp_scale = 1

# cells extractor params
ce_scale = 1 / 8
ce_pool_size = (3, 3)
ce_dim = 512
ce_head_nums = 8
ce_heads = 1

# decoder
embed_dim = 512
feat_dim = 512
lm_state_dim = 512
proj_dim = 512
cover_kernel = 7
att_threshold = 0.5
spatial_att_weight_loss_wight = 1.0

# train params
base_lr = 0.0001
min_lr = 1e-6
weight_decay = 0

num_epochs = 20
sync_rate = 20  # model synchronization interval (steps)

log_sep = 20  # logging interval (steps)

work_dir = './experiments/heads_1'

train_checkpoint = None  # path to resume from, or None for fresh training

eval_checkpoint = os.path.join(work_dir, 'best_f1_model.pth')
libs/data/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data.distributed import DistributedSampler
3
+ from .batch_sampler import BucketSampler
4
+ from .dataset import LRCRecordLoader
5
+ from .dataset import Dataset, collate_func
6
+ from libs.utils.comm import distributed, get_rank, get_world_size
7
+ from . import transform as T
8
+
9
+
10
def create_train_dataloader(vocab, lrcs_path, num_workers, max_batch_size, max_pixel_nums, bucket_seps, data_root_dir):
    """Build the bucketed training DataLoader over one or more LRC files.

    Samples are grouped by size via BucketSampler so each batch respects the
    max_pixel_nums budget and max_batch_size cap.
    """
    loaders = [LRCRecordLoader(lrc_path, data_root_dir) for lrc_path in lrcs_path]

    # Label construction + span/divide targets, then tensor conversion and
    # ImageNet normalization.
    transforms = T.Compose([
        T.TableToLabel(vocab),
        T.CalRowColSpans(),
        T.CalCellSpans(),
        T.CalHeadBodyDivide(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = Dataset(loaders, transforms)
    sampler = BucketSampler(
        dataset,
        get_world_size(),
        get_rank(),
        max_pixel_nums=max_pixel_nums,
        max_batch_size=max_batch_size,
        seps=bucket_seps,
    )

    return torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=collate_func,
        batch_sampler=sampler,
    )
35
+
36
+
37
def create_valid_dataloader(vocab, lrc_path, num_workers, batch_size, data_root_dir):
    """Build the validation DataLoader for a single LRC file.

    Uses a DistributedSampler when running distributed, otherwise a plain
    sequential (unshuffled) loader; incomplete final batches are kept.
    """
    record_loader = LRCRecordLoader(lrc_path, data_root_dir)

    transforms = T.Compose([
        T.TableToLabel(vocab),
        T.CalRowColSpans(),
        T.CalCellSpans(),
        T.CalHeadBodyDivide(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = Dataset([record_loader], transforms)

    common_kwargs = dict(
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=collate_func,
        drop_last=False,
    )
    if distributed():
        sampler = DistributedSampler(dataset, get_world_size(), get_rank(), True)
        return torch.utils.data.DataLoader(dataset, sampler=sampler, **common_kwargs)

    return torch.utils.data.DataLoader(dataset, shuffle=False, **common_kwargs)
libs/data/batch_sampler.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import copy
3
+ import random
4
+ from collections import defaultdict
5
+ from libs.utils import logger
6
+ import numpy as np
7
+
8
class BucketSampler:
    """Batch sampler that groups samples of similar size into buckets so each
    batch contains similarly-sized images, bounding padding waste and keeping
    every batch within a pixel budget.

    Batches are reshuffled per epoch with a seeded RNG and striped across
    ranks for distributed training.
    """
    def __init__(self, dataset, world_size, rank_id, fix_batch_size=None, max_pixel_nums=None, max_batch_size=8, min_batch_size=1, seps=(100, 100, 20)):
        # dataset must expose get_info(idx) -> tuple of sizes; `seps` holds
        # one bucket width per element of that tuple.
        self.dataset = dataset
        self.world_size = world_size
        self.rank_id = rank_id
        self.seps = seps
        self.fix_batch_size = fix_batch_size
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.max_pixel_nums = max_pixel_nums
        # either a fixed batch size or a pixel budget must be provided
        assert (fix_batch_size is not None) or (max_pixel_nums is not None)
        self.cal_buckets()
        self.seed = 20
        self.epoch = 0

    def count_keys(self):
        """Collect the size-info tuple of every sample (full dataset scan)."""
        infos = []
        for i in tqdm.tqdm(range(len(self.dataset))):
            info = self.dataset.get_info(i)
            infos.append(info)
        return infos

    def cal_buckets(self):
        """Quantize sample sizes into buckets and fix a batch size per bucket."""
        infos = self.count_keys()
        # NOTE(review): debug dump written into the current working directory
        np.save('count_keys.npy', infos)
        min_sizes = None # (64, 18, 2)
        max_sizes = None # (1223, 742, 2080)
        for info in infos:
            if min_sizes is None:
                min_sizes = info
                max_sizes = info
            else: # get the max size of each item of tuple
                min_sizes = tuple(min(min_sizes[idx], info[idx]) for idx in range(len(min_sizes)))
                max_sizes = tuple(max(max_sizes[idx], info[idx]) for idx in range(len(max_sizes)))
        assert (min_sizes is not None) and (len(self.seps) == len(min_sizes))
        print('max sizes: {}, min size: {}'.format(max_sizes, min_sizes))
        buckets = defaultdict(list)
        for idx, info in enumerate(infos):
            # bucket key: per-dimension quantized offset from the minimum,
            # joined into a string like "3-1-0"
            bucket_idxes = list()
            for sep, size, min_size in zip(self.seps, info, min_sizes):
                bucket_idx = (size - min_size) // sep
                bucket_idxes.append(str(bucket_idx))
            bucket_idxes = '-'.join(bucket_idxes)
            buckets[bucket_idxes].append(idx)
        np.save('buckets.npy', buckets)

        valid_buckets = dict()
        for bucket_key, bucket_samples in buckets.items():
            # drop buckets that can never form a batch
            if len(bucket_samples) < self.min_batch_size:
                continue
            if (self.fix_batch_size is not None) and (len(bucket_samples) < self.fix_batch_size):
                continue

            # upper-bound sample size in this bucket
            # (assumes info starts with (w, h, ...) — TODO confirm vs get_info)
            w, h, *_ = [(int(item) + 1) * sep + min_size for item, min_size, sep in zip(bucket_key.split('-'), min_sizes, self.seps)]
            if self.fix_batch_size is not None:
                # NOTE(review): when fix_batch_size is set, max_pixel_nums may
                # legally be None per the __init__ assert, and this comparison
                # would then raise TypeError — confirm intended usage.
                if h * w * self.fix_batch_size > self.max_pixel_nums:
                    continue
            else:
                if h * w * self.min_batch_size > self.max_pixel_nums:
                    continue

            if self.fix_batch_size is not None:
                batch_size = self.fix_batch_size
            else:
                # largest batch fitting the pixel budget, clamped to
                # [min_batch_size, max_batch_size] and the bucket population
                batch_size = min(self.max_batch_size, max(self.max_pixel_nums // (w * h), self.min_batch_size), len(bucket_samples))

            valid_buckets[bucket_key] = dict(
                samples=bucket_samples,
                batch_size=batch_size
            )

        self.buckets = [valid_buckets[bucket_key] for bucket_key in sorted(valid_buckets.keys())]
        total_nums = len(infos)
        valid_nums = sum([len(item['samples']) for item in valid_buckets.values()])
        logger.info('Total %d samples, but ignore %d samples' % (total_nums, total_nums - valid_nums))

    def __iter__(self):
        # deterministic per-epoch shuffle (same on every rank)
        random_inst = random.Random(self.seed + self.epoch)
        batches = list()
        for bucket in self.buckets:
            sample = copy.deepcopy(bucket['samples'])
            batch_size = bucket['batch_size']
            random_inst.shuffle(sample)
            idx = 0
            while idx < len(sample):
                batch = sample[idx:idx + batch_size]
                idx += batch_size
                # drop a trailing remainder smaller than the minimum batch
                if len(batch) < self.min_batch_size:
                    continue
                batches.append(batch)
        random_inst.shuffle(batches)

        # truncate so all ranks agree on the count, then stripe by rank id
        align_nums = (len(batches) // self.world_size) * self.world_size
        batches = batches[: align_nums]
        for batch_idx in range(self.rank_id, len(batches), self.world_size):
            yield batches[batch_idx]

    def __len__(self):
        # per-rank batch count, mirroring the remainder/striping logic above
        batch_nums = 0
        for bucket in self.buckets:
            bucket_sample_nums = len(bucket["samples"])
            bucket_bs = bucket['batch_size']
            batch_nums += bucket_sample_nums // bucket_bs
            if bucket_sample_nums % bucket_bs >= self.min_batch_size:
                batch_nums += 1

        return batch_nums // self.world_size

    def set_epoch(self, epoch):
        # called by the training loop so the shuffle changes each epoch
        self.epoch = epoch
+
libs/data/dataset.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import copy
3
+ import json
4
+ import pickle
5
+ import random
6
+ from torch._C import layout
7
+ import tqdm
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ from .list_record_cache import ListRecordLoader
12
+ from libs.utils.format_translate import table_to_html
13
+
14
+
15
class LRCRecordLoader:
    """Wrapper around ListRecordLoader that also resolves and loads the
    table image for each record.

    data_dir is prepended to each record's relative 'image_path'.
    """
    def __init__(self, lrc_path, data_dir=''):
        self.loader = ListRecordLoader(lrc_path)
        self.data_root_dir = data_dir

    def __len__(self):
        return len(self.loader)

    def get_info(self, idx):
        """Return (width, height, n_cells) for sample idx (no transforms).

        n_cells is the number of layout grid slots (rows * cols), not the
        number of distinct cells.
        """
        table = self.loader.get_record(idx)
        # Join with the data root, consistent with get_data. The original
        # opened table['image_path'] directly, which broke whenever a
        # non-empty data_dir was configured.
        img_path = os.path.join(self.data_root_dir, table['image_path'])
        image = Image.open(img_path).convert('RGB')
        w = image.width
        h = image.height
        n_rows, n_cols = table['layout'].shape
        n_cells = n_rows * n_cols
        return w, h, n_cells

    def get_data(self, idx):
        """Return (PIL RGB image, table record dict) for sample idx."""
        table = self.loader.get_record(idx)
        img_path = os.path.join(self.data_root_dir, table['image_path'])
        image = Image.open(img_path).convert('RGB')
        return image, table
37
+
38
+
39
class Dataset:
    """Concatenation of several record loaders behind one index space,
    applying a shared transform pipeline on access.
    """

    def __init__(self, loaders, transforms):
        self.loaders = loaders
        self.transforms = transforms

    def _match_loader(self, idx):
        """Map a global index to (loader, index local to that loader)."""
        base = 0
        for candidate in self.loaders:
            size = len(candidate)
            if idx < base + size:
                return candidate, idx - base
            base += size
        raise IndexError()

    def get_info(self, idx):
        """Size info of sample idx without loading/transforming it."""
        owner, local_idx = self._match_loader(idx)
        return owner.get_info(local_idx)

    def __len__(self):
        return sum(len(loader) for loader in self.loaders)

    def __getitem__(self, idx):
        try:
            owner, local_idx = self._match_loader(idx)
            image, table = owner.get_data(local_idx)
            has_layout = 'layout' in table.keys()
            # Unlabeled records (no layout) go through the transforms without
            # the table; the pipeline then yields None targets.
            transformed = self.transforms(image, table) if has_layout else self.transforms(image)
            (image, _, cls_label,
             rows_fg_span, rows_bg_span,
             cols_fg_span, cols_bg_span,
             cells_span, divide) = transformed
            return dict(
                id=idx,
                image_size=(image.shape[2], image.shape[1]),
                image=image,
                cls_label=cls_label,
                rows_fg_span=rows_fg_span,
                rows_bg_span=rows_bg_span,
                cols_fg_span=cols_fg_span,
                cols_bg_span=cols_bg_span,
                cells_span=cells_span,
                layout=table['layout'] if has_layout else None,
                divide=divide,
                table=table
            )
        except Exception as e:
            print('Error occured while load data: %d' % idx)
            raise e
85
+
86
+
87
def collate_func(batch_data):
    """Collate a list of Dataset samples into a padded batch.

    Images are zero-padded to the batch max H/W with a matching 0/1 mask.
    When labels are present, class labels are padded to the max label length
    (mask 1 on real tokens) and layouts to the max grid, filled with -1.
    Span/table fields remain per-sample Python lists.
    """
    batch_size = len(batch_data)

    image_dim = batch_data[0]['image'].shape[0]
    max_h = max([data['image'].shape[1] for data in batch_data])
    max_w = max([data['image'].shape[2] for data in batch_data])

    batch_id = list()
    batch_image_size = list()

    batch_image = torch.zeros([batch_size, image_dim, max_h, max_w], dtype=torch.float)
    batch_image_mask = torch.zeros([batch_size, 1, max_h, max_w], dtype=torch.float)
    batch_rows_fg_span = list()
    batch_rows_bg_span = list()
    batch_cols_fg_span = list()
    batch_cols_bg_span = list()
    batch_cells_span = list()
    batch_divide = list()
    tables = list()

    if all([(data['cls_label'] is None) and (data['layout'] is None) for data in batch_data]):
        # inference batch: no labels anywhere, keep plain Python lists
        batch_cls_label = list()
        batch_label_mask = list()
        batch_layout = list()
    else:
        # training batch: every sample must be fully labeled (no mixing)
        assert not any([(data['cls_label'] is None) or (data['layout'] is None) for data in batch_data])
        max_label_length = max([data['cls_label'].shape[0] for data in batch_data])
        batch_cls_label = torch.zeros([batch_size, max_label_length], dtype=torch.long)
        batch_label_mask = torch.zeros([batch_size, max_label_length], dtype=torch.float)
        max_nr = max([data['layout'].shape[0] for data in batch_data])
        max_nc = max([data['layout'].shape[1] for data in batch_data])
        # -1 marks padded layout slots (valid cell ids start at 0)
        batch_layout = torch.full([batch_size, max_nr, max_nc], -1, dtype=torch.float)

    for batch_idx, data in enumerate(batch_data):
        batch_id.append(data['id'])
        batch_image_size.append(data['image_size'])

        _, cur_h, cur_w = data['image'].shape
        batch_image[batch_idx, :, :cur_h, :cur_w] = data["image"]
        batch_image_mask[batch_idx, :, :cur_h, :cur_w] = 1

        if (data['cls_label'] is None) and (data['layout'] is None):
            # inference batch: the label containers are lists here
            batch_cls_label.append(data["cls_label"])
            batch_label_mask.append(None)
            batch_layout.append(data["layout"])
        else:
            label_length = data['cls_label'].shape[0]
            batch_cls_label[batch_idx, :label_length] = data['cls_label']
            batch_label_mask[batch_idx, :label_length] = 1.0
            layout_nr, layout_nc = data["layout" ].shape
            batch_layout[batch_idx, :layout_nr, :layout_nc] = torch.from_numpy(data['layout']).float()

        batch_rows_fg_span.append(data["rows_fg_span"])
        batch_rows_bg_span.append(data['rows_bg_span'])
        batch_cols_fg_span.append(data["cols_fg_span"])
        batch_cols_bg_span.append(data["cols_bg_span"])
        batch_cells_span.append(data["cells_span"])
        batch_divide.append(data["divide"])
        tables.append(data['table'])

    # divides become a tensor only for labeled batches
    batch_divide = torch.tensor(batch_divide, dtype=torch.long) if batch_divide[0] is not None else batch_divide

    return dict(
        ids=batch_id,
        images_size=batch_image_size,
        images=batch_image,
        images_mask=batch_image_mask,
        cls_labels=batch_cls_label,
        labels_mask=batch_label_mask,
        rows_fg_spans=batch_rows_fg_span,
        rows_bg_spans=batch_rows_bg_span,
        cols_fg_spans=batch_cols_fg_span,
        cols_bg_spans=batch_cols_bg_span,
        cells_spans=batch_cells_span,
        divide_labels=batch_divide,
        layouts=batch_layout,
        tables=tables
    )
libs/data/list_record_cache.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import threading
4
+
5
+
6
def merge_record_file(files, dst_file):
    """Concatenate the given record files into dst_file, byte-for-byte, in order.

    The previous implementation shelled out to `cat a b > dst` via os.system,
    which breaks on file names containing shell metacharacters/spaces
    (injection risk) and is POSIX-only. Plain Python copying produces the
    identical output portably.
    """
    with open(dst_file, 'wb') as dst:
        for file in files:
            with open(file, 'rb') as src:
                # stream in 1 MiB chunks to keep memory bounded
                while True:
                    chunk = src.read(1 << 20)
                    if not chunk:
                        break
                    dst.write(chunk)
12
+
13
+
14
class ListRecordCacher:
    """Append-only writer for the LRC record file format.

    File layout:
        [8-byte footer offset][record][record]...[footer]
    where each record is [8-byte size][pickled bytes] and the footer is
    [8-byte size][pickled list of (absolute offset, length) per record].
    All integers are 8-byte big-endian unsigned.
    """
    OFFSET_LENGTH = 8  # width of every size/offset field in bytes
    def __init__(self, cache_path):
        self._record_pos_list = list()
        self._cache_file = open(cache_path, 'wb')
        # placeholder for the footer offset; patched when the file is closed
        self._cached_bytes = b'\x00' * self.OFFSET_LENGTH

    def add_record(self, record):
        """Pickle and append one record object."""
        record_bytes = pickle.dumps(record)
        return self.add_record_bytes(record_bytes)

    def add_record_bytes(self, record_bytes):
        """Append one already-serialized record, tracking its (offset, size)."""
        bytes_size = len(record_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + record_bytes

        cur_record_pos = None
        if len(self._record_pos_list) == 0:
            # first payload starts after the header (8B) + its own size prefix (8B)
            cur_record_pos = [self.OFFSET_LENGTH*2, bytes_size]
        else:
            # previous payload end + this record's size prefix
            cur_record_pos = [sum(self._record_pos_list[-1]) + self.OFFSET_LENGTH, bytes_size]
        self._record_pos_list.append(cur_record_pos)

        self._cached_bytes += total_bytes
        # flush the in-memory buffer to disk once it grows past ~1 MiB
        if len(self._cached_bytes) > 1024*1024:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def flush(self):
        """Write any buffered bytes to the end of the file."""
        if len(self._cached_bytes) > 0:
            self._cache_file.seek(0, 2)
            self._cache_file.write(self._cached_bytes)
            self._cached_bytes = b''

    def _wirte_record_pos_list(self):
        # (sic: "wirte") Write the footer and patch the header with its offset.
        self.flush()
        self._cache_file.seek(0, 2)
        offset = self._cache_file.tell()
        offset_bytes = offset.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        # patch the 8-byte header placeholder with the footer offset
        self._cache_file.seek(0)
        self._cache_file.write(offset_bytes)

        data_bytes = pickle.dumps(self._record_pos_list)
        bytes_size = len(data_bytes)
        offset_bytes = bytes_size.to_bytes(
            length=self.OFFSET_LENGTH,
            byteorder='big', signed=False
        )
        total_bytes = offset_bytes + data_bytes
        self._cache_file.seek(0, 2)
        self._cache_file.write(total_bytes)

    def close(self):
        """Finalize the file (write footer) and close it. Idempotent."""
        if not self._cache_file.closed:
            self._wirte_record_pos_list()
            self._cache_file.close()

    def __del__(self):
        self.close()
80
+
81
+
82
class ListRecordLoader:
    """Random-access reader for LRC files written by ListRecordCacher.

    Also supports files made by concatenating several LRC files (see
    merge_record_file): _scan_file walks each segment's footer in turn and
    rebases that segment's record offsets.
    """
    OFFSET_LENGTH = 8  # width of every size/offset field in bytes
    def __init__(self, load_path):
        self._sync_lock = threading.Lock()  # serializes seek+read pairs across threads
        self._size = os.path.getsize(load_path)
        self._load_path = load_path
        self._open_file()
        self._scan_file()

    def _open_file(self):
        self._pid = os.getpid()
        self._cache_file = open(self._load_path, 'rb')

    def _check_reopen(self):
        # reopen after fork so dataloader worker processes do not share a
        # file descriptor (and its seek position) with the parent
        if (self._pid != os.getpid()):
            self._open_file()

    def _scan_file(self):
        """Build the absolute (offset, length) index for all records."""
        record_pos_list = list()
        pos = 0  # start of the current concatenated segment
        while True:
            if pos >= self._size:
                break
            # segment header: offset of this segment's footer (relative)
            self._cache_file.seek(pos)
            offset = int().from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            offset = pos + offset
            self._cache_file.seek(offset)

            # footer: [size][pickled list of (offset, length)]
            byte_size = int().from_bytes(
                self._cache_file.read(self.OFFSET_LENGTH),
                byteorder='big', signed=False
            )
            record_pos_list_bytes = self._cache_file.read(byte_size)
            sub_record_pos_list = pickle.loads(record_pos_list_bytes)
            assert isinstance(sub_record_pos_list, list)
            # rebase segment-relative offsets to absolute file offsets
            sub_record_pos_list = [[item[0]+pos, item[1]] for item in sub_record_pos_list]
            record_pos_list.extend(sub_record_pos_list)
            # next segment starts right after this segment's footer
            pos = self._cache_file.tell()

        self._record_pos_list = record_pos_list

    def get_record(self, idx):
        """Load and unpickle record idx."""
        self._check_reopen()
        record_bytes = self.get_record_bytes(idx)
        record = pickle.loads(record_bytes)
        return record

    def get_record_bytes(self, idx):
        """Read the raw pickled bytes of record idx (thread-safe)."""
        offset, length = self._record_pos_list[idx]
        self._sync_lock.acquire()
        try:
            self._cache_file.seek(offset)
            record_bytes = self._cache_file.read(length)
        finally:
            self._sync_lock.release()
        return record_bytes

    def __len__(self):
        return len(self._record_pos_list)
libs/data/transform.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from torchvision.transforms import functional as F
6
+ from libs.utils.format_translate import table_to_latex
7
+ from .utils import extract_fg_bg_spans, cal_cell_spans
8
+
9
+
10
class Compose:
    """Chain transforms; each step receives the running tuple unpacked as
    positional arguments and returns the next tuple.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *data):
        current = data
        for step in self.transforms:
            current = step(*current)
        return current
18
+
19
class TableToLabel:
    """Turn a table record into a LaTeX token-id sequence via the vocab.

    With no table, passes the image through with (None, None) placeholders.
    """

    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, image, table=None):
        if table is None:
            return image, None, None
        # image.size = (w, h)
        latex_tokens = table_to_latex(table)
        return image, table, self.vocab.words_to_ids(latex_tokens)
29
+
30
class CalRowColSpans:
    """Compute row/column foreground/background separator spans from the
    table annotation and the image size; Nones pass through unchanged.
    """

    def __call__(self, image, table=None, cls_label=None):
        if table is None:
            return image, table, None, None, None, None, None
        size = (image.width, image.height)
        rows_fg, rows_bg, cols_fg, cols_bg = extract_fg_bg_spans(table, size)
        return image, table, cls_label, rows_fg, rows_bg, cols_fg, cols_bg
37
+
38
class CalCellSpans:
    """Append per-cell grid spans computed from the table layout; when no
    table is given, appends None instead.
    """

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None):
        spans = cal_cell_spans(table) if table is not None else None
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, spans
45
+
46
class CalHeadBodyDivide:
    """Append the head/body divide (the number of header rows) from
    table['head_rows']; None when no table is given.
    """

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None):
        divide = None if table is None else len(table['head_rows'])
        return image, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
54
+
55
class ToTensor:
    """Convert the PIL image to a float tensor and the class-label sequence
    (when present) to a LongTensor; everything else passes through.
    """

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        tensor_image = F.to_tensor(image)
        label = cls_label if cls_label is None else torch.tensor(cls_label, dtype=torch.long)
        return tensor_image, table, label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
61
+
62
class Normalize:
    """Channel-wise normalization of the image tensor with fixed mean/std."""

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, image, table=None, cls_label=None, rows_fg_span=None, rows_bg_span=None, cols_fg_span=None, cols_bg_span=None, cells_span=None, divide=None):
        normalized = F.normalize(image, self.mean, self.std, self.inplace)
        return normalized, table, cls_label, rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span, cells_span, divide
libs/data/utils.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from libs.utils.format_translate import table_to_html, format_html
3
+ import numpy as np
4
+
5
+
6
class InvalidFormat(Exception):
    """Raised when a table annotation cannot yield a consistent span geometry."""
8
+
9
+
10
def segmentation_to_bbox(segmentation):
    """Return the tight axis-aligned bbox over all contour points.

    Args:
        segmentation: iterable of contours, each an iterable of (x, y) points.

    Returns:
        Tuple (x1, y1, x2, y2) of min/max coordinates.

    Raises:
        ValueError: if the segmentation contains no points (as the original
            ``min([])`` did).
    """
    # Flatten once instead of re-scanning every contour four times.
    points = [pt for contour in segmentation for pt in contour]
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    return (min(xs), min(ys), max(xs), max(ys))
16
+
17
+
18
def cal_cell_bbox(table):
    """Compute one pixel bbox per cell, or None when no segmentation is usable.

    The union of the cell's subline polygons is preferred; when a cell has no
    sublines the cell-level polygon is used instead.
    """
    cells_bbox = []
    for cell in table['cells']:
        if 'segmentation' not in cell:
            cells_bbox.append(None)
            continue
        # Gather subline polygons first, falling back to the cell polygon.
        contours = []
        for subline in cell.get('sublines', []):
            contours.extend(subline['segmentation'])
        if not contours:
            contours = cell['segmentation']
        cells_bbox.append(segmentation_to_bbox(contours) if contours else None)
    return cells_bbox
36
+
37
+
38
def cal_cell_spans(table):
    """Return the grid span ``[x1, y1, x2, y2]`` (all inclusive) of every cell.

    Args:
        table: dict with a 2D integer ``'layout'`` array mapping grid positions
            to cell ids, and a ``'cells'`` list (only its length is used).

    Returns:
        List of ``[col_start, row_start, col_end, row_end]`` per cell id.

    Raises:
        AssertionError: if a cell's occupied positions do not form a full
            rectangle in the layout.
    """
    layout = table['layout']
    num_cells = len(table['cells'])
    cells_span = list()
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # Bug fix: spans are inclusive, so the rectangle check must cover the
        # last row/column too. The previous `layout[y1:y2, x1:x2]` excluded
        # them (and checked an empty region for 1-row/1-col cells), so broken
        # non-rectangular layouts slipped through.
        assert np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span
51
+
52
def cal_fg_bg_span(spans, edge):
    """Split the axis [0, edge] into foreground gaps and background spans.

    Args:
        spans: ordered list of [start, end] spans; entries may be None.
        edge: extent of the axis.

    Returns:
        (fg_spans, bg_spans): the gaps between consecutive spans, and the
        spans that are strictly separated from both neighbours (and from the
        0/edge borders).
    """
    total = len(spans)

    def _left_clear(i):
        # Span must start strictly after its left neighbour (or after 0).
        if i == 0:
            return spans[i][0] > 0
        return spans[i - 1] is not None and spans[i][0] > spans[i - 1][1]

    def _right_clear(i):
        # Span must end strictly before its right neighbour (or before edge).
        if i == total - 1:
            return spans[i][1] < edge
        return spans[i + 1] is not None and spans[i][1] < spans[i + 1][0]

    bg_spans = [spans[i] for i in range(total)
                if spans[i] is not None and _left_clear(i) and _right_clear(i)]

    fg_spans = []
    for i in range(total + 1):
        # Gap i runs from the end of span i-1 (or 0) to the start of span i
        # (or edge); gaps touching a None span are skipped.
        if i == 0:
            start = 0
        elif spans[i - 1] is None:
            continue
        else:
            start = spans[i - 1][1]

        if i == total:
            end = edge
        elif spans[i] is None:
            continue
        else:
            end = spans[i][0]

        if end > start:
            fg_spans.append([start, end])

    return fg_spans, bg_spans
99
+
100
+
101
def shrink_spans(spans, size):
    """Adjust spans so each fits inside (0, size) and no longer overlaps its
    original neighbours.

    Overlaps are resolved symmetrically: each side retreats by
    ``ceil(overlap / 2)``. Comparisons are always made against the *original*
    neighbour spans, not the already-adjusted ones.

    Args:
        spans: ordered list of [start, end] spans.
        size: axis extent; border spans are clamped to [1, size - 1].

    Returns:
        New list of adjusted [start, end] spans.

    Raises:
        InvalidFormat: if resolving a right-side overlap collapses a span
            below a width of 1.
    """
    new_spans = list()
    for idx, (start, end) in enumerate(spans):
        # --- left edge ---
        if idx == 0:
            # First span: keep at least one pixel away from the border.
            if start <= 0:
                start = 1
        else:
            _, pre_end = spans[idx - 1]
            if start <= pre_end:
                # Overlap with the previous span: move our start forward by half.
                shrink_distance = pre_end - start + 1
                start = start + math.ceil(shrink_distance / 2)

        # --- right edge ---
        if idx == len(spans) - 1:
            # Last span: keep away from the far border.
            if end >= size:
                end = size - 1
        else:
            next_start, _ = spans[idx + 1]
            if end >= next_start:
                # Overlap with the next span: pull our end back by half.
                shrink_distance = end - next_start + 1
                end = end - math.ceil(shrink_distance / 2)
            # NOTE(review): this width check only runs for non-last spans as
            # grouped in the source; the original formatting is ambiguous —
            # confirm whether the last span should be validated too.
            if end - start < 1:
                raise InvalidFormat()

        new_spans.append([start, end])
    return new_spans
126
+
127
+
128
def cal_row_span(table, cells_span, cells_bbox, height):
    """Derive per-row pixel spans from cell bboxes, then split them into
    foreground gaps and background separator spans.

    A row's top edge comes from cells whose grid span starts in it, its bottom
    edge from cells ending in it.

    Raises:
        InvalidFormat: when a row has no cell starting or ending in it, or when
            the spans cannot be shrunk into a consistent layout.
    """
    layout = table['layout']
    rows_span = []
    for row_idx in range(layout.shape[0]):
        top_edges = []
        bottom_edges = []
        for cell_id in layout[row_idx, :]:
            span = cells_span[cell_id]
            bbox = cells_bbox[cell_id]
            if bbox is None:
                continue
            if span[1] == row_idx:
                top_edges.append(bbox[1])
            if span[3] == row_idx:
                bottom_edges.append(bbox[3])
        if not top_edges or not bottom_edges:
            raise InvalidFormat()
        # Clamp both edges into [1, height - 1].
        y1 = min(max(1, min(top_edges)), height - 1)
        y2 = min(max(1, max(bottom_edges) + 1), height - 1)
        rows_span.append([y1, y2])
    rows_span = shrink_spans(rows_span, height)
    rows_fg_span, rows_bg_span = cal_fg_bg_span(rows_span, height)
    return rows_fg_span, rows_bg_span
152
+
153
+
154
def cal_col_span(table, cells_span, cells_bbox, width):
    """Column counterpart of ``cal_row_span``: derive per-column pixel spans
    from cell bboxes, then split them into foreground and background spans.

    Raises:
        InvalidFormat: when a column has no cell starting or ending in it, or
            when the spans cannot be shrunk into a consistent layout.
    """
    layout = table['layout']
    cols_span = []
    for col_idx in range(layout.shape[1]):
        left_edges = []
        right_edges = []
        for cell_id in layout[:, col_idx]:
            span = cells_span[cell_id]
            bbox = cells_bbox[cell_id]
            if bbox is None:
                continue
            if span[0] == col_idx:
                left_edges.append(bbox[0])
            if span[2] == col_idx:
                right_edges.append(bbox[2])
        if not left_edges or not right_edges:
            raise InvalidFormat()
        # Clamp both edges into [1, width - 1].
        x1 = min(max(1, min(left_edges)), width - 1)
        x2 = min(max(1, max(right_edges) + 1), width - 1)
        cols_span.append([x1, x2])
    cols_span = shrink_spans(cols_span, width)
    cols_fg_span, cols_bg_span = cal_fg_bg_span(cols_span, width)
    return cols_fg_span, cols_bg_span
178
+
179
+
180
def extract_fg_bg_spans(table, image_size):
    """Compute foreground (gap) and background (separator) spans for both axes.

    Args:
        table: table annotation with 'layout' and 'cells'.
        image_size: (width, height) of the table image.

    Returns:
        (rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span).

    Raises:
        InvalidFormat: propagated when the annotation geometry is inconsistent.
    """
    width, height = image_size
    cells_bbox = cal_cell_bbox(table)
    cells_span = cal_cell_spans(table)
    # cal rows fg bg span
    rows_fg_span, rows_bg_span = cal_row_span(table, cells_span, cells_bbox, height)
    # cal cols fg bg span
    cols_fg_span, cols_bg_span = cal_col_span(table, cells_span, cells_bbox, width)
    return rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span
libs/model/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .model import Model
2
+ import torch
3
+ from torch import nn
4
+ from libs.utils.comm import get_world_size
5
+
6
+
7
def build_model(cfg):
    """Build the table-structure model.

    Args:
        cfg: experiment configuration forwarded to ``Model``.

    Returns:
        A ``Model`` instance using ``nn.BatchNorm2d`` normalization.
    """
    # NOTE(review): the original branched on get_world_size() but selected
    # nn.BatchNorm2d in BOTH branches — the dead branch is removed here.
    # If the distributed path was meant to use nn.SyncBatchNorm, restore the
    # branch deliberately; behavior is unchanged as written.
    norm_layer = nn.BatchNorm2d
    model = Model(
        cfg,
        norm_layer=norm_layer
    )
    return model
libs/model/backbone.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ import torch.nn as nn
4
+ from typing import Type, Any, Callable, Union, List, Optional
5
+
6
+
7
+
8
# Filesystem locations of ImageNet-pretrained ResNet checkpoints, consumed by
# `_resnet` when pretrained=True.
# NOTE(review): these are hard-coded, cluster-specific absolute paths — they
# must exist on the machine running training or `torch.load` will fail.
model_paths = {
    'resnet18': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet18-5c106cde.pth',
    'resnet34': '/Pretrain/resnet_34.pth',
    'resnet50': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet50-19c8e357.pth',
    'resnet101': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet101-5d3b4d8f.pth',
    'resnet152': '/yrfs2/cv6/frwang/PretrainedModelParams/pytorch/ImageNet/ResNet/resnet152-b121ed2d.pth',
}
15
+
16
+
17
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution, bias-free, with reflect padding sized to the dilation."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
        padding_mode='reflect',
    )
21
+
22
+
23
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution (channel projection), bias-free."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
26
+
27
+
28
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (ResNet-18/34 style)."""

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (together with the optional downsample) carries the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Shortcut path: identity, or projected when shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
75
+
76
+
77
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block.

    ResNet V1.5 variant: the stride sits on the 3x3 conv (conv2), not the
    first 1x1 conv, per
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (together with the optional downsample) carries the stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Shortcut path: identity, or projected when shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
133
+
134
+
135
class ResNet(nn.Module):
    """ResNet backbone returning the four stage feature maps (c2..c5).

    Differences from torchvision's ResNet: the stem conv uses stride 1 with
    reflect padding and there is no max-pool / avg-pool / fc head, so the
    stage output strides relative to the input are 1, 2, 4 and 8.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: stride 1 (torchvision uses stride 2) so resolution is preserved.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3,
                               bias=False, padding_mode='reflect')
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # He initialization for convs; unit scale / zero shift for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Build one residual stage of `blocks` blocks; mutates self.inplanes/self.dilation."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation: same receptive field, no downsampling.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut when spatial or channel shape changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # NOTE: despite the `-> Tensor` annotation kept from torchvision, this
        # returns a 4-tuple of stage features.
        input = x  # kept as-is (unused)
        x = self.conv1(x)  # stride 1: full input resolution
        x = self.bn1(x)
        x = self.relu(x)

        c2 = self.layer1(x)   # total stride 1
        c3 = self.layer2(c2)  # total stride 2
        c4 = self.layer3(c3)  # total stride 4
        c5 = self.layer4(c4)  # total stride 8
        return c2, c3, c4, c5

    def forward(self, x: Tensor) -> Tensor:
        """Return the (c2, c3, c4, c5) multi-scale feature maps."""
        return self._forward_impl(x)
231
+
232
+
233
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    **kwargs: Any
) -> ResNet:
    """Build a ResNet and optionally load matching pretrained weights.

    Only checkpoint entries whose name AND shape match the freshly built model
    are copied, so architecture tweaks (e.g. the reflect-padded stem) still
    load the remaining weights cleanly.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        checkpoint = torch.load(model_paths[arch], map_location='cpu')
        state_dict = model.state_dict()
        for name, param in state_dict.items():
            pretrained_param = checkpoint.get(name)
            if pretrained_param is not None and pretrained_param.shape == param.shape:
                state_dict[name] = pretrained_param
        model.load_state_dict(state_dict)
    return model
250
+
251
+
252
def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet:
    """ResNet-18 (BasicBlock x [2, 2, 2, 2])."""
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)


def resnet34(pretrained: bool = False, **kwargs: Any) -> ResNet:
    """ResNet-34 (BasicBlock x [3, 4, 6, 3])."""
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)


def resnet50(pretrained: bool = False, **kwargs: Any) -> ResNet:
    """ResNet-50 (Bottleneck x [3, 4, 6, 3])."""
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained: bool = False, **kwargs: Any) -> ResNet:
    """ResNet-101 (Bottleneck x [3, 4, 23, 3])."""
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)


def resnet152(pretrained: bool = False, **kwargs: Any) -> ResNet:
    """ResNet-152 (Bottleneck x [3, 8, 36, 3])."""
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, **kwargs)
270
+
271
+
272
def build_backbone(arch, pretrained=True, norm_layer=nn.BatchNorm2d):
    """Instantiate a ResNet backbone by short name.

    Args:
        arch: one of 'res18', 'res34', 'res50', 'res101', 'res152'.
        pretrained: load the locally cached ImageNet weights when True.
        norm_layer: normalization layer class forwarded to the ResNet.

    Returns:
        The constructed ResNet.

    Raises:
        ValueError: for an unknown arch string.
    """
    arch_map = {
        # 'res18' was missing even though resnet18 (and its checkpoint path)
        # exists in this module — added for completeness; existing keys are
        # unchanged.
        'res18': resnet18,
        'res34': resnet34,
        'res50': resnet50,
        'res101': resnet101,
        'res152': resnet152
    }
    if arch not in arch_map:
        raise ValueError('Unknown backbone arch: %s' % arch)
    return arch_map[arch](pretrained=pretrained, norm_layer=norm_layer)
libs/model/cells_extractor.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from .extractor import RoiPosFeatExtraxtor
6
+
7
+
8
+ class SALayer(nn.Module):
9
+ def __init__(self, in_dim, att_dim, head_nums):
10
+ super().__init__()
11
+ self.in_dim = in_dim
12
+ self.att_dim = att_dim
13
+ self.head_nums = head_nums
14
+
15
+ assert self.in_dim % self.head_nums == 0
16
+
17
+ self.key_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
18
+ self.query_layer = nn.Conv1d(self.in_dim, self.att_dim, 1, 1, 0)
19
+ self.value_layer = nn.Conv1d(self.in_dim, self.in_dim, 1, 1, 0)
20
+ self.scale = 1 / math.sqrt(self.att_dim)
21
+
22
+ def forward(self, feats, masks=None):
23
+ bs, c, n = feats.shape
24
+ keys = self.key_layer(feats).reshape(bs, -1, self.head_nums, n)
25
+ querys = self.query_layer(feats).reshape(bs, -1, self.head_nums, n)
26
+ values = self.value_layer(feats).reshape(bs, -1, self.head_nums, n)
27
+
28
+ logits = torch.einsum('bchk,bchq->bhkq', keys, querys) * self.scale
29
+ if masks is not None:
30
+ logits = logits - (1 - masks[:, None, :, None]) * 1e8
31
+ weights = torch.softmax(logits, dim=2)
32
+
33
+ new_feats = torch.einsum('bchk,bhkq->bchq', values, weights)
34
+ new_feats = new_feats.reshape(bs, -1, n)
35
+ return new_feats + feats
36
+
37
+
38
def gen_cells_bbox(row_segments, col_segments, device):
    """Enumerate per-image cell bboxes from row/column separator positions.

    Args:
        row_segments / col_segments: per image, the ordered separator
            coordinates; N separators delimit N-1 rows (or columns).
        device: target device for the bbox tensors.

    Returns:
        List of (num_cells, 4) float tensors, one per image, cells in
        row-major order, each row being [x1, y1, x2, y2].
    """
    cells_bbox = []
    for rows, cols in zip(row_segments, col_segments):
        # Row-major enumeration of all grid cells.
        boxes = [
            [cols[c], rows[r], cols[c + 1], rows[r + 1]]
            for r in range(len(rows) - 1)
            for c in range(len(cols) - 1)
        ]
        cells_bbox.append(torch.tensor(boxes, dtype=torch.float, device=device))
    return cells_bbox
56
+
57
+
58
def align_cells_feat(cells_feat, num_rows, num_cols):
    """Zero-pad per-image cell features onto a common (rows, cols) grid.

    Args:
        cells_feat: list of (num_cells, C) tensors, cells in row-major order.
        num_rows / num_cols: grid dimensions per image.

    Returns:
        A (batch, C, max_rows, max_cols) tensor of padded grids, plus a
        (batch, max_rows, max_cols) validity mask (1 inside each image's
        grid, 0 in the padding).
    """
    batch_size = len(cells_feat)
    dtype = cells_feat[0].dtype
    device = cells_feat[0].device

    max_rows = max(num_rows)
    max_cols = max(num_cols)

    masks = torch.zeros([batch_size, max_rows, max_cols], dtype=dtype, device=device)
    padded = []
    for i, (feat, rows, cols) in enumerate(zip(cells_feat, num_rows, num_cols)):
        # (num_cells, C) -> (C, rows, cols), then pad to the common grid size.
        grid = feat.transpose(0, 1).reshape(-1, rows, cols)
        padded.append(F.pad(
            grid,
            (0, max_cols - cols, 0, max_rows - rows, 0, 0),
            mode='constant',
            value=0
        ))
        masks[i, :rows, :cols] = 1
    return torch.stack(padded, dim=0), masks
84
+
85
+
86
class CellsExtractor(nn.Module):
    """Pool one feature vector per grid cell (RoI extraction over the row/col
    grid), then refine the cell grid with alternating column-wise and
    row-wise self-attention rounds."""

    def __init__(self, in_dim, cell_dim, heads, head_nums, pool_size, scale=1):
        super().__init__()
        self.in_dim = in_dim        # channels of the incoming feature map
        self.cell_dim = cell_dim    # channels of each pooled cell feature
        self.pool_size = pool_size  # RoI extractor output size
        self.scale = scale          # feature-map scale passed to the extractor
        self.box_feat_extractor = RoiPosFeatExtraxtor(
            self.scale,
            self.pool_size,
            self.in_dim,
            self.cell_dim
        )
        self.heads = heads          # number of alternating row/col attention rounds
        self.row_sas = nn.ModuleList()
        self.col_sas = nn.ModuleList()
        for _ in range(self.heads):
            self.row_sas.append(SALayer(cell_dim, cell_dim, head_nums))
            self.col_sas.append(SALayer(cell_dim, cell_dim, head_nums))


    def forward(self, feats, row_segments, col_segments, img_sizes):
        """Extract and refine cell features.

        Args:
            feats: batched feature map the RoI extractor pools from.
            row_segments / col_segments: per-image separator positions;
                N separators delimit N-1 rows (columns).
            img_sizes: per-image sizes forwarded to the RoI extractor.

        Returns:
            (aligned_cells_feat, masks): a (bs, cell_dim, max_rows, max_cols)
            grid of refined cell features and its validity mask.
        """
        device = feats.device
        num_rows = [len(row_segments_pi) - 1 for row_segments_pi in row_segments]
        num_cols = [len(col_segments_pi) - 1 for col_segments_pi in col_segments]

        # One bbox per grid cell, then RoI-pool a feature for each.
        cells_bbox = gen_cells_bbox(row_segments, col_segments, device)
        cells_feat = self.box_feat_extractor(feats, cells_bbox, img_sizes)

        aligned_cells_feat, masks = align_cells_feat(cells_feat, num_rows, num_cols)

        bs, c, nr, nc = aligned_cells_feat.shape

        for idx in range(self.heads):
            # Column pass: every grid row becomes an independent sequence of
            # nc cells -> attention along the column axis.
            col_cells_feat = aligned_cells_feat.permute(0, 2, 1, 3).contiguous().reshape(bs * nr, c, nc)
            col_masks = masks.reshape(bs * nr, nc)
            col_cells_feat = self.col_sas[idx](col_cells_feat, col_masks)  # self-attention
            aligned_cells_feat = col_cells_feat.reshape(bs, nr, c, nc).permute(0, 2, 1, 3).contiguous()

            # Row pass: every grid column becomes an independent sequence of
            # nr cells -> attention along the row axis.
            row_cells_feat = aligned_cells_feat.permute(0, 3, 1, 2).contiguous().reshape(bs * nc, c, nr)
            row_masks = masks.transpose(1, 2).reshape(bs * nc, nr)
            row_cells_feat = self.row_sas[idx](row_cells_feat, row_masks)  # self-attention
            aligned_cells_feat = row_cells_feat.reshape(bs, nc, c, nr).permute(0, 2, 3, 1).contiguous()

        return aligned_cells_feat, masks
libs/model/decoder.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from numpy.core.fromnumeric import argmax
3
+ import torch
4
+ from torch import nn
5
+ from torch._C import device, dtype, layout
6
+ from torch.nn import functional as F
7
+ from torch.nn.functional import cross_entropy, embedding
8
+ from torch.nn.modules import loss
9
+ from torch.nn.modules.activation import Tanh
10
+ from libs.utils.metric import CellMergeAcc, AccMetric
11
+ from .utils import gen_proposals
12
+
13
+
14
class ImageAttention(nn.Module):
    """Coverage-style spatial attention over a 2D feature map.

    Attention logits fuse the key map, the (transformed) query vector, the
    previous attention weight map and its cumulative coverage. In training
    mode the attended output is teacher-forced from the ground-truth layout;
    in inference mode candidate layout proposals are generated, scored, and
    pruned to the top 6 (beam-search style over the batch dimension).
    """

    def __init__(self, key_dim, query_dim, cover_kernel):
        super().__init__()
        self.query_transform = nn.Linear(query_dim, key_dim)
        # Coverage convolutions lift the (1-channel) weight maps to key_dim.
        self.weight_transform = nn.Conv2d(1, key_dim, cover_kernel, 1, padding=cover_kernel // 2)
        self.cum_weight_transform = nn.Conv2d(1, key_dim, cover_kernel, 1, padding=cover_kernel // 2)
        self.logit_transform = nn.Conv2d(key_dim, 1, 1, 1, 0)

    def forward(self, key, key_mask, query, spatial_att_weight, cum_spatial_att_weight, value, state, layouts=None, layouts_cum=None, spatial_att_weight_scores=None):
        """Compute new attention maps and attended outputs.

        Training returns (state, outputs, logit, weight, cum_weight, None, None);
        inference returns the same first five plus the stacked layout proposals
        and their sorted scores (top 6 beams).
        """
        query = self.query_transform(query)
        weight_query = self.weight_transform(spatial_att_weight)
        cum_weight_query = self.cum_weight_transform(cum_spatial_att_weight)
        # Additive fusion of content (key), query, and coverage terms.
        fusion = key + query[:, :, None, None] + weight_query + cum_weight_query
        # cal new_spatial_att_logit
        new_spatial_att_logit = self.logit_transform(torch.tanh(fusion))
        # cal new_spatial_att_weight: mask out invalid positions, softmax over H*W.
        new_spatial_att_weight = new_spatial_att_logit - (1 - key_mask) * 1e8
        bs, _, h, w = new_spatial_att_weight.shape
        new_spatial_att_weight = new_spatial_att_weight.reshape(bs, h * w)
        new_spatial_att_weight = torch.softmax(new_spatial_att_weight, dim=1).reshape(bs, 1, h, w)
        # cal new_cum_spatial_att_weight
        if self.training:
            # Teacher forcing: attend over the ground-truth layout cells (mean
            # of value features where layout == 1); empty layouts yield zeros.
            outputs = list()
            for (value_pi, layout) in zip(value, layouts):
                h, w = torch.where(layout == 1.)
                if len(h) == 0 or len(w) == 0:
                    outputs.append(torch.zeros_like(query[0]))
                else:
                    outputs.append(value_pi[:, h, w].mean(-1))
            outputs = torch.stack(outputs, dim=0)
            # Coverage accumulates the ground-truth layout, clamped to 1.
            new_cum_spatial_att_weight = torch.clamp(layouts.unsqueeze(1).float() + cum_spatial_att_weight, max=1.)
            return state, outputs, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, None, None
        else:
            # Inference: expand each active beam with layout proposals from
            # gen_proposals, then keep the 6 best-scoring beams.
            state_list = list()
            outputs_list = list()
            scores_list = list()
            proposals_list = list()
            new_spatial_att_weight_list = list()
            new_cum_spatial_att_weight_list = list()
            layouts_pred = new_spatial_att_logit.squeeze(1).sigmoid()
            for idx, (value_pi, state_pi, layout) in enumerate(zip(value, state, layouts_pred)):
                if cum_spatial_att_weight[idx].min() == 1:
                    # Beam fully covered: carry it forward unchanged with an
                    # all-zero placeholder proposal and zero output.
                    state_list.append(state_pi)
                    outputs_list.append(torch.zeros_like(query[0]))
                    proposals_list.append(torch.cat((layouts_cum[idx], torch.zeros_like(layout.unsqueeze(0))), dim=0))
                    scores_list.append(spatial_att_weight_scores[idx])
                    new_spatial_att_weight_list.append(new_spatial_att_weight[idx])
                    new_cum_spatial_att_weight_list.append(cum_spatial_att_weight[idx])
                else:
                    # Anchor at the top-left-most uncovered position, then
                    # generate candidate cell layouts from there.
                    srow, scol = torch.where(cum_spatial_att_weight[idx].squeeze(0) == cum_spatial_att_weight[idx].squeeze(0).min())
                    scol = scol[srow == srow.min()].min()
                    srow = srow.min()
                    proposals, scores = gen_proposals(layout, srow, scol, score_threshold=0.5)
                    # Beam score = proposal score + accumulated beam score.
                    scores = scores + spatial_att_weight_scores[idx]
                    for s in scores:
                        scores_list.append(s)
                    for p in proposals:
                        proposals_list.append(torch.cat((layouts_cum[idx], p.unsqueeze(0)), dim=0))
                        h, w = torch.where(p == 1.)
                        outputs_list.append(value_pi[:, h, w].mean(-1))
                        state_list.append(state_pi)
                        new_spatial_att_weight_list.append(new_spatial_att_weight[idx])
                        new_cum_spatial_att_weight_list.append(torch.clamp(cum_spatial_att_weight[idx] + p.unsqueeze(0), max=1.))
            state_list = torch.stack(state_list, dim=0)
            proposals_list = torch.stack(proposals_list, dim=0)
            scores_list = torch.stack(scores_list, dim=0)
            outputs_list = torch.stack(outputs_list, dim=0)
            new_spatial_att_weight_list = torch.stack(new_spatial_att_weight_list, dim=0)
            new_cum_spatial_att_weight_list = torch.stack(new_cum_spatial_att_weight_list, dim=0)
            # Keep the 6 highest-scoring candidate beams.
            sorted_scores, sorted_idxes = torch.sort(scores_list, dim=0, descending=True)
            sorted_scores = sorted_scores[:6]
            sorted_idxes = sorted_idxes[:6]
            proposals = proposals_list[sorted_idxes]
            new_spatial_att_weight = new_spatial_att_weight_list[sorted_idxes]
            new_cum_spatial_att_weight = new_cum_spatial_att_weight_list[sorted_idxes]
            outputs = outputs_list[sorted_idxes]
            state = state_list[sorted_idxes]
            return state, outputs, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, proposals, sorted_scores
92
+
93
+
94
+
95
+ class Decoder(nn.Module):
96
+ def __init__(self, vocab, embed_dim, feat_dim, lm_state_dim, proj_dim, cover_kernel, att_threshold, spatial_att_logit_loss_wight):
97
+ super().__init__()
98
+ self.vocab = vocab
99
+ self.embed_dim = embed_dim
100
+ self.feat_dim = feat_dim
101
+ self.lm_state_dim = lm_state_dim
102
+ self.proj_dim = proj_dim
103
+ self.cover_kernel = cover_kernel
104
+ self.att_threshold = att_threshold
105
+ self.spatial_att_logit_loss_wight = spatial_att_logit_loss_wight
106
+ self.feat_projection = nn.Conv2d(self.feat_dim, self.proj_dim, 1, 1, 0)
107
+ self.state_init_projection = nn.Conv2d(self.feat_dim, self.lm_state_dim, 1, 1, 0)
108
+ self.lm_rnn1 = nn.GRUCell(input_size=self.feat_dim, hidden_size=self.lm_state_dim)
109
+ self.lm_rnn2 = nn.GRUCell(input_size=self.feat_dim, hidden_size=self.lm_state_dim)
110
+ self.image_attention = ImageAttention(self.proj_dim, self.feat_dim + self.lm_state_dim, cover_kernel)
111
+ self.struct_cls = nn.Sequential(
112
+ nn.Linear(self.feat_dim + self.lm_state_dim, self.lm_state_dim),
113
+ nn.Tanh(),
114
+ nn.Linear(self.lm_state_dim, len(self.vocab))
115
+ )
116
+
117
+ def init_state(self, feats, feats_mask):
118
+ bs, _, h, w = feats.shape
119
+ project_feats = self.feat_projection(feats) * feats_mask
120
+ init_state = torch.sum(self.state_init_projection(feats), dim=(2, 3))/torch.sum(feats_mask, dim=(2, 3))
121
+ init_context = torch.sum(feats, dim=(2, 3)) / torch.sum(feats_mask, dim=(2, 3))
122
+ init_spatial_att_weight = torch.zeros([bs, 1, h, w], dtype=torch.float, device=feats.device)
123
+ init_cum_spatial_att_weight = torch.zeros([bs, 1, h, w], dtype=torch.float, device=feats.device)
124
+ return project_feats, init_state, init_context, init_spatial_att_weight, init_cum_spatial_att_weight
125
+
126
+ def step(self, feats, project_feats, feats_mask, state, context, spatial_att_weight, cum_spatial_att_weight, layouts=None, layouts_cum=None, spatial_att_weight_scores=None):
127
+ new_state = self.lm_rnn1(context, state)
128
+ new_state, new_context, new_spatial_att_logit, \
129
+ new_spatial_att_weight, new_cum_spatial_att_weight, \
130
+ layouts_cum, spatial_att_weight_scores = self.image_attention(
131
+ project_feats,
132
+ feats_mask,
133
+ torch.cat([context, new_state], dim=1),
134
+ spatial_att_weight,
135
+ cum_spatial_att_weight,
136
+ feats,
137
+ new_state,
138
+ layouts,
139
+ layouts_cum,
140
+ spatial_att_weight_scores
141
+ )
142
+ new_state = self.lm_rnn2(new_context, new_state)
143
+ cls_feat = torch.cat([new_context, new_state], dim=1)
144
+ cls_logits_pt = self.struct_cls(cls_feat)
145
+ return cls_logits_pt, new_state, new_context, new_spatial_att_logit, new_spatial_att_weight, new_cum_spatial_att_weight, layouts_cum, spatial_att_weight_scores
146
+
147
+ def forward(self, feats, feats_mask, cls_labels=None, labels_mask=None, layouts=None):
148
+ if self.training:
149
+ return self.forward_backward(feats, feats_mask, cls_labels, labels_mask, layouts)
150
+ else:
151
+ return self.inference(feats, feats_mask)
152
+
153
+ def inference(self, feats, feats_mask):
154
+ bs, _, h, w = feats.shape
155
+ device = feats.device
156
+ assert bs == 1, print('bs should be 1')
157
+ layouts_cum = torch.zeros_like(feats[:, : 1])
158
+ spatial_att_weight_scores = torch.zeros(bs).to(device=device, dtype=feats.dtype)
159
+
160
+ project_feats, init_state, init_context, spatial_att_weight, cum_spatial_att_weight = self.init_state(feats, feats_mask)
161
+ state = init_state
162
+ context = init_context
163
+
164
+ for _ in range(h*w):
165
+ cls_logits_pt, state, context, spatial_att_logit, spatial_att_weight, \
166
+ cum_spatial_att_weight, layouts_cum, spatial_att_weight_scores \
167
+ = self.step(
168
+ feats, project_feats,
169
+ feats_mask, state, context,
170
+ spatial_att_weight, cum_spatial_att_weight, None, layouts_cum, spatial_att_weight_scores)
171
+ feats = feats[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
172
+ feats_mask = feats_mask[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
173
+ project_feats = project_feats[:1].repeat(layouts_cum.shape[0], 1, 1, 1)
174
+ if cum_spatial_att_weight.min() == 1:
175
+ break
176
+ spatial_att_logit_preds = layouts_cum[spatial_att_weight_scores.argmax(), 1:].unsqueeze(0)
177
+ return spatial_att_logit_preds, {}
178
+
179
    def forward_backward(self, feats, feats_mask, cls_labels, labels_mask, layouts):
        """Teacher-forced training pass.

        Runs the decoder for ``max_length`` steps, accumulating the
        structure-token classification loss and the per-step spatial
        attention (cell region) loss, plus accuracy metrics.

        Returns:
            (spatial_att_logit_preds, loss_cache)
        """
        device = feats.device
        # Number of supervised steps per sample.
        valid_cls_length = torch.sum((labels_mask == 1) & (cls_labels != -1), dim=1).detach()
        # Number of cells per sample (layout ids run 0..max).
        valid_spatial_att_logit_length = torch.stack([layout.max() + 1 for layout in layouts])
        max_length = valid_cls_length.max()

        project_feats, init_state, init_context, spatial_att_weight, cum_spatial_att_weight = self.init_state(feats, feats_mask)
        state = init_state
        context = init_context

        loss_cache = dict()

        cls_loss = list()
        cls_preds = list()

        spatial_att_logit_loss = list()
        spatial_att_logit_preds = list()
        spatial_att_logit_masks = list()
        spatial_att_logit_labels = list()
        for time_t in range(max_length):
            # Teacher forcing: the ground-truth region of the current cell
            # (layouts == time_t) is fed into the attention step.
            cls_logits_pt, state, context, spatial_att_logit, spatial_att_weight, cum_spatial_att_weight, *_ \
                = self.step(
                    feats, project_feats,
                    feats_mask, state, context,
                    spatial_att_weight, cum_spatial_att_weight, layouts == time_t
                )

            cls_label = cls_labels[:, time_t]
            label_mask = labels_mask[:, time_t]
            # cal cls loss
            cls_loss_pt = F.cross_entropy(cls_logits_pt, cls_label, ignore_index=-1, reduction='none') * label_mask
            cls_loss.append(cls_loss_pt)
            # save for acc
            cls_preds.append(torch.argmax(cls_logits_pt, dim=1).detach())

            spatial_att_logit_preds.append(spatial_att_logit.sigmoid() > self.att_threshold)
            spatial_att_logit_masks.append((layouts != -1).unsqueeze(1))
            spatial_att_logit_labels.append((layouts == time_t).unsqueeze(1))
            # cal spatial att loss
            spatial_att_logit_loss_pt = list()
            for spatial_att_logit_pi, layout in zip(spatial_att_logit, layouts):
                target = layout == time_t
                # Samples whose cell sequence is exhausted at this step
                # contribute zero loss.
                if torch.any(target) == False:
                    spatial_att_logit_loss_pt_pi = torch.tensor(0.0, dtype=torch.float, device=device)
                else:
                    mask = (layout != -1).float()
                    spatial_att_logit_loss_pt_pi = F.binary_cross_entropy_with_logits(
                        spatial_att_logit_pi,
                        target.float().unsqueeze(0),
                        reduction='none'
                    )
                    # Only grid positions inside the table layout count.
                    spatial_att_logit_loss_pt_pi = (spatial_att_logit_loss_pt_pi * mask).sum()
                spatial_att_logit_loss_pt.append(spatial_att_logit_loss_pt_pi)
            spatial_att_logit_loss_pt = torch.stack(spatial_att_logit_loss_pt, dim=0)
            spatial_att_logit_loss.append(spatial_att_logit_loss_pt)

        # Per-sample sums normalized by valid lengths, then batch-mean.
        cls_loss = torch.mean(torch.sum(torch.stack(cls_loss, dim=1), dim=1)/valid_cls_length)
        spatial_att_logit_loss = self.spatial_att_logit_loss_wight * torch.mean(torch.sum(torch.stack(spatial_att_logit_loss, dim=1), dim=1) / valid_spatial_att_logit_length)

        loss_cache['cls_loss'] = cls_loss
        loss_cache['spatial_att_logit_loss'] = spatial_att_logit_loss

        cls_preds = torch.stack(cls_preds, dim=1)
        spatial_att_logit_preds = torch.stack(spatial_att_logit_preds, dim=1)
        spatial_att_logit_masks = torch.stack(spatial_att_logit_masks, dim=1)
        spatial_att_logit_labels = torch.stack(spatial_att_logit_labels, dim=1)

        # Accuracy metrics: overall, per special-token class, and cell merge.
        acc_metric = AccMetric()
        cell_merge_acc = CellMergeAcc()
        cls_correct, cls_total = acc_metric(cls_preds, cls_labels, labels_mask)
        cls_none_correct, cls_none_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.none_id))
        cls_bold_correct, cls_bold_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.bold_id))
        cls_space_correct, cls_space_total = acc_metric(cls_preds, cls_labels, (labels_mask == 1) & (cls_labels == self.vocab.space_id))
        cls_blank_correct = cls_none_correct + cls_bold_correct + cls_space_correct
        cls_blank_total = cls_none_total + cls_bold_total + cls_space_total
        cells_correct_nums, cells_total_nums = cell_merge_acc(spatial_att_logit_preds, spatial_att_logit_labels, spatial_att_logit_masks)
        loss_cache['cls_acc'] = cls_correct / cls_total
        loss_cache['cls_none_acc'] = cls_none_correct / cls_none_total
        loss_cache['cls_bold_acc'] = cls_bold_correct / cls_bold_total
        loss_cache['cls_space_acc'] = cls_space_correct / cls_space_total
        loss_cache['cls_blank_acc'] = cls_blank_correct / cls_blank_total
        loss_cache['spatial_att_logit_acc'] = cells_correct_nums / cells_total_nums

        return (spatial_att_logit_preds), loss_cache
263
+
264
def build_decoder(cfg):
    """Factory: construct a :class:`Decoder` from config fields.

    NOTE(review): the keyword set used here (line_dim, hidden_dim,
    max_length) differs from the positional arguments used where Decoder
    is constructed in model.py — confirm both call sites match
    ``Decoder.__init__``.
    """
    decoder = Decoder(
        vocab=cfg.vocab,
        feat_dim=cfg.encode_dim,
        line_dim=cfg.extractor_dim,
        embed_dim=cfg.embed_dim,
        lm_state_dim=cfg.lm_state_dim,
        proj_dim=cfg.proj_dim,
        hidden_dim=cfg.hidden_dim,
        cover_kernel=cfg.cover_kernel,
        max_length=cfg.max_length
    )
    return decoder
277
+
libs/model/divide_predictor.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from .sa import SALayer
5
+ from libs.utils.metric import cal_cls_acc
6
+
7
+
8
def align_segments_feat(segments_feat):
    """Right-pad a list of (C, n_i) segment-feature tensors to a common
    width and stack them into one (B, C, n_max) batch.

    Returns:
        (padded, masks): the stacked features and a (B, n_max) float mask
        holding 1 for real segment slots and 0 for padding.
    """
    ref = segments_feat[0]
    n_max = max(feat.shape[1] for feat in segments_feat)
    masks = torch.zeros([len(segments_feat), n_max], dtype=ref.dtype, device=ref.device)

    padded = []
    for idx, feat in enumerate(segments_feat):
        n_cur = feat.shape[1]
        masks[idx, :n_cur] = 1
        # Pad only the segment axis (last dim) on the right.
        padded.append(F.pad(feat, (0, n_max - n_cur, 0, 0), mode='constant', value=0))
    return torch.stack(padded, dim=0), masks
29
+
30
+
31
class HeadBodyDividePredictor(nn.Module):
    """Predicts which row segment divides the table head from the body.

    Gathers one feature column per row-segment position, fuses them with
    multi-head self-attention, and classifies exactly one divider index
    per sample.
    """

    def __init__(self, in_dim, head_nums, scale=1):
        super().__init__()
        self.in_dim = in_dim
        # Image -> feature-map coordinate ratio for segment positions.
        self.scale = scale
        self.fusion_layer = SALayer(in_dim, in_dim, head_nums)
        self.classifier= nn.Conv1d(in_dim, 1, 1, 1, 0)

    def forward(self, feats, segments, divide_labels=None):
        """Return (divide_preds, result_info, ext_info).

        Args:
            feats: per-sample (C, N) row-feature maps.
            segments: per-sample row-segment positions in image coords.
            divide_labels: GT divider indices (training only).
        """
        # Map segment positions into feature-map coordinates.
        segments = [[int(subitem * self.scale) for subitem in item] for item in segments]
        # One feature column per segment position.
        segments_feat = [feats_pi[:, segments_pi] for feats_pi, segments_pi in zip(feats, segments)]
        aligned_segments_feat, masks = align_segments_feat(segments_feat)
        aligned_segments_feat = self.fusion_layer(aligned_segments_feat, masks)
        divide_logits = self.classifier(aligned_segments_feat).squeeze(1)
        # Push padded slots to a huge negative value before argmax / CE.
        divide_logits = divide_logits - (1 - masks) * 1e8
        divide_preds = torch.argmax(divide_logits, dim=1)

        result_info = dict()
        ext_info = dict()
        if self.training:
            result_info['divide_loss'] = F.cross_entropy(divide_logits, divide_labels)
            correct_nums, total_nums = cal_cls_acc(divide_preds, divide_labels)
            if total_nums != 0:
                result_info['divide_acc'] = correct_nums / total_nums

        divide_preds = divide_preds.detach().cpu().tolist()
        return divide_preds, result_info, ext_info
libs/model/extractor.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision.ops import roi_align
4
+
5
+
6
def convert_to_roi_format(lines_box):
    """Concatenate per-image box tensors into ROI-align format.

    Args:
        lines_box: list of (n_i, 4) box tensors, one entry per image.

    Returns:
        (sum(n_i), 5) tensor whose rows are (batch_index, x1, y1, x2, y2).
    """
    boxes = torch.cat(lines_box, dim=0)
    # One batch-index column per image, matching its box count.
    batch_ids = [
        torch.full((per_image.shape[0], 1), img_idx, dtype=boxes.dtype, device=boxes.device)
        for img_idx, per_image in enumerate(lines_box)
    ]
    return torch.cat([torch.cat(batch_ids, dim=0), boxes], dim=1)
18
+
19
+
20
class RoiFeatExtraxtor(nn.Module):
    """ROI-align based line feature extractor.

    Pools a fixed-size window per line box from the feature map and
    projects the flattened window to ``output_dim`` with a small MLP.
    (The "Extraxtor" typo is kept: it is the public class name.)
    """

    def __init__(self, scale, pool_size, input_dim, output_dim):
        super().__init__()
        # Feature-map / image coordinate ratio passed to roi_align.
        self.scale = scale
        # (h, w) of the pooled window.
        self.pool_size = pool_size
        self.output_dim = output_dim
        # MLP input is the flattened pooled window.
        input_dim = input_dim * self.pool_size[0] * self.pool_size[1]
        self.fc = nn.Sequential(
            nn.Linear(input_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim)
        )

    def forward(self, feats, lines_box):
        """Return a list (one per image) of (n_lines, output_dim) features."""
        rois = convert_to_roi_format(lines_box)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=self.pool_size,
            spatial_scale=self.scale,
            sampling_ratio=2
        )

        lines_feat = lines_feat.reshape(lines_feat.shape[0], -1)
        lines_feat = self.fc(lines_feat)
        # Split the flat ROI batch back into per-image chunks.
        lines_feat = torch.split(lines_feat, [item.shape[0] for item in lines_box])
        return list(lines_feat)
47
+
48
+
49
class RoiPosFeatExtraxtor(nn.Module):
    """ROI-align line feature extractor with an additive box-position
    embedding (LayerNorm'd linear projection of the normalized box)."""

    def __init__(self, scale, pool_size, input_dim, output_dim):
        super().__init__()
        # Feature-map / image coordinate ratio passed to roi_align.
        self.scale = scale
        # (h, w) of the pooled window.
        self.pool_size = pool_size
        self.output_dim = output_dim
        input_dim = input_dim * self.pool_size[0] * self.pool_size[1]
        self.fc = nn.Sequential(
            nn.Linear(input_dim, self.output_dim),
            nn.ReLU(),
            nn.Linear(self.output_dim, self.output_dim)
        )
        self.bbox_ln = nn.LayerNorm(self.output_dim)
        self.bbox_tranform = nn.Linear(4, self.output_dim)

        self.add_ln = nn.LayerNorm(self.output_dim)

    def forward(self, feats, lines_box, img_sizes):
        """Return a list (one per image) of (n_lines, output_dim) features
        with positional embeddings added.

        NOTE(review): the positional step below normalizes the caller's
        ``lines_box`` tensors IN PLACE — confirm no caller reuses the
        boxes afterwards.
        """
        rois = convert_to_roi_format(lines_box)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=self.pool_size,
            spatial_scale=self.scale,
            sampling_ratio=2
        )
        lines_feat = lines_feat.reshape(lines_feat.shape[0], -1)
        lines_feat = self.fc(lines_feat)
        lines_feat = list(torch.split(lines_feat, [item.shape[0] for item in lines_box]))

        # Add Pos Embedding
        # Boxes are scaled into feature coords then normalized by the
        # feature-map size; ``img_size`` is unpacked but unused here.
        feats_H, feats_W = feats.shape[-2:]
        for idx, (line_box, img_size) in enumerate(zip(lines_box, img_sizes)):
            line_box[:, 0] = line_box[:, 0] * self.scale / feats_W
            line_box[:, 1] = line_box[:, 1] * self.scale / feats_H
            line_box[:, 2] = line_box[:, 2] * self.scale / feats_W
            line_box[:, 3] = line_box[:, 3] * self.scale / feats_H
            lines_feat[idx] = self.add_ln(lines_feat[idx] + self.bbox_ln(self.bbox_tranform(line_box)))

        return list(lines_feat)
libs/model/fpn.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class FPN(nn.Module):
7
+ def __init__(self, in_channels, out_channels):
8
+ super().__init__()
9
+ assert len(in_channels) == 4
10
+ self.in_channels = in_channels
11
+
12
+ self.lat_layers = nn.ModuleList()
13
+ self.out_layers = nn.ModuleList()
14
+ for in_channels_pl in in_channels:
15
+ self.lat_layers.append(
16
+ nn.Conv2d(in_channels_pl, out_channels, kernel_size=1, stride=1, padding=0)
17
+ )
18
+ self.out_layers.append(
19
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
20
+ )
21
+
22
+ def forward(self, feats):
23
+ c2, c3, c4, c5 = feats
24
+ p5 = self.lat_layers[3](c5)
25
+ p4 = F.interpolate(p5, size=c4.shape[2:], align_corners=False, mode='bilinear') + self.lat_layers[2](c4)
26
+ p3 = F.interpolate(p4, size=c3.shape[2:], align_corners=False, mode='bilinear') + self.lat_layers[1](c3)
27
+ p2 = F.interpolate(p3, size=c2.shape[2:], align_corners=False, mode='bilinear') + self.lat_layers[0](c2)
28
+
29
+ p2 = self.out_layers[0](p2)
30
+ p3 = self.out_layers[1](p3)
31
+ p4 = self.out_layers[2](p4)
32
+ p5 = self.out_layers[3](p5)
33
+ return p2, p3, p4, p5
34
+
35
+
36
def build_fpn(in_channels, out_channels):
    """Factory wrapper around :class:`FPN`."""
    return FPN(in_channels, out_channels)
libs/model/model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from .backbone import build_backbone
4
+ from .fpn import build_fpn
5
+ from .pan import PAN
6
+ from .segment_predictor import SegmentPredictor
7
+ from .divide_predictor import HeadBodyDividePredictor
8
+ from .cells_extractor import CellsExtractor
9
+ from .decoder import Decoder
10
+ from .utils import extend_segments, spatial_att_to_spans
11
+
12
+
13
class Model(nn.Module):
    """End-to-end table structure recognition model.

    Pipeline: backbone + FPN features -> row/col separator prediction ->
    head/body divide prediction -> cell feature grid (PAN + extractor) ->
    autoregressive decoder producing the cell layout.
    """

    def __init__(self, cfg, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.backbone = build_backbone(cfg.arch, cfg.pretrained_backbone, norm_layer=norm_layer)
        self.fpn = build_fpn(cfg.backbone_out_channels, cfg.fpn_out_channels)
        self.pan = PAN(cfg.pan_num_levels, cfg.pan_in_dim, cfg.pan_out_dim)
        self.row_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.rs_scale, type='row')
        self.col_segment_predictor = SegmentPredictor(cfg.fpn_out_channels, scale=cfg.cs_scale, type='col')
        self.divide_predictor = HeadBodyDividePredictor(cfg.fpn_out_channels, cfg.dp_head_nums, scale=cfg.dp_scale)
        self.cells_extractor = CellsExtractor(cfg.fpn_out_channels, cfg.ce_dim, cfg.ce_heads, cfg.ce_head_nums, cfg.ce_pool_size, cfg.ce_scale)
        self.decoder = Decoder(cfg.vocab, cfg.embed_dim, cfg.feat_dim, cfg.lm_state_dim, cfg.proj_dim, cfg.cover_kernel, cfg.att_threshold, cfg.spatial_att_weight_loss_wight)

    def forward(self, images, images_size, cls_labels=None, labels_mask=None, layouts=None, rows_fg_spans=None,
            rows_bg_spans=None, cols_fg_spans=None, cols_bg_spans=None, cells_spans=None, divide_labels=None):
        """Run the full pipeline.

        Returns at train time ((row_segments, col_segments, divide_preds),
        result_info) and at eval time the same tuple extended with the
        decoded cell spans.
        """
        feats = self.fpn(self.backbone(images))

        # Per-row feature vectors for the head/body divide predictor.
        row_feats = torch.mean(feats[0], dim=3)

        result_info = dict()
        ext_info = dict()
        row_segments, rs_result_info, rs_ext_info = self.row_segment_predictor(feats[0], images_size, rows_fg_spans, rows_bg_spans)
        rs_result_info = {'row_%s' % key: val for key, val in rs_result_info.items()}
        rs_ext_info = {'row_%s' % key: val for key, val in rs_ext_info.items()}
        result_info.update(rs_result_info)
        ext_info.update(rs_ext_info)
        col_segments, cs_result_info, cs_ext_info = self.col_segment_predictor(feats[0], images_size, cols_fg_spans, cols_bg_spans)
        cs_result_info = {'col_%s' % key: val for key, val in cs_result_info.items()}
        cs_ext_info = {'col_%s' % key: val for key, val in cs_ext_info.items()}
        result_info.update(cs_result_info)
        ext_info.update(cs_ext_info)

        if self.training:
            # Merge predicted extra separators into the GT grid so the
            # decoder trains on the extended layout.
            row_segments, col_segments, cells_spans, layouts, divide_labels = extend_segments(row_segments, rs_ext_info['row_ext_segments'],
                col_segments, cs_ext_info['col_ext_segments'], cells_spans, layouts, divide_labels)

        divide_preds, dp_result_info, dp_ext_info = self.divide_predictor(row_feats, row_segments, divide_labels=divide_labels)
        result_info.update(dp_result_info)
        ext_info.update(dp_ext_info)

        feat_maps, feats_masks = self.cells_extractor(self.pan(feats), row_segments, col_segments, images_size)
        if self.training:
            # Fix: the assert message was a `print(...)` call, which runs
            # print and passes None as the message; use a plain string.
            assert feat_maps.shape[-2:] == layouts.shape[-2:], 'feat_maps is not the same with layouts'

        de_preds, de_result_info = self.decoder(feat_maps, feats_masks.unsqueeze(1), cls_labels, labels_mask, layouts)
        result_info.update(de_result_info)

        if not self.training:
            # Same fix as above: plain string assertion message.
            assert de_preds.shape[0] == 1, 'batch size should be 1'
            de_recog_spans = spatial_att_to_spans(de_preds[0])
            return (row_segments, col_segments, divide_preds, de_recog_spans), result_info
        else:
            return (row_segments, col_segments, divide_preds), result_info
libs/model/pan.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
class PAN(nn.Module):
    """Top-down aggregation over FPN levels (PAN-style).

    Successively resizes-and-merges p2 into p3, p3 into p4, and p4 into
    p5, returning only the final merged (smallest) map.

    NOTE(review): forward unpacks exactly four levels and indexes
    pan_layers[0..2], so in practice ``num_levels`` must be 4 even though
    it is configurable — confirm with callers.
    """

    def __init__(self, num_levels, in_channels, out_channels):
        super().__init__()
        self.num_levels = num_levels
        self.in_channels = in_channels
        # Stored but not consumed by any layer in this module.
        self.out_channels = out_channels
        self.pan_layers = nn.ModuleList()
        for _ in range(num_levels - 1):
            self.pan_layers.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, padding_mode='reflect')
            )

    def forward(self, feats):
        p2, p3, p4, p5 = feats
        p2_ = p2
        # Resize the finer map to the next level's size, add, then smooth.
        p3_ = self.pan_layers[0](F.interpolate(p2_, size=p3.shape[2:], align_corners=False, mode='bilinear') + p3)
        p4_ = self.pan_layers[1](F.interpolate(p3_, size=p4.shape[2:], align_corners=False, mode='bilinear') + p4)
        p5_ = self.pan_layers[2](F.interpolate(p4_, size=p5.shape[2:], align_corners=False, mode='bilinear') + p5)
        return p5_
libs/model/sa.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class SALayer(nn.Module):
7
+ def __init__(self, in_dim, att_dim, head_nums):
8
+ super().__init__()
9
+ self.in_dim = in_dim
10
+ self.att_dim = att_dim
11
+ self.head_nums = head_nums
12
+
13
+ assert self.in_dim % self.head_nums == 0
14
+
15
+ self.key_layer = nn.Conv1d(self.in_dim, self.att_dim * self.head_nums, 1, 1, 0)
16
+ self.query_layer = nn.Conv1d(self.in_dim, self.att_dim * self.head_nums, 1, 1, 0)
17
+ self.value_layer = nn.Conv1d(self.in_dim, self.in_dim, 1, 1, 0)
18
+ self.scale = 1 / math.sqrt(self.att_dim)
19
+
20
+ def forward(self, feats, masks=None):
21
+ bs, c, n = feats.shape
22
+ keys = self.key_layer(feats).reshape(bs, -1, self.head_nums, n)
23
+ querys = self.query_layer(feats).reshape(bs, -1, self.head_nums, n)
24
+ values = self.value_layer(feats).reshape(bs, -1, self.head_nums, n)
25
+
26
+ logits = torch.einsum('bchk,bchq->bhkq', keys, querys) * self.scale
27
+ if masks is not None:
28
+ logits = logits - (1 - masks[:, None, :, None]) * 1e8
29
+ weights = torch.softmax(logits, dim=2)
30
+
31
+ new_feats = torch.einsum('bchk,bhkq->bchq', values, weights)
32
+ new_feats = new_feats.reshape(bs, -1, n)
33
+ return new_feats + feats
34
+
35
+
libs/model/segment_predictor.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.modules.activation import ReLU
5
+ from libs.utils.metric import cal_segment_pr
6
+ from .utils import draw_spans, save_logitmap
7
+
8
def cal_segments(cls_probs, spans, scale=1.0):
    """Pick one separator position per span.

    For each (start, end) span (image coordinates), find the index of the
    highest probability inside the span on the ``cls_probs`` axis (scaled
    by ``scale``) and map it back to image coordinates.
    """
    peaks = []
    for span in spans:
        lo = int(span[0] * scale)
        hi = int(span[1] * scale)
        peak = lo + torch.argmax(cls_probs[lo:hi]).item()
        peaks.append(int(peak / scale))
    return peaks
16
+
17
+
18
def cal_spans(cls_probs, threshold=0.5):
    """Group consecutive positions whose probability exceeds ``threshold``
    into half-open [start, end) spans."""
    flags = (cls_probs > threshold).long().tolist()
    spans = []
    prev = 0
    for pos, flag in enumerate(flags):
        if flag == 1:
            if prev == 1:
                # Extend the running span.
                spans[-1][1] = pos + 1
            else:
                # Open a new span at this position.
                spans.append([pos, pos + 1])
        prev = flag
    return spans
29
+
30
def cls_logits_to_segments(segments_logit, masks, type, spans=None, scale=1, threshold=0.5):
    """Reduce a (B, 1, H, W) separator logit map to per-sample segments.

    For 'col' the sigmoid map is averaged over rows, for 'row' over
    columns; the 1-D probability profile is then peak-picked inside
    candidate spans (supplied ``spans``, or spans thresholded from the
    profile when ``spans`` is None).

    NOTE(review): spans produced by cal_spans are already on the profile
    axis, yet cal_segments rescales them by ``scale`` — verify behavior
    for scale != 1.

    Returns:
        (segments, cls_probs, lengths): per-sample peak positions, the
        probability profiles, and the valid profile length per sample.
    """
    if type == 'col':
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=1)
        lengths = [int(mask[0, :].sum().item()) for mask in masks]
    else:
        cls_probs = segments_logit.squeeze(1).sigmoid().mean(dim=2)
        lengths = [int(mask[:, 0].sum().item()) for mask in masks]

    batch_size = cls_probs.shape[0]
    segments = list()
    for batch_idx in range(batch_size):
        length = lengths[batch_idx]
        if spans is None:
            spans_pi = cal_spans(cls_probs[batch_idx, :length], threshold)
            # Fallback: too few detected spans -> use the two borders.
            if len(spans_pi) <= 2:
                spans_pi = [[0, 1], [length-1, length]]
        else:
            spans_pi = spans[batch_idx]
        segments_pi = cal_segments(cls_probs[batch_idx, :length], spans_pi, scale)
        segments.append(segments_pi)
    return segments, cls_probs, lengths
51
+
52
+
53
def cal_ext_segments(cls_probs, lengths, bg_spans, scale=1, threshold=0.5):
    """Find extra separator lines inside background spans.

    Within each background span (regions without annotated lines), take
    the position with the maximum predicted probability and keep it only
    if that probability exceeds ``threshold``.
    """
    batch_size = cls_probs.shape[0]
    ext_segments = list()
    for batch_idx in range(batch_size):
        length = lengths[batch_idx]
        ext_segments_pi = cal_segments(cls_probs[batch_idx, :length], bg_spans[batch_idx], scale)
        # Keep only confident peaks.
        ext_segments_pi = [segment for segment in ext_segments_pi if cls_probs[batch_idx, segment] > threshold]
        ext_segments.append(ext_segments_pi)
    return ext_segments
65
+
66
+
67
def gen_masks(sizes, scale, device):
    """Build a (B, H*, W*) float validity mask at feature scale.

    Args:
        sizes: per-sample (h, w) sizes in image coordinates.
        scale: image -> feature-map coordinate ratio.
        device: target device.

    Returns:
        Float mask with 1 inside each sample's scaled valid region.
    """
    batch_size = len(sizes)
    max_size = [int(max(item) * scale) for item in zip(*sizes)]
    masks = torch.zeros([batch_size, *max_size], dtype=torch.float, device=device)
    for batch_idx in range(batch_size):
        # Fix: the valid extent must be scaled like the mask dimensions;
        # previously the unscaled sizes were used, which (via slice
        # clamping) over-marked the whole mask whenever scale < 1.
        h = int(sizes[batch_idx][0] * scale)
        w = int(sizes[batch_idx][1] * scale)
        masks[batch_idx, :h, :w] = 1.
    return masks
74
+
75
+
76
def gen_targets(sizes, scale, device, fg_spans, bg_spans, type):
    """Rasterize foreground spans into a (B, H*, W*) binary target map at
    feature scale.

    For 'col' each span marks full columns, otherwise full rows.
    ``bg_spans`` is accepted for signature symmetry but not used here.
    """
    max_size = [int(max(dims) * scale) for dims in zip(*sizes)]
    targets = torch.zeros([len(sizes), *max_size], dtype=torch.float, device=device)
    for sample_idx, sample_spans in enumerate(fg_spans):
        for span in sample_spans:
            lo = int(span[0] * scale)
            hi = int(span[1] * scale)
            if type == 'col':
                targets[sample_idx, :, lo:hi] = 1.
            else:
                targets[sample_idx, lo:hi, :] = 1.
    return targets
88
+
89
+
90
class SegmentPredictor(nn.Module):
    """Predicts row or column separator lines from a feature map.

    Emits a per-pixel separator logit map; at train time it is supervised
    by rasterized foreground spans, and extra high-confidence lines found
    inside background spans are reported for grid extension.
    """

    def __init__(self, in_dim, scale=1, threshold=0.5, type=None):
        super().__init__()
        # Image -> feature-map coordinate ratio.
        self.scale = scale
        self.in_dim = in_dim
        assert type in ['col', 'row']
        self.type = type
        self.threshold = threshold
        self.convs = nn.Sequential(
            nn.Conv2d(in_dim, in_dim // 2, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
            nn.ReLU(),
            nn.Conv2d(in_dim // 2, 1, kernel_size=(1,1), stride=(1,1), padding=(0,0))
        )

    def forward(self, feats, images_size, fg_spans=None, bg_spans=None):
        """Return (pred_segments, result_info, ext_info).

        At train time ``pred_segments`` are peak-picked inside the GT
        fg_spans; losses/metrics go into result_info and predicted extra
        segments into ext_info['ext_segments'].
        """
        batch_size = feats.shape[0]
        # Reverse each size pair (presumably (w, h) -> (h, w)) — confirm
        # the upstream ordering.
        images_size = [image_size[::-1] for image_size in images_size]
        segments_logit = self.convs(feats)
        masks = gen_masks(images_size, self.scale, feats.device)
        # save_logitmap('row_segment.png', segments_logit[0][0])
        result_info = dict()
        ext_info = dict()

        if self.training:
            targets = gen_targets(images_size, self.scale, feats.device, fg_spans, bg_spans, self.type)
            segments_loss = F.binary_cross_entropy_with_logits(
                segments_logit,
                targets.unsqueeze(1),
                reduction='none'
            )
            # Masked sum normalized by the number of positive target pixels.
            segments_loss = (segments_loss * masks[:, None, :, :]).sum() / targets.sum()
            result_info['segments_loss'] = segments_loss

            # Free-running prediction for precision/recall metrics.
            pred_segments, cls_probs, lengths = cls_logits_to_segments(segments_logit, masks, self.type, spans=None, scale=self.scale, threshold=self.threshold)
            correct_nums, segment_nums, span_nums = cal_segment_pr(pred_segments, fg_spans, bg_spans)
            if segment_nums != 0:
                result_info['precision'] = correct_nums/segment_nums
            if span_nums != 0:
                result_info['recall'] = correct_nums/span_nums
            # Confident extra lines inside background spans (grid extension).
            ext_segments = cal_ext_segments(cls_probs, lengths, bg_spans, self.scale, self.threshold)
            ext_info['ext_segments'] = ext_segments

        # Returned segments are peak-picked inside fg_spans when provided.
        pred_segments, *_ = cls_logits_to_segments(segments_logit, masks, self.type, spans=fg_spans, scale=self.scale, threshold=self.threshold)
        return pred_segments, result_info, ext_info
libs/model/utils.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import copy
3
+ import torch
4
+ import numpy as np
5
+ from torch.nn import functional as F
6
+
7
def proposal_colspan(layout, layout_score, srow, scol):
    """Rectangle hypothesis grown column-first from the seed (srow, scol).

    If the active (==1) region of ``layout`` is already a filled
    rectangle, return it unchanged with its mean score. Otherwise extend
    right along row ``srow`` while cells are active, then extend down
    while whole rows of that width stay active; ``layout`` is overwritten
    in place with the resulting rectangle.
    """
    y, x = torch.where(layout == 1)
    if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
        return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
    else:
        lf_row = srow
        lf_col = scol

        # Consecutive active cells to the right of the seed.
        col_count = 0
        for col_ in range(lf_col, x.max() + 1):
            if layout[lf_row, col_] == 1:
                col_count = col_count + 1
            else:
                break
        # Rows (from the seed down) fully active over that width.
        row_count = 0
        for row_ in range(lf_row, y.max() + 1):
            if torch.all(layout[row_, lf_col: lf_col + col_count] == 1):
                row_count = row_count + 1
            else:
                break

        # Overwrite with the rectangle hypothesis.
        layout[:, :] = 0
        layout[lf_row:lf_row + row_count, lf_col : lf_col + col_count] = 1
        return layout, layout_score[lf_row:lf_row + row_count, lf_col : lf_col + col_count].mean()
32
+
33
def proposal_rowspan(layout, layout_score, srow, scol):
    """Rectangle hypothesis grown row-first from the seed (srow, scol).

    Mirror of :func:`proposal_colspan`: extend down along column ``scol``
    while cells are active, then extend right while whole columns of that
    height stay active; ``layout`` is overwritten in place.
    """
    y, x = torch.where(layout == 1)
    if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
        return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
    else:
        lf_row = srow
        lf_col = scol

        # Consecutive active cells below the seed.
        row_count = 0
        for row_ in range(lf_row, y.max() + 1):
            if layout[row_, lf_col] == 1:
                row_count = row_count + 1
            else:
                break
        # Columns (from the seed right) fully active over that height.
        col_count = 0
        for col_ in range(lf_col, x.max() + 1):
            if torch.all(layout[lf_row : lf_row + row_count, col_] == 1):
                col_count = col_count + 1
            else:
                break

        # Overwrite with the rectangle hypothesis.
        layout[:, :] = 0
        layout[lf_row:lf_row + row_count, lf_col : lf_col + col_count] = 1
        return layout, layout_score[lf_row:lf_row + row_count, lf_col : lf_col + col_count].mean()
59
+
60
def proposal_maxcontain(layout, layout_score, srow, scol):
    """Rectangle hypothesis spanning the whole active extent.

    If the active (==1) region is already a filled rectangle it is
    returned unchanged with the mean score over that rectangle. Otherwise
    ``layout`` is overwritten in place with the rectangle from
    (srow, scol) to the bottom-right extent of the active region.
    """
    ys, xs = torch.where(layout == 1)
    top, bottom = ys.min(), ys.max() + 1
    left, right = xs.min(), xs.max() + 1
    if torch.all(layout[top:bottom, left:right] == 1):
        return layout, layout_score[top:bottom, left:right].mean()
    layout[:, :] = 0
    layout[srow:bottom, scol:right] = 1
    return layout, layout_score[srow:bottom, scol:right].mean()
73
+
74
def proposal_maxrowspan(layout, layout_score, srow, scol):
    """Rectangle hypothesis using the full active column extent.

    If the active region is already a filled rectangle, return it
    unchanged. Otherwise count how many consecutive rows (from the seed
    row down) are identical to the seed row, and overwrite ``layout`` in
    place with a rectangle of those rows spanning out to the right edge
    of the active region.
    """
    y, x = torch.where(layout == 1)
    if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
        return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
    else:
        lf_row = srow
        lf_col = scol

        # Rows below the seed with an identical activation pattern.
        row_count = 1
        for row_ in range(lf_row + 1, y.max() + 1):
            if torch.all(layout[lf_row] == layout[row_]):
                row_count = row_count + 1
            else:
                break

        # Overwrite with the rectangle hypothesis.
        layout[:, :] = 0
        layout[lf_row : lf_row + row_count, lf_col : x.max() + 1] = 1
        return layout, layout_score[lf_row : lf_row + row_count, lf_col : x.max() + 1].mean()
94
+
95
def proposal_maxcolspan(layout, layout_score, srow, scol):
    """Rectangle hypothesis using the full active row extent.

    Mirror of :func:`proposal_maxrowspan`: count consecutive columns
    (from the seed column right) identical to the seed column, then
    overwrite ``layout`` in place with a rectangle of those columns
    spanning down to the bottom of the active region.
    """
    y, x = torch.where(layout == 1)
    if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
        return layout, layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean()
    else:
        lf_row = srow
        lf_col = scol

        # Columns right of the seed with an identical activation pattern.
        col_count = 1
        for col_ in range(lf_col + 1, x.max() + 1):
            if torch.all(layout[:, lf_col] == layout[:, col_]):
                col_count = col_count + 1
            else:
                break

        # Overwrite with the rectangle hypothesis.
        layout[:, :] = 0
        layout[lf_row : y.max() + 1, lf_col : lf_col + col_count] = 1
        return layout, layout_score[lf_row : y.max() + 1, lf_col : lf_col + col_count].mean()
115
+
116
def gen_proposals(layout_score, srow, scol, score_threshold=0.5):
    """Generate rectangular cell proposals from an attention score map.

    Thresholds ``layout_score`` into a binary layout (forcing the seed
    position (srow, scol) on). If the active region is already a filled
    rectangle, return it as the single proposal with its log mean score.
    Otherwise return five competing rectangle hypotheses (col-span,
    row-span, max-contain, max-row-span, max-col-span) with log scores.
    """
    layout = layout_score > score_threshold
    # The seed cell is always part of the proposal.
    layout[srow, scol] = 1

    y, x = torch.where(layout == 1)
    if torch.all(layout[y.min():y.max() + 1, x.min():x.max()+1] == 1):
        return layout.unsqueeze(0), layout_score[y.min():y.max() + 1, x.min():x.max()+1].mean().unsqueeze(0).log()
    else:
        # Each proposal_* mutates its own deep copy of the layout.
        proposal_1, score_1 = proposal_colspan(copy.deepcopy(layout), layout_score, srow, scol)
        proposal_2, score_2 = proposal_rowspan(copy.deepcopy(layout), layout_score, srow, scol)
        proposal_3, score_3 = proposal_maxcontain(copy.deepcopy(layout), layout_score, srow, scol)
        proposal_4, score_4 = proposal_maxrowspan(copy.deepcopy(layout), layout_score, srow, scol)
        proposal_5, score_5 = proposal_maxcolspan(copy.deepcopy(layout), layout_score, srow, scol)
        proposals = torch.stack([proposal_1, proposal_2, proposal_3, proposal_4, proposal_5], dim=0)
        scores = torch.stack([score_1.log(), score_2.log(), score_3.log(), score_4.log(), score_5.log()], dim=0)
        return proposals, scores
132
+
133
def extend_segments(row_segments, rows_es, col_segments, cols_es, cells_spans, layouts, divide_labels):
    """Merge extra (predicted) separator lines into the ground-truth grid.

    Inserts ``rows_es``/``cols_es`` into each sample's row/col segments,
    re-sorts them, and remaps cell spans, layout grids, and the head/body
    divide label onto the extended grid indices.

    Returns:
        (ext_row_segments, ext_col_segments, ext_cells_spans,
        aligned layouts tensor, divide labels tensor)
    """
    batch_size = len(row_segments)
    ext_row_segments = list()
    ext_col_segments = list()
    ext_cells_spans = list()
    ext_layouts = list()
    ext_divide_labels = list()
    for batch_idx in range(batch_size):
        row_segments_pi = row_segments[batch_idx]
        col_segments_pi = col_segments[batch_idx]
        rows_es_pi = rows_es[batch_idx]
        cols_es_pi = cols_es[batch_idx]
        cells_spans_pi = cells_spans[batch_idx]

        # Original segments first, extras appended after.
        ext_row_segments_pi = row_segments_pi + rows_es_pi
        ext_col_segments_pi = col_segments_pi + cols_es_pi

        # Argsort: final position of each original/extra segment.
        row_segments_idx = sorted(list(range(len(ext_row_segments_pi))), key=lambda idx: ext_row_segments_pi[idx])
        col_segments_idx = sorted(list(range(len(ext_col_segments_pi))), key=lambda idx: ext_col_segments_pi[idx])

        # Remap the divide row index onto the extended ordering.
        ext_divide_labels.append(row_segments_idx.index(divide_labels[batch_idx].item()))

        ext_row_segments.append([ext_row_segments_pi[idx] for idx in row_segments_idx])
        ext_col_segments.append([ext_col_segments_pi[idx] for idx in col_segments_idx])

        # Rebuild the layout grid on the extended (finer) grid; -1 = empty.
        ext_layouts_pi = np.full((len(ext_row_segments_pi) - 1, len(ext_col_segments_pi) - 1), -1)
        ext_cells_spans_pi = list()
        for cell_idx, cell_span in enumerate(cells_spans_pi):
            l, t, r, b = cell_span
            # Inclusive span ends map via the separator AFTER the edge.
            l = col_segments_idx.index(l)
            r = col_segments_idx.index(r+1) - 1
            t = row_segments_idx.index(t)
            b = row_segments_idx.index(b+1) - 1
            ext_cells_spans_pi.append([l, t, r, b])
            ext_layouts_pi[t:b+1, l:r+1] = cell_idx
        ext_cells_spans.append(ext_cells_spans_pi)
        ext_layouts.append(ext_layouts_pi)

    return ext_row_segments, ext_col_segments, ext_cells_spans, aligned_layouts(ext_layouts, layouts), torch.tensor(ext_divide_labels).to(divide_labels.device)
172
+
173
def aligned_layouts(layouts_list, layouts):
    """Pad a list of 2-D numpy layout arrays to a common (rows, cols) with
    -1 (empty) and stack them into one tensor matching ``layouts``'s dtype
    and device."""
    dtype, device = layouts.dtype, layouts.device
    rows_max = max(arr.shape[0] for arr in layouts_list)
    cols_max = max(arr.shape[1] for arr in layouts_list)

    padded = []
    for arr in layouts_list:
        grid = torch.from_numpy(arr).to(dtype=dtype, device=device)
        # Pad on the right/bottom only.
        padded.append(
            F.pad(
                grid,
                (0, cols_max - arr.shape[1], 0, rows_max - arr.shape[0]),
                mode='constant',
                value=-1
            )
        )
    return torch.stack(padded, dim=0)
195
+
196
def parse_layout(spans, num_rows, num_cols):
    """Rasterize predicted cell spans into a (num_rows, num_cols) id grid.

    Args:
        spans: iterable of (x1, y1, x2, y2, prob) grid-coordinate spans,
            inclusive at both ends; later spans overwrite earlier ones.
        num_rows, num_cols: shape of the output grid.

    Returns:
        Integer ndarray whose entries are cell ids, re-numbered contiguously
        in raster-scan (row-major) order of first appearance.
    """
    # BUGFIX: `np.int` was removed in NumPy 1.24; the builtin `int` is the
    # supported equivalent dtype alias.
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    for cell_count, (x1, y1, x2, y2, prob) in enumerate(spans):
        layout[y1:y2 + 1, x1:x2 + 1] = cell_count

    # Re-number ids so they are contiguous and ordered by first appearance.
    # NOTE(review): positions not covered by any span keep id -1 and are also
    # assigned a fresh id here — presumably intended; confirm with callers.
    cells_id = list()
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            cell_id = layout[row_idx, col_idx]
            if cell_id in cells_id:
                layout[row_idx, col_idx] = cells_id.index(cell_id)
            else:
                layout[row_idx, col_idx] = len(cells_id)
                cells_id.append(cell_id)
    return layout
213
+
214
+
215
def parse_cells(layout, row_segments, col_segments):
    """Convert a cell-id grid plus row/col boundary coordinates into polygons.

    Each cell's grid extent is mapped through the separator coordinates; the
    result is one dict per cell with a rectangular 'segmentation' contour.
    """
    cells = list()
    for cell_id in range(np.max(layout) + 1):
        rows, cols = np.nonzero(layout == cell_id)
        top, bottom = np.min(rows), np.max(rows)
        left, right = np.min(cols), np.max(cols)
        # Sanity check: the occupied region must form a solid rectangle.
        assert np.all(layout[top:bottom, left:right] == cell_id)
        px1, px2 = col_segments[left], col_segments[right + 1]
        py1, py2 = row_segments[top], row_segments[bottom + 1]
        cells.append(dict(
            segmentation=[[[px1, py1], [px2, py1], [px2, py2], [px1, py2]]]
        ))
    return cells
234
+
235
+
236
def process_layout(score, index):
    """Greedy rectangularization of a cell-index map (legacy 2-arg version).

    NOTE(review): this definition is immediately shadowed by the later
    `process_layout(score, index, use_score=..., ...)` below, so this body is
    dead code; kept only to avoid disturbing the module's history.

    Scans positions in raster order; each unassigned position seeds a cell id
    that is grown into the largest top-left-anchored rectangle of equal
    `index` values.
    """
    layout = torch.full_like(index, -1)
    layout_mask = torch.full_like(index, -1)
    nrow, ncol = score.shape
    for cell_id in range(nrow * ncol):
        # Stop once every position has been assigned (mask fully 1).
        if layout_mask.min() != -1:
            break
        # Top-left-most unassigned position.
        crow, ccol = torch.where(layout_mask == layout_mask.min())
        ccol = ccol[crow == crow.min()].min()
        crow = crow.min()
        id = index[crow, ccol]
        h, w = torch.where(index == id)
        if h.shape[0] == 1 or w.shape[0] == 1:
            # Degenerate (single row/col) region: take it whole.
            layout_mask[h, w] = 1
            layout[h, w] = cell_id
            continue
        else:
            h_min = h.min()
            h_max = h.max()
            w_min = w.min()
            w_max = w.max()
            if torch.all(index[h_min:h_max+1, w_min:w_max+1] == id):
                # The region is already a solid rectangle.
                layout_mask[h_min:h_max+1, w_min:w_max+1] = 1
                layout[h_min:h_max+1, w_min:w_max+1] = cell_id
            else:
                # Grow greedily: first along the row, then downwards while
                # whole rows keep matching this id.
                lf_row = crow
                lf_col = ccol

                col_mem = -1
                for col_ in range(lf_col, w_max + 1):
                    if index[lf_row, col_] == id:
                        layout_mask[lf_row, col_] = 1
                        layout[lf_row, col_] = cell_id
                        col_mem = col_
                    else:
                        break
                for row_ in range(lf_row + 1, h_max + 1):
                    if torch.all(index[row_, lf_col: col_mem + 1] == id):
                        layout_mask[row_, lf_col: col_mem + 1] = 1
                        layout[row_, lf_col: col_mem + 1] = cell_id
                    else:
                        break
    return layout
279
+
280
def process_layout(score, index, use_score=False, is_merge=True, score_threshold=0.5):
    """Greedy rectangularization of a cell-index map.

    When `use_score` is set, positions whose confidence is below
    `score_threshold` are first re-labelled: all merged into one new cell id
    (`is_merge=True`) or each given its own fresh id (`is_merge=False`).
    Then positions are scanned in raster order; each unassigned position
    seeds a cell id grown into the largest top-left-anchored rectangle of
    equal `index` values.
    """
    if use_score:
        if is_merge:
            y, x = torch.where(score < score_threshold)
            index[y, x] = index.max() + 1
        else:
            y, x = torch.where(score < score_threshold)
            index[y, x] = torch.arange(index.max() + 1, index.max() + 1 + len(y)).to(index.device, index.dtype)

    layout = torch.full_like(index, -1)
    layout_mask = torch.full_like(index, -1)
    nrow, ncol = score.shape
    # Upper bound on iterations: fresh ids may exceed nrow*ncol when
    # is_merge=False added one id per low-score position.
    for cell_id in range(max(nrow * ncol, index.max() + 1)):
        if layout_mask.min() != -1:
            break
        # Top-left-most unassigned position.
        crow, ccol = torch.where(layout_mask == layout_mask.min())
        ccol = ccol[crow == crow.min()].min()
        crow = crow.min()
        id = index[crow, ccol]
        h, w = torch.where(index == id)
        if h.shape[0] == 1 or w.shape[0] == 1:
            # Degenerate (single row/col) region: take it whole.
            layout_mask[h, w] = 1
            layout[h, w] = cell_id
            continue
        else:
            h_min = h.min()
            h_max = h.max()
            w_min = w.min()
            w_max = w.max()
            if torch.all(index[h_min:h_max+1, w_min:w_max+1] == id):
                # The region is already a solid rectangle.
                layout_mask[h_min:h_max+1, w_min:w_max+1] = 1
                layout[h_min:h_max+1, w_min:w_max+1] = cell_id
            else:
                # Grow greedily: first along the seed row, then downwards
                # while whole rows keep matching this id.
                lf_row = crow
                lf_col = ccol

                col_mem = -1
                for col_ in range(lf_col, w_max + 1):
                    if index[lf_row, col_] == id:
                        layout_mask[lf_row, col_] = 1
                        layout[lf_row, col_] = cell_id
                        col_mem = col_
                    else:
                        break
                for row_ in range(lf_row + 1, h_max + 1):
                    if torch.all(index[row_, lf_col: col_mem + 1] == id):
                        layout_mask[row_, lf_col: col_mem + 1] = 1
                        layout[row_, lf_col: col_mem + 1] = cell_id
                    else:
                        break
    return layout
331
+
332
def layout2spans(layout):
    """Recover inclusive [x1, y1, x2, y2] grid spans from a cell-id layout.

    Ids absent from the grid are skipped; the span list is wrapped in an
    outer list (single-table batch convention).
    """
    num_rows, num_cols = layout.shape[-2:]
    spans = list()
    for cell_id in range(num_rows * num_cols):
        positions = np.argwhere(layout == cell_id)
        if positions.size == 0:
            continue
        ys, xs = positions[:, 0], positions[:, 1]
        y1, y2 = np.min(ys), np.max(ys)
        x1, x2 = np.min(xs), np.max(xs)
        # Occupied region must be a solid rectangle.
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        spans.append([x1, y1, x2, y2])
    return [spans]
346
+
347
def spatial_att_to_spans(spatial_att_weight_pred):
    """Decode a (num_cells, H, W) spatial attention map into cell grid spans."""
    # For each grid position keep the cell with the highest attention weight.
    max_score, max_index = spatial_att_weight_pred.max(dim=0)
    # First pass: low-score positions become fresh singleton cells
    # (is_merge=False), detaching them from uncertain assignments.
    layout = process_layout(max_score, max_index, use_score=True, is_merge=False)
    # Second pass without thresholding regularizes the map into rectangles.
    layout = process_layout(max_score, layout)

    layout = layout.cpu().numpy()
    spans = layout2spans(layout)
    return spans
355
+
356
+
357
def save_logitmap(filename, logit):
    """Write a logit tensor to disk as an 8-bit probability image (debugging aid)."""
    cv2.imwrite(filename, (logit.sigmoid()*255).cpu().numpy().astype('uint8'))
359
+
360
+
361
def draw_spans(dst, src, spans, type):
    """Read image `src`, draw row/column span rectangles, write result to `dst`.

    `type` selects orientation ('row' or 'col'); the parameter name shadows
    the builtin but is kept to preserve keyword-argument compatibility.
    Unknown `type` values silently draw nothing.
    """
    image = cv2.imread(src)
    H, W, *_ = image.shape
    for span in spans:
        if type == 'col':
            # Vertical band spanning the full image height.
            cv2.rectangle(image, (span[0], 0), (span[1], H), (0, 0, 255), thickness=1)
        elif type == 'row':
            # Horizontal band spanning the full image width.
            cv2.rectangle(image, (0, span[0]), (W, span[1]), (0, 0, 255), thickness=1)
    cv2.imwrite(dst, image)
370
+
371
+
libs/utils/__init__.py ADDED
File without changes
libs/utils/cal_f1.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from numpy.core.fromnumeric import sort
3
+ import tqdm
4
+ import json
5
+ import copy
6
+ import Polygon
7
+ import numpy as np
8
+ from .scitsr.eval import json2Relations, eval_relations
9
+
10
+
11
def parse_layout(spans, num_rows, num_cols):
    """Rasterize cell spans into a (num_rows, num_cols) id grid.

    Args:
        spans: iterable of (x1, y1, x2, y2) grid-coordinate spans, inclusive
            at both ends; later spans overwrite earlier ones.
        num_rows, num_cols: shape of the output grid.

    Returns:
        Integer ndarray whose entries are cell ids, re-numbered contiguously
        in raster-scan (row-major) order of first appearance.
    """
    # BUGFIX: `np.int` was removed in NumPy 1.24; the builtin `int` is the
    # supported equivalent dtype alias.
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    for cell_count, (x1, y1, x2, y2) in enumerate(spans):
        layout[y1:y2 + 1, x1:x2 + 1] = cell_count

    # Re-number ids so they are contiguous and ordered by first appearance.
    # NOTE(review): positions not covered by any span keep id -1 and are also
    # assigned a fresh id here — presumably intended; confirm with callers.
    cells_id = list()
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            cell_id = layout[row_idx, col_idx]
            if cell_id in cells_id:
                layout[row_idx, col_idx] = cells_id.index(cell_id)
            else:
                layout[row_idx, col_idx] = len(cells_id)
                cells_id.append(cell_id)
    return layout
28
+
29
+
30
def parse_cells(layout, spans, row_segments, col_segments, lines):
    """Build cell polygons from the layout grid and attach OCR transcripts.

    NOTE(review): the `spans` parameter is unused in this body; kept for
    signature compatibility with callers.
    """
    cells = list()
    num_cells = np.max(layout) + 1
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # Sanity check: occupied region must be a solid rectangle.
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        # Map grid extents through separator pixel coordinates.
        x1 = col_segments[x1]
        x2 = col_segments[x2+1]
        y1 = row_segments[y1]
        y2 = row_segments[y2+1]
        cell = dict(
            segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]
        )
        cells.append(cell)

    # Assign each OCR text line to its best-overlapping cell (in place).
    extend_cell_lines(cells, lines)

    return cells
52
+
53
+
54
def extend_cell_lines(cells, lines):
    """Assign OCR text lines to cells by maximum polygon overlap (in place).

    Each line goes to the cell covering the largest fraction of its area;
    a cell's transcript is the concatenation of its lines sorted by the
    line's top y-coordinate. Cells with no line get an empty transcript.
    """
    def segmentation_to_polygon(segmentation):
        # Union of all contours of one segmentation.
        polygon = Polygon.Polygon()
        for contour in segmentation:
            polygon = polygon + Polygon.Polygon(contour)
        return polygon

    # Defensive copy: line dicts are read but must not alias caller data.
    lines = copy.deepcopy(lines)

    cells_poly = [segmentation_to_polygon(item['segmentation']) for item in cells]
    lines_poly = [segmentation_to_polygon(item['segmentation']) for item in lines]

    cells_lines = [[] for _ in range(len(cells))]

    for line_idx, line_poly in enumerate(lines_poly):
        # Degenerate (zero-area) lines cannot be assigned.
        if line_poly.area() == 0:
            continue
        line_area = line_poly.area()
        max_overlap = 0
        max_overlap_idx = None
        for cell_idx, cell_poly in enumerate(cells_poly):
            # Fraction of the line's area covered by this cell.
            overlap = (cell_poly & line_poly).area() / line_area
            if overlap > max_overlap:
                max_overlap_idx = cell_idx
                max_overlap = overlap
        if max_overlap > 0:
            cells_lines[max_overlap_idx].append(line_idx)
    # Order each cell's lines top-to-bottom by the line bbox's y1.
    lines_y1 = [segmentation_to_bbox(item['segmentation'])[1] for item in lines]
    cells_lines = [sorted(item, key=lambda idx: lines_y1[idx]) for item in cells_lines]

    for cell, cell_lines in zip(cells, cells_lines):
        transcript = []
        for idx in cell_lines:
            transcript.extend(lines[idx]['transcript'])
        cell['transcript'] = transcript
90
+
91
def segmentation_to_bbox(segmentation):
    """Axis-aligned bounding box [x1, y1, x2, y2] over all contour points."""
    xs = [pt[0] for contour in segmentation for pt in contour]
    ys = [pt[1] for contour in segmentation for pt in contour]
    return [min(xs), min(ys), max(xs), max(ys)]
97
+
98
+
99
def cal_cell_spans(table):
    """Compute the inclusive grid span [x1, y1, x2, y2] of every cell."""
    layout = table['layout']
    spans = list()
    for cell_id in range(len(table['cells'])):
        positions = np.argwhere(layout == cell_id)
        y1, y2 = np.min(positions[:, 0]), np.max(positions[:, 0])
        x1, x2 = np.min(positions[:, 1]), np.max(positions[:, 1])
        # Occupied region must be a solid rectangle.
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        spans.append([x1, y1, x2, y2])
    return spans
112
+
113
+
114
def pred_result_to_table(table, pred_result):
    """Convert a model prediction into the project's table dict format.

    `table` supplies ground-truth OCR lines (cells that carry a 'bbox');
    `pred_result` is (row_segments, col_segments, divide, spans) where
    `divide` is the number of header rows.
    """
    # gt ocr result
    lines = [dict(segmentation=cell['segmentation'], transcript=cell['transcript']) for cell in table['cells'] if 'bbox' in cell.keys()]

    row_segments, col_segments, divide, spans = pred_result
    # N separators bound N-1 rows/cols.
    num_rows = len(row_segments) - 1
    num_cols = len(col_segments) - 1

    layout = parse_layout(spans, num_rows, num_cols)
    cells = parse_cells(layout, spans, row_segments, col_segments, lines)
    # Rows before the divide are header, the rest body.
    head_rows = list(range(0, divide))
    body_rows = list(range(divide, num_rows))

    table = dict(
        layout=layout,
        head_rows=head_rows,
        body_rows=body_rows,
        cells=cells
    )

    return table
135
+
136
+
137
def table_to_relations(table):
    """Flatten a table into SciTSR-style cell relations with split content."""
    spans = cal_cell_spans(table)
    contents = [''.join(cell['transcript']).split() for cell in table['cells']]
    cells = [
        dict(start_row=y1, end_row=y2, start_col=x1, end_col=x2, content=content)
        for (x1, y1, x2, y2), content in zip(spans, contents)
    ]
    return dict(cells=cells)
145
+
146
+
147
def cal_f1(label, pred):
    """SciTSR relation-level precision/recall/F1 for one (label, pred) pair.

    Returns [precision, recall, f1]; F1 is 0 when both P and R are 0.
    """
    label = json2Relations(label, splitted_content=True)
    pred = json2Relations(pred, splitted_content=True)
    precision, recall = eval_relations(gt=[label], res=[pred], cmp_blank=True)
    f1 = 2.0 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    return [precision, recall, f1]
153
+
154
+
155
def single_process(labels, preds):
    """Score every label against its prediction sequentially.

    Missing entries fall back to '' on either side; returns {key: [P, R, F1]}.
    """
    scores = dict()
    for key in tqdm.tqdm(labels.keys()):
        pred = preds.get(key, '')
        label = labels.get(key, '')
        score = cal_f1(label, pred)
        scores[key] = score
    return scores
163
+
164
+
165
def _worker(labels, preds, keys, result_queue):
    """Multiprocessing worker: score the assigned keys, push (key, score)."""
    for key in keys:
        label = labels.get(key, '')
        pred = preds.get(key, '')
        score = cal_f1(label, pred)
        result_queue.put((key, score))
171
+
172
+
173
def multi_process(labels, preds, num_workers):
    """Score label/pred pairs across `num_workers` daemon processes.

    Keys are round-robin-sharded to workers; results are drained from a
    managed queue (exactly len(keys) items) with a live running-mean bar.
    Workers are daemonic, so they are reaped when the parent exits.
    """
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()
    keys = list(labels.keys())
    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                labels,
                preds,
                keys[worker_idx::num_workers],
                result_queue
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)

    scores = dict()
    tq = tqdm.tqdm(total=len(keys))
    for _ in range(len(keys)):
        key, val = result_queue.get()
        scores[key] = val
        # Running mean over all results so far, shown as percentages.
        P, R, F1 = (100 * np.array(list(scores.values()))).mean(0).tolist()
        tq.set_description('P: %.2f, R: %.2f, F1: %.2f' % (P, R, F1), False)
        tq.update()

    return scores
203
+
204
+
205
def evaluate_f1(labels, preds, num_workers=0):
    """Score prediction/label pairs with SciTSR P/R/F1.

    Samples are keyed by their position in the input sequences; the returned
    score list is restored to input order regardless of worker completion
    order. `num_workers == 0` runs sequentially.
    """
    preds = dict(enumerate(preds))
    labels = dict(enumerate(labels))
    if num_workers == 0:
        scores = single_process(labels, preds)
    else:
        scores = multi_process(labels, preds, num_workers)
    # Keys are the integer sample indices; sort positions by the key stored
    # at each position, then look scores up by key.
    ordering = sorted(range(len(scores)), key=lambda idx: list(scores.keys())[idx])
    return [scores[idx] for idx in ordering]
+ return scores
libs/utils/checkpoint.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from .comm import get_rank, synchronize
4
+
5
+
6
def save_checkpoint(checkpoint, model, optimizer=None, best_metric=None, epoch=None):
    """Save model (and optional optimizer/metric/epoch) state to `checkpoint`.

    Only rank 0 writes; all ranks synchronize afterwards so no process races
    ahead of the file being flushed. DDP wrappers are unwrapped first so the
    saved keys are prefix-free.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    if get_rank() == 0:
        if not os.path.exists(os.path.dirname(checkpoint)):
            os.makedirs(os.path.dirname(checkpoint))

        infos = dict()
        infos['model_param'] = model.state_dict()
        if optimizer is not None:
            infos['opt_param'] = optimizer.state_dict()

        if best_metric is not None:
            infos['best_metric'] = best_metric

        if epoch is not None:
            infos['epoch'] = epoch

        torch.save(infos, checkpoint)
    synchronize()
26
+
27
+
28
def load_checkpoint(checkpoint, model, optimizer=None):
    """Load model (and optionally optimizer) state from a checkpoint path.

    DDP wrappers are unwrapped first; model weights load with strict=False
    so missing/extra keys are tolerated. Returns (best_metric, epoch),
    either being None when absent from the checkpoint.
    """
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module
    state = torch.load(checkpoint, map_location='cpu')

    model.load_state_dict(state['model_param'], strict=False)

    if optimizer is not None and 'opt_param' in state:
        optimizer.load_state_dict(state['opt_param'])

    return state.get('best_metric'), state.get('epoch')
libs/utils/comm.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains primitives for multi-gpu communication.
3
+ This is useful when doing distributed training.
4
+ """
5
+ import os
6
+ import pickle
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+
11
+
12
def distributed():
    """True when launched with more than one process (WORLD_SIZE > 1)."""
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return world_size > 1
16
+
17
+
18
def get_world_size():
    """Number of distributed processes; 1 outside a distributed run."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
24
+
25
+
26
def get_rank():
    """Global rank of this process; 0 outside a distributed run."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
32
+
33
+
34
def get_local_rank():
    """Node-local rank from the launcher env; falls back to the global rank."""
    local_rank = os.environ.get('LOCAL_RANK')
    return get_rank() if local_rank is None else int(local_rank)
39
+
40
+
41
def is_main_process():
    """True only on the rank-0 (master) process."""
    return not get_rank()
43
+
44
+
45
def synchronize():
    """Barrier across all processes; a no-op outside distributed runs."""
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    if dist.get_world_size() > 1:
        dist.barrier()
58
+
59
+
60
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank

    NOTE(review): tensors are staged on "cuda", so this path requires a GPU
    when world_size > 1; single-process runs short-circuit before any
    device work.
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # Trim each rank's payload back to its true length before unpickling.
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
101
+
102
+
103
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.

    NOTE(review): values must be stackable tensors; non-rank-0 processes
    receive unreduced (sum-partial) values by design of dist.reduce.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
libs/utils/context_cacher.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class ContextCacher:
    """Tiny process-wide key/value store for passing context between modules."""

    def __init__(self):
        self.infos = dict()

    def reset(self):
        """Drop every cached entry."""
        self.infos.clear()

    def cache_info(self, key, info):
        """Store `info` under `key`, overwriting any previous value."""
        self.infos[key] = info

    def get_info(self, key):
        """Fetch the value cached under `key` (raises KeyError when absent)."""
        return self.infos[key]
13
+
14
+
15
+ global_context_cacher = ContextCacher()
libs/utils/counter.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from collections import defaultdict
3
+ from .comm import distributed, all_gather
4
+
5
+
6
def format_dict(res_dict):
    """Render a dict as a 'key: value, key2: value2' string."""
    return ', '.join('%s: %s' % (key, val) for key, val in res_dict.items())
11
+
12
+
13
class Counter:
    """Accumulates scalar training metrics with a bounded per-metric history.

    NOTE: the name shadows collections.Counter by design of this module;
    only defaultdict is imported from collections here.
    """

    def __init__(self, cache_nums=1000):
        # Keep at most `cache_nums` latest values per metric (None = unbounded).
        self.cache_nums = cache_nums
        self.reset()

    def update(self, metric):
        """Append one value per metric name; tensors are unwrapped to scalars."""
        for name, value in metric.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.metric_dict[name].append(value)
            if self.cache_nums is not None:
                # Drop the oldest entries beyond the cache window.
                self.metric_dict[name] = self.metric_dict[name][-self.cache_nums:]

    def reset(self):
        """Forget all accumulated metrics."""
        self.metric_dict = defaultdict(list)

    def _sync(self):
        """Merge metric histories gathered from every process."""
        merged = defaultdict(list)
        for metric_dict in all_gather(self.metric_dict):
            for name, values in metric_dict.items():
                merged[name].extend(values)
        return merged

    def format_mean(self, sync=True):
        """Format per-metric means; optionally synced across processes."""
        metrics = self._sync() if (sync and distributed()) else self.metric_dict
        means = {name: '%.4f' % (sum(vals) / len(vals)) for name, vals in metrics.items()}
        return format_dict(means)
libs/utils/format_translate.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import copy
3
+ import Polygon
4
+ import numpy as np
5
+ from bs4 import BeautifulSoup as bs
6
+ from .time_counter import format_table
7
+
8
+
9
def check_continuous(seq):
    """Assert that `seq` increases by exactly 1 between neighbours.

    Empty and single-element sequences pass trivially.
    """
    for prev, cur in zip(seq, seq[1:]):
        assert prev + 1 == cur
15
+
16
def table_to_latex(table):
    """Map each cell of `table` to a content-class token.

    Classes: '</none>' empty, '</bold>' the literal '<b> </b>',
    '</space>' a single space, '</line>' anything else.
    """
    def classify(transcript):
        text = ''.join(transcript)
        if text == '':
            return '</none>'
        if text == '<b> </b>':
            return '</bold>'
        if text == ' ':
            return '</space>'
        return '</line>'

    # Layout ids must be contiguous and cover every cell exactly once.
    assert table['layout'].max() + 1 == len(table['cells'])
    return [classify(cell['transcript']) for cell in table['cells']]
30
+
31
def html_to_table(html):
    """Parse a PubTabNet-style HTML token stream into the table dict format.

    Walks the structure tokens, placing each <td> into a dynamically grown
    layout grid while honouring rowspan/colspan, and records which rows fall
    inside <thead>/<tbody>. Returns dict(layout, cells, head_rows, body_rows).
    """
    tokens = html['html']['structure']['tokens']

    layout = [[]]

    def extend_table(x, y):
        # Grow the grid so (x, y) is addressable; new positions get -1.
        assert (x >= 0) and (y >= 0)
        nonlocal layout

        if x >= len(layout[0]):
            for row in layout:
                row.extend([-1] * (x - len(row) + 1))

        if y >= len(layout):
            for _ in range(y - len(layout) + 1):
                layout.append([-1] * len(layout[0]))

    def set_cell_val(x, y, val):
        assert (x >= 0) and (y >= 0)
        nonlocal layout
        extend_table(x, y)
        layout[y][x] = val

    def get_cell_val(x, y):
        assert (x >= 0) and (y >= 0)
        nonlocal layout
        extend_table(x, y)
        return layout[y][x]

    def parse_span_val(token):
        # Extract the integer between the token's first and last quote,
        # e.g. ' colspan="3"' -> 3.
        span_val = int(token[token.index('"') + 1:token.rindex('"')])
        return span_val

    def maskout_left_rows():
        # Trim speculative rows created by rowspans past the current row.
        nonlocal row_idx, layout
        layout = layout[:max(row_idx+1, 1)]

    row_idx = -1
    col_idx = -1
    line_idx = -1           # running cell index, in token order
    inside_head = False
    inside_body = False
    head_rows = list()
    body_rows = list()
    col_span = 1
    row_span = 1
    for token in tokens:
        if token == '<thead>':
            inside_head = True
            maskout_left_rows()
        elif token == '</thead>':
            inside_head = False
            maskout_left_rows()
        elif token == '<tbody>':
            inside_body = True
            maskout_left_rows()
        elif token == '</tbody>':
            inside_body = False
            maskout_left_rows()
        elif token == '<tr>':
            row_idx += 1
            col_idx = -1
            if inside_head:
                head_rows.append(row_idx)
            if inside_body:
                body_rows.append(row_idx)
        elif token in ['<td>', '<td']:
            # '<td' (no '>') means span attributes follow before '</td>'.
            line_idx += 1
            col_idx += 1
            row_span = 1
            col_span = 1
            # Skip columns already claimed by an earlier rowspan.
            while get_cell_val(col_idx, row_idx) != -1:
                col_idx += 1
        elif 'colspan' in token:
            col_span = parse_span_val(token)
        elif 'rowspan' in token:
            row_span = parse_span_val(token)
        elif token == '</td>':
            # Commit the cell over its full span.
            for cur_row_idx in range(row_idx, row_idx + row_span):
                for cur_col_idx in range(col_idx, col_idx + col_span):
                    set_cell_val(cur_col_idx, cur_row_idx, line_idx)
            col_idx += col_span - 1

    # Head/body row ranges must be contiguous and cover the whole grid.
    check_continuous(head_rows)
    check_continuous(body_rows)
    assert len(set(head_rows) | set(body_rows)) == len(layout)
    layout = np.array(layout)
    assert np.all(layout >= 0)

    cells_info = list()
    for cell_idx, cell in enumerate(html['html']['cells']):
        transcript = cell['tokens']
        cell_info = dict(
            transcript=transcript
        )
        if 'bbox' in cell:
            # Keep both the bbox and an equivalent rectangular contour.
            x1, y1, x2, y2 = cell['bbox']
            cell_info['bbox'] = [x1, y1, x2, y2]
            cell_info['segmentation'] = [[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]
        cells_info.append(cell_info)

    table = dict(
        layout=layout,
        cells=cells_info,
        head_rows=head_rows,
        body_rows=body_rows
    )
    return table
139
+
140
+
141
def segmentation_to_bbox(segmentation):
    """Axis-aligned bounding box [x1, y1, x2, y2] over all contour points."""
    xs = [pt[0] for contour in segmentation for pt in contour]
    ys = [pt[1] for contour in segmentation for pt in contour]
    return [min(xs), min(ys), max(xs), max(ys)]
147
+
148
+
149
def table_to_html(table):
    """Serialize a table dict back into PubTabNet-style HTML structure tokens.

    Emits <thead>/<tbody> sections based on head_rows/body_rows, one <td>
    (with rowspan/colspan attribute tokens) per cell at the cell's first
    raster-scan occurrence, plus per-cell content tokens and optional bbox.
    """
    layout = table['layout']
    head_rows = table['head_rows']
    body_rows = table['body_rows']

    # Exclusive [start, end) row/col span of every cell.
    cells_span = list()
    for cell_idx in range(len(table['cells'])):
        cell_positions = np.argwhere(layout == cell_idx)
        row_span = [np.min(cell_positions[:, 0]), np.max(cell_positions[:, 0]) + 1]
        col_span = [np.min(cell_positions[:, 1]), np.max(cell_positions[:, 1]) + 1]
        assert np.all(layout[row_span[0]:row_span[1], col_span[0]:col_span[1]] == cell_idx)
        cells_span.append([row_span, col_span])

    cells = list()
    tokens = ['<thead>']
    inside_head = True
    for row_idx in range(layout.shape[0]):
        if row_idx in body_rows:
            # First body row closes the head section.
            if inside_head:
                tokens.append('</thead>')
                tokens.append('<tbody>')
                inside_head = False
        tokens.append('<tr>')
        for col_idx in range(table['layout'].shape[1]):
            cell_idx = layout[row_idx][col_idx]
            # Raster order guarantees ids appear in sequence.
            assert cell_idx <= len(cells)
            if cell_idx == len(cells):
                # First occurrence of this cell: emit its <td> tokens.
                row_span, col_span = cells_span[cell_idx]
                if (row_span[1] - row_span[0]) == 1 and (col_span[1] - col_span[0] == 1):
                    tokens.append('<td>')
                else:
                    tokens.append('<td')
                    if (row_span[1] - row_span[0]) > 1:
                        tokens.append(' rowspan="%d"' % (row_span[1] - row_span[0]))
                    if (col_span[1] - col_span[0]) > 1:
                        tokens.append(' colspan="%d"' % (col_span[1] - col_span[0]))
                    tokens.append('>')
                tokens.append('</td>')

                cell = dict()
                cell['tokens'] = table['cells'][cell_idx]['transcript']
                if 'segmentation' in table['cells'][cell_idx]:
                    cell['bbox'] = segmentation_to_bbox(table['cells'][cell_idx]['segmentation'])
                cells.append(cell)
        tokens.append('</tr>')
    # Tables with no body rows still get an (empty) tbody section.
    if inside_head:
        tokens.append('</thead>')
        tokens.append('<tbody>')
    tokens.append('</tbody>')

    html = dict(
        html=dict(
            cells=cells,
            structure=dict(
                tokens=tokens
            )
        )
    )
    return html
208
+
209
+
210
def format_html_for_vis(html):
    """Render structure tokens + cell contents into a styled, prettified HTML page.

    Cell contents are spliced between each '<td...>' and its '</td>' via regex
    match offsets, tracking the cumulative insertion offset.
    """
    html_string = '''<html>
    <head>
    <meta charset="UTF-8">
    <style>
    table, th, td {
      border: 1px solid black;
      font-size: 10px;
    }
    </style>
    </head>
    <body>
    <table frame="hsides" rules="groups" width="100%%">
    %s
    </table>
    </body>
    </html>''' % ''.join(html['html']['structure']['tokens'])
    cell_nodes = list(re.finditer(r'(<td[^<>]*>)(</td>)', html_string))
    assert len(cell_nodes) == len(html['html']['cells']), 'Number of cells defined in tags does not match the length of cells'
    cells = [''.join(c['tokens']) for c in html['html']['cells']]
    offset = 0
    for n, cell in zip(cell_nodes, cells):
        # Insert content between the opening tag and '</td>'; `offset`
        # compensates for earlier insertions shifting the match positions.
        html_string = html_string[:n.end(1) + offset] + cell + html_string[n.start(2) + offset:]
        offset += len(cell)
    # prettify the html
    soup = bs(html_string)
    html_string = soup.prettify()
    return html_string
238
+
239
+
240
def format_html(html):
    """Splice cell contents into the structure tokens; return the full page."""
    html_string = '''<html><body><table>%s</table></body></html>''' % ''.join(html['html']['structure']['tokens'])
    cell_nodes = list(re.finditer(r'(<td[^<>]*>)(</td>)', html_string))
    assert len(cell_nodes) == len(html['html']['cells']), 'Number of cells defined in tags does not match the length of cells'
    contents = (''.join(c['tokens']) for c in html['html']['cells'])
    offset = 0
    for node, content in zip(cell_nodes, contents):
        # Insert between the opening tag and '</td>'; `offset` compensates
        # for earlier insertions shifting the original match positions.
        html_string = html_string[:node.end(1) + offset] + content + html_string[node.start(2) + offset:]
        offset += len(content)
    return html_string
250
+
251
+
252
def format_table_layout(table):
    """Render a table layout as an ASCII grid of each cell's line indices."""
    layout = table['table']['layout']
    cell_lines = [cell['lines_idx'] for cell in table['table']['cells']]

    table_cells_info = list()
    for row in layout:
        row_cells_info = list()
        for cell_idx in row:
            # Each grid position shows the comma-joined line ids of its cell.
            cell_str = ','.join([str(item) for item in cell_lines[cell_idx]])
            row_cells_info.append(cell_str)
        table_cells_info.append(row_cells_info)

    return format_table(table_cells_info, padding=1)
265
+
266
+
267
def remove_blank_cell(html):
    """Delete every '<td ...></td>' whose content is empty from `html`."""
    pos = 0
    while '<td' in html[pos:]:
        cell_start = html.index('<td', pos)
        content_start = html.index('>', cell_start) + 1
        content_end = html.index('</td>', content_start)
        cell_end = content_end + len('</td>')
        if content_end == content_start:
            # Empty cell: splice it out and rescan from the same offset.
            html = html[:cell_start] + html[cell_end:]
            pos = cell_start
        else:
            pos = cell_end
    return html
libs/utils/logger.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ import logging
3
+ import os
4
+ import sys
5
+
6
+ from .comm import get_rank
7
+
8
+
9
+ _default_logger = None
10
+
11
+
12
def __init_logger():
    """Install the module-wide 'default' logger on the master process.

    Attaches a single stdout StreamHandler (guarded against duplicates on
    re-import). On non-zero ranks `_default_logger` stays None, which is
    fine because info()/error() also gate on rank 0.
    """
    global _default_logger
    if get_rank() == 0:
        logger = logging.getLogger('default')
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

        if not any([isinstance(item, logging.StreamHandler) for item in logger.handlers]):
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(formatter)
            logger.addHandler(ch)
        _default_logger = logger
25
+
26
+
27
+ __init_logger()
28
+
29
+
30
def setup_logger(name, save_dir, filename="log.txt"):
    """Replace the module default logger with a named one, optionally file-backed.

    Adds a stdout handler once, removes any previous FileHandlers, and — when
    `save_dir` is truthy — attaches a fresh FileHandler at save_dir/filename
    (creating directories as needed). No-op on non-master ranks.
    """
    global _default_logger
    # don't log results for the non-master process
    if get_rank() == 0:
        logger = logging.getLogger(name)
        logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

        if not any([isinstance(item, logging.StreamHandler) for item in logger.handlers]):
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(formatter)
            logger.addHandler(ch)

        # Drop stale file handlers from earlier setup calls before re-adding.
        logger.handlers = [item for item in logger.handlers if not isinstance(item, logging.FileHandler)]
        if save_dir:
            log_path = os.path.join(save_dir, filename)
            if not os.path.exists(os.path.dirname(log_path)):
                os.makedirs(os.path.dirname(log_path))
            fh = logging.FileHandler(log_path)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

        _default_logger = logger
55
+
56
+
57
def info(*args, **kwargs):
    """Log at INFO level via the default logger, on the master process only."""
    if get_rank() == 0:
        _default_logger.info(*args, **kwargs)
60
+
61
+
62
def error(*args, **kwargs):
    """Log at ERROR level via the default logger, on the master process only."""
    if get_rank() == 0:
        _default_logger.error(*args, **kwargs)
libs/utils/metric.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .utils import match_segment_spans, find_unmatch_segment_spans
3
+ from .teds import TEDS
4
+
5
+
6
class CellMergeAcc:
    """Cell-merge accuracy: a cell counts as correct only when every one of
    its (masked) merge predictions matches the label exactly."""

    def __call__(self, preds, labels, labels_mask):
        masked_preds = preds & labels_mask
        masked_labels = labels & labels_mask
        batch, cells = masked_labels.shape[0], masked_labels.shape[1]
        # A cell is correct iff all of its flattened entries agree with the label.
        agree = (masked_preds == masked_labels).reshape(batch, cells, -1).min(-1)[0]
        # Only cells with at least one positive label entry are counted.
        counted = masked_labels.reshape(batch, cells, -1).max(-1)[0]
        correct_nums = float(torch.sum(agree & counted).detach().cpu().item())
        total_nums = max(float(torch.sum(counted).detach().cpu().item()), 1e-6)
        return correct_nums, total_nums
16
+
17
+
18
class AccMetric:
    """Element-wise accuracy over positions that are masked-in and labeled."""

    def __call__(self, preds, labels, labels_mask):
        valid = (labels_mask != 0) & (labels != -1)  # -1 marks ignored labels
        hits = (preds == labels) & valid
        correct_nums = float(torch.sum(hits).detach().cpu().item())
        total_nums = max(float(torch.sum(valid).detach().cpu().item()), 1e-6)
        return correct_nums, total_nums
24
+
25
+
26
def cal_cls_acc(cls_preds, cls_labels):
    """Return (num correct, num valid) for classification, ignoring -1 labels."""
    valid = cls_labels != -1
    total_nums = float(torch.sum(valid).item())
    pred_nums = float(torch.sum((cls_preds == cls_labels) & valid).item())
    return pred_nums, total_nums
31
+
32
+
33
def cal_segment_pr(pred_segments, fg_spans, bg_spans):
    """Accumulate precision/recall counts for predicted split segments.

    Predictions that fall in neither a foreground nor a background span are
    excluded from the precision denominator.

    Returns:
        (correct predictions, counted predictions, ground-truth span count).
    """
    correct_nums, segment_nums, span_nums = 0, 0, 0
    for segments, fg, bg in zip(pred_segments, fg_spans, bg_spans):
        matched_idx, _ = match_segment_spans(segments, fg)
        orphan_idx = find_unmatch_segment_spans(segments, fg + bg)

        correct_nums += len(matched_idx)
        segment_nums += len(segments) - len(orphan_idx)
        span_nums += len(fg)

    return correct_nums, segment_nums, span_nums
46
+
47
+
48
class TEDSMetric:
    """Thin wrapper around TEDS for scoring parallel lists of HTML strings."""

    def __init__(self, num_workers=1, structure_only=False):
        self.evaluator = TEDS(n_jobs=num_workers, structure_only=structure_only)

    def __call__(self, pred_htmls, label_htmls):
        assert len(pred_htmls) == len(label_htmls)
        # batch_evaluate expects keyed dicts; use list positions as keys.
        preds = {idx: item for idx, item in enumerate(pred_htmls)}
        labels = {idx: dict(html=item) for idx, item in enumerate(label_htmls)}
        score_map = self.evaluator.batch_evaluate(preds, labels)
        return [score_map[idx] for idx in range(len(pred_htmls))]
libs/utils/model_synchronizer.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .comm import get_world_size
3
+ import torch.distributed as dist
4
+
5
+
6
class ModelSynchronizer:
    """BMUF-style (block momentum) parameter synchronizer for distributed training.

    Periodically averages model parameters/buffers across workers and applies
    a Nesterov block-momentum update towards the averaged model.
    """

    # Default block momentum keyed by world size.
    bm_map = {
        2: 0.65,
        4: 0.75,
        8: 0.875,
        12: 0.8875,
        16: 0.9,
        32: 0.9
    }

    def __init__(self, model, sync_rate, bm=None, blr=1.0, rescale_grad=1.0):
        """
        Args:
            model: model whose parameters are synchronized.
            sync_rate: number of calls between synchronizations.
            bm: block momentum; looked up in bm_map by world size when None.
            blr: block learning rate.
            rescale_grad: scale applied to local params when forming the block gradient.
        """
        if bm is None:
            self.bm = self.bm_map[get_world_size()]
        else:
            self.bm = bm
        self.blr = blr
        self.model = model
        self.sync_rate = sync_rate
        self.rescale_grad = rescale_grad
        self.count = 0

        self.param_align()

        self.momentums = dict()
        self.global_params = dict()
        for k, v in self.model.named_parameters():
            temp = torch.zeros_like(v, requires_grad=False)
            temp.copy_(v.data)
            # BUGFIX: store the detached copy, not the live parameter.
            # The original stored `v`, so the "global" model aliased the local
            # parameters and the block-momentum update degenerated.
            self.global_params[k] = temp
            self.momentums[k] = torch.zeros_like(v, requires_grad=False)

    def param_align(self):
        """Broadcast rank-0 parameters/buffers so all workers start identical."""
        for v in self.model.parameters():
            dist.broadcast_multigpu([v.data], src=0)

        for k, v in self.model.named_buffers():
            # BatchNorm step counters should not be broadcast.
            if 'num_batches_tracked' in k:
                continue
            dist.broadcast_multigpu([v.data], src=0)

    def sync_params(self):
        """All-reduce average of parameters and (non-counter) buffers."""
        size = float(get_world_size())
        for v in self.model.parameters():
            dist.all_reduce(v.data, op=dist.ReduceOp.SUM)
            v.data /= size

        for k, v in self.model.named_buffers():
            if 'num_batches_tracked' in k:
                continue
            dist.all_reduce(v.data, op=dist.ReduceOp.SUM)
            v.data /= size

    def __call__(self, final_align=False):
        """Advance one step; synchronize every sync_rate steps or on final_align."""
        self.count += 1
        if (self.count % self.sync_rate == 0) or final_align:
            with torch.no_grad():
                if final_align:
                    self.param_align()
                else:
                    self.sync_params()

                # Nesterov block-momentum update (BMUF).
                for k, v in self.model.named_parameters():
                    global_param = self.global_params[k]
                    momentum = self.momentums[k]
                    grad = v.data * self.rescale_grad - global_param
                    momentum *= self.bm
                    global_param -= momentum
                    momentum += self.blr * grad
                    global_param += (1.0 + self.bm) * momentum
                    v.detach().copy_(global_param.detach())
libs/utils/scitsr/__init__.py ADDED
File without changes
libs/utils/scitsr/eval.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019-present, Zewen Chi
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from typing import List
9
+
10
+ from .relation import Relation
11
+ from .table import Table, Chunk
12
+
13
+
14
+ DIR_HORIZ = 1
15
+ DIR_VERT = 2
16
+ DIR_SAME_CELL = 3
17
+
18
+
19
def normalize(s: str, rule=0):
    """Uppercase *s* with all whitespace (CR/LF/space/tab) removed; rule 0 only."""
    if rule != 0:
        raise NotImplementedError
    for ch in ("\r", "\n", " ", "\t"):
        s = s.replace(ch, "")
    return s.upper()
28
+
29
+
30
def eval_relations(gt: "List[List]", res: "List[List]", cmp_blank=True):
    """Evaluate extraction results against ground truth.

    Args:
        gt: a list of lists of Relation (ground truth per table).
        res: a list of lists of Relation (predictions per table).
        cmp_blank: forwarded to Relation.equal when comparing relations.

    Returns:
        (precision, recall) macro-averaged over tables; (0.0, 0.0) on empty input.
    """
    assert len(gt) == len(res)
    # BUGFIX: avoid ZeroDivisionError (total == 0) when there are no tables.
    if not gt:
        return 0.0, 0.0

    tot_prec = 0
    tot_recall = 0
    total = 0
    idx, t = 0, len(gt)
    for _gt, _res in zip(gt, res):
        idx += 1
        # Lightweight in-place progress line.
        print('Eval %d/%d (%d%%)' % (idx, t, idx / t * 100), ' ' * 45, end='\r')
        corr = compare_rel(_gt, _res, cmp_blank)
        precision = corr / len(_res) if len(_res) != 0 else 0
        recall = corr / len(_gt) if len(_gt) != 0 else 0
        tot_prec += precision
        tot_recall += recall
        total += 1

    precision = tot_prec / total
    recall = tot_recall / total
    return precision, recall
65
+
66
def compare_rel(gt_rel: "List[Relation]", res_rel: "List[Relation]", cmp_blank=True):
    """Count ground-truth relations matched by some predicted relation.

    Each predicted relation can be consumed by at most one ground-truth
    relation (greedy first-match in list order).

    Args:
        gt_rel: ground-truth relations.
        res_rel: predicted relations.
        cmp_blank: forwarded to Relation.equal.

    Returns:
        number of matched ground-truth relations.
    """
    count = 0
    remaining = list(res_rel)
    for gt in gt_rel:
        matched_idx = None
        for i, res in enumerate(remaining):
            if gt.equal(res, cmp_blank):
                matched_idx = i
                count += 1
                break
        # BUGFIX: remove by the recorded index rather than relying on the
        # leaked loop variable `i` after the inner loop exits.
        if matched_idx is not None:
            remaining.pop(matched_idx)

    return count
88
+
89
def Table2Relations(t: Table):
    """Convert a Table object to a List of Relation.

    For every cell, emits one horizontal relation to the nearest distinct
    cell on its right and one vertical relation to the nearest distinct cell
    below it, skipping empty (-1) grid positions; results are de-duplicated
    by (from_id, to_id) pairs.
    """
    ret = []
    cl = t.coo2cell_id  # grid of cell ids, -1 for empty positions
    # remove duplicates with pair set
    used = set()

    # look right: nearest distinct neighbour on the same row
    for r in range(t.row_n):
        for cFrom in range(t.col_n - 1):
            cTo = cFrom + 1
            loop = True
            while loop and cTo < t.col_n:
                fid, tid = cl[r][cFrom], cl[r][cTo]
                if fid != -1 and tid != -1 and fid != tid:
                    if (fid, tid) not in used:
                        ret.append(Relation(
                            from_text=t.cells[fid].text,
                            to_text=t.cells[tid].text,
                            direction=DIR_HORIZ,
                            from_id=fid,
                            to_id=tid,
                            no_blanks=cTo - cFrom - 1  # empty cells skipped over
                        ))
                        used.add((fid, tid))
                    loop = False
                else:
                    if fid != -1 and tid != -1 and fid == tid:
                        # Same spanning cell: advance the from-pointer for
                        # subsequent while iterations (the for-loop variable
                        # itself is rebound on the next outer iteration).
                        cFrom = cTo
                    cTo += 1

    # look down: nearest distinct neighbour in the same column
    for c in range(t.col_n):
        for rFrom in range(t.row_n - 1):
            rTo = rFrom + 1
            loop = True
            while loop and rTo < t.row_n:
                fid, tid = cl[rFrom][c], cl[rTo][c]
                if fid != -1 and tid != -1 and fid != tid:
                    if (fid, tid) not in used:
                        ret.append(Relation(
                            from_text=t.cells[fid].text,
                            to_text=t.cells[tid].text,
                            direction=DIR_VERT,
                            from_id=fid,
                            to_id=tid,
                            no_blanks=rTo - rFrom - 1
                        ))
                        used.add((fid, tid))
                    loop = False
                else:
                    if fid != -1 and tid != -1 and fid == tid:
                        rFrom = rTo
                    rTo += 1

    return ret
146
+
147
def json2Table(json_obj, tid="", splitted_content=False):
    """Construct a Table object from a json object.

    Args:
        json_obj: parsed json holding a "cells" list.
        tid: table id forwarded to Table.
        splitted_content: when True, cell content is a token list to join.

    Returns:
        a Table object
    """
    row_n, col_n = 0, 0
    cells = []
    for cell_obj in json_obj["cells"]:
        content = cell_obj["content"]
        if content is None:
            continue
        content = " ".join(content) if splitted_content else content.strip()
        if content == "":
            continue
        start_row, end_row = cell_obj["start_row"], cell_obj["end_row"]
        start_col, end_col = cell_obj["start_col"], cell_obj["end_col"]
        # Track the largest indices to size the grid.
        row_n = max(row_n, end_row)
        col_n = max(col_n, end_col)
        cells.append(Chunk(content, (start_row, end_row, start_col, end_col)))
    return Table(row_n + 1, col_n + 1, cells, tid)
175
+
176
def json2Relations(json_obj, splitted_content):
    """Parse a json table and convert it straight to a list of Relation."""
    table = json2Table(json_obj, "", splitted_content)
    return Table2Relations(table)
178
+
179
+
libs/utils/scitsr/relation.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019-present, Zewen Chi
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import sys
8
+
9
def normalize(s: str, rule=0):
    """Strip all whitespace characters and uppercase *s* (only rule 0 exists)."""
    if rule != 0:
        raise NotImplementedError
    stripped = "".join(ch for ch in s if ch not in "\r\n \t")
    return stripped.upper()
19
+
20
+
21
class Relation(object):
    """An adjacency relation between two table cells (horizontal or vertical)."""

    def __init__(self, from_text, to_text, direction, from_id=0, to_id=0, no_blanks=0):
        self.from_text = from_text
        self.to_text = to_text
        self.direction = direction
        self.no_blanks = no_blanks
        self.from_id = from_id
        self.to_id = to_id

    def _normalized_texts(self, rl):
        """Normalize both endpoints of self and rl, warning on empty text."""
        texts = (normalize(self.from_text), normalize(self.to_text),
                 normalize(rl.from_text), normalize(rl.to_text))
        if any(len(t) == 0 for t in texts):
            print("Warning: Text comparison of 0-length strings after normalization",
                  file=sys.stderr)
        return texts

    def __eq__(self, rl):
        this_ft, this_tt, rl_ft, rl_tt = self._normalized_texts(rl)
        return this_ft == rl_ft and this_tt == rl_tt and \
            self.direction == rl.direction and self.no_blanks == rl.no_blanks

    def equal(self, rl, cmp_blank=True):
        """Like __eq__, but the blank-count comparison is optional."""
        this_ft, this_tt, rl_ft, rl_tt = self._normalized_texts(rl)
        if this_ft != rl_ft or this_tt != rl_tt:
            return False
        if self.direction != rl.direction:
            return False
        return self.no_blanks == rl.no_blanks if cmp_blank else True

    def __str__(self):
        return "%d:%d" % (self.direction, self.no_blanks)
libs/utils/scitsr/table.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019-present, Zewen Chi
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+
9
+ from typing import Iterable, List, Tuple
10
+
11
+
12
def load_chunks(chunk_path):
    """Load Chunk objects from a json file, repairing swapped x coordinates.

    NOTE: zero-size/empty-text chunks are intentionally kept (the filter is
    disabled upstream).
    """
    with open(chunk_path, 'r') as f:
        raw_chunks = json.load(f)['chunks']
    loaded = []
    for raw in raw_chunks:
        if raw["pos"][1] < raw["pos"][0]:
            # Repair an inverted x-interval before constructing the Chunk.
            raw["pos"][0], raw["pos"][1] = raw["pos"][1], raw["pos"][0]
            print("Warning load illegal chunk.")
        loaded.append(Chunk.load_from_dict(raw))
    return loaded
26
+
27
+
28
class Box(object):
    """Axis-aligned box with pos = (x1, x2, y1, y2)."""

    def __init__(self, pos):
        """pos: (x1, x2, y1, y2)"""
        self.set_pos(pos)

    def set_pos(self, pos):
        """Store coordinates and derive width/height; intervals must be ordered."""
        x1, x2, y1, y2 = pos[0], pos[1], pos[2], pos[3]
        assert x1 <= x2
        assert y1 <= y2
        self.x1, self.x2 = x1, x2
        self.y1, self.y2 = y1, y2
        self.w = x2 - x1
        self.h = y2 - y1
        self.pos = pos

    def __lt__(self, other):
        # Order boxes lexicographically by their pos tuples.
        return self.pos.__lt__(other.pos)

    def __contains__(self, other):
        # True when *other* lies entirely inside this box.
        return (other.x1 >= self.x1 and other.x2 <= self.x2
                and other.y1 >= self.y1 and other.y2 <= self.y2)

    def __str__(self):
        return 'Box(%d, %d, %d, %d)' % self.pos

    def __hash__(self):
        return self.pos.__hash__()
59
+
60
+
61
class Chunk(Box):
    """A Box carrying text content, a font size, and an optional cell id."""

    def __init__(self, text: str, pos: Tuple, size: float = 0.0, cell_id=None):
        super(Chunk, self).__init__(pos)
        self.text = text
        self.size = size
        self.cell_id = cell_id

    def __str__(self):
        return 'Chunk(text="%s", pos=(%d, %d, %d, %d))' % (self.text, *self.pos)

    def __repr__(self):
        return self.__str__()

    def dump_as_json_obj(self):
        """Serialize to a plain dict (inverse of load_from_dict)."""
        return {"text": self.text, "pos": self.pos, "cell_id": self.cell_id}

    @classmethod
    def load_from_dict(cls, d):
        """Build a Chunk from a dict with "text", "pos" and optional "cell_id"."""
        assert type(d) == dict
        assert type(d["text"]) == str
        assert len(d["pos"]) == 4
        return cls(d["text"].strip(), d["pos"], cell_id=d.get("cell_id"))
85
+
86
+
87
class Table(object):
    """
    The output of table segmentation.
    With the Table object, we can get the set of cells
    and their corresponding text.
    """

    def __init__(self, row_n, col_n, cells: Iterable[Chunk] = None, tid=""):
        # NOTE the Chunk object here represents the coordinate of the cell in
        # the table: x is the row index, y the column index.
        self.tid = tid
        self.row_n = row_n
        self.col_n = col_n
        # Grid mapping (row, col) -> index into self.cells; -1 when empty.
        self.coo2cell_id = [
            [-1 for _ in range(col_n)] for _ in range(row_n)]
        self.cells: List[Chunk] = []
        for cell in cells:
            self.add_cell(cell)

    def reverse(self, is_col=True):
        """Mirror all cells along rows (is_col=True) or columns (is_col=False)."""
        cells = self.cells
        self.cells = []
        for cell in cells:
            if is_col:
                mirrored = Chunk(cell.text, (
                    self.row_n - cell.x2, self.row_n - cell.x1, cell.y1, cell.y2))
            else:
                # BUGFIX: the mirrored y-interval must be emitted low-to-high
                # ((col_n - y2, col_n - y1)); the original emitted it in the
                # reversed order, which trips Box.set_pos's ordering assertion
                # for any cell with y1 < y2.
                # NOTE(review): mirroring with n - idx (not n - 1 - idx) can
                # still yield an index == col_n/row_n — confirm intended range.
                mirrored = Chunk(cell.text, (
                    cell.x1, cell.x2, self.col_n - cell.y2, self.col_n - cell.y1))
            self.add_cell(mirrored)

    def add_cell(self, cell: Chunk):
        """Register *cell* and stamp its id over its (row, col) span."""
        # TODO Check conflicts of cells
        assert cell.y2 < self.col_n
        assert cell.x2 < self.row_n

        for x in range(cell.x1, cell.x2 + 1, 1):
            for y in range(cell.y1, cell.y2 + 1, 1):
                self.coo2cell_id[x][y] = len(self.cells)
        self.cells.append(cell)

    def __getitem__(self, id_tuple):
        """Return the cell occupying grid position (row_id, col_id)."""
        row_id, col_id = id_tuple
        assert row_id < self.row_n and col_id < self.col_n
        return self.cells[self.coo2cell_id[row_id][col_id]]
libs/utils/teds.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 IBM
2
+ # Author: peter.zhong@au1.ibm.com
3
+ #
4
+ # This is free software; you can redistribute it and/or modify
5
+ # it under the terms of the Apache 2.0 License.
6
+ #
7
+ # This software is distributed in the hope that it will be useful,
8
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
9
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10
+ # Apache 2.0 License for more details.
11
+
12
+ import distance
13
+ from apted import APTED, Config
14
+ from apted.helpers import Tree
15
+ from lxml import etree, html
16
+ from collections import deque
17
+ from tqdm import tqdm
18
+ from concurrent.futures import ProcessPoolExecutor, as_completed
19
+
20
+
21
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    A parallel version of the map function with a progress bar.

    Args:
        array (array-like): An array to iterate over.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of cores to use
        use_kwargs (boolean, default=False): Whether to consider the elements of
            array as dictionaries of keyword arguments to function
        front_num (int, default=0): The number of iterations to run serially
            before kicking off the parallel job. Useful for catching bugs
    Returns:
        [function(array[0]), function(array[1]), ...]
    """
    apply = (lambda a: function(**a)) if use_kwargs else function
    # Run the first few iterations serially to catch bugs early.
    front = [apply(a) for a in array[:front_num]] if front_num > 0 else []
    rest = array[front_num:]
    # With a single job just run a plain (progress-wrapped) comprehension.
    if n_jobs == 1:
        return front + [apply(a) for a in tqdm(rest)]
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in rest]
        else:
            futures = [pool.submit(function, a) for a in rest]
        progress_opts = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        # Drain as_completed purely to drive the progress bar.
        for _ in tqdm(as_completed(futures), **progress_opts):
            pass
    results = []
    # Collect results in submission order, storing raised exceptions in-place.
    for _, future in tqdm(enumerate(futures)):
        try:
            results.append(future.result())
        except Exception as exc:
            results.append(exc)
    return front + results
68
+
69
+
70
class TableTree(Tree):
    """APTED tree node describing one HTML tag (with span/content for td)."""

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            node_repr = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                (self.tag, self.colspan, self.rowspan, self.content)
        else:
            node_repr = '"tag": %s' % self.tag
        node_repr += ''.join(child.bracket() for child in self.children)
        return "{{{}}}".format(node_repr)
88
+
89
+
90
class CustomConfig(Config):
    """APTED cost config comparing tags, spans, and normalized cell text."""

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value"""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1"""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        attrs_differ = (node1.tag != node2.tag
                        or node1.colspan != node2.colspan
                        or node1.rowspan != node2.rowspan)
        if attrs_differ:
            return 1.
        # For matching <td> nodes, charge the text edit distance.
        if node1.tag == 'td' and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.
110
+
111
+
112
class TEDS(object):
    ''' Tree Edit Distance basead Similarity
    '''
    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        # structure_only: compare tag structure only, ignoring cell text.
        # n_jobs: number of processes used by batch_evaluate.
        # ignore_nodes: iterable of tag names stripped before comparison.
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells

        Appends the open tag, the text characters, recursed children, the
        close tag (except for 'unk') and tail text into self.__tokens__.
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted
        '''
        global __tokens__
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                # Tokenize the cell, dropping the outer <td>/</td> tokens.
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell, *deque())
        else:
            new_node = TableTree(node.tag, None, None, None, *deque())
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            # td is a leaf of the comparison tree; do not recurse into it here.
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            # Only the root call returns the built tree.
            return new_node

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
            given sample

        Returns 1 - edit_distance / max_tree_size, or 0.0 when either input
        is empty or lacks a body/table element.
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath('body/table') and true.xpath('body/table'):
            pred = pred.xpath('body/table')[0]
            true = true.xpath('body/table')[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            # Normalize the edit distance by the larger tree's node count.
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
            a batch of samples
            @params pred_json: {'FILENAME': 'HTML CODE', ...}
            @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
            @output: {'FILENAME': 'TEDS SCORE', ...}
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
        else:
            # Missing predictions default to '' and score 0.0.
            inputs = [{'pred': pred_json.get(filename, ''), 'true': true_json[filename]['html']} for filename in samples]
            scores = parallel_process(inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores
200
+
201
+
202
if __name__ == '__main__':
    # Smoke test: score sample predictions against ground truth and print.
    import json
    import pprint
    with open('sample_pred.json') as fp:
        pred_json = json.load(fp)
    with open('sample_gt.json') as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pprint.PrettyPrinter().pprint(scores)
libs/utils/teds_multiprocess.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import tqdm
3
+ from libs.utils.teds import TEDS
4
+ from collections import defaultdict
5
+
6
+
7
+ # def parse_args():
8
+ # import argparse
9
+ # parser = argparse.ArgumentParser()
10
+ # parser.add_argument('pred_path', type=str, default=None)
11
+ # parser.add_argument('label_path', type=str, default=None)
12
+ # parser.add_argument('-s', '--structure_only', action='store_true')
13
+ # parser.add_argument('-n', '--num_workers', type=int, default=1)
14
+ # args = parser.parse_args()
15
+ # return args
16
+
17
+
18
def is_simple(data):
    """Return True when the html string contains no colspan/rowspan markup."""
    return ('colspan' not in data) and ('rowspan' not in data)
23
+
24
+
25
def judge_type(data):
    """Classify a table's html as 'Simple' or 'Complex' (contains spans)."""
    return 'Simple' if is_simple(data) else 'Complex'
30
+
31
+
32
def single_process(pred_htmls, label_htmls, structure_only=False):
    """Score every labeled sample with TEDS in the current process."""
    evaluator = TEDS(structure_only=structure_only)
    scores = dict()
    for key in tqdm.tqdm(label_htmls.keys()):
        # Missing predictions default to '' and score 0.0.
        scores[key] = evaluator.evaluate(pred_htmls.get(key, ''), label_htmls[key]['html'])
    return scores
41
+
42
+
43
def _worker(pred_htmls, label_htmls, keys, result_queue, structure_only=False):
    """Worker process: score the assigned keys and push (key, score) pairs."""
    evaluator = TEDS(structure_only=structure_only)
    for key in keys:
        score = evaluator.evaluate(pred_htmls.get(key, ''), label_htmls[key]['html'])
        result_queue.put((key, score))
50
+
51
+
52
def multi_process(pred_htmls, label_htmls, num_workers, structure_only=False):
    """Score samples with TEDS across multiple worker processes.

    Args:
        pred_htmls: {key: html} predictions.
        label_htmls: {key: {'html': html}} ground truth.
        num_workers: number of worker processes to spawn.
        structure_only: forwarded to the TEDS evaluator in each worker.

    Returns:
        {key: score} for every key in label_htmls.
    """
    import multiprocessing
    manager = multiprocessing.Manager()
    result_queue = manager.Queue()
    keys = list(label_htmls.keys())
    workers = list()
    for worker_idx in range(num_workers):
        worker = multiprocessing.Process(
            target=_worker,
            args=(
                pred_htmls,
                label_htmls,
                keys[worker_idx::num_workers],  # round-robin key split
                result_queue,
                # BUGFIX: structure_only was previously not forwarded, so
                # workers always evaluated with content comparison enabled.
                structure_only
            )
        )
        worker.daemon = True
        worker.start()
        workers.append(worker)
    scores = dict()
    tq = tqdm.tqdm(total=len(keys))
    for _ in range(len(keys)):
        key, val = result_queue.get()
        scores[key] = val
        # Show the running mean TEDS in the progress bar.
        teds = sum(scores.values()) / len(scores)
        tq.set_description('Teds: %s' % teds, False)
        tq.update()
    tq.close()
    return scores
81
+
82
+
83
def evaluate(pred_htmls, label_htmls, num_workers, structure_only=False):
    """Compute overall TEDS plus per-type (Simple/Complex) averages."""
    if num_workers <= 1:
        scores = single_process(pred_htmls, label_htmls, structure_only)
    else:
        scores = multi_process(pred_htmls, label_htmls, num_workers, structure_only)
    teds = sum(scores.values()) / len(scores)

    # Bucket scores by table complexity, then average each bucket.
    typed_scores = defaultdict(list)
    for key, score in scores.items():
        typed_scores[judge_type(label_htmls[key]['html'])].append(score)

    typed_teds = {kind: sum(vals) / len(vals) for kind, vals in typed_scores.items()}
    return teds, typed_teds
97
+
98
+
99
+ # def main():
100
+ # args = parse_args()
101
+ # pred_data = json.load(open(args.pred_path))
102
+ # label_data = json.load(open(args.label_path))
103
+
104
+ # teds, typed_teds = evaluate(pred_data, label_data, args.num_workers, args.structure_only)
105
+ # print('Teds: %s' % teds)
106
+ # for key, val in typed_teds.items():
107
+ # print(' %s Teds: %s' % (key, val))
108
+
109
+
110
+ # if __name__ == '__main__':
111
+ # main()
libs/utils/time_counter.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ import time
3
+ import datetime
4
+
5
+
6
class TimeCounter:
    """Estimates remaining training time from elapsed batches."""

    def __init__(self, start_epoch, num_epochs, epoch_iters):
        self.start_epoch = start_epoch
        self.num_epochs = num_epochs
        self.epoch_iters = epoch_iters
        self.start_time = None

    def reset(self):
        """Start (or restart) the wall-clock reference point."""
        self.start_time = time.time()

    def step(self, epoch, batch):
        """Return the estimated remaining time as an H:MM:SS-style string."""
        elapsed = time.time() - self.start_time
        done_batches = (epoch - self.start_epoch) * self.epoch_iters + batch
        per_batch = elapsed / done_batches
        projected_total = (self.num_epochs - self.start_epoch) * self.epoch_iters * per_batch
        remaining = projected_total - elapsed
        return str(datetime.timedelta(seconds=remaining))
23
+
24
+
25
def format_table(table, padding=1):
    """Render a 2D list as a text table with Unicode box-drawing borders.

    Args:
        table: list of rows; every item is str()-ed. Rows may be ragged;
            missing cells render empty.
        padding: spaces added on each side of every cell.

    Returns:
        the rendered multi-line string (trailing newline included).

    NOTE: the border characters in the checked-in file were mojibake
    (U+FFFD replacement pairs); they are restored here as the standard
    box-drawing set.
    """
    table = [[str(subitem) for subitem in item] for item in table]
    num_cols = max(len(item) for item in table)
    cols_width = [0] * num_cols

    # Column width = widest cell in that column.
    for row in table:
        for col_idx, cell in enumerate(row):
            cols_width[col_idx] = max(cols_width[col_idx], len(cell))

    # Top border.
    string = '┌'
    for col_idx in range(num_cols):
        string += '─' * (padding * 2 + cols_width[col_idx])
        string += '┐' if col_idx == num_cols - 1 else '┬'
    string += '\n'

    for row_idx, row in enumerate(table):
        # Content line, each cell centered inside its column.
        string += '│'
        for col_idx in range(num_cols):
            word = row[col_idx] if col_idx < len(row) else ''
            col_width = cols_width[col_idx]
            left_pad = (col_width - len(word)) // 2
            right_pad = col_width - len(word) - left_pad
            string += ' ' * (padding + left_pad)
            string += word
            string += ' ' * (padding + right_pad)
            string += '│'
        string += '\n'

        # Separator after every row; the last one is the bottom border.
        last_row = row_idx == len(table) - 1
        string += '└' if last_row else '├'
        for col_idx in range(num_cols):
            string += '─' * (padding * 2 + cols_width[col_idx])
            if col_idx == num_cols - 1:
                string += '┘' if last_row else '┤'
            else:
                string += '┴' if last_row else '┼'
        string += '\n'
    return string
79
+
80
+
81
class TicTocCounter:
    """Accumulates named tic/toc wall-clock intervals for lightweight profiling."""

    def __init__(self):
        self.tics = dict()             # name -> timestamp of last tic()
        self.seps = defaultdict(list)  # name -> list of measured durations

    def tic(self, name):
        """Mark the start time for *name*."""
        self.tics[name] = time.time()

    def toc(self, name):
        """Record elapsed time since the last tic(name); no-op without a tic."""
        now = time.time()
        if name in self.tics:
            self.seps[name].append(now - self.tics[name])

    def __repr__(self):
        rows = [['Name', 'Mean Time', 'Total Time']]
        for name, durations in self.seps.items():
            total = sum(durations)
            rows.append([name, '%0.4f' % (total / len(durations)), '%0.4f' % total])
        return 'TicTocCount Result:\n' + format_table(rows)

    def reset(self):
        """Forget all pending tics and recorded durations."""
        self.tics.clear()
        self.seps.clear()


# Shared module-level counter instance.
global_tictoc_counter = TicTocCounter()
libs/utils/utils.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import copy
3
+ import Polygon
4
+ import numpy as np
5
+
6
+
7
def cal_mean_lr(optimizer):
    """Return the average learning rate across all optimizer param groups."""
    total = 0.0
    count = 0
    for group in optimizer.param_groups:
        total += group['lr']
        count += 1
    return total / count
10
+
11
+
12
def cal_pr_f1(pr_info):
    """Compute precision/recall/F1 from counts.

    Args:
        pr_info: sequence (num_correct, num_predicted, num_labeled).

    Returns:
        (precision, recall, f1). Zero denominators yield 0.0 instead of
        raising ZeroDivisionError (e.g. no predictions or no labels yet).
    """
    precision = pr_info[0] / pr_info[1] if pr_info[1] else 0.0
    recall = pr_info[0] / pr_info[2] if pr_info[2] else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2*precision*recall/(precision+recall)
    return precision, recall, f1
17
+
18
+
19
def match_segment_spans(segments, spans):
    """Greedily pair segment positions with half-open spans.

    A segment at position p matches span s when s[0] <= p < s[1]. Each span
    is consumed at most once; a segment is recorded once per span it claims
    (the scan does not stop at the first match).

    Returns:
        (matched_segment_indices, matched_span_indices), aligned pairwise.
    """
    matched_segments, matched_spans = [], []
    for seg_idx, position in enumerate(segments):
        for span_idx, span in enumerate(spans):
            if span_idx in matched_spans:
                continue
            if span[0] <= position < span[1]:
                matched_segments.append(seg_idx)
                matched_spans.append(span_idx)
    return matched_segments, matched_spans
31
+
32
+
33
def find_unmatch_segment_spans(segments, spans):
    """Return indices of segments falling inside none of the half-open spans."""
    return [
        seg_idx for seg_idx, position in enumerate(segments)
        if not any(span[0] <= position < span[1] for span in spans)
    ]
45
+
46
+
47
def parse_layout(spans, num_rows, num_cols):
    """Build a (num_rows, num_cols) grid of cell ids from span rectangles.

    Args:
        spans: iterable of (x1, y1, x2, y2) grid coordinates, inclusive.
        num_rows, num_cols: grid dimensions.

    Returns:
        np.ndarray of ints where each entry is a cell id, renumbered in
        row-major first-appearance order (positions not covered by any span
        start as -1 and are renumbered too).
    """
    # Bug fix: np.int was removed in NumPy 1.24 (requirements pin
    # numpy==1.26.4), so dtype=np.int raised AttributeError. The builtin
    # int maps to NumPy's default integer type.
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    cell_count = 0
    for x1, y1, x2, y2 in spans:
        layout[y1:y2+1, x1:x2+1] = cell_count
        cell_count += 1

    # renumber ids to row-major first-appearance order; each position is
    # read exactly once, so overwriting in place is safe
    cells_id = list()
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            cell_id = layout[row_idx, col_idx]
            if cell_id in cells_id:
                layout[row_idx, col_idx] = cells_id.index(cell_id)
            else:
                layout[row_idx, col_idx] = len(cells_id)
                cells_id.append(cell_id)
    return layout
64
+
65
+
66
def parse_cells(layout, spans, row_segments, col_segments):
    """Build cell dicts (rectangular segmentations in pixel coords) from a layout grid.

    Each cell id's grid bounding box is mapped to pixel coordinates through the
    row/col segment boundary lists; cells originating from an explicit span get
    a placeholder transcript.
    """
    cells = []
    for cell_id in range(int(np.max(layout)) + 1):
        rows, cols = np.nonzero(layout == cell_id)
        top, bottom = np.min(rows), np.max(rows)
        left, right = np.min(cols), np.max(cols)
        # the occupied region must be a solid rectangle of this cell id
        assert np.all(layout[top:bottom, left:right] == cell_id)
        px1 = col_segments[left]
        px2 = col_segments[right + 1]
        py1 = row_segments[top]
        py2 = row_segments[bottom + 1]
        cells.append(dict(
            segmentation=[[[px1, py1], [px2, py1], [px2, py2], [px1, py2]]]
        ))
    for span in spans:
        # span = (x1, y1, ...): mark the cell anchored at that grid position
        cells[layout[span[1], span[0]]]['transcript'] = 'None'
    return cells
88
+
89
+
90
def segmentation_to_bbox(segmentation):
    """Return [x1, y1, x2, y2] enclosing all contour points of *segmentation*."""
    xs = [pt[0] for contour in segmentation for pt in contour]
    ys = [pt[1] for contour in segmentation for pt in contour]
    return [min(xs), min(ys), max(xs), max(ys)]
96
+
97
+
98
def extend_cell_lines(cells, lines):
    """Attach to each cell (in place, key 'lines_idx') the text lines it contains.

    Every line is assigned to the single cell whose polygon overlaps it most,
    measured relative to the line's own area; zero-area lines are skipped.
    Within a cell, line indices are ordered top-to-bottom by bbox y1.
    """
    def to_polygon(segmentation):
        poly = Polygon.Polygon()
        for contour in segmentation:
            poly = poly + Polygon.Polygon(contour)
        return poly

    lines = copy.deepcopy(lines)

    cell_polys = [to_polygon(cell['segmentation']) for cell in cells]
    line_polys = [to_polygon(line['segmentation']) for line in lines]

    assigned = [[] for _ in cells]
    for line_idx, line_poly in enumerate(line_polys):
        if line_poly.area() == 0:
            continue
        line_area = line_poly.area()
        best_overlap = 0
        best_cell = None
        for cell_idx, cell_poly in enumerate(cell_polys):
            overlap = (cell_poly & line_poly).area() / line_area
            if overlap > best_overlap:
                best_cell = cell_idx
                best_overlap = overlap
        if best_overlap > 0:
            assigned[best_cell].append(line_idx)

    tops = [segmentation_to_bbox(line['segmentation'])[1] for line in lines]
    for cell, line_ids in zip(cells, assigned):
        cell['lines_idx'] = sorted(line_ids, key=lambda idx: tops[idx])
131
+
132
def rerange_layout(table):
    """Renumber ``table['layout']`` cell ids (in place) to row-major
    first-appearance order and reorder ``table['cells']`` to match,
    dropping cells no longer referenced by the layout."""
    layout = table['layout']
    old_to_new = {}
    kept_old_ids = []
    for row_idx in range(layout.shape[0]):
        for col_idx in range(layout.shape[1]):
            old_id = layout[row_idx, col_idx]
            if old_id not in old_to_new:
                old_to_new[old_id] = len(kept_old_ids)
                kept_old_ids.append(old_id)
            # each position is visited exactly once, so in-place rewrite is safe
            layout[row_idx, col_idx] = old_to_new[old_id]
    table['layout'] = layout
    table['cells'] = [table['cells'][old_id] for old_id in kept_old_ids]
145
+
146
def cal_cell_spans(table):
    """Return the inclusive grid span [x1, y1, x2, y2] for every cell id."""
    layout = table['layout']
    spans = []
    for cell_id in range(len(table['cells'])):
        rows, cols = np.nonzero(layout == cell_id)
        top, bottom = np.min(rows), np.max(rows)
        left, right = np.min(cols), np.max(cols)
        # sanity check: the occupied area forms a solid rectangle
        assert np.all(layout[top:bottom, left:right] == cell_id)
        spans.append([left, top, right, bottom])
    return spans
159
+
160
+
161
def remove_repeat_rcs(table):
    """Collapse duplicated rows/columns of ``table['layout']`` in place.

    Repeats until a fixed point, each pass dropping:
      * body rows whose entries all belong to a single cell,
      * columns whose entries all belong to a single cell,
      * any row/column whose cell-id pattern duplicates an earlier one.
    ``head_rows``/``body_rows`` are remapped to the kept row indices each
    pass, and cell ids are re-ranged at the end via ``rerange_layout``.
    """
    layout = table['layout']
    head_rows = table['head_rows']
    body_rows = table['body_rows']
    while True:
        num_rows = layout.shape[0]
        num_cols = layout.shape[1]
        valid_rows_idx = list()
        valid_rows_key = list()

        for row_idx in range(num_rows):
            row = layout[row_idx, :]
            if len(np.unique(row)) == 1 and row_idx in body_rows: # remove repeated row
                continue
            # string key identifies rows with an identical cell-id pattern
            row_key = ','.join([str(item) for item in row])
            if row_key not in valid_rows_key:
                valid_rows_idx.append(row_idx)
                valid_rows_key.append(row_key)

        valid_cols_idx = list()
        valid_cols_key = list()
        for col_idx in range(num_cols):
            col = layout[:, col_idx]
            if len(np.unique(col)) == 1: # remove repeated col
                continue
            col_key = ','.join([str(item) for item in col])
            if col_key not in valid_cols_key:
                valid_cols_idx.append(col_idx)
                valid_cols_key.append(col_key)
        # fixed point: nothing was removed this pass
        if (len(valid_rows_idx) == num_rows) and (len(valid_cols_idx) == num_cols):
            break
        layout = layout[valid_rows_idx][:, valid_cols_idx]
        # remap head/body row indices onto the kept (re-sliced) layout
        head_rows = [n_idx for n_idx, o_idx in enumerate(valid_rows_idx) if o_idx in head_rows]
        body_rows = [n_idx for n_idx, o_idx in enumerate(valid_rows_idx) if o_idx in body_rows]

    table['layout'] = layout
    table['head_rows'] = head_rows
    table['body_rows'] = body_rows
    rerange_layout(table)
200
+
201
+
202
def pred_result_to_table(pred_result):
    """Assemble a table dict from the raw prediction tuple.

    Args:
        pred_result: (row_segments, col_segments, divide, spans) where
            ``divide`` is the row index separating header from body.

    Returns:
        dict with keys layout, head_rows, body_rows, cells.
    """
    row_segments, col_segments, divide, spans = pred_result
    num_rows = len(row_segments) - 1
    num_cols = len(col_segments) - 1

    layout = parse_layout(spans, num_rows, num_cols)
    table = dict(
        layout=layout,
        head_rows=list(range(0, divide)),          # rows above the divide
        body_rows=list(range(divide, num_rows)),   # remaining rows
        cells=parse_cells(layout, spans, row_segments, col_segments),
    )

    # remove_repeat_rcs(table)  # kept disabled, as in the original

    return table
222
+
223
+
224
def is_simple_table(table):
    """True when every grid position is its own cell (no merged cells)."""
    num_rows, num_cols = table['layout'].shape
    return num_rows * num_cols == len(table['cells'])
231
+
232
+
233
def tensor_to_image(tensor):
    """Convert a CHW (or HW) tensor to a uint8 numpy image for visualization.

    Multi-channel tensors that are not 1- or 3-channel are collapsed to one
    channel via the per-pixel L2 norm. Values are min-max scaled to [0, 255].

    Fix: a constant tensor now maps to all zeros instead of dividing by zero
    (which produced NaN and an undefined uint8 cast).
    """
    image = tensor.detach().cpu().numpy()
    if (len(image.shape) == 3) and (image.shape[0] != 3) and (image.shape[0] != 1):
        image = np.sqrt(np.sum(np.power(image, 2), axis=0, keepdims=True))
    value_range = np.max(image) - np.min(image)
    if value_range == 0:
        # constant input: avoid 0/0 -> NaN before the uint8 cast
        image = np.zeros_like(image)
    else:
        image = 255 * (image - np.min(image)) / value_range
    image = image.astype(np.uint8)
    if len(image.shape) == 3:
        # CHW -> HWC for image libraries; squeeze a single channel
        image = np.transpose(image, (1, 2, 0)).copy()
        if image.shape[2] == 1:
            image = image[:, :, 0]
    return image
244
+
245
+
246
def visualize_layout(image, table):
    """Draw every cell's segmentation polygon onto *image* in color (255, 0, 0)."""
    def draw(img, segmentation, color):
        for contour in segmentation:
            pts = np.array(contour, dtype=np.int32)
            img = cv2.polylines(img, [pts], True, color)
        return img

    for cell in table['cells']:
        if 'segmentation' in cell:
            image = draw(image, cell['segmentation'], (255, 0, 0))
    return image
256
+
257
# Markup tokens that carry no visible text of their own; is_blank/filt_content
# use this list to decide whether a transcript is effectively empty.
virtual_chars = ["<b>", "</b>", "<i>", "</i>", "<sup>", "</sup>", "<sub>", "</sub>", "<overline>", "</overline>", "<underline>", "</underline>", "<strike>", "</strike>"]
258
+
259
+
260
def is_blank(content):
    """True when *content* holds only whitespace and/or virtual markup tokens."""
    stripped = content
    for token in virtual_chars:
        stripped = stripped.replace(token, '')
    return stripped.strip() == ''
267
+
268
+
269
def filt_content(content, filt_blank=False, filt_virtual=False, filt_pad=False):
    """Normalize a cell transcript.

    Args:
        content: transcript string, possibly containing markup tokens.
        filt_blank: replace content by '' when it is blank aside from
            virtual markup tokens (see ``is_blank``).
        filt_virtual: strip the markup tokens listed in ``virtual_chars``.
        filt_pad: strip surrounding whitespace.

    Returns:
        The filtered content string.
    """
    global virtual_chars
    if filt_blank:
        if is_blank(content):
            content = ''

    if filt_virtual:
        # Bug fix: the original looped ``for item in content`` — over the
        # characters of the content itself — which replaced every character
        # with '' and erased the whole string. Strip the markup tokens instead.
        for item in virtual_chars:
            content = content.replace(item, '')

    if filt_pad:
        content = content.strip()

    return content
283
+
284
+
285
def filt_transcript(html, filt_blank=False, filt_virtual=False, filt_pad=False):
    """Apply ``filt_content`` to the text of every <td>...</td> cell in *html*.

    Cells are scanned left to right; after rewriting a cell's content the
    scan resumes just past that cell, compensating for any length change.
    """
    start_idx = 0
    while '<td' in html[start_idx:]:
        start_idx = html[start_idx:].index('<td') + start_idx
        # first '>' after '<td' closes the (possibly attributed) opening tag
        content_start_idx = html[start_idx:].index('>') + 1 + start_idx
        content_end_idx = html[content_start_idx:].index('</td>') + content_start_idx
        end_idx = content_end_idx + len('</td>')

        content = html[content_start_idx:content_end_idx]
        content = filt_content(content, filt_blank, filt_virtual, filt_pad)
        html = html[:content_start_idx] + content + html[content_end_idx:]
        # shift the resume point back by however much the cell shrank
        start_idx = end_idx - (content_end_idx-content_start_idx - len(content))
    return html
libs/utils/vocab.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Vocab:
    """Tiny fixed vocabulary mapping structural tokens to integer ids."""

    # token order defines the id assignment
    key_words = [
        '</line>',
        '</none>', # []
        '</bold>', # ['<b>', ' ', '</b>']
        '</space>' # [' ']
    ]

    def __init__(self):
        self._words_ids_map = {word: idx for idx, word in enumerate(self.key_words)}
        self._ids_words_map = {idx: word for idx, word in enumerate(self.key_words)}

        self.line_id = self._words_ids_map['</line>']
        self.none_id = self._words_ids_map['</none>']
        self.bold_id = self._words_ids_map['</bold>']
        self.space_id = self._words_ids_map['</space>']
        # ids whose tokens render as (effectively) blank content
        self.blank_ids = [self.none_id, self.bold_id, self.space_id]

    def __len__(self):
        return len(self._words_ids_map)

    def word_to_id(self, word):
        """Return the id of *word*; raises KeyError for unknown tokens."""
        return self._words_ids_map[word]

    def words_to_ids(self, words):
        """Map a sequence of tokens to their ids."""
        return [self.word_to_id(word) for word in words]

    def id_to_word(self, word_id):
        """Return the token for *word_id*; raises KeyError for unknown ids."""
        return self._ids_words_map[word_id]

    def ids_to_words(self, words_id):
        """Map a sequence of ids back to their tokens."""
        return [self.id_to_word(word_id) for word_id in words_id]
+ return [self.id_to_word(word_id) for word_id in words_id]
requirements.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict==2.4.0
2
+ aliyun-python-sdk-core==2.16.0
3
+ aliyun-python-sdk-kms==2.16.5
4
+ anyio==4.10.0
5
+ apted==1.0.3
6
+ beautifulsoup4==4.12.2
7
+ blinker==1.4
8
+ certifi==2025.8.3
9
+ charset-normalizer==3.4.3
10
+ click==8.1.8
11
+ colorama==0.4.6
12
+ contourpy==1.3.0
13
+ crcmod==1.7
14
+ cryptography==3.4.8
15
+ cycler==0.12.1
16
+ dbus-python==1.2.18
17
+ Distance==0.1.3
18
+ distro==1.7.0
19
+ exceptiongroup==1.3.0
20
+ filelock==3.14.0
21
+ fonttools==4.59.2
22
+ h11==0.16.0
23
+ httpcore==1.0.9
24
+ httplib2==0.20.2
25
+ httpx==0.28.1
26
+ idna==3.10
27
+ importlib-metadata==4.6.4
28
+ importlib_resources==6.5.2
29
+ jeepney==0.7.1
30
+ jmespath==0.10.0
31
+ keyring==23.5.0
32
+ kiwisolver==1.4.7
33
+ launchpadlib==1.10.16
34
+ lazr.restfulclient==0.14.4
35
+ lazr.uri==1.0.6
36
+ lxml==4.9.2
37
+ Mako==1.1.3
38
+ Markdown==3.3.6
39
+ markdown-it-py==3.0.0
40
+ MarkupSafe==2.0.1
41
+ matplotlib==3.7.1
42
+ mdurl==0.1.2
43
+ mmcv-full==1.6.2
44
+ mmdet==2.28.2
45
+ model-index==0.1.11
46
+ more-itertools==8.10.0
47
+ numpy==1.26.4
48
+ oauthlib==3.2.0
49
+ opencv-python==4.7.0.72
50
+ opendatalab==0.0.10
51
+ openmim==0.3.9
52
+ openxlab==0.1.2
53
+ ordered-set==4.1.0
54
+ oss2==2.17.0
55
+ packaging==24.2
56
+ pandas==2.0.2
57
+ Pillow==10.0.0
58
+ platformdirs==4.4.0
59
+ polygon==1.1.0
60
+ Polygon3==3.0.9.1
61
+ pycocotools==2.0.10
62
+ pycryptodome==3.23.0
63
+ Pygments==2.19.2
64
+ PyGObject==3.42.1
65
+ PyJWT==2.3.0
66
+ pyparsing==3.2.3
67
+ python-apt==2.4.0+ubuntu4
68
+ python-dateutil==2.9.0.post0
69
+ pytz==2023.4
70
+ PyYAML==6.0.2
71
+ requests==2.28.2
72
+ rich==13.4.2
73
+ scipy==1.13.1
74
+ seaborn==0.12.2
75
+ SecretStorage==3.3.1
76
+ shapely==2.0.1
77
+ six==1.17.0
78
+ sniffio==1.3.1
79
+ soupsieve==2.8
80
+ tabulate==0.9.0
81
+ terminaltables==3.1.10
82
+ tomli==2.2.1
83
+ torch==1.12.0
84
+ torchvision==0.13.0
85
+ tqdm==4.65.0
86
+ typing_extensions==4.15.0
87
+ tzdata==2025.2
88
+ urllib3==1.26.20
89
+ wadllib==1.3.6
90
+ websocket-client==1.8.0
91
+ websockets==15.0.1
92
+ yapf==0.43.0
93
+ zipp==3.23.0
runner/train.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ import torch
3
+ import tqdm
4
+ import json
5
+ import os
6
+ import sys
7
+ sys.path.append('./')
8
+ sys.path.append('../')
9
+ import numpy as np
10
+ from torch.optim.lr_scheduler import CosineAnnealingLR
11
+ from collections import defaultdict
12
+ from libs.utils.cal_f1 import pred_result_to_table, table_to_relations, evaluate_f1
13
+ from libs.utils.comm import distributed, synchronize
14
+ from libs.utils.checkpoint import load_checkpoint, save_checkpoint
15
+ from libs.data import create_train_dataloader, create_valid_dataloader
16
+ from libs.utils.model_synchronizer import ModelSynchronizer
17
+ from libs.utils.time_counter import TimeCounter
18
+ from libs.utils.utils import is_simple_table
19
+ from libs.utils.utils import cal_mean_lr
20
+ from libs.utils.counter import Counter
21
+ from libs.utils import logger
22
+ from libs.model import build_model
23
+ from libs.configs import cfg, setup_config
24
+
25
+
26
+ metrics_name = ['f1']
27
+ best_metrics = [0.0]
28
+
29
+
30
+ def init():
31
+ import argparse
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument('--cfg', type=str, default='debug')
34
+ parser.add_argument('--local_rank', type=int, default=0)
35
+ args = parser.parse_args()
36
+ setup_config(args.cfg)
37
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
38
+ num_gpus = int(os.environ['MORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
39
+ distributed = num_gpus > 1
40
+ if distributed:
41
+ torch.cuda.set_device(args.local_rank)
42
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
43
+ synchronize()
44
+ logger.setup_logger('Line Detect Model', cfg.work_dir, 'train.log')
45
+ logger.info('Use config:%s' % args.cfg)
46
+
47
+
48
+ def train(cfg, epoch, dataloader, model, optimizer, scheduler, time_counter, synchronizer=None):
49
+ model.train()
50
+ counter = Counter(cache_nums=1000)
51
+ for it, data_batch in enumerate(dataloader):
52
+ ids = data_batch['ids']
53
+ images_size = data_batch['images_size']
54
+ images = data_batch['images'].to(cfg.device)
55
+ cls_labels = data_batch['cls_labels'].to(cfg.device)
56
+ labels_mask = data_batch['labels_mask'].to(cfg.device)
57
+ rows_fg_spans = data_batch['rows_fg_spans']
58
+ rows_bg_spans = data_batch['rows_bg_spans']
59
+ cols_fg_spans = data_batch['cols_fg_spans']
60
+ cols_bg_spans = data_batch['cols_bg_spans']
61
+ cells_spans = data_batch['cells_spans']
62
+ divide_labels = data_batch['divide_labels'].to(cfg.device)
63
+ layouts = data_batch['layouts'].to(cfg.device)
64
+
65
+ try:
66
+ optimizer.zero_grad()
67
+ pred_result, result_info = model(
68
+ images, images_size,
69
+ cls_labels, labels_mask, layouts,
70
+ rows_fg_spans, rows_bg_spans,
71
+ cols_fg_spans, cols_bg_spans,
72
+ cells_spans, divide_labels,
73
+ )
74
+ loss = sum([val for key, val in result_info.items() if 'loss' in key])
75
+ loss.backward()
76
+ optimizer.step()
77
+ scheduler.step()
78
+ counter.update(result_info)
79
+ except:
80
+ logger.info('CUDA Out Of Memory')
81
+
82
+ if it % cfg.log_sep == 0:
83
+ logger.info(
84
+ '[Train][Epoch %03d Iter %04d][Memory: %.0f ][Mean LR: %f ][Left: %s] %s' %
85
+ (
86
+ epoch,
87
+ it,
88
+ torch.cuda.max_memory_allocated()/1024/1024,
89
+ cal_mean_lr(optimizer),
90
+ time_counter.step(epoch, it + 1),
91
+ counter.format_mean(sync=False)
92
+ )
93
+ )
94
+
95
+ if synchronizer is not None:
96
+ synchronizer()
97
+ if synchronizer is not None:
98
+ synchronizer(final_align=True)
99
+
100
+
101
+ def valid(cfg, dataloader, model):
102
+ model.eval()
103
+ total_label_relations = list()
104
+ total_pred_relations = list()
105
+ total_relations_metric = list()
106
+
107
+ for it, data_batch in enumerate(tqdm.tqdm(dataloader)):
108
+ ids = data_batch['ids']
109
+ images_size = data_batch['images_size']
110
+ images = data_batch['images'].to(cfg.device)
111
+ tables = data_batch['tables']
112
+ pred_result, _ = model(images, images_size)
113
+ pred_tables = [
114
+ pred_result_to_table(tables[batch_idx],
115
+ (pred_result[0][batch_idx], pred_result[1][batch_idx],
116
+ pred_result[2][batch_idx], pred_result[3][batch_idx])
117
+ )
118
+ for batch_idx in range(len(ids))
119
+ ]
120
+ pred_relations = [table_to_relations(table) for table in pred_tables]
121
+ total_pred_relations.extend(pred_relations)
122
+ # label
123
+ label_relations = []
124
+ for table in tables:
125
+ label_path = os.path.join(cfg.valid_data_dir, table['label_path'])
126
+ with open(table['label_path'], 'r') as f:
127
+ label_relations.append(json.load(f))
128
+ total_label_relations.extend(label_relations)
129
+
130
+ # cal P, R, F1
131
+ total_relations_metric = evaluate_f1(total_label_relations, total_pred_relations, num_workers=40)
132
+ P, R, F1 = np.array(total_relations_metric).mean(0).tolist()
133
+ F1 = 2 * P * R / (P + R)
134
+ logger.info('[Valid] Total Type Mertric: Precision: %s, Recall: %s, F1-Score: %s' % (P, R, F1))
135
+ return (F1,)
136
+
137
+
138
+ def build_optimizer(cfg, model):
139
+ params = list()
140
+ for _, value in model.named_parameters():
141
+ if not value.requires_grad:
142
+ continue
143
+ lr = cfg.base_lr
144
+ weight_decay = cfg.weight_decay
145
+ params += [{'params': [value], 'lr': lr, 'weight_decay': weight_decay}]
146
+ optimizer = torch.optim.Adam(params, cfg.base_lr)
147
+ return optimizer
148
+
149
+
150
+ def build_scheduler(cfg, optimizer, epoch_iters, start_epoch=0):
151
+ scheduler = CosineAnnealingLR(
152
+ optimizer=optimizer,
153
+ T_max=cfg.num_epochs * epoch_iters,
154
+ eta_min=cfg.min_lr,
155
+ last_epoch=-1 if start_epoch == 0 else start_epoch * epoch_iters
156
+ )
157
+ return scheduler
158
+
159
+
160
+ def main():
161
+ init()
162
+
163
+ train_dataloader = create_train_dataloader(
164
+ cfg.vocab,
165
+ cfg.train_lrcs_path,
166
+ cfg.train_num_workers,
167
+ cfg.train_max_batch_size,
168
+ cfg.train_max_pixel_nums,
169
+ cfg.train_bucket_seps,
170
+ cfg.train_data_dir
171
+ )
172
+
173
+ logger.info(
174
+ 'Train dataset have %d samples, %d batchs' %
175
+ (
176
+ len(train_dataloader.dataset),
177
+ len(train_dataloader.batch_sampler)
178
+ )
179
+ )
180
+
181
+ valid_dataloader = create_valid_dataloader(
182
+ cfg.vocab,
183
+ cfg.valid_lrc_path,
184
+ cfg.valid_num_workers,
185
+ cfg.valid_batch_size,
186
+ cfg.valid_data_dir
187
+ )
188
+
189
+ logger.info(
190
+ 'Valid dataset have %d samples, %d batchs with batch_size=%d' %
191
+ (
192
+ len(valid_dataloader.dataset),
193
+ len(valid_dataloader.batch_sampler),
194
+ valid_dataloader.batch_size
195
+ )
196
+ )
197
+
198
+ model = build_model(cfg)
199
+ model.cuda()
200
+
201
+ if distributed():
202
+ synchronizer = ModelSynchronizer(model, cfg.sync_rate)
203
+ else:
204
+ synchronizer = None
205
+
206
+ epoch_iters = len(train_dataloader.batch_sampler)
207
+ optimizer = build_optimizer(cfg, model)
208
+
209
+ global metrics_name
210
+ global best_metrics
211
+ start_epoch = 0
212
+
213
+ resume_path = os.path.join(cfg.work_dir, 'latest_model.pth')
214
+ if os.path.exists(resume_path):
215
+ best_metrics, start_epoch = load_checkpoint(resume_path, model, optimizer)
216
+ start_epoch += 1
217
+ logger.info('resume from: %s' % resume_path)
218
+ elif cfg.train_checkpoint is not None:
219
+ load_checkpoint(cfg.train_checkpoint, model)
220
+ logger.info('load checkpoint from: %s' % cfg.train_checkpoint)
221
+
222
+ scheduler = build_scheduler(cfg, optimizer, epoch_iters, start_epoch)
223
+
224
+ time_counter = TimeCounter(start_epoch, cfg.num_epochs, epoch_iters)
225
+ time_counter.reset()
226
+
227
+ for epoch in range(start_epoch, cfg.num_epochs):
228
+ if hasattr(train_dataloader.sampler, 'set_epoch'):
229
+ train_dataloader.sampler.set_epoch(epoch)
230
+ train(cfg, epoch, train_dataloader, model, optimizer, scheduler, time_counter, synchronizer)
231
+
232
+ with torch.no_grad():
233
+ metrics = valid(cfg, valid_dataloader, model)
234
+
235
+ for metric_idx in range(len(metrics_name)):
236
+ if metrics[metric_idx] > best_metrics[metric_idx]:
237
+ best_metrics[metric_idx] = metrics[metric_idx]
238
+ save_checkpoint(os.path.join(cfg.work_dir, 'best_%s_model.pth' % metrics_name[metric_idx]), model, optimizer, best_metrics, epoch)
239
+ logger.info('Save current model as best_%s_model' % metrics_name[metric_idx])
240
+
241
+ save_checkpoint(os.path.join(cfg.work_dir, 'latest_model.pth'), model, optimizer, best_metrics, epoch)
242
+
243
+
244
+ if __name__ == '__main__':
245
+ main()
runner/valid.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import json
3
+ sys.path.append('./')
4
+ sys.path.append('../')
5
+ import os
6
+ import tqdm
7
+ import torch
8
+ import numpy as np
9
+ from libs.configs import cfg, setup_config
10
+ from libs.model import build_model
11
+ from libs.data import create_valid_dataloader
12
+ from libs.utils import logger
13
+ from libs.utils.cal_f1 import pred_result_to_table, table_to_relations, evaluate_f1
14
+ from libs.utils.checkpoint import load_checkpoint
15
+ from libs.utils.comm import synchronize, all_gather
16
+
17
+
18
def init():
    """Parse CLI args, load the named config (optionally overriding the lrc
    dataset path), initialize distributed state when WORLD_SIZE > 1, and set
    up the validation logger."""
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--lrc", type=str, default=None)
    parser.add_argument("--cfg", type=str, default='default')
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    setup_config(args.cfg)
    if args.lrc is not None:
        # command-line dataset path overrides the config's default
        cfg.valid_lrc_path = args.lrc

    os.environ['LOCAL_RANK'] = str(args.local_rank)

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    # NOTE(review): this local flag shadows libs.utils.comm.distributed within
    # this function — intentional here, but easy to trip over.
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    logger.setup_logger('Line Detect Model', cfg.work_dir, 'valid.log')
    logger.info('Use config: %s' % args.cfg)
    logger.info('Evaluate Dataset: %s' % cfg.valid_lrc_path)
45
+
46
+
47
def valid(cfg, dataloader, model):
    """Evaluate *model* on *dataloader* and return (F1,).

    Predictions are converted to table structures and compared against the
    ground-truth relation JSON files; precision/recall are macro-averaged
    over samples and F1 is derived from those averages.
    """
    model.eval()
    total_label_relations = list()
    total_pred_relations = list()
    total_relations_metric = list()

    for it, data_batch in enumerate(tqdm.tqdm(dataloader)):
        ids = data_batch['ids']
        images_size = data_batch['images_size']
        images = data_batch['images'].to(cfg.device)
        tables = data_batch['tables']

        pred_result, _ = model(images, images_size)

        # pred
        pred_tables = [
            pred_result_to_table(tables[batch_idx],
                (pred_result[0][batch_idx], pred_result[1][batch_idx], \
                pred_result[2][batch_idx], pred_result[3][batch_idx])
            ) \
            for batch_idx in range(len(ids))
        ]
        pred_relations = [table_to_relations(table) for table in pred_tables]
        total_pred_relations.extend(pred_relations)
        # label
        # NOTE(review): label_path is opened as-is here, while runner/train.py
        # joins it with cfg.valid_data_dir — confirm both call sites receive
        # the same path layout.
        label_relations = []
        for table in tables:
            with open(table['label_path'], 'r') as f:
                label_relations.append(json.load(f))
        total_label_relations.extend(label_relations)

    # cal P, R, F1
    total_relations_metric = evaluate_f1(total_label_relations, total_pred_relations, num_workers=40)
    P, R, F1 = np.array(total_relations_metric).mean(0).tolist()
    # F1 recomputed from macro-averaged P and R, not the mean of per-sample F1s
    F1 = 2 * P * R / (P + R)
    logger.info('[Valid] Total Type Mertric: Precision: %s, Recall: %s, F1-Score: %s' % (P, R, F1))

    return (F1, )
85
+
86
+
87
def main():
    """Validation entry point: parse args/config, build the valid dataloader
    and model, load ``cfg.eval_checkpoint``, and run one evaluation pass."""
    init()

    valid_dataloader = create_valid_dataloader(
        cfg.vocab,
        cfg.valid_lrc_path,
        cfg.valid_num_workers,
        cfg.valid_batch_size
    )
    logger.info(
        'Valid dataset have %d samples, %d batchs with batch_size=%d' % \
        (
            len(valid_dataloader.dataset),
            len(valid_dataloader.batch_sampler),
            valid_dataloader.batch_size
        )
    )

    model = build_model(cfg)
    model.cuda()

    load_checkpoint(cfg.eval_checkpoint, model)
    logger.info('Load checkpoint from: %s' % cfg.eval_checkpoint)

    # inference only: no gradients needed
    with torch.no_grad():
        valid(cfg, valid_dataloader, model)


if __name__ == '__main__':
    main()