| import os |
| import shutil |
| import tempfile |
| from time import perf_counter |
| from typing import Any, List, Union |
|
|
| from doctr import models as models |
| from doctr.io import DocumentFile |
| from doctr.models import ocr_predictor |
| from PIL import Image |
|
|
| from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest |
| from inference.core.entities.requests.inference import InferenceRequest |
| from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse |
| from inference.core.entities.responses.inference import InferenceResponse |
| from inference.core.env import MODEL_CACHE_DIR |
| from inference.core.models.roboflow import RoboflowCoreModel |
| from inference.core.utils.image_utils import load_image |
|
|
|
|
class DocTR(RoboflowCoreModel):
    """docTR OCR pipeline combining a text-detection and a text-recognition model.

    Builds a ``DocTRDet`` and a ``DocTRRec`` sub-model, stages their ``model.pt``
    weight files under the file names docTR's pretrained-weight loader looks up,
    and wraps the resulting ``ocr_predictor``.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR model.

        ``RoboflowCoreModel.__init__`` is not called here; the identifying
        attributes are set directly and the weights come from the two sub-models.

        Args:
            *args: Variable length argument list (unused).
            model_id: Stored as ``self.endpoint``.
            **kwargs: Arbitrary keyword arguments; only ``api_key`` is read, and it
                is forwarded to both sub-models.
        """
        api_key = kwargs.get("api_key")
        self.api_key = api_key
        self.dataset_id = "doctr"
        self.version_id = "default"
        self.endpoint = model_id

        # docTR resolves pretrained weights as
        # ``$DOCTR_CACHE_DIR/models/<file name from its download URL>`` and only
        # downloads when that file is absent. BOTH weight files therefore have to
        # be staged inside the single directory this variable points to.
        doctr_cache_dir = os.path.join(MODEL_CACHE_DIR, "doctr_rec")
        os.environ["DOCTR_CACHE_DIR"] = doctr_cache_dir

        # Constructing the sub-models is expected to place their ``model.pt``
        # files under MODEL_CACHE_DIR; the copies below read them from there.
        self.det_model = DocTRDet(api_key=api_key)
        self.rec_model = DocTRRec(api_key=api_key)

        weights_dir = os.path.join(doctr_cache_dir, "models")
        os.makedirs(weights_dir, exist_ok=True)
        # Target names must match docTR's release file names for these
        # architectures (hash suffix included), otherwise the staged copy is
        # ignored by docTR.
        staged_weights = (
            (
                os.path.join(MODEL_CACHE_DIR, "doctr_det", "db_resnet50", "model.pt"),
                "db_resnet50-ac60cadc.pt",
            ),
            (
                os.path.join(MODEL_CACHE_DIR, "doctr_rec", "crnn_vgg16_bn", "model.pt"),
                "crnn_vgg16_bn-9762b0b0.pt",
            ),
        )
        for source_path, cached_name in staged_weights:
            shutil.copyfile(source_path, os.path.join(weights_dir, cached_name))

        self.model = ocr_predictor(
            det_arch=self.det_model.version_id,
            reco_arch=self.rec_model.version_id,
            pretrained=True,
        )
        self.task_type = "ocr"

    def clear_cache(self) -> None:
        """Clear the caches of both the detection and the recognition sub-model."""
        self.det_model.clear_cache()
        self.rec_model.clear_cache()

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Return *image* unchanged.

        DocTR pre-processes images as part of its inference pipeline, so no
        preprocessing is required here.
        """
        return image

    def infer_from_request(
        self, request: DoctrOCRInferenceRequest
    ) -> DoctrOCRInferenceResponse:
        """Run OCR for *request* and wrap the text in a response.

        Every field of the request is passed to ``infer`` as a keyword argument.

        Args:
            request: The OCR inference request.

        Returns:
            DoctrOCRInferenceResponse: The recognized text plus the elapsed time
            in seconds, measured with ``perf_counter``.
        """
        t1 = perf_counter()
        result = self.infer(**request.dict())
        return DoctrOCRInferenceResponse(
            result=result,
            time=perf_counter() - t1,
        )

    def infer(self, image: Any, **kwargs) -> str:
        """Run OCR on a single image and return the recognized text.

        Args:
            image: Any image input accepted by ``load_image``.
            **kwargs: Ignored. They are accepted so that ``infer_from_request``
                can forward every request field.

        Returns:
            str: The recognized words of the first (only) page. Words are joined
            with single spaces within each line, and lines are joined with
            single spaces, in docTR's export order.
        """
        import io

        img = load_image(image)
        # NOTE(review): assumes load_image returns a tuple whose first element is
        # a uint8 numpy array that PIL can wrap — confirm against image_utils.
        pil_image = Image.fromarray(img[0])

        # Round-trip through an in-memory JPEG instead of a NamedTemporaryFile.
        # docTR decodes the same JPEG bytes either way, but this avoids disk I/O
        # and reopening an open temp file by name, which fails on Windows.
        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        doc = DocumentFile.from_images([buffer.getvalue()])

        export = self.model(doc).export()
        blocks = export["pages"][0]["blocks"]
        lines = [
            " ".join(word["value"] for word in line["words"])
            for block in blocks
            for line in block["lines"]
        ]
        return " ".join(lines)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
|
|
|
|
class DocTRRec(RoboflowCoreModel):
    """docTR text-recognition sub-model (default architecture ``crnn_vgg16_bn``).

    All initialization is delegated to ``RoboflowCoreModel``; this class only
    fixes the default model id and declares the weight file it needs.
    """

    def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs):
        """Initializes the DocTR recognition model.

        Args:
            *args: Forwarded to ``RoboflowCoreModel.__init__``.
            model_id: Roboflow model identifier of the recognition weights.
            **kwargs: Forwarded to ``RoboflowCoreModel.__init__``.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Presumably consulted by ``RoboflowCoreModel`` when fetching weights —
        confirm against the base class.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
|
|
|
|
class DocTRDet(RoboflowCoreModel):
    """docTR text-detection sub-model (default architecture ``db_resnet50``).

    All initialization is delegated to ``RoboflowCoreModel``; this class only
    fixes the default model id and declares the weight file it needs.
    """

    def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs):
        """Initializes the DocTR detection model.

        Args:
            *args: Forwarded to ``RoboflowCoreModel.__init__``.
            model_id: Roboflow model identifier of the detection weights.
            **kwargs: Forwarded to ``RoboflowCoreModel.__init__``.
        """
        super().__init__(*args, model_id=model_id, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Get the list of required files for inference.

        Presumably consulted by ``RoboflowCoreModel`` when fetching weights —
        confirm against the base class.

        Returns:
            list: A list of required files for inference, e.g., ["model.pt"].
        """
        return ["model.pt"]
|
|