|
|
import time |
|
|
from doctr.file_utils import is_tf_available |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
if is_tf_available(): |
|
|
import tensorflow as tf |
|
|
from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor |
|
|
else: |
|
|
import torch |
|
|
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor |
|
|
|
|
|
class OCRModel:
    """Thin wrapper that picks an OCR backend/device and drives a predictor.

    On construction the TensorFlow backend is selected when
    ``is_tf_available()`` is true, otherwise the PyTorch backend. The chosen
    backend's ``DET_ARCHS`` / ``RECO_ARCHS`` are exposed as instance
    attributes, and ``self.device`` holds the device object later handed to
    ``load_predictor`` and ``forward_image``.
    """

    # Predictor options forwarded by ``load_model`` and their defaults.
    _DEFAULT_PREDICTOR_OPTIONS = {
        "assume_straight_pages": True,
        "straighten_pages": False,
        "export_as_straight_boxes": False,
        "disable_page_orientation": False,
        "disable_crop_orientation": False,
        "bin_thresh": 0.3,
        "box_thresh": 0.1,
    }

    def __init__(self):
        """Create the wrapper with no predictor loaded and select the backend."""
        self.predictor = None
        self.device = None
        self._init_backend()

    def _init_backend(self):
        """Select the backend, publish its architecture lists and pick a device.

        Sets ``self.DET_ARCHS``, ``self.RECO_ARCHS`` and ``self.device``.
        A GPU device is chosen when the backend reports one, else the CPU.
        """
        if is_tf_available():
            import tensorflow as tf
            from backend.tensorflow import DET_ARCHS, RECO_ARCHS

            # TensorFlow device types are upper-case; querying "gpu" matches
            # nothing and would always fall through to the CPU branch.
            gpus = tf.config.experimental.list_physical_devices("GPU")
            self.device = tf.device("/gpu:0" if gpus else "/cpu:0")
        else:
            import torch
            from backend.pytorch import DET_ARCHS, RECO_ARCHS

            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.DET_ARCHS = DET_ARCHS
        self.RECO_ARCHS = RECO_ARCHS

    def load_model(self, det_arch, reco_arch, **kwargs):
        """Build the predictor and store it in ``self.predictor``.

        Args:
            det_arch: detection architecture passed to ``load_predictor``.
            reco_arch: recognition architecture passed to ``load_predictor``.
            **kwargs: optional overrides for the options listed in
                ``_DEFAULT_PREDICTOR_OPTIONS``; any other key is ignored.
        """
        model_params = {"det_arch": det_arch, "reco_arch": reco_arch}
        for option, default in self._DEFAULT_PREDICTOR_OPTIONS.items():
            model_params[option] = kwargs.get(option, default)

        self.predictor = load_predictor(**model_params, device=self.device)

    def process_page(self, page):
        """Run the loaded predictor on a single page.

        Args:
            page: image array; its first two dimensions are used as
                (height, width) when resizing the segmentation map.

        Returns:
            A ``(seg_map, out)`` tuple: the output of ``forward_image``
            squeezed and resized to the page size, and the result of calling
            the predictor on ``[page]``.

        Raises:
            RuntimeError: if ``load_model`` has not been called yet.
        """
        if self.predictor is None:
            raise RuntimeError(
                "No predictor loaded; call load_model() before process_page()."
            )

        seg_map = np.squeeze(forward_image(self.predictor, page, self.device))
        height, width = page.shape[0], page.shape[1]
        # cv2.resize expects the target size as (width, height).
        seg_map = cv2.resize(seg_map, (width, height), interpolation=cv2.INTER_LINEAR)
        out = self.predictor([page])
        return seg_map, out

    def get_reconstructed_page(self, out, page, margin=10):
        """Return the synthesized first page cropped around its text blocks.

        Args:
            out: OCR result; ``out.pages[0]`` must provide ``export()`` and
                ``synthesize()``.
            page: unused; kept so existing callers keep working.
            margin: padding in pixels added around the union of all block
                boxes before cropping (clamped to the image bounds).

        Returns:
            The cropped synthesized image, or the full synthesized image when
            there are no blocks, no usable coordinates, or the box is empty.
        """
        first_page = out.pages[0]
        page_export = first_page.export()
        img = first_page.synthesize()

        blocks = page_export["blocks"]
        if not blocks:
            return img

        img_h, img_w = img.shape[0], img.shape[1]
        x_min, y_min = img_w, img_h
        x_max, y_max = 0, 0
        valid_coords = False

        for block in blocks:
            # Block geometry is multiplied by the image size below, i.e. it is
            # treated as relative (x, y) points.
            coords = np.array(block["geometry"], dtype=float)
            if coords.size == 0:
                continue
            # A block with no finite x or no finite y carries no position; if
            # it were kept, the NaN comparisons below would silently widen the
            # crop to the whole image.
            if np.isnan(coords).all(axis=0).any():
                continue

            valid_coords = True
            xs, ys = coords[:, 0], coords[:, 1]
            x_min = min(x_min, max(0, np.nanmin(xs) * img_w))
            y_min = min(y_min, max(0, np.nanmin(ys) * img_h))
            x_max = max(x_max, min(img_w, np.nanmax(xs) * img_w))
            y_max = max(y_max, min(img_h, np.nanmax(ys) * img_h))

        if not valid_coords or x_min >= x_max or y_min >= y_max:
            return img

        # Pad the box, then clamp it back inside the image.
        left = max(0, int(x_min - margin))
        top = max(0, int(y_min - margin))
        right = min(img_w, int(x_max + margin))
        bottom = min(img_h, int(y_max + margin))

        return img[top:bottom, left:right]