Spaces:
Running
Running
File size: 1,876 Bytes
f3270e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from collections.abc import Callable
import torch
from doctr.models import kie_predictor, ocr_predictor
from .schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn
def _move_to_device(predictor: Callable) -> Callable:
"""Move the predictor to the desired device
Args:
predictor: the predictor to move
Returns:
Callable: the predictor moved to the desired device
"""
return predictor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
def init_predictor(request: KIEIn | OCRIn | RecognitionIn | DetectionIn) -> Callable:
    """Build the predictor that matches the incoming request.

    Args:
        request: input request

    Returns:
        Callable: the predictor
    """
    params = request.model_dump()
    # The thresholds are applied to the detection postprocessor afterwards;
    # they are not accepted as factory kwargs, so strip them out first.
    bin_thresh = params.pop("bin_thresh", None)
    box_thresh = params.pop("box_thresh", None)

    def _tune_postprocessor(pred: Callable) -> None:
        # Override the default detection postprocessing thresholds in place.
        pred.det_predictor.model.postprocessor.bin_thresh = bin_thresh
        pred.det_predictor.model.postprocessor.box_thresh = box_thresh

    if isinstance(request, (OCRIn, RecognitionIn, DetectionIn)):
        predictor = ocr_predictor(pretrained=True, **params)
        _tune_postprocessor(predictor)
        # Detection / recognition requests only need the matching sub-predictor.
        if isinstance(request, DetectionIn):
            return _move_to_device(predictor.det_predictor)
        if isinstance(request, RecognitionIn):
            return _move_to_device(predictor.reco_predictor)
        return _move_to_device(predictor)
    elif isinstance(request, KIEIn):
        predictor = kie_predictor(pretrained=True, **params)
        _tune_postprocessor(predictor)
        return _move_to_device(predictor)
|