Spaces:
Runtime error
Runtime error
| # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import numpy as np | |
| from ppstructure.table.table_master_match import deal_eb_token, deal_bb | |
def distance(box_1, box_2):
    """Return an L1-style distance between two 4-value boxes.

    Both boxes are unpacked as ``(x0, y0, x1, y1)``. The result is the sum of
    the absolute differences of all four coordinates, plus the smaller of the
    two per-corner L1 distances (first corner ``(x0, y0)`` vs. second corner
    ``(x1, y1)``). The extra term favours boxes that share at least one corner.
    """
    ax0, ay0, ax1, ay1 = box_1
    bx0, by0, bx1, by1 = box_2
    total = abs(bx0 - ax0) + abs(by0 - ay0) + abs(bx1 - ax1) + abs(by1 - ay1)
    first_corner = abs(bx0 - ax0) + abs(by0 - ay0)
    second_corner = abs(bx1 - ax1) + abs(by1 - ay1)
    return total + min(first_corner, second_corner)
def compute_iou(rec1, rec2):
    """Return the intersection-over-union (IoU) of two axis-aligned rectangles.

    :param rec1: rectangle as (y0, x0, y1, x1), i.e. (top, left, bottom, right)
    :param rec2: rectangle in the same layout as ``rec1``
    :return: IoU as a float scalar; 0.0 when the rectangles do not overlap
        (touching edges count as no overlap)

    The formula treats both axes the same way, so rectangles given as
    (x0, y0, x1, y1) yield the same value.
    """
    area_1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    area_2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # Edges of the would-be intersection rectangle.
    top = max(rec1[0], rec2[0])
    bottom = min(rec1[2], rec2[2])
    left = max(rec1[1], rec2[1])
    right = min(rec1[3], rec2[3])

    overlaps = left < right and top < bottom
    if not overlaps:
        return 0.0
    overlap = (right - left) * (bottom - top)
    return overlap / (area_1 + area_2 - overlap) * 1.0
class TableMatch:
    """Fill a predicted table structure with OCR text.

    Each OCR detection box is assigned to the predicted cell box it matches
    best (see ``match_result``). The matched texts are then written into the
    cell tokens of the predicted HTML structure.
    """

    def __init__(self, filter_ocr_result=False, use_master=False):
        # When True, OCR boxes whose lowest point lies above every predicted
        # cell box are dropped before matching (see ``_filter_ocr_result``).
        self.filter_ocr_result = filter_ocr_result
        # Selects ``get_pred_html_master`` instead of ``get_pred_html``.
        self.use_master = use_master

    def __call__(self, structure_res, dt_boxes, rec_res):
        """Return the HTML string for one table.

        :param structure_res: ``(pred_structures, pred_bboxes)`` - predicted
            HTML tokens and predicted cell boxes
        :param dt_boxes: OCR detection boxes, one per entry of ``rec_res``
        :param rec_res: OCR results whose item ``[0]`` is the recognized text
            (presumably ``(text, score)`` pairs - confirm against caller)
        :return: the assembled HTML string
        """
        pred_structures, pred_bboxes = structure_res
        if self.filter_ocr_result:
            dt_boxes, rec_res = self._filter_ocr_result(
                pred_bboxes, dt_boxes, rec_res)
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        if self.use_master:
            pred_html, _ = self.get_pred_html_master(
                pred_structures, matched_index, rec_res)
        else:
            pred_html, _ = self.get_pred_html(
                pred_structures, matched_index, rec_res)
        return pred_html

    def match_result(self, dt_boxes, pred_bboxes):
        """Assign every OCR box to its closest predicted cell box.

        Candidates are ranked by ``1 - IoU`` first and by ``distance``
        second; ties go to the lowest cell index.

        :param dt_boxes: OCR boxes as 4-value boxes
        :param pred_bboxes: predicted cell boxes, 4 or 8 values each
        :return: dict mapping cell index -> list of OCR box indexes, in the
            order of ``dt_boxes``; empty when there are no cell boxes
        """
        matched = {}
        if len(pred_bboxes) == 0:
            # Nothing to match against; every OCR box stays unassigned.
            return matched
        # Normalise 8-value polygons once (not once per OCR box) to
        # (x_min, y_min, x_max, y_max), treating even positions as x and
        # odd positions as y.
        cell_boxes = [
            [
                np.min(box[0::2]), np.min(box[1::2]),
                np.max(box[0::2]), np.max(box[1::2])
            ] if len(box) == 8 else box
            for box in pred_bboxes
        ]
        for i, gt_box in enumerate(dt_boxes):
            costs = [(1. - compute_iou(gt_box, cell_box),
                      distance(gt_box, cell_box)) for cell_box in cell_boxes]
            # min() returns the first minimal index, i.e. the same choice a
            # stable sort by cost followed by taking the head would make.
            best = min(range(len(costs)), key=costs.__getitem__)
            matched.setdefault(best, []).append(i)
        return matched

    def _merge_cell_text(self, cell_indexes, ocr_contents):
        """Join the OCR texts matched to one cell; return ``(text, is_bold)``.

        A single matched text is used verbatim. Several texts are cleaned
        first: a leading space is removed, surrounding ``<b>``/``</b>`` tags
        are stripped, and the texts are joined with single spaces. In that
        case the cell is reported bold when the first matched text contains
        ``<b>``.
        """
        multi = len(cell_indexes) > 1
        # Item [0] of an OCR result is its text; test the text itself rather
        # than membership in the result container.
        is_bold = multi and '<b>' in ocr_contents[cell_indexes[0]][0]
        last = len(cell_indexes) - 1
        pieces = []
        for i, ocr_index in enumerate(cell_indexes):
            content = ocr_contents[ocr_index][0]
            if multi:
                if len(content) == 0:
                    continue
                if content[0] == ' ':
                    content = content[1:]
                # Only strip tags that actually wrap the text.
                if content.startswith('<b>'):
                    content = content[3:]
                if content.endswith('</b>'):
                    content = content[:-4]
                if len(content) == 0:
                    continue
                if i != last and content[-1] != ' ':
                    content += ' '
            pieces.append(content)
        return ''.join(pieces), is_bold

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        """Insert matched OCR text into the predicted structure tokens.

        Every token containing ``</td>`` is one cell; cells are numbered in
        order of appearance and looked up in ``matched_index``.

        :return: ``(html, tokens)`` where ``html`` is ``''.join(tokens)``
        """
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if '</td>' not in tag:
                end_html.append(tag)
                continue
            empty_cell = tag == '<td></td>'
            if empty_cell:
                end_html.append('<td>')
            if td_index in matched_index:
                txt, b_with = self._merge_cell_text(
                    matched_index[td_index], ocr_contents)
                if b_with:
                    end_html.append('<b>')
                if txt:
                    end_html.append(txt)
                if b_with:
                    end_html.append('</b>')
            # An empty '<td></td>' cell was split above; close it here.
            end_html.append('</td>' if empty_cell else tag)
            td_index += 1
        return ''.join(end_html), end_html

    def get_pred_html_master(self, pred_structures, matched_index,
                             ocr_contents):
        """Like ``get_pred_html``, but for TableMaster-style tokens.

        Every token is passed through ``deal_eb_token``, and the joined HTML
        through ``deal_bb`` (helpers imported from ``table_master_match``).

        :return: ``(html, tokens)``
        """
        end_html = []
        td_index = 0
        for token in pred_structures:
            if '</td>' in token:
                txt = ''
                if td_index in matched_index:
                    txt, b_with = self._merge_cell_text(
                        matched_index[td_index], ocr_contents)
                    if b_with:
                        txt = '<b>{}</b>'.format(txt)
                if token == '<td></td>':
                    token = '<td>{}</td>'.format(txt)
                else:
                    token = '{}</td>'.format(txt)
                td_index += 1
            token = deal_eb_token(token)
            end_html.append(token)
        html = ''.join(end_html)
        html = deal_bb(html)
        return html, end_html

    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
        """Drop OCR boxes lying above the predicted cells.

        A box is dropped when its largest odd-position (y) coordinate is
        smaller than the smallest odd-position coordinate over all predicted
        cell boxes. With no predicted cells there is no reference line, so
        every box is kept.

        :return: ``(dt_boxes, rec_res)`` as new lists
        """
        if len(pred_bboxes) == 0:
            return list(dt_boxes), list(rec_res)
        # asarray lets plain nested lists work as well as numpy arrays.
        y1 = np.asarray(pred_bboxes)[:, 1::2].min()
        new_dt_boxes = []
        new_rec_res = []
        for box, rec in zip(dt_boxes, rec_res):
            if np.max(box[1::2]) < y1:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res