# Source: App / ch_ppocr_rec / utils.py
# Uploaded with huggingface_hub by LinhKL2002 (revision 4dbe5d1, verified)
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
class CTCLabelDecode:
    """Convert CTC network output (per-step class scores) into text labels.

    The character table is built from an explicit character list and/or a
    dictionary file (one character per line). A CTC "blank" token is inserted
    at index 0 and a space character is appended at the end, matching the
    index layout the recognition model was trained with.
    """

    def __init__(
        self,
        character: Optional[List[str]] = None,
        character_path: Union[str, Path, None] = None,
    ):
        """Build the decoder from a character list and/or a dictionary file.

        Args:
            character: explicit list of characters.
            character_path: path to a UTF-8 dictionary file, one character
                per line; takes precedence over ``character`` when both are
                given.

        Raises:
            ValueError: if neither argument yields a character list.
        """
        self.character = self.get_character(character, character_path)
        # Reverse lookup table: character -> index in self.character.
        self.dict = {char: i for i, char in enumerate(self.character)}

    def __call__(
        self, preds: np.ndarray, return_word_box: bool = False, **kwargs
    ) -> List[Tuple[str, float]]:
        """Decode a batch of raw model predictions.

        Args:
            preds: array of shape (batch, seq_len, num_classes); argmax over
                the last axis gives the predicted character index per step.
            return_word_box: when True, each result carries word-grouping
                info and ``kwargs`` must provide ``wh_ratio_list`` (one
                width/height ratio per sample) and ``max_wh_ratio``.

        Returns:
            One ``(text, confidence)`` tuple per batch element; with
            ``return_word_box`` each tuple gains a third element
            ``[num_cols, word_list, word_col_list, state_list, conf_list]``.
        """
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(
            preds_idx, preds_prob, return_word_box, is_remove_duplicate=True
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                # Rescale the column count from the padded input width back
                # to this sample's actual aspect ratio.
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        return text

    def get_character(
        self,
        character: Optional[List[str]] = None,
        character_path: Union[str, Path, None] = None,
    ) -> List[str]:
        """Assemble the full character table used for decoding.

        ``character_path`` overrides ``character`` when both are supplied.
        A trailing space is appended and the CTC blank token is inserted at
        index 0 (so model class 0 maps to blank).

        Raises:
            ValueError: if no usable character source is provided.
        """
        if character is None and character_path is None:
            raise ValueError("character must not be None")

        character_list = None
        if character:
            character_list = character

        # The file, when given, takes precedence over the explicit list.
        if character_path:
            character_list = self.read_character_file(character_path)

        if character_list is None:
            raise ValueError("character must not be None")

        character_list = self.insert_special_char(
            character_list, " ", len(character_list)
        )
        character_list = self.insert_special_char(character_list, "blank", 0)
        return character_list

    @staticmethod
    def read_character_file(character_path: Union[str, Path]) -> List[str]:
        """Read a UTF-8 dictionary file, one character (or token) per line.

        The file is read in binary mode so only explicit CR/LF terminators
        are stripped; significant whitespace entries (e.g. a space line)
        are preserved.
        """
        with open(character_path, "rb") as f:
            # strip("\r\n") removes any mix of CR/LF at both ends and is
            # equivalent to the chained strip("\n").strip("\r\n").
            return [line.decode("utf-8").strip("\r\n") for line in f]

    @staticmethod
    def insert_special_char(
        character_list: List[str], special_char: str, loc: int = -1
    ) -> List[str]:
        """Insert ``special_char`` into ``character_list`` at index ``loc``.

        Note: mutates and returns the same list. With ``loc=-1`` the token
        is inserted *before* the last element (``list.insert`` semantics),
        not appended — callers here always pass an explicit index.
        """
        character_list.insert(loc, special_char)
        return character_list

    def decode(
        self,
        text_index: np.ndarray,
        text_prob: Optional[np.ndarray] = None,
        return_word_box: bool = False,
        is_remove_duplicate: bool = False,
    ) -> List[Tuple[str, float]]:
        """Convert per-step character indices into text labels.

        Args:
            text_index: (batch, seq_len) predicted character indices.
            text_prob: (batch, seq_len) per-step confidences; when None, a
                confidence of 1 is assumed for every kept character.
            return_word_box: also return word-grouping info per sample.
            is_remove_duplicate: collapse consecutive duplicate indices
                (the standard CTC decoding rule).

        Returns:
            List of ``(text, mean_confidence)`` tuples; with
            ``return_word_box`` each tuple gains
            ``[seq_len, word_list, word_col_list, state_list, conf_list]``.
        """
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            # Boolean mask over time steps: True where the step's character
            # survives CTC collapsing and blank removal.
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                # Drop a step when it repeats the previous step's index.
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token

            if text_prob is not None:
                conf_list = np.array(text_prob[batch_idx][selection]).tolist()
            else:
                # Fix: use the number of *kept* steps so conf_list stays
                # aligned with char_list (len(selection) is the full
                # sequence length, not the kept count).
                conf_list = [1] * int(selection.sum())

            # Avoid np.mean([]) warning/NaN when nothing was decoded.
            if len(conf_list) == 0:
                conf_list = [0]

            char_list = [
                self.character[text_id] for text_id in text_index[batch_idx][selection]
            ]
            text = "".join(char_list)

            if return_word_box:
                word_list, word_col_list, state_list = self.get_word_info(
                    text, selection
                )
                result_list.append(
                    (
                        text,
                        np.mean(conf_list).tolist(),
                        [
                            len(text_index[batch_idx]),
                            word_list,
                            word_col_list,
                            state_list,
                            conf_list,
                        ],
                    )
                )
            else:
                result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

    @staticmethod
    def get_word_info(
        text: str, selection: np.ndarray
    ) -> Tuple[List[List[str]], List[List[int]], List[str]]:
        """
        Group the decoded characters and record the corresponding decoded positions.
        from https://github.com/PaddlePaddle/PaddleOCR/blob/fbba2178d7093f1dffca65a5b963ec277f1a6125/ppocr/postprocess/rec_postprocess.py#L70
        Args:
            text: the decoded text
            selection: the bool array that identifies which columns of features are decoded as non-separated characters
        Returns:
            word_list: list of the grouped words
            word_col_list: list of decoding positions corresponding to each character in the grouped word
            state_list: list of marker to identify the type of grouping words, including two types of grouping words:
                - 'cn': continuous chinese characters (e.g., 你好啊)
                - 'en&num': continuous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
        """
        state = None
        word_content = []
        word_col_content = []
        word_list = []
        word_col_list = []
        state_list = []
        # Feature-column indices of the kept (decoded) characters; gaps
        # between consecutive columns act as word separators.
        valid_col = np.where(selection)[0]
        col_width = np.zeros(valid_col.shape)
        if len(valid_col) > 0:
            col_width[1:] = valid_col[1:] - valid_col[:-1]
            # First column's "gap" is capped so a leading margin alone does
            # not split; CJK glyphs are wider, hence the larger cap of 3.
            col_width[0] = min(
                3 if "\u4e00" <= text[0] <= "\u9fff" else 2, int(valid_col[0])
            )

        for c_i, char in enumerate(text):
            # Classify the character: CJK Unified Ideographs vs everything
            # else (latin letters, digits, punctuation).
            if "\u4e00" <= char <= "\u9fff":
                c_state = "cn"
            else:
                c_state = "en&num"

            if state is None:
                state = c_state

            # A script change or a wide column gap (> 4) starts a new word.
            if state != c_state or col_width[c_i] > 4:
                if len(word_content) != 0:
                    word_list.append(word_content)
                    word_col_list.append(word_col_content)
                    state_list.append(state)
                    word_content = []
                    word_col_content = []
                state = c_state

            word_content.append(char)
            word_col_content.append(int(valid_col[c_i]))

        # Flush the trailing word, if any.
        if len(word_content) != 0:
            word_list.append(word_content)
            word_col_list.append(word_col_content)
            state_list.append(state)

        return word_list, word_col_list, state_list

    @staticmethod
    def get_ignored_tokens() -> List[int]:
        """Class indices excluded from the output text (CTC blank only)."""
        return [0]  # for ctc blank