manga1 / modules /ocr /base.py
sayed555's picture
Upload 342 files
7689b07 verified
from typing import Tuple, List, Dict, Union, Callable
import numpy as np
from collections import OrderedDict
from utils.textblock import TextBlock
from utils.registry import Registry
# Registry instance (labelled 'OCR') for OCR module classes; OCRBase.__init__
# searches OCR.module_dict to find the key its own class is stored under.
OCR = Registry('OCR')
# Short alias for the registry's registration entry point.
# NOTE(review): presumably applied as a class decorator on OCR implementations —
# confirm against utils.registry.
register_OCR = OCR.register_module
from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER
class OCRBase(BaseModule):
    """Base class for OCR modules stored in the ``OCR`` registry.

    Concrete modules are expected to override :meth:`_ocr_blk_list` and
    :meth:`ocr_img`; both raise ``NotImplementedError`` in this base class.
    :meth:`run_ocr` is the public entry point that ties them together.
    """

    # Hook tables are class attributes, so they are shared by every instance
    # (and subclass) that does not rebind them.
    _postprocess_hooks = OrderedDict()
    # Declared here but not consulted by any method visible in this class.
    _preprocess_hooks = OrderedDict()

    # NOTE(review): not read in this class; presumably flags modules that
    # recognise single text lines only — confirm against the subclasses.
    _line_only: bool = False

    def __init__(self, **params) -> None:
        """Initialise the base module and resolve this module's registry name.

        ``self.name`` becomes the first key of ``OCR.module_dict`` whose value
        equals this instance's exact class, or ``''`` when no entry matches.
        """
        super().__init__(**params)
        self.name = ''
        own_cls = self.__class__
        registered = OCR.module_dict
        for reg_key in registered:
            if registered[reg_key] == own_cls:
                self.name = reg_key
                break

    def run_ocr(self, img: np.ndarray, blk_list: List[TextBlock] = None, *args, **kwargs) -> Union[List[TextBlock], str]:
        """Run OCR on a whole image or on a collection of text blocks.

        Args:
            img: Image handed to the recognition methods.
            blk_list: Text blocks to recognise. ``None`` switches to
                whole-image mode; a single ``TextBlock`` is accepted and is
                wrapped into a one-element list.
            *args, **kwargs: Forwarded unchanged to :meth:`_ocr_blk_list`.

        Returns:
            The result of :meth:`ocr_img` when ``blk_list`` is ``None``,
            otherwise the (possibly wrapped) block list after recognition and
            after all post-process hooks have been called.
        """
        # Load lazily: only when some model is reported as not yet loaded.
        if not self.all_model_loaded():
            self.load_model()

        # Whole-image mode: no blocks to fill in and no hooks to run.
        if blk_list is None:
            return self.ocr_img(img)

        if isinstance(blk_list, TextBlock):
            blk_list = [blk_list]

        # Every module except the one registered as 'none_ocr' starts from
        # empty text, so results of a previous run are not kept.
        if self.name != 'none_ocr':
            for blk in blk_list:
                blk.text = []

        self._ocr_blk_list(img, blk_list, *args, **kwargs)

        # Hooks are called in insertion order with keyword arguments only.
        for hook in self._postprocess_hooks.values():
            hook(textblocks=blk_list, img=img, ocr_module=self)

        return blk_list

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock], *args, **kwargs) -> None:
        """Recognise text for each block of ``blk_list``; must be overridden."""
        raise NotImplementedError

    def ocr_img(self, img: np.ndarray) -> str:
        """Recognise text on the whole image ``img``; must be overridden."""
        raise NotImplementedError