myj / model.py
sonygod's picture
imporve ι‡ζž„ι‘΅ι’
2205e98
import time
from doctr.file_utils import is_tf_available
import numpy as np
import cv2
if is_tf_available():
import tensorflow as tf
from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
else:
import torch
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
class OCRModel:
    """Wrapper around an OCR predictor built by the backend's ``load_predictor``.

    At construction time it picks the backend (TensorFlow when
    ``is_tf_available()`` is true, PyTorch otherwise) and a compute device.
    It then exposes helpers to load a predictor, run it on one page, and crop
    the synthesized page to the region covered by the detected text blocks.
    """

    # Defaults applied by ``load_model`` for every option the caller does not
    # pass explicitly; keys are forwarded verbatim to ``load_predictor``.
    _PREDICTOR_DEFAULTS = {
        "assume_straight_pages": True,
        "straighten_pages": False,
        "export_as_straight_boxes": False,
        "disable_page_orientation": False,
        "disable_crop_orientation": False,
        "bin_thresh": 0.3,
        "box_thresh": 0.1,
    }

    def __init__(self):
        # Set by load_model(); None until a model has been loaded.
        self.predictor = None
        # Backend-specific device handle, chosen by _init_backend().
        self.device = None
        self._init_backend()

    def _init_backend(self):
        """Expose the backend's architecture lists and select a device.

        Relies on the module-level conditional import. When TensorFlow is
        available it binds ``tf`` and the ``backend.tensorflow`` names;
        otherwise it binds ``torch`` and the ``backend.pytorch`` names. The
        same ``is_tf_available()`` test picks the branch here.
        """
        self.DET_ARCHS = DET_ARCHS
        self.RECO_ARCHS = RECO_ARCHS
        if is_tf_available():
            # TensorFlow physical device types are upper-case ("GPU"/"CPU");
            # querying "gpu" never matches and silently forces the CPU.
            has_gpu = bool(tf.config.list_physical_devices("GPU"))
            self.device = tf.device("/gpu:0" if has_gpu else "/cpu:0")
        else:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_model(self, det_arch, reco_arch, **kwargs):
        """Build the OCR predictor and store it on ``self.predictor``.

        Args:
            det_arch: detection architecture name (presumably one of
                ``self.DET_ARCHS``; not validated here).
            reco_arch: recognition architecture name (presumably one of
                ``self.RECO_ARCHS``; not validated here).
            **kwargs: optional overrides for the keys of
                ``_PREDICTOR_DEFAULTS``. Any other keys are ignored.
        """
        model_params = {"det_arch": det_arch, "reco_arch": reco_arch}
        model_params.update(
            {key: kwargs.get(key, default) for key, default in self._PREDICTOR_DEFAULTS.items()}
        )
        self.predictor = load_predictor(**model_params, device=self.device)

    def process_page(self, page):
        """Run the loaded predictor on a single page image.

        Args:
            page: image array indexed (height, width, ...), as implied by the
                resize below.

        Returns:
            A ``(seg_map, out)`` tuple. ``seg_map`` is the squeezed
            segmentation map from ``forward_image``, resized to the page's
            width and height. ``out`` is the predictor's output for ``[page]``.

        Raises:
            RuntimeError: if ``load_model`` has not been called yet.
        """
        if self.predictor is None:
            raise RuntimeError("No predictor loaded; call load_model() first.")
        seg_map = np.squeeze(forward_image(self.predictor, page, self.device))
        # cv2.resize expects dsize as (width, height).
        seg_map = cv2.resize(
            seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR
        )
        out = self.predictor([page])
        return seg_map, out

    def get_reconstructed_page(self, out, page, margin=10):
        """Return the synthesized first page, cropped to its text blocks.

        Block geometries are treated as relative coordinates, because they
        are scaled by the image width and height. A geometry may be any set
        of (x, y) points or a flat sequence of x, y pairs. NaN coordinates are
        ignored.

        Args:
            out: predictor output exposing ``pages[0].export()`` (a dict with
                a ``"blocks"`` list) and ``pages[0].synthesize()``.
            page: unused; kept for interface compatibility.
            margin: padding in pixels added on every side of the text
                bounding box before clamping to the image.

        Returns:
            The cropped image, or the full synthesized image when there is no
            usable block geometry.

        Raises:
            ValueError: if ``margin`` is negative.
        """
        if margin < 0:
            raise ValueError(f"margin must be non-negative, got {margin}")

        first_page = out.pages[0]
        page_export = first_page.export()
        img = first_page.synthesize()
        height, width = img.shape[:2]

        point_sets = [
            np.asarray(block["geometry"], dtype=float).reshape(-1, 2)
            for block in page_export["blocks"] or ()
            if np.size(block["geometry"]) > 0
        ]
        if not point_sets:
            return img  # no blocks detected: keep the whole page

        points = np.concatenate(point_sets)
        xs = points[:, 0][~np.isnan(points[:, 0])]
        ys = points[:, 1][~np.isnan(points[:, 1])]
        if xs.size == 0 or ys.size == 0:
            return img  # only NaN coordinates: nothing to crop to

        # Scale relative coordinates to pixels and clamp to the image.
        x_min = max(0.0, xs.min() * width)
        y_min = max(0.0, ys.min() * height)
        x_max = min(float(width), xs.max() * width)
        y_max = min(float(height), ys.max() * height)
        if x_min >= x_max or y_min >= y_max:
            return img  # degenerate or fully out-of-bounds box

        x0 = max(0, int(x_min - margin))
        y0 = max(0, int(y_min - margin))
        x1 = min(width, int(x_max + margin))
        y1 = min(height, int(y_max + margin))
        if x0 >= x1 or y0 >= y1:
            return img  # truncation with a tiny margin left an empty crop
        return img[y0:y1, x0:x1]