Spaces:
Sleeping
Sleeping
| # Copyright (C) 2021-2024, Mindee. | |
| # This program is licensed under the Apache License 2.0. | |
| # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
| from typing import Any, Callable, Dict, List, Optional, Tuple | |
| import numpy as np | |
| from doctr.models.builder import DocumentBuilder | |
| from doctr.utils.geometry import extract_crops, extract_rcrops | |
| from .._utils import rectify_crops, rectify_loc_preds | |
| from ..classification import crop_orientation_predictor | |
| from ..classification.predictor import OrientationPredictor | |
| __all__ = ["_OCRPredictor"] | |
class _OCRPredictor:
    """Implements an object able to localize and identify text elements in a set of documents

    This base class only holds the framework-agnostic glue: crop extraction, crop/box rectification,
    padding removal on localization predictions, and regrouping of flat predictions per page.

    Args:
    ----
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the page general orientation based on the median line orientation.
            Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
            accordingly. Doing so will improve performances for documents with page-uniform rotations.
        preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
        symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically.
        **kwargs: keyword args of `DocumentBuilder`
    """

    # Quoted so that the annotation is never evaluated at class-creation time.
    crop_orientation_predictor: "Optional[OrientationPredictor]"

    def __init__(
        self,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        **kwargs: Any,
    ) -> None:
        self.assume_straight_pages = assume_straight_pages
        self.straighten_pages = straighten_pages
        # The crop orientation model is only instantiated when pages may contain rotated elements
        self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
        self.doc_builder = DocumentBuilder(**kwargs)
        self.preserve_aspect_ratio = preserve_aspect_ratio
        self.symmetric_pad = symmetric_pad
        # Callables registered through `add_hook`
        self.hooks: List[Callable] = []

    @staticmethod
    def _generate_crops(
        pages: List[np.ndarray],
        loc_preds: List[np.ndarray],
        channels_last: bool,
        assume_straight_pages: bool = False,
    ) -> List[List[np.ndarray]]:
        """Extract the crops of each page from its localization predictions

        Args:
        ----
            pages: page images, one per document page
            loc_preds: localization predictions, one array per page (paired with `pages` by position)
            channels_last: forwarded as-is to the extraction function
            assume_straight_pages: if True uses `extract_crops`, otherwise `extract_rcrops`

        Returns:
        -------
            for each page, the list of crops returned by the extraction function
        """
        extraction_fn = extract_crops if assume_straight_pages else extract_rcrops

        # Only the first four entries along the second axis are passed to the extraction function
        return [
            extraction_fn(page, _boxes[:, :4], channels_last=channels_last)  # type: ignore[operator]
            for page, _boxes in zip(pages, loc_preds)
        ]

    @staticmethod
    def _prepare_crops(
        pages: List[np.ndarray],
        loc_preds: List[np.ndarray],
        channels_last: bool,
        assume_straight_pages: bool = False,
    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
        """Extract the crops of each page and drop the empty ones along with their localization prediction

        Args:
        ----
            pages: page images, one per document page
            loc_preds: localization predictions, one array per page
            channels_last: forwarded as-is to the extraction function
            assume_straight_pages: if True uses `extract_crops`, otherwise `extract_rcrops`

        Returns:
        -------
            a tuple with the kept crops of each page, and the matching filtered localization predictions
        """
        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)

        # Avoid sending zero-sized crops
        is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
        crops = [
            [crop for crop, _kept in zip(page_crops, page_kept) if _kept]
            for page_crops, page_kept in zip(crops, is_kept)
        ]
        # Keep the localization predictions aligned with the remaining crops
        loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)]

        return crops, loc_preds

    def _rectify_crops(
        self,
        crops: List[List[np.ndarray]],
        loc_preds: List[np.ndarray],
    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray], List[Tuple[int, float]]]:
        """Run the crop orientation predictor on each page and rectify crops and localization predictions

        Args:
        ----
            crops: for each page, the list of crops
            loc_preds: localization predictions, one array per page

        Returns:
        -------
            a tuple with the rectified crops of each page, the rectified localization predictions of each page,
            and the flat list (all pages concatenated) of (value, confidence) orientation predictions
        """
        # Nothing to rectify: `zip(*[])` cannot be unpacked into three values
        if not crops:
            return [], [], []

        # Work at a page level
        orientations, classes, probs = zip(*[self.crop_orientation_predictor(page_crops) for page_crops in crops])  # type: ignore[misc]
        rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
        # Pages without any localization prediction are passed through untouched
        rect_loc_preds = [
            rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
            for page_loc_preds, orientation in zip(loc_preds, orientations)
        ]
        # Flatten to list of tuples with (value, confidence)
        crop_orientations = [
            (orientation, prob)
            for page_classes, page_probs in zip(classes, probs)
            for orientation, prob in zip(page_classes, page_probs)
        ]
        return rect_crops, rect_loc_preds, crop_orientations  # type: ignore[return-value]

    def _remove_padding(
        self,
        pages: List[np.ndarray],
        loc_preds: List[np.ndarray],
    ) -> List[np.ndarray]:
        """Remap localization predictions from the padded image back to the original page

        Only the coordinate along the shorter page side is rescaled; square pages are left untouched.
        The prediction arrays are modified in place.

        Args:
        ----
            pages: page images, whose first two dimensions are read as (height, width)
            loc_preds: localization predictions, one array per page. Indexed as `[:, [x1, y1, x2, y2, ...]]`
                when `assume_straight_pages` is True, and as `[:, :, (x, y)]` otherwise.

        Returns:
        -------
            `loc_preds` itself if `preserve_aspect_ratio` is False, otherwise a new list holding the
            (in-place updated) prediction arrays
        """
        if not self.preserve_aspect_ratio:
            return loc_preds

        # Rectify loc_preds to remove padding
        rectified_preds = []
        for page, loc_pred in zip(pages, loc_preds):
            h, w = page.shape[0], page.shape[1]
            if h != w:
                # Taller page: y unchanged, dilate x coord (axis 0)
                # Wider page: x unchanged, dilate y coord (axis 1)
                axis, long_side, short_side = (0, h, w) if h > w else (1, w, h)
                coords: Tuple[Any, ...]
                if self.assume_straight_pages:
                    coords = (slice(None), [axis, axis + 2])
                else:
                    coords = (slice(None), slice(None), axis)

                if self.symmetric_pad:
                    # Padding is split on both sides: dilate around the center, then clip to [0, 1]
                    loc_pred[coords] = np.clip((loc_pred[coords] - 0.5) * long_side / short_side + 0.5, 0, 1)
                else:
                    loc_pred[coords] *= long_side / short_side
            rectified_preds.append(loc_pred)
        return rectified_preds

    @staticmethod
    def _process_predictions(
        loc_preds: List[np.ndarray],
        word_preds: List[Tuple[str, float]],
        crop_orientations: List[Dict[str, Any]],
    ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]], List[List[Dict[str, Any]]]]:
        """Regroup the flat word and crop orientation predictions by page

        Args:
        ----
            loc_preds: localization predictions, one array per page. The first dimension of each array
                gives the number of predictions belonging to that page.
            word_preds: word predictions of all pages, concatenated in page order
            crop_orientations: crop orientation predictions of all pages, concatenated in page order

        Returns:
        -------
            a tuple with `loc_preds` (unchanged), the word predictions of each page,
            and the crop orientation predictions of each page
        """
        text_preds = []
        crop_orientation_preds = []
        # Text & crop orientation predictions at page level
        _idx = 0
        for page_boxes in loc_preds:
            _end = _idx + page_boxes.shape[0]
            text_preds.append(word_preds[_idx:_end])
            crop_orientation_preds.append(crop_orientations[_idx:_end])
            _idx = _end

        return loc_preds, text_preds, crop_orientation_preds

    def add_hook(self, hook: Callable) -> None:
        """Add a hook to the predictor

        Args:
        ----
            hook: a callable that takes as input the `loc_preds` and returns the modified `loc_preds`
        """
        self.hooks.append(hook)