# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to for full license details.

from typing import Any

from doctr.models.utils import _CompiledModule

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import OrientationPredictor

__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]

# Names of the classification backbones listed by this module.
ARCHS: list[str] = [
    "magc_resnet31",
    "mobilenet_v3_small",
    "mobilenet_v3_small_r",
    "mobilenet_v3_large",
    "mobilenet_v3_large_r",
    "resnet18",
    "resnet31",
    "resnet34",
    "resnet50",
    "resnet34_wide",
    "textnet_tiny",
    "textnet_small",
    "textnet_base",
    "vgg16_bn_r",
    "vit_s",
    "vit_b",
    "vip_tiny",
    "vip_base",
]

# Architecture names accepted (as strings) by the orientation predictor factories below.
ORIENTATION_ARCHS: list[str] = [
    "mobilenet_v3_small_crop_orientation",
    "mobilenet_v3_small_page_orientation",
]


def _orientation_predictor(
    arch: Any, pretrained: bool, model_type: str, disabled: bool = False, **kwargs: Any
) -> OrientationPredictor:
    """Build an ``OrientationPredictor`` from an architecture name or a model instance.

    Args:
        arch: either one of the names in ``ORIENTATION_ARCHS`` or an already instantiated model
            (a ``classification.MobileNetV3`` or a ``_CompiledModule``)
        pretrained: forwarded as ``pretrained`` to the model constructor when ``arch`` is a string;
            unused when ``arch`` is a model instance
        model_type: ``"crop"`` selects a default batch size of 128, any other value selects 4
        disabled: if True, a predictor wrapping neither a pre-processor nor a model is returned
        **kwargs: keyword arguments forwarded to the ``PreProcessor``; ``mean``, ``std`` and
            ``batch_size`` are filled in from the model configuration / defaults when missing

    Returns:
        OrientationPredictor

    Raises:
        ValueError: if ``arch`` is an unknown name or an instance of an unsupported type
    """
    # Disabled predictor: nothing to load, nothing to pre-process
    if disabled:
        return OrientationPredictor(None, None)

    if isinstance(arch, str):
        if arch not in ORIENTATION_ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")
        # Look the constructor up by name in the classification module namespace
        build_model = vars(classification)[arch]
        _model = build_model(pretrained=pretrained)
    else:
        # torch compiled models are accepted alongside the plain MobileNetV3 class
        supported_types = (classification.MobileNetV3, _CompiledModule)
        if not isinstance(arch, supported_types):
            raise ValueError(f"unknown architecture: {type(arch)}")
        _model = arch

    # Caller-provided values take precedence over the model configuration and defaults
    kwargs.setdefault("mean", _model.cfg["mean"])
    kwargs.setdefault("std", _model.cfg["std"])
    default_batch_size = 128 if model_type == "crop" else 4
    kwargs.setdefault("batch_size", default_batch_size)

    # Drop the first entry of the configured input shape before handing it to the pre-processor
    input_shape = _model.cfg["input_shape"][1:]
    pre_processor = PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs)
    return OrientationPredictor(pre_processor, _model)


def crop_orientation_predictor(
    arch: Any = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, batch_size: int = 128, **kwargs: Any
) -> OrientationPredictor:
    """Crop orientation classification architecture.

    >>> import numpy as np
    >>> from doctr.models import crop_orientation_predictor
    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
    >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
    >>> out = model([input_crop])

    Args:
        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation'),
            or an already instantiated model
        pretrained: forwarded as ``pretrained`` to the model constructor when ``arch`` is a name
        batch_size: batch size handed to the pre-processor
        **kwargs: extra keyword arguments forwarded to ``_orientation_predictor``
            (e.g. ``disabled``, or pre-processor options such as ``mean`` / ``std``)

    Returns:
        OrientationPredictor
    """
    return _orientation_predictor(
        arch=arch,
        pretrained=pretrained,
        model_type="crop",
        batch_size=batch_size,
        **kwargs,
    )


def page_orientation_predictor(
    arch: Any = "mobilenet_v3_small_page_orientation", pretrained: bool = False, batch_size: int = 4, **kwargs: Any
) -> OrientationPredictor:
    """Page orientation classification architecture.

    >>> import numpy as np
    >>> from doctr.models import page_orientation_predictor
    >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
    >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation'),
            or an already instantiated model
        pretrained: forwarded as ``pretrained`` to the model constructor when ``arch`` is a name
        batch_size: batch size handed to the pre-processor
        **kwargs: extra keyword arguments forwarded to ``_orientation_predictor``
            (e.g. ``disabled``, or pre-processor options such as ``mean`` / ``std``)

    Returns:
        OrientationPredictor
    """
    return _orientation_predictor(
        arch=arch,
        pretrained=pretrained,
        model_type="page",
        batch_size=batch_size,
        **kwargs,
    )