Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from skimage.morphology import skeletonize | |
| from mmocr.models.builder import POSTPROCESSOR | |
| from .base_postprocessor import BasePostprocessor | |
| from .utils import centralize, fill_hole, merge_disks | |
class TextSnakePostprocessor(BasePostprocessor):
    """Decoding predictions of TextSnake to instances. This was partially
    adapted from https://github.com/princewang1994/TextSnake.pytorch.

    Args:
        text_repr_type (str): The boundary encoding type. Only 'poly' is
            accepted by this postprocessor.
        min_text_region_confidence (float): The confidence threshold of text
            region in TextSnake.
        min_center_region_confidence (float): The confidence threshold of text
            center region in TextSnake.
        min_center_area (int): The minimal text center region area.
        disk_overlap_thr (float): The radius overlap threshold for merging
            disks.
        radius_shrink_ratio (float): The shrink ratio of ordered disks radii.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 text_repr_type='poly',
                 min_text_region_confidence=0.6,
                 min_center_region_confidence=0.2,
                 min_center_area=30,
                 disk_overlap_thr=0.03,
                 radius_shrink_ratio=1.03,
                 **kwargs):
        super().__init__(text_repr_type)
        assert text_repr_type == 'poly'
        self.min_text_region_confidence = min_text_region_confidence
        self.min_center_region_confidence = min_center_region_confidence
        self.min_center_area = min_center_area
        self.disk_overlap_thr = disk_overlap_thr
        self.radius_shrink_ratio = radius_shrink_ratio

    def __call__(self, preds):
        """Decode a TextSnake prediction map into text instance boundaries.

        The input tensor is left unmodified.

        Args:
            preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
                Channels 0 and 1 are read as text-region and center-region
                logits, channels 2, 3 and 4 as sin, cos and radius maps.

        Returns:
            list[list[float]]: One entry per kept instance: the flattened
            contour coordinates of the instance followed by its confidence.
        """
        assert preds.dim() == 3

        # Work on a detached view and compute the sigmoid out of place so
        # that the caller's tensor is never overwritten (an in-place write
        # would make repeated calls apply the sigmoid twice and fails on
        # leaf tensors that require grad).
        preds = preds.detach()
        score_maps = torch.sigmoid(preds[:2]).cpu().numpy()

        pred_text_score = score_maps[0]
        # NOTE(review): only the shape of this mask is used below; the text
        # region threshold does not filter anything in this method.
        pred_text_mask = pred_text_score > self.min_text_region_confidence
        # The center score is weighted by the text-region score.
        pred_center_score = score_maps[1] * pred_text_score
        pred_center_mask = \
            pred_center_score > self.min_center_region_confidence

        pred_sin = preds[2].cpu().numpy()
        pred_cos = preds[3].cpu().numpy()
        pred_radius = preds[4].cpu().numpy()
        mask_sz = pred_text_mask.shape

        # Rescale (sin, cos) to unit length; the epsilon avoids a division
        # by zero where both maps are zero.
        scale = np.sqrt(1.0 / (pred_sin**2 + pred_cos**2 + 1e-8))
        pred_sin = pred_sin * scale
        pred_cos = pred_cos * scale

        pred_center_mask = fill_hole(pred_center_mask).astype(np.uint8)
        center_contours, _ = cv2.findContours(pred_center_mask, cv2.RETR_TREE,
                                              cv2.CHAIN_APPROX_SIMPLE)

        boundaries = []
        for contour in center_contours:
            # Skip center regions that are too small.
            if cv2.contourArea(contour) < self.min_center_area:
                continue

            # Rasterize this center region alone and thin it to a skeleton.
            instance_center_mask = np.zeros(mask_sz, dtype=np.uint8)
            cv2.drawContours(instance_center_mask, [contour], -1, 1, -1)
            skeleton = skeletonize(instance_center_mask)
            skeleton_yx = np.argwhere(skeleton > 0)
            y, x = skeleton_yx[:, 0], skeleton_yx[:, 1]
            cos = pred_cos[y, x].reshape((-1, 1))
            sin = pred_sin[y, x].reshape((-1, 1))
            radius = pred_radius[y, x].reshape((-1, 1))

            # Presumably moves the skeleton points onto the center line of
            # the region; see `centralize` for the exact semantics.
            center_line_yx = centralize(skeleton_yx, cos, -sin, radius,
                                        instance_center_mask)
            y, x = center_line_yx[:, 0], center_line_yx[:, 1]
            radius = (pred_radius[y, x] * self.radius_shrink_ratio).reshape(
                (-1, 1))
            disk_scores = pred_center_score[y, x].reshape((-1, 1))

            # Each disk row is (x, y, radius, score); fliplr turns the
            # (y, x) center-line points into (x, y).
            instance_disks = np.hstack(
                [np.fliplr(center_line_yx), radius, disk_scores])
            instance_disks = merge_disks(instance_disks, self.disk_overlap_thr)

            # Paint every sufficiently large disk to form the instance mask.
            instance_mask = np.zeros(mask_sz, dtype=np.uint8)
            for disk_x, disk_y, disk_radius, _ in instance_disks:
                if disk_radius > 1:
                    cv2.circle(instance_mask, (int(disk_x), int(disk_y)),
                               int(disk_radius), 1, -1)
            contours, _ = cv2.findContours(instance_mask, cv2.RETR_TREE,
                                           cv2.CHAIN_APPROX_SIMPLE)

            # Only the first contour is used; it must enclose a non-zero
            # area and hold more than 8 coordinate values.
            if (len(contours) > 0 and cv2.contourArea(contours[0]) > 0
                    and contours[0].size > 8):
                # Instance confidence: mean text score over the mask.
                score = np.sum(instance_mask * pred_text_score) / (
                    np.sum(instance_mask) + 1e-8)
                boundary = contours[0].flatten().tolist()
                boundaries.append(boundary + [score])

        return boundaries