File size: 5,950 Bytes
cb0ad2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import math
import numpy as np


class InvalidFormat(Exception):
    """Signal that table data could not be converted into usable spans."""


def segmentation_to_bbox(segmentation):
    """Return the bounding box enclosing every point of every contour.

    Args:
        segmentation: A sequence of contours; each contour is a sequence of
            points whose element ``[0]`` is read as x and ``[1]`` as y.

    Returns:
        A tuple ``(x1, y1, x2, y2)`` holding the minimum x, minimum y,
        maximum x and maximum y over all points.

    Raises:
        ValueError: If there are no points at all (``min`` of an empty list).
    """
    points = [pt for contour in segmentation for pt in contour]
    xs = [pt[0] for pt in points]
    ys = [pt[1] for pt in points]
    return (min(xs), min(ys), max(xs), max(ys))


def cal_cell_bbox(table):
    """Compute one bounding box (or ``None``) for each cell of ``table``.

    For a cell that has a ``'segmentation'`` key, the contours of all its
    ``'sublines'`` (when that key is present) are gathered first; if that
    yields nothing, the cell's own ``'segmentation'`` is used instead.

    Args:
        table: Mapping with a ``'cells'`` entry that iterates over cells.

    Returns:
        A list parallel to ``table['cells']``. Each entry is the result of
        ``segmentation_to_bbox`` for that cell, or ``None`` when the cell has
        no ``'segmentation'`` key or no contours to measure.
    """
    boxes = []
    for cell in table['cells']:
        if 'segmentation' not in cell:
            boxes.append(None)
            continue

        # Prefer the contours of the individual sublines when they exist.
        contours = []
        if 'sublines' in cell:
            for subline in cell['sublines']:
                contours += subline['segmentation']

        # Fall back to the cell-level contours when sublines gave nothing.
        if len(contours) == 0:
            contours = cell['segmentation']

        boxes.append(
            segmentation_to_bbox(contours) if len(contours) != 0 else None
        )
    return boxes
            

def cal_cell_spans(table):
    """Compute the grid span ``[x1, y1, x2, y2]`` of every cell in the layout.

    ``table['layout']`` is indexed as ``layout[row, col]`` and each entry is
    compared against the cell index. For every index in
    ``range(len(table['cells']))`` the span is the inclusive min/max column
    (``x1``, ``x2``) and row (``y1``, ``y2``) at which that index occurs.

    Args:
        table: Mapping with ``'layout'`` (a 2-D numpy array of cell ids) and
            ``'cells'`` (only its length is used here).

    Returns:
        A list with one ``[x1, y1, x2, y2]`` entry per cell; all four bounds
        are inclusive grid indices (numpy integer scalars).

    Raises:
        ValueError: If a cell id does not occur anywhere in the layout.
        AssertionError: If the rectangle spanned by a cell contains any
            other cell id, i.e. the cell is not a solid rectangle.
    """
    layout = table['layout']
    cells_span = []
    for cell_id in range(len(table['cells'])):
        cell_positions = np.argwhere(layout == cell_id)
        if cell_positions.size == 0:
            # numpy would raise an opaque "zero-size array" ValueError on the
            # reductions below; keep the type but name the offending cell.
            raise ValueError(
                'cell %d does not appear in the layout' % cell_id
            )

        rows = cell_positions[:, 0]
        cols = cell_positions[:, 1]
        y1, y2 = np.min(rows), np.max(rows)
        x1, x2 = np.min(cols), np.max(cols)

        # y2/x2 are inclusive, so the slice needs "+ 1" to cover the whole
        # rectangle; without it the last row/column is never checked and the
        # slice is empty for 1x1 cells. Raised explicitly (not via ``assert``)
        # so the validation is not stripped under ``python -O``.
        if not np.all(layout[y1:y2 + 1, x1:x2 + 1] == cell_id):
            raise AssertionError(
                'cell %d does not fill a solid rectangle in the layout'
                % cell_id
            )
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def cal_fg_bg_span(spans, edge):
    """Split an ordered list of spans into "fg" gaps and isolated "bg" spans.

    Args:
        spans: Sequence whose items are either ``None`` or a ``[start, end]``
            pair.
        edge: Upper boundary used after the last span (the lower boundary
            before the first span is ``0``).

    Returns:
        A tuple ``(fg_spans, bg_spans)``:

        * ``fg_spans`` holds a new ``[s, e]`` list for every gap between one
          span's end and the next span's start (with ``0`` before the first
          span and ``edge`` after the last), kept only when ``e > s`` and
          neither neighbouring span is ``None``.
        * ``bg_spans`` holds the input spans themselves that start strictly
          after the previous span's end (or after ``0`` for the first one)
          and end strictly before the next span's start (or before ``edge``
          for the last one); a span with a ``None`` neighbour is dropped.
    """
    last_pos = len(spans) - 1

    bg_spans = []
    for pos, current in enumerate(spans):
        if current is None:
            continue

        # Lower limit: image start for the first span, else previous end.
        if pos == 0:
            lower = 0
        else:
            before = spans[pos - 1]
            if before is None:
                continue
            lower = before[1]
        if current[0] <= lower:
            continue

        # Upper limit: ``edge`` for the last span, else the next start.
        if pos == last_pos:
            upper = edge
        else:
            after = spans[pos + 1]
            if after is None:
                continue
            upper = after[0]
        if current[1] >= upper:
            continue

        bg_spans.append(current)

    # Pad with sentinel pairs so every gap has a "left" and a "right" span:
    # the left sentinel ends at 0 and the right sentinel starts at ``edge``.
    left_neighbours = [(0, 0)] + list(spans)
    right_neighbours = list(spans) + [(edge, edge)]

    fg_spans = []
    for left, right in zip(left_neighbours, right_neighbours):
        if left is None or right is None:
            continue
        gap_start = left[1]
        gap_end = right[0]
        if gap_end <= gap_start:
            continue
        fg_spans.append([gap_start, gap_end])

    return fg_spans, bg_spans


def shrink_spans(spans, size):
    """Trim spans so they stay inside ``(0, size)`` and off their neighbours.

    Each ``[start, end]`` pair is compared against the *original* previous
    and next pairs. Where two adjacent spans touch or overlap, each side is
    moved inward by half of the overlap (rounded up). The first span's start
    is set to ``1`` when it is ``<= 0`` and the last span's end is set to
    ``size - 1`` when it is ``>= size``.

    Args:
        spans: Sequence of ``(start, end)`` pairs.
        size: Upper limit used to clamp the end of the last span.

    Returns:
        A new list of ``[start, end]`` lists, one per input span.

    Raises:
        InvalidFormat: If an adjusted span has ``end - start < 1``.
    """
    last_pos = len(spans) - 1
    shrunk = []
    for pos, (lo, hi) in enumerate(spans):
        # Adjust the leading side.
        if pos > 0:
            _, prev_hi = spans[pos - 1]
            if lo <= prev_hi:
                overlap = prev_hi - lo + 1
                lo = lo + math.ceil(overlap / 2)
        elif lo <= 0:
            lo = 1

        # Adjust the trailing side.
        if pos < last_pos:
            next_lo, _ = spans[pos + 1]
            if hi >= next_lo:
                overlap = hi - next_lo + 1
                hi = hi - math.ceil(overlap / 2)
        elif hi >= size:
            hi = size - 1

        if hi - lo < 1:
            raise InvalidFormat()

        shrunk.append([lo, hi])
    return shrunk


def cal_row_span(table, cells_span, cells_bbox, height):
    """Derive per-row vertical spans and split them into fg/bg spans.

    For each layout row, the top is the smallest ``bbox[1]`` among cells
    whose grid span starts on that row, and the bottom is the largest
    ``bbox[3]`` (plus one) among cells whose grid span ends on that row.
    Both values are clamped to ``[1, height - 1]``. Cells whose bbox is
    ``None`` are ignored.

    Args:
        table: Mapping with a 2-D ``'layout'`` array of cell ids.
        cells_span: Per-cell ``[x1, y1, x2, y2]`` grid spans, indexed by id.
        cells_bbox: Per-cell ``(x1, y1, x2, y2)`` boxes or ``None``.
        height: Vertical limit used for clamping and as the outer edge.

    Returns:
        ``(rows_fg_span, rows_bg_span)`` as produced by ``cal_fg_bg_span``
        after the row spans have gone through ``shrink_spans``.

    Raises:
        InvalidFormat: If some row has no usable top or bottom coordinate
            (``shrink_spans`` may raise it as well).
    """
    grid = table['layout']
    limit = height - 1
    row_extents = []
    for row_idx in range(grid.shape[0]):
        tops = []
        bottoms = []
        for cell_id in grid[row_idx, :]:
            grid_span = cells_span[cell_id]
            bbox = cells_bbox[cell_id]
            if bbox is None:
                continue
            if grid_span[1] == row_idx:
                tops.append(bbox[1])
            if grid_span[3] == row_idx:
                bottoms.append(bbox[3])

        if not tops or not bottoms:
            raise InvalidFormat()

        top = min(max(1, min(tops)), limit)
        bottom = min(max(1, max(bottoms) + 1), limit)
        row_extents.append([top, bottom])

    row_extents = shrink_spans(row_extents, height)
    return cal_fg_bg_span(row_extents, height)


def cal_col_span(table, cells_span, cells_bbox, width):
    """Derive per-column horizontal spans and split them into fg/bg spans.

    For each layout column, the left side is the smallest ``bbox[0]`` among
    cells whose grid span starts on that column, and the right side is the
    largest ``bbox[2]`` (plus one) among cells whose grid span ends on that
    column. Both values are clamped to ``[1, width - 1]``. Cells whose bbox
    is ``None`` are ignored.

    Args:
        table: Mapping with a 2-D ``'layout'`` array of cell ids.
        cells_span: Per-cell ``[x1, y1, x2, y2]`` grid spans, indexed by id.
        cells_bbox: Per-cell ``(x1, y1, x2, y2)`` boxes or ``None``.
        width: Horizontal limit used for clamping and as the outer edge.

    Returns:
        ``(cols_fg_span, cols_bg_span)`` as produced by ``cal_fg_bg_span``
        after the column spans have gone through ``shrink_spans``.

    Raises:
        InvalidFormat: If some column has no usable left or right coordinate
            (``shrink_spans`` may raise it as well).
    """
    grid = table['layout']
    limit = width - 1
    col_extents = []
    for col_idx in range(grid.shape[1]):
        lefts = []
        rights = []
        for cell_id in grid[:, col_idx]:
            grid_span = cells_span[cell_id]
            bbox = cells_bbox[cell_id]
            if bbox is None:
                continue
            if grid_span[0] == col_idx:
                lefts.append(bbox[0])
            if grid_span[2] == col_idx:
                rights.append(bbox[2])

        if not lefts or not rights:
            raise InvalidFormat()

        left = min(max(1, min(lefts)), limit)
        right = min(max(1, max(rights) + 1), limit)
        col_extents.append([left, right])

    col_extents = shrink_spans(col_extents, width)
    return cal_fg_bg_span(col_extents, width)


def extract_fg_bg_spans(table, image_size):
    """Compute row/column fg and bg spans plus the per-cell grid spans.

    Args:
        table: Table mapping passed through to ``cal_cell_bbox``,
            ``cal_cell_spans``, ``cal_row_span`` and ``cal_col_span``.
        image_size: A ``(width, height)`` pair.

    Returns:
        A 5-tuple ``(rows_fg_span, rows_bg_span, cols_fg_span, cols_bg_span,
        cells_span)``.
    """
    img_width, img_height = image_size
    bboxes = cal_cell_bbox(table)
    grid_spans = cal_cell_spans(table)

    # Rows are measured against the image height, columns against its width.
    row_fg, row_bg = cal_row_span(table, grid_spans, bboxes, img_height)
    col_fg, col_bg = cal_col_span(table, grid_spans, bboxes, img_width)

    return row_fg, row_bg, col_fg, col_bg, grid_spans