|
|
|
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from skimage.morphology import skeletonize |
|
|
|
|
|
from mmocr.models.builder import POSTPROCESSOR |
|
|
from .base_postprocessor import BasePostprocessor |
|
|
from .utils import centralize, fill_hole, merge_disks |
|
|
|
|
|
|
|
|
@POSTPROCESSOR.register_module()
class TextSnakePostprocessor(BasePostprocessor):
    """Decoding predictions of TextSnake to instances. This was partially
    adapted from https://github.com/princewang1994/TextSnake.pytorch.

    Args:
        text_repr_type (str): The boundary encoding type. Only 'poly' is
            supported; any other value fails an assertion.
        min_text_region_confidence (float): The confidence threshold of text
            region in TextSnake.
        min_center_region_confidence (float): The confidence threshold of text
            center region in TextSnake. It is compared against the text
            center score multiplied by the text region score.
        min_center_area (int): The minimal text center region area. Center
            contours whose ``cv2.contourArea`` is smaller are discarded.
        disk_overlap_thr (float): The radius overlap threshold for merging
            disks (forwarded to ``merge_disks``).
        radius_shrink_ratio (float): Factor the predicted radii of the
            ordered disks are multiplied by before the disks are drawn. The
            default of 1.03 slightly enlarges them.
    """

    def __init__(self,
                 text_repr_type='poly',
                 min_text_region_confidence=0.6,
                 min_center_region_confidence=0.2,
                 min_center_area=30,
                 disk_overlap_thr=0.03,
                 radius_shrink_ratio=1.03,
                 **kwargs):
        super().__init__(text_repr_type)
        assert text_repr_type == 'poly'
        self.min_text_region_confidence = min_text_region_confidence
        self.min_center_region_confidence = min_center_region_confidence
        self.min_center_area = min_center_area
        self.disk_overlap_thr = disk_overlap_thr
        self.radius_shrink_ratio = radius_shrink_ratio

    @staticmethod
    def _find_contours(mask):
        """Return the contours of a uint8 mask (``RETR_TREE`` retrieval).

        ``cv2.findContours`` returns ``(image, contours, hierarchy)`` in
        OpenCV 3.x but ``(contours, hierarchy)`` in OpenCV 4.x; indexing with
        ``-2`` selects the contours in both versions.
        """
        return cv2.findContours(mask, cv2.RETR_TREE,
                                cv2.CHAIN_APPROX_SIMPLE)[-2]

    def __call__(self, preds):
        """
        Args:
            preds (Tensor): Prediction map with shape :math:`(C, H, W)`.
                Channels 0-4 are read as the text region logit, the text
                center logit, sin, cos and radius, respectively.

        Returns:
            list[list[float]]: One entry per instance: the flattened boundary
            points ``[x1, y1, x2, y2, ...]`` followed by the instance
            confidence.
        """
        assert preds.dim() == 3

        # Apply sigmoid out of place. Writing it back into ``preds`` would
        # mutate the caller's tensor, so a second call on the same tensor
        # would apply the sigmoid twice.
        preds = preds.detach()
        preds = torch.cat((torch.sigmoid(preds[:2]), preds[2:]), dim=0)
        preds = preds.cpu().numpy()

        pred_text_score = preds[0]
        pred_text_mask = pred_text_score > self.min_text_region_confidence
        pred_center_score = preds[1] * pred_text_score
        pred_center_mask = \
            pred_center_score > self.min_center_region_confidence
        pred_sin = preds[2]
        pred_cos = preds[3]
        pred_radius = preds[4]
        mask_sz = pred_text_mask.shape

        # Renormalize (sin, cos) to unit length; 1e-8 avoids division by 0.
        scale = np.sqrt(1.0 / (pred_sin**2 + pred_cos**2 + 1e-8))
        pred_sin = pred_sin * scale
        pred_cos = pred_cos * scale

        pred_center_mask = fill_hole(pred_center_mask).astype(np.uint8)
        center_contours = self._find_contours(pred_center_mask)

        boundaries = []
        for contour in center_contours:
            if cv2.contourArea(contour) < self.min_center_area:
                continue
            # Rasterize this center region alone (filled) and skeletonize it.
            instance_center_mask = np.zeros(mask_sz, dtype=np.uint8)
            cv2.drawContours(instance_center_mask, [contour], -1, 1, -1)
            skeleton = skeletonize(instance_center_mask)
            skeleton_yx = np.argwhere(skeleton > 0)
            y, x = skeleton_yx[:, 0], skeleton_yx[:, 1]
            cos = pred_cos[y, x].reshape((-1, 1))
            sin = pred_sin[y, x].reshape((-1, 1))
            radius = pred_radius[y, x].reshape((-1, 1))

            center_line_yx = centralize(skeleton_yx, cos, -sin, radius,
                                        instance_center_mask)
            y, x = center_line_yx[:, 0], center_line_yx[:, 1]
            radius = (pred_radius[y, x] * self.radius_shrink_ratio).reshape(
                (-1, 1))
            score = pred_center_score[y, x].reshape((-1, 1))
            # Disk rows are (x, y, radius, score); fliplr turns (y, x) into
            # (x, y).
            instance_disks = np.hstack(
                [np.fliplr(center_line_yx), radius, score])
            instance_disks = merge_disks(instance_disks, self.disk_overlap_thr)

            # Union of the (filled) disks forms the instance region.
            instance_mask = np.zeros(mask_sz, dtype=np.uint8)
            for disk_x, disk_y, disk_radius, _ in instance_disks:
                if disk_radius > 1:
                    cv2.circle(instance_mask, (int(disk_x), int(disk_y)),
                               int(disk_radius), 1, -1)
            contours = self._find_contours(instance_mask)
            if len(contours) == 0:
                continue

            # OpenCV does not specify contour order, and the union of disks
            # may yield several pieces (plus hole contours under RETR_TREE);
            # keep the largest contour rather than whichever comes first.
            boundary_contour = max(contours, key=cv2.contourArea)
            # ``size > 8`` requires more than four (x, y) points.
            if (cv2.contourArea(boundary_contour) <= 0
                    or boundary_contour.size <= 8):
                continue

            # Instance confidence: mean text-region score over the drawn
            # disks (instance_mask is 0/1).
            score = np.sum(instance_mask * pred_text_score) / (
                np.sum(instance_mask) + 1e-8)
            boundary = boundary_contour.flatten().tolist()
            boundaries.append(boundary + [float(score)])

        return boundaries
|
|
|