"""Helpers for table-structure post-processing.

Covers training metrics, turning predicted row/column segments and spans into
a table layout with cells, visualisation, and filtering of cell transcripts
in HTML.
"""
import cv2
import copy
import Polygon
import numpy as np


def cal_mean_lr(optimizer):
    """Return the mean learning rate over all of ``optimizer.param_groups``."""
    lrs = [group['lr'] for group in optimizer.param_groups]
    return sum(lrs) / len(lrs)


def cal_pr_f1(pr_info):
    """Compute precision, recall and F1 from accumulated counts.

    Args:
        pr_info: indexable ``(a, b, c)`` where precision is ``a / b`` and
            recall is ``a / c`` (presumably correct / predicted / ground-truth
            counts -- confirm against the caller).

    Returns:
        ``(precision, recall, f1)``. Any ratio whose denominator is zero is
        reported as 0 instead of raising ``ZeroDivisionError``.
    """
    precision = pr_info[0] / pr_info[1] if pr_info[1] != 0 else 0.0
    recall = pr_info[0] / pr_info[2] if pr_info[2] != 0 else 0.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator != 0 else 0.0
    return precision, recall, f1


def match_segment_spans(segments, spans):
    """Pair segment positions with the spans that contain them.

    A segment matches a span when ``span[0] <= segment < span[1]``. Each span
    is matched at most once (earliest segment wins). A segment may match
    several spans, because the inner loop does not stop after a match.

    Returns:
        ``(matched_segments, matched_spans)``: parallel lists of segment
        indices and span indices, one entry per matched pair.
    """
    matched_segments = list()
    matched_spans = list()
    for segment_idx, segment in enumerate(segments):
        for span_idx, span in enumerate(spans):
            if span_idx in matched_spans:
                continue
            if span[0] <= segment < span[1]:
                matched_segments.append(segment_idx)
                matched_spans.append(span_idx)
    return matched_segments, matched_spans


def find_unmatch_segment_spans(segments, spans):
    """Return the indices of segments that lie inside none of ``spans``.

    Containment uses the half-open test ``span[0] <= segment < span[1]``.
    """
    return [
        segment_idx
        for segment_idx, segment in enumerate(segments)
        if not any(span[0] <= segment < span[1] for span in spans)
    ]


def parse_layout(spans, num_rows, num_cols):
    """Rasterise cell spans into a ``(num_rows, num_cols)`` grid of cell ids.

    Args:
        spans: iterable of ``(x1, y1, x2, y2)`` with inclusive column (x) and
            row (y) grid bounds. Later spans overwrite earlier ones where they
            overlap.
        num_rows: number of grid rows.
        num_cols: number of grid columns.

    Returns:
        Integer array holding a cell id per grid position. Ids are renumbered
        0, 1, 2, ... in row-major order of first appearance.
    """
    # ``np.int`` was merely an alias of the builtin ``int`` and no longer
    # exists in recent NumPy; ``int`` yields the same dtype.
    layout = np.full([num_rows, num_cols], -1, dtype=int)
    for cell_count, (x1, y1, x2, y2) in enumerate(spans):
        layout[y1:y2 + 1, x1:x2 + 1] = cell_count

    # NOTE(review): positions not covered by any span keep the value -1 and
    # are therefore all merged into one cell id below -- confirm that spans
    # always cover the whole grid.
    new_ids = dict()  # old id -> new id, in order of first appearance
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            old_id = layout[row_idx, col_idx]
            layout[row_idx, col_idx] = new_ids.setdefault(old_id, len(new_ids))
    return layout


def parse_cells(layout, spans, row_segments, col_segments):
    """Build one cell dict per cell id of ``layout``.

    ``row_segments`` / ``col_segments`` hold grid boundaries: entries ``i`` and
    ``i + 1`` bound grid row / column ``i``. Each cell's grid bounding box is
    mapped through them to a single rectangular contour stored under
    ``'segmentation'``. For every entry of ``spans``, the cell at grid
    position ``(row=span[1], col=span[0])`` also gets ``'transcript'`` set to
    the string ``'None'``.
    """
    cells = list()
    num_cells = np.max(layout) + 1
    for cell_id in range(num_cells):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # NOTE(review): this slice excludes row y2 and column x2, so it only
        # partially verifies that the cell is rectangular.
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        x1 = col_segments[x1]
        x2 = col_segments[x2 + 1]
        y1 = row_segments[y1]
        y2 = row_segments[y2 + 1]
        cells.append(dict(segmentation=[[[x1, y1], [x2, y1], [x2, y2], [x1, y2]]]))

    for span in spans:
        cell_id = layout[span[1], span[0]]
        cells[cell_id]['transcript'] = 'None'
    return cells


def segmentation_to_bbox(segmentation):
    """Return the axis-aligned ``[x1, y1, x2, y2]`` box enclosing all contours.

    ``segmentation`` is a list of contours, each a sequence of ``(x, y)``
    points.
    """
    xs = [pt[0] for contour in segmentation for pt in contour]
    ys = [pt[1] for contour in segmentation for pt in contour]
    return [min(xs), min(ys), max(xs), max(ys)]


def extend_cell_lines(cells, lines):
    """Assign every text line to the cell it overlaps most (in place).

    A line goes to the cell whose polygon covers the largest fraction of the
    line's area. Lines with zero area, or with no overlap with any cell, are
    left unassigned. Each dict in ``cells`` receives ``'lines_idx'``: indices
    into ``lines``, sorted by the top y coordinate of each line's bounding box.
    """
    def segmentation_to_polygon(segmentation):
        # Combine all contours of one segmentation into a single Polygon.
        polygon = Polygon.Polygon()
        for contour in segmentation:
            polygon = polygon + Polygon.Polygon(contour)
        return polygon

    lines = copy.deepcopy(lines)
    cells_poly = [segmentation_to_polygon(item['segmentation']) for item in cells]
    lines_poly = [segmentation_to_polygon(item['segmentation']) for item in lines]

    cells_lines = [[] for _ in range(len(cells))]
    for line_idx, line_poly in enumerate(lines_poly):
        line_area = line_poly.area()
        if line_area == 0:
            continue
        max_overlap = 0
        max_overlap_idx = None
        for cell_idx, cell_poly in enumerate(cells_poly):
            overlap = (cell_poly & line_poly).area() / line_area
            if overlap > max_overlap:
                max_overlap_idx = cell_idx
                max_overlap = overlap
        if max_overlap > 0:
            cells_lines[max_overlap_idx].append(line_idx)

    lines_y1 = [segmentation_to_bbox(item['segmentation'])[1] for item in lines]
    for cell, cell_lines in zip(cells, cells_lines):
        cell['lines_idx'] = sorted(cell_lines, key=lambda idx: lines_y1[idx])


def rerange_layout(table):
    """Compact the cell ids of ``table['layout']`` and reorder the cells.

    Ids are renumbered 0, 1, 2, ... in row-major order of first appearance.
    ``table['cells']`` is rebuilt in that order, so cells whose id no longer
    appears in the layout are dropped. Modifies ``table`` and its layout array
    in place.
    """
    layout = table['layout']
    cells = table['cells']
    new_ids = dict()  # old id -> new id, in order of first appearance
    for row_idx in range(layout.shape[0]):
        for col_idx in range(layout.shape[1]):
            old_id = layout[row_idx, col_idx]
            layout[row_idx, col_idx] = new_ids.setdefault(old_id, len(new_ids))
    table['layout'] = layout
    table['cells'] = [cells[old_id] for old_id in new_ids]


def cal_cell_spans(table):
    """Return the inclusive grid span ``[x1, y1, x2, y2]`` of every cell.

    Spans are listed for cell ids ``0 .. len(table['cells']) - 1``, using
    column (x) and row (y) indices into ``table['layout']``.
    """
    layout = table['layout']
    cells_span = list()
    for cell_id in range(len(table['cells'])):
        cell_positions = np.argwhere(layout == cell_id)
        y1 = np.min(cell_positions[:, 0])
        y2 = np.max(cell_positions[:, 0])
        x1 = np.min(cell_positions[:, 1])
        x2 = np.max(cell_positions[:, 1])
        # NOTE(review): this slice excludes row y2 and column x2, so it only
        # partially verifies that the cell is rectangular.
        assert np.all(layout[y1:y2, x1:x2] == cell_id)
        cells_span.append([x1, y1, x2, y2])
    return cells_span


def remove_repeat_rcs(table):
    """Iteratively drop redundant rows and columns from ``table['layout']``.

    Until nothing changes, this removes:

    * body rows whose positions all belong to one cell;
    * columns whose positions all belong to one cell;
    * rows and columns identical to an earlier kept row or column.

    ``head_rows`` / ``body_rows`` are re-indexed to the kept rows, then
    ``rerange_layout`` compacts the cell ids. Modifies ``table`` in place.
    """
    layout = table['layout']
    head_rows = table['head_rows']
    body_rows = table['body_rows']
    while True:
        num_rows, num_cols = layout.shape

        valid_rows_idx = list()
        seen_rows = set()
        for row_idx in range(num_rows):
            row = layout[row_idx, :]
            if len(np.unique(row)) == 1 and row_idx in body_rows:
                continue  # a body row spanned entirely by one cell
            row_key = tuple(row.tolist())
            if row_key not in seen_rows:
                valid_rows_idx.append(row_idx)
                seen_rows.add(row_key)

        valid_cols_idx = list()
        seen_cols = set()
        for col_idx in range(num_cols):
            col = layout[:, col_idx]
            if len(np.unique(col)) == 1:
                continue  # a column spanned entirely by one cell
            col_key = tuple(col.tolist())
            if col_key not in seen_cols:
                valid_cols_idx.append(col_idx)
                seen_cols.add(col_key)

        if len(valid_rows_idx) == num_rows and len(valid_cols_idx) == num_cols:
            break

        layout = layout[valid_rows_idx][:, valid_cols_idx]
        head_rows = [new_idx for new_idx, old_idx in enumerate(valid_rows_idx)
                     if old_idx in head_rows]
        body_rows = [new_idx for new_idx, old_idx in enumerate(valid_rows_idx)
                     if old_idx in body_rows]

    table['layout'] = layout
    table['head_rows'] = head_rows
    table['body_rows'] = body_rows
    rerange_layout(table)


def pred_result_to_table(pred_result):
    """Convert a model prediction into a table dict.

    Args:
        pred_result: ``(row_segments, col_segments, divide, spans)``. The grid
            has ``len(row_segments) - 1`` rows and ``len(col_segments) - 1``
            columns. Rows before ``divide`` are head rows; the rest are body
            rows.

    Returns:
        dict with keys ``layout``, ``head_rows``, ``body_rows`` and ``cells``.
    """
    row_segments, col_segments, divide, spans = pred_result
    num_rows = len(row_segments) - 1
    num_cols = len(col_segments) - 1
    layout = parse_layout(spans, num_rows, num_cols)
    cells = parse_cells(layout, spans, row_segments, col_segments)
    head_rows = list(range(0, divide))
    body_rows = list(range(divide, num_rows))
    table = dict(
        layout=layout,
        head_rows=head_rows,
        body_rows=body_rows,
        cells=cells,
    )
    # Redundant row/column removal is currently disabled.
    # remove_repeat_rcs(table)
    return table


def is_simple_table(table):
    """Return True when the table has as many cells as grid positions.

    Assuming every cell occupies at least one grid position, this means that
    no cell spans several rows or columns.
    """
    num_rows, num_cols = table['layout'].shape
    return num_rows * num_cols == len(table['cells'])


def tensor_to_image(tensor):
    """Convert a tensor to a ``uint8`` numpy image for visualisation.

    A 3-D input whose first dimension is neither 1 nor 3 is first reduced to
    one channel by taking the L2 norm over that dimension. Values are then
    min-max rescaled to ``[0, 255]``; a constant input becomes all zeros. For
    3-D results the first axis is moved last (presumably ``(C, H, W)`` to
    ``(H, W, C)``), and a single channel is squeezed away.
    """
    image = tensor.detach().cpu().numpy()
    if image.ndim == 3 and image.shape[0] not in (1, 3):
        image = np.sqrt(np.sum(np.power(image, 2), axis=0, keepdims=True))
    min_value = np.min(image)
    value_range = np.max(image) - min_value
    if value_range > 0:
        image = 255 * (image - min_value) / value_range
    else:
        # Constant input: 0 / 0 would produce NaNs before the uint8 cast.
        image = np.zeros_like(image)
    image = image.astype(np.uint8)
    if image.ndim == 3:
        image = np.transpose(image, (1, 2, 0)).copy()
        if image.shape[2] == 1:
            image = image[:, :, 0]
    return image


def visualize_layout(image, table):
    """Draw every cell's segmentation contours as closed polylines.

    Contours are drawn with color ``(255, 0, 0)`` using ``cv2.polylines``.
    Cells without a ``'segmentation'`` key are skipped. Returns the image.
    """
    def draw_segmentation(image, segmentation, color):
        for contour in segmentation:
            contour = np.array(contour, dtype=np.int32)
            image = cv2.polylines(image, [contour], True, color)
        return image

    for cell in table['cells']:
        if 'segmentation' in cell:
            image = draw_segmentation(image, cell['segmentation'], (255, 0, 0))
    return image


# NOTE(review): every entry is an empty string, so the str.replace() calls
# that use this list are no-ops. The original entries (possibly markup tags)
# appear to have been lost -- confirm and restore.
virtual_chars = ["", "", "", "", "", "", "", "", "", "", "", "", "", ""]


def is_blank(content):
    """Return True if ``content`` is empty after removing every
    ``virtual_chars`` entry and surrounding whitespace."""
    for item in virtual_chars:
        content = content.replace(item, '')
    return content.strip() == ''


def filt_content(content, filt_blank=False, filt_virtual=False, filt_pad=False):
    """Clean one cell's content string.

    Args:
        content: the cell text.
        filt_blank: replace content that ``is_blank`` deems blank with ``''``.
        filt_virtual: remove every occurrence of each ``virtual_chars`` entry.
        filt_pad: strip leading and trailing whitespace.

    Returns:
        The filtered string.
    """
    if filt_blank and is_blank(content):
        content = ''
    if filt_virtual:
        # Iterate over virtual_chars, not over content: removing each of
        # content's own characters would always empty the string.
        for item in virtual_chars:
            content = content.replace(item, '')
    if filt_pad:
        content = content.strip()
    return content


def filt_transcript(html, filt_blank=False, filt_virtual=False, filt_pad=False):
    """Apply ``filt_content`` to the content of every table cell in ``html``.

    Cell content is the text between the ``>`` that closes an opening
    ``<td...`` tag and the following ``</td>``. Raises ``ValueError`` if an
    opening tag has no matching ``>`` or ``</td>``.

    NOTE(review): the ``'<td'`` / ``'</td>'`` delimiters were reconstructed
    from a damaged copy of this function -- confirm they match the HTML this
    is applied to.
    """
    start_idx = 0
    while '<td' in html[start_idx:]:
        tag_start_idx = html[start_idx:].index('<td') + start_idx
        content_start_idx = html[tag_start_idx:].index('>') + 1 + tag_start_idx
        content_end_idx = html[content_start_idx:].index('</td>') + content_start_idx
        end_idx = content_end_idx + len('</td>')
        content = html[content_start_idx:content_end_idx]
        content = filt_content(content, filt_blank, filt_virtual, filt_pad)
        html = html[:content_start_idx] + content + html[content_end_idx:]
        # Resume right after the closing tag, shifted by the length change.
        start_idx = end_idx - (content_end_idx - content_start_idx - len(content))
    return html